lazy-load sentencetransformer instance

This commit is contained in:
dekzter 2025-04-05 20:05:45 -04:00
parent 7f0c426206
commit cdf9df03bd
3 changed files with 24 additions and 21 deletions

View file

@ -19,7 +19,7 @@ from core.models import CoreSettings
from asgiref.sync import async_to_sync
from channels.layers import get_channel_layer
from core.apps import st_model
from core.utils import SentenceTransformer
logger = logging.getLogger(__name__)
@ -69,6 +69,8 @@ def match_epg_channels():
"""
logger.info("Starting EPG matching logic...")
st_model = SentenceTransformer.get_model()
# Attempt to retrieve a "preferred-region" if configured
try:
region_obj = CoreSettings.objects.get(key="preferred-region")

View file

@ -2,26 +2,6 @@ from django.apps import AppConfig
from django.conf import settings
import os, logging
logger = logging.getLogger(__name__)
st_model = None
class CoreConfig(AppConfig):
default_auto_field = 'django.db.models.BigAutoField'
name = 'core'
def ready(self):
global st_model
from sentence_transformers import SentenceTransformer
# Load the sentence-transformers model once at the module level
SENTENCE_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
MODEL_PATH = os.path.join(settings.MEDIA_ROOT, "models", "all-MiniLM-L6-v2")
os.makedirs(MODEL_PATH, exist_ok=True)
# If not present locally, download:
if not os.path.exists(os.path.join(MODEL_PATH, "config.json")):
logger.info(f"Local model not found in {MODEL_PATH}; downloading from {SENTENCE_MODEL_NAME}...")
st_model = SentenceTransformer(SENTENCE_MODEL_NAME, cache_folder=MODEL_PATH)
else:
logger.info(f"Loading local model from {MODEL_PATH}")
st_model = SentenceTransformer(MODEL_PATH)

View file

@ -159,3 +159,24 @@ def send_websocket_event(event, success, data):
"data": {"success": True, "type": "epg_channels"}
}
)
class SentenceTransformer
_instance = None
@classmethod
def get_model(cls):
if cls._instance is None:
from sentence_transformers import SentenceTransformer as st
# Load the sentence-transformers model once at the module level
SENTENCE_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
MODEL_PATH = os.path.join(settings.MEDIA_ROOT, "models", "all-MiniLM-L6-v2")
os.makedirs(MODEL_PATH, exist_ok=True)
# If not present locally, download:
if not os.path.exists(os.path.join(MODEL_PATH, "config.json")):
logger.info(f"Local model not found in {MODEL_PATH}; downloading from {SENTENCE_MODEL_NAME}...")
st_model = st(SENTENCE_MODEL_NAME, cache_folder=MODEL_PATH)
else:
logger.info(f"Loading local model from {MODEL_PATH}")
st_model = st(MODEL_PATH)