Mirror of https://github.com/Dispatcharr/Dispatcharr.git, synced 2026-01-23 02:35:14 +00:00
lazy-load sentencetransformer instance
parent 7f0c426206
commit cdf9df03bd
3 changed files with 24 additions and 21 deletions
@@ -19,7 +19,7 @@ from core.models import CoreSettings
 from asgiref.sync import async_to_sync
 from channels.layers import get_channel_layer
-from core.apps import st_model
+from core.utils import SentenceTransformer
 
 logger = logging.getLogger(__name__)
 
@@ -69,6 +69,8 @@ def match_epg_channels():
     """
     logger.info("Starting EPG matching logic...")
 
+    st_model = SentenceTransformer.get_model()
+
     # Attempt to retrieve a "preferred-region" if configured
     try:
         region_obj = CoreSettings.objects.get(key="preferred-region")
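With this change match_epg_channels() pulls the model through SentenceTransformer.get_model() instead of a module-level import, so nothing is loaded until the task actually runs. A minimal sketch of how the lazily loaded model could then be used for channel-name matching; the helper name best_epg_match, the variable names, and the 0.6 threshold are illustrative assumptions, only get_model() itself comes from this diff.

# Hypothetical helper, not part of this commit: fuzzy-match one channel name
# against candidate EPG names with the lazily loaded model.
from sentence_transformers import util

from core.utils import SentenceTransformer


def best_epg_match(channel_name, epg_names, threshold=0.6):
    st_model = SentenceTransformer.get_model()  # loads the model on first call only

    # Encode the channel name and every candidate EPG name into embeddings.
    channel_vec = st_model.encode(channel_name, convert_to_tensor=True)
    epg_vecs = st_model.encode(epg_names, convert_to_tensor=True)

    # Cosine similarity of the channel against all candidates.
    scores = util.cos_sim(channel_vec, epg_vecs)[0]
    best_idx = int(scores.argmax())

    if float(scores[best_idx]) >= threshold:
        return epg_names[best_idx], float(scores[best_idx])
    return None, 0.0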
20 core/apps.py

@@ -2,26 +2,6 @@ from django.apps import AppConfig
-from django.conf import settings
-import os, logging
 
-logger = logging.getLogger(__name__)
-st_model = None
 
 class CoreConfig(AppConfig):
     default_auto_field = 'django.db.models.BigAutoField'
     name = 'core'
 
-    def ready(self):
-        global st_model
-        from sentence_transformers import SentenceTransformer
-
-        # Load the sentence-transformers model once at the module level
-        SENTENCE_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
-        MODEL_PATH = os.path.join(settings.MEDIA_ROOT, "models", "all-MiniLM-L6-v2")
-        os.makedirs(MODEL_PATH, exist_ok=True)
-
-        # If not present locally, download:
-        if not os.path.exists(os.path.join(MODEL_PATH, "config.json")):
-            logger.info(f"Local model not found in {MODEL_PATH}; downloading from {SENTENCE_MODEL_NAME}...")
-            st_model = SentenceTransformer(SENTENCE_MODEL_NAME, cache_folder=MODEL_PATH)
-        else:
-            logger.info(f"Loading local model from {MODEL_PATH}")
-            st_model = SentenceTransformer(MODEL_PATH)
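The eager load in CoreConfig.ready() goes away; its download-once-then-load-from-disk behaviour moves behind the lazy accessor shown in the next hunk. Read on its own, that behaviour is roughly the following sketch; the helper name load_local_st_model is hypothetical, while the model name, the MEDIA_ROOT/models path, and the config.json check mirror the diff.

# Illustrative standalone version of the removed loading logic (the helper
# name is hypothetical; paths and model name follow the diff).
import os
import logging

from django.conf import settings
from sentence_transformers import SentenceTransformer

logger = logging.getLogger(__name__)


def load_local_st_model(model_name="sentence-transformers/all-MiniLM-L6-v2"):
    model_path = os.path.join(settings.MEDIA_ROOT, "models", "all-MiniLM-L6-v2")
    os.makedirs(model_path, exist_ok=True)

    if not os.path.exists(os.path.join(model_path, "config.json")):
        # First run: download the model and cache it under MEDIA_ROOT/models.
        logger.info("Local model not found in %s; downloading %s...", model_path, model_name)
        return SentenceTransformer(model_name, cache_folder=model_path)

    # Subsequent runs: load straight from the local directory.
    logger.info("Loading local model from %s", model_path)
    return SentenceTransformer(model_path)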
core/utils.py

@@ -159,3 +159,24 @@ def send_websocket_event(event, success, data):
             "data": {"success": True, "type": "epg_channels"}
         }
     )
+
+class SentenceTransformer:
+    _instance = None
+
+    @classmethod
+    def get_model(cls):
+        if cls._instance is None:
+            from sentence_transformers import SentenceTransformer as st
+
+            # Load the sentence-transformers model once, on first use
+            SENTENCE_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
+            MODEL_PATH = os.path.join(settings.MEDIA_ROOT, "models", "all-MiniLM-L6-v2")
+            os.makedirs(MODEL_PATH, exist_ok=True)
+
+            # If not present locally, download:
+            if not os.path.exists(os.path.join(MODEL_PATH, "config.json")):
+                logger.info(f"Local model not found in {MODEL_PATH}; downloading from {SENTENCE_MODEL_NAME}...")
+                cls._instance = st(SENTENCE_MODEL_NAME, cache_folder=MODEL_PATH)
+            else:
+                logger.info(f"Loading local model from {MODEL_PATH}")
+                cls._instance = st(MODEL_PATH)
+        return cls._instance
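get_model() caches one model per process but takes no lock, so two threads hitting the cold cache at the same time could each load the model. A minimal thread-safe sketch of the same lazy-singleton pattern, assuming a hypothetical LazySentenceTransformer class and leaving out the local-cache handling for brevity:

# Thread-safe variant of the lazy singleton; the class name and the lock are
# illustrative additions, not part of the commit.
import threading


class LazySentenceTransformer:
    _instance = None
    _lock = threading.Lock()

    @classmethod
    def get_model(cls):
        # Fast path: the model is already loaded, no locking needed.
        if cls._instance is None:
            with cls._lock:
                # Re-check inside the lock so only one thread loads the model.
                if cls._instance is None:
                    from sentence_transformers import SentenceTransformer as st
                    cls._instance = st("sentence-transformers/all-MiniLM-L6-v2")
        return cls._instance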