mirror of
https://github.com/photoprism/photoprism.git
synced 2026-01-23 02:24:24 +00:00
AI: Generate Labels using the Ollama API #5232
Signed-off-by: Michael Mayer <michael@photoprism.app>
This commit is contained in:
parent
cf06f52025
commit
8a7c61f467
6 changed files with 288 additions and 9 deletions
|
|
@ -7,6 +7,7 @@ import (
|
|||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/photoprism/photoprism/internal/entity"
|
||||
"github.com/photoprism/photoprism/pkg/clean"
|
||||
|
|
@ -48,6 +49,10 @@ func PerformApiRequest(apiRequest *ApiRequest, uri, method, key string) (apiResp
|
|||
return apiResponse, clientErr
|
||||
}
|
||||
|
||||
defer func() {
|
||||
_ = clientResp.Body.Close()
|
||||
}()
|
||||
|
||||
apiResponse = &ApiResponse{}
|
||||
|
||||
// Parse and return response, or an error if the request failed.
|
||||
|
|
@ -61,13 +66,17 @@ func PerformApiRequest(apiRequest *ApiRequest, uri, method, key string) (apiResp
|
|||
log.Debugf("vision: %s (status code %d)", apiJson, clientResp.StatusCode)
|
||||
}
|
||||
case ApiFormatOllama:
|
||||
ollamaResponse := &ApiResponseOllama{}
|
||||
apiJson, apiErr := io.ReadAll(clientResp.Body)
|
||||
if apiErr != nil {
|
||||
return apiResponse, apiErr
|
||||
}
|
||||
|
||||
if apiJson, apiErr := io.ReadAll(clientResp.Body); apiErr != nil {
|
||||
return apiResponse, apiErr
|
||||
} else if apiErr = json.Unmarshal(apiJson, ollamaResponse); apiErr != nil {
|
||||
return apiResponse, apiErr
|
||||
} else if clientResp.StatusCode >= 300 {
|
||||
ollamaResponse, decodeErr := decodeOllamaResponse(apiJson)
|
||||
if decodeErr != nil {
|
||||
return apiResponse, decodeErr
|
||||
}
|
||||
|
||||
if clientResp.StatusCode >= 300 {
|
||||
log.Debugf("vision: %s (status code %d)", apiJson, clientResp.StatusCode)
|
||||
}
|
||||
|
||||
|
|
@ -77,9 +86,41 @@ func PerformApiRequest(apiRequest *ApiRequest, uri, method, key string) (apiResp
|
|||
Name: ollamaResponse.Model,
|
||||
}
|
||||
|
||||
apiResponse.Result.Caption = &CaptionResult{
|
||||
Text: ollamaResponse.Response,
|
||||
Source: entity.SrcImage,
|
||||
// Copy structured results when provided.
|
||||
if len(ollamaResponse.Result.Labels) > 0 {
|
||||
apiResponse.Result.Labels = append(apiResponse.Result.Labels, ollamaResponse.Result.Labels...)
|
||||
}
|
||||
|
||||
parsedLabels := false
|
||||
|
||||
if len(apiResponse.Result.Labels) > 0 {
|
||||
parsedLabels = true
|
||||
}
|
||||
|
||||
if !parsedLabels {
|
||||
if apiRequest.Format == FormatJSON {
|
||||
if labels, parseErr := parseOllamaLabels(ollamaResponse.Response); parseErr != nil {
|
||||
log.Debugf("vision: %s (parse ollama labels)", clean.Error(parseErr))
|
||||
} else if len(labels) > 0 {
|
||||
apiResponse.Result.Labels = append(apiResponse.Result.Labels, labels...)
|
||||
parsedLabels = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if parsedLabels {
|
||||
for i := range apiResponse.Result.Labels {
|
||||
if apiResponse.Result.Labels[i].Source == "" {
|
||||
apiResponse.Result.Labels[i].Source = entity.SrcVision
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if caption := strings.TrimSpace(ollamaResponse.Response); caption != "" {
|
||||
apiResponse.Result.Caption = &CaptionResult{
|
||||
Text: caption,
|
||||
Source: entity.SrcImage,
|
||||
}
|
||||
}
|
||||
}
|
||||
default:
|
||||
return apiResponse, fmt.Errorf("unsupported response format %s", clean.Log(apiRequest.ResponseFormat))
|
||||
|
|
@ -87,3 +128,39 @@ func PerformApiRequest(apiRequest *ApiRequest, uri, method, key string) (apiResp
|
|||
|
||||
return apiResponse, nil
|
||||
}
|
||||
|
||||
func decodeOllamaResponse(data []byte) (*ApiResponseOllama, error) {
|
||||
resp := &ApiResponseOllama{}
|
||||
dec := json.NewDecoder(bytes.NewReader(data))
|
||||
|
||||
for {
|
||||
var chunk ApiResponseOllama
|
||||
if err := dec.Decode(&chunk); err != nil {
|
||||
if errors.Is(err, io.EOF) {
|
||||
break
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
*resp = chunk
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func parseOllamaLabels(raw string) ([]LabelResult, error) {
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var payload struct {
|
||||
Labels []LabelResult `json:"labels"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal([]byte(raw), &payload); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return payload.Labels, nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,6 +1,9 @@
|
|||
package vision
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
|
@ -39,3 +42,57 @@ func TestNewApiRequest(t *testing.T) {
|
|||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestPerformApiRequestOllama(t *testing.T) {
|
||||
t.Run("Labels", func(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
var req ApiRequest
|
||||
assert.NoError(t, json.NewDecoder(r.Body).Decode(&req))
|
||||
assert.Equal(t, FormatJSON, req.Format)
|
||||
assert.NoError(t, json.NewEncoder(w).Encode(ApiResponseOllama{
|
||||
Model: "qwen2.5vl:latest",
|
||||
Response: `{"labels":[{"name":"test","confidence":0.9,"topicality":0.8}]}`,
|
||||
}))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
apiRequest := &ApiRequest{
|
||||
Id: "test",
|
||||
Model: "qwen2.5vl:latest",
|
||||
Format: FormatJSON,
|
||||
Images: []string{"data:image/jpeg;base64,AA=="},
|
||||
ResponseFormat: ApiFormatOllama,
|
||||
}
|
||||
|
||||
resp, err := PerformApiRequest(apiRequest, server.URL, http.MethodPost, "")
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, resp.Result.Labels, 1)
|
||||
assert.Equal(t, "test", resp.Result.Labels[0].Name)
|
||||
assert.Nil(t, resp.Result.Caption)
|
||||
})
|
||||
|
||||
t.Run("CaptionFallback", func(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.NoError(t, json.NewEncoder(w).Encode(ApiResponseOllama{
|
||||
Model: "qwen2.5vl:latest",
|
||||
Response: "plain text",
|
||||
}))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
apiRequest := &ApiRequest{
|
||||
Id: "test2",
|
||||
Model: "qwen2.5vl:latest",
|
||||
Format: FormatJSON,
|
||||
Images: []string{"data:image/jpeg;base64,AA=="},
|
||||
ResponseFormat: ApiFormatOllama,
|
||||
}
|
||||
|
||||
resp, err := PerformApiRequest(apiRequest, server.URL, http.MethodPost, "")
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, resp.Result.Labels, 0)
|
||||
if assert.NotNil(t, resp.Result.Caption) {
|
||||
assert.Equal(t, "plain text", resp.Result.Caption.Text)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
|
|||
|
|
@ -21,6 +21,10 @@ import (
|
|||
|
||||
type Files = []string
|
||||
|
||||
// FormatJSON is the format name used to request a JSON response.
const FormatJSON = "json"
|
||||
|
||||
// ApiRequestOptions represents additional model parameters listed in the documentation.
|
||||
type ApiRequestOptions struct {
|
||||
NumKeep int `yaml:"NumKeep,omitempty" json:"num_keep,omitempty"`
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ import (
|
|||
"errors"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/photoprism/photoprism/internal/ai/classify"
|
||||
"github.com/photoprism/photoprism/pkg/clean"
|
||||
|
|
@ -30,6 +31,10 @@ func Labels(images Files, mediaSrc media.Src, labelSrc string) (result classify.
|
|||
return result, err
|
||||
}
|
||||
|
||||
if format := model.GetFormat(); format != "" {
|
||||
apiRequest.Format = format
|
||||
}
|
||||
|
||||
switch model.Service.RequestFormat {
|
||||
case ApiFormatOllama:
|
||||
apiRequest.Model, _, _ = model.Model()
|
||||
|
|
@ -45,6 +50,15 @@ func Labels(images Files, mediaSrc media.Src, labelSrc string) (result classify.
|
|||
apiRequest.Prompt = model.Prompt
|
||||
}
|
||||
|
||||
if schemaPrompt := model.SchemaInstructions(); schemaPrompt != "" {
|
||||
prompt := strings.TrimSpace(apiRequest.Prompt)
|
||||
if prompt != "" {
|
||||
apiRequest.Prompt = fmt.Sprintf("%s\n\n%s", prompt, schemaPrompt)
|
||||
} else {
|
||||
apiRequest.Prompt = schemaPrompt
|
||||
}
|
||||
}
|
||||
|
||||
// Log JSON request data in trace mode.
|
||||
apiRequest.WriteLog()
|
||||
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ package vision
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
|
|
@ -10,11 +11,14 @@ import (
|
|||
"github.com/photoprism/photoprism/internal/ai/nsfw"
|
||||
"github.com/photoprism/photoprism/internal/ai/tensorflow"
|
||||
"github.com/photoprism/photoprism/pkg/clean"
|
||||
"github.com/photoprism/photoprism/pkg/fs"
|
||||
"github.com/photoprism/photoprism/pkg/service/http/scheme"
|
||||
)
|
||||
|
||||
var modelMutex = sync.Mutex{}
|
||||
|
||||
const labelSchemaEnvVar = "PHOTOPRISM_VISION_LABEL_SCHEMA_FILE"
|
||||
|
||||
// Default model version strings.
|
||||
var (
|
||||
VersionLatest = "latest"
|
||||
|
|
@ -30,6 +34,9 @@ type Model struct {
|
|||
Version string `yaml:"Version,omitempty" json:"version,omitempty"`
|
||||
System string `yaml:"System,omitempty" json:"system,omitempty"`
|
||||
Prompt string `yaml:"Prompt,omitempty" json:"prompt,omitempty"`
|
||||
Format string `yaml:"Format,omitempty" json:"format,omitempty"`
|
||||
Schema string `yaml:"Schema,omitempty" json:"schema,omitempty"`
|
||||
SchemaFile string `yaml:"SchemaFile,omitempty" json:"schemaFile,omitempty"`
|
||||
Resolution int `yaml:"Resolution,omitempty" json:"resolution,omitempty"`
|
||||
TensorFlow *tensorflow.ModelInfo `yaml:"TensorFlow,omitempty" json:"tensorflow,omitempty"`
|
||||
Options *ApiRequestOptions `yaml:"Options,omitempty" json:"options,omitempty"`
|
||||
|
|
@ -39,6 +46,8 @@ type Model struct {
|
|||
classifyModel *classify.Model
|
||||
faceModel *face.Model
|
||||
nsfwModel *nsfw.Model
|
||||
schemaOnce sync.Once
|
||||
schema string
|
||||
}
|
||||
|
||||
// Models represents a set of computer vision models.
|
||||
|
|
@ -152,6 +161,19 @@ func (m *Model) GetSystemPrompt() string {
|
|||
}
|
||||
}
|
||||
|
||||
// GetFormat returns the configured response format or a sensible default.
|
||||
func (m *Model) GetFormat() string {
|
||||
if f := strings.TrimSpace(strings.ToLower(m.Format)); f != "" {
|
||||
return f
|
||||
}
|
||||
|
||||
if m.Type == ModelTypeLabels && m.EndpointResponseFormat() == ApiFormatOllama {
|
||||
return FormatJSON
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// GetOptions returns the API request options.
|
||||
func (m *Model) GetOptions() *ApiRequestOptions {
|
||||
if m.Options != nil {
|
||||
|
|
@ -174,6 +196,56 @@ func (m *Model) GetOptions() *ApiRequestOptions {
|
|||
}
|
||||
}
|
||||
|
||||
// SchemaTemplate returns the model-specific JSON schema template, if any.
|
||||
func (m *Model) SchemaTemplate() string {
|
||||
m.schemaOnce.Do(func() {
|
||||
var schema string
|
||||
|
||||
if m.Type == ModelTypeLabels {
|
||||
if envFile := strings.TrimSpace(os.Getenv(labelSchemaEnvVar)); envFile != "" {
|
||||
path := fs.Abs(envFile)
|
||||
if path == "" {
|
||||
path = envFile
|
||||
}
|
||||
if data, err := os.ReadFile(path); err != nil {
|
||||
log.Warnf("vision: failed to read schema from %s (%s)", clean.Log(path), err)
|
||||
} else {
|
||||
schema = string(data)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if schema == "" && strings.TrimSpace(m.Schema) != "" {
|
||||
schema = m.Schema
|
||||
}
|
||||
|
||||
if schema == "" && strings.TrimSpace(m.SchemaFile) != "" {
|
||||
path := fs.Abs(m.SchemaFile)
|
||||
if path == "" {
|
||||
path = m.SchemaFile
|
||||
}
|
||||
if data, err := os.ReadFile(path); err != nil {
|
||||
log.Warnf("vision: failed to read schema from %s (%s)", clean.Log(path), err)
|
||||
} else {
|
||||
schema = string(data)
|
||||
}
|
||||
}
|
||||
|
||||
m.schema = strings.TrimSpace(schema)
|
||||
})
|
||||
|
||||
return m.schema
|
||||
}
|
||||
|
||||
// SchemaInstructions returns a helper string that can be appended to prompts.
|
||||
func (m *Model) SchemaInstructions() string {
|
||||
if schema := m.SchemaTemplate(); schema != "" {
|
||||
return fmt.Sprintf("Return JSON that matches this schema:\n%s", schema)
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// ClassifyModel returns the matching classify model instance, if any.
|
||||
func (m *Model) ClassifyModel() *classify.Model {
|
||||
// Use mutex to prevent models from being loaded and
|
||||
|
|
|
|||
|
|
@ -2,10 +2,13 @@ package vision
|
|||
|
||||
import (
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/photoprism/photoprism/pkg/fs"
|
||||
"github.com/photoprism/photoprism/pkg/service/http/scheme"
|
||||
)
|
||||
|
||||
|
|
@ -71,3 +74,55 @@ func TestParseTypes(t *testing.T) {
|
|||
assert.Equal(t, ModelTypes{}, result)
|
||||
})
|
||||
}
|
||||
|
||||
func TestModelFormatAndSchema(t *testing.T) {
|
||||
t.Run("DefaultOllamaFormat", func(t *testing.T) {
|
||||
m := &Model{
|
||||
Type: ModelTypeLabels,
|
||||
Service: Service{
|
||||
RequestFormat: ApiFormatOllama,
|
||||
ResponseFormat: ApiFormatOllama,
|
||||
},
|
||||
}
|
||||
|
||||
assert.Equal(t, FormatJSON, m.GetFormat())
|
||||
})
|
||||
|
||||
t.Run("InlineSchema", func(t *testing.T) {
|
||||
schema := "{\n \"labels\": []\n}"
|
||||
m := &Model{Schema: schema}
|
||||
|
||||
assert.Equal(t, schema, m.SchemaTemplate())
|
||||
assert.Contains(t, m.SchemaInstructions(), "Return JSON")
|
||||
})
|
||||
|
||||
t.Run("SchemaFileAndEnv", func(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
filePath := filepath.Join(tempDir, "schema.json")
|
||||
content := "{\n \"labels\": [{\"name\": \"test\"}]\n}"
|
||||
assert.NoError(t, os.WriteFile(filePath, []byte(content), fs.ModeConfigFile))
|
||||
|
||||
m := &Model{
|
||||
Type: ModelTypeLabels,
|
||||
SchemaFile: filePath,
|
||||
}
|
||||
|
||||
// First read should use file content.
|
||||
assert.Equal(t, content, m.SchemaTemplate())
|
||||
|
||||
// Reset and use env override with a different file.
|
||||
otherFile := filepath.Join(tempDir, "schema-override.json")
|
||||
otherContent := "{\n \"labels\": []\n, \"markers\": []\n}"
|
||||
assert.NoError(t, os.WriteFile(otherFile, []byte(otherContent), fs.ModeConfigFile))
|
||||
|
||||
t.Setenv(labelSchemaEnvVar, otherFile)
|
||||
|
||||
m2 := &Model{Type: ModelTypeLabels}
|
||||
assert.Equal(t, otherContent, m2.SchemaTemplate())
|
||||
})
|
||||
|
||||
t.Run("FormatOverride", func(t *testing.T) {
|
||||
m := &Model{Format: "JSON"}
|
||||
assert.Equal(t, FormatJSON, m.GetFormat())
|
||||
})
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue