diff --git a/apps/channels/tasks.py b/apps/channels/tasks.py index 9ac4d475..4fd11adf 100644 --- a/apps/channels/tasks.py +++ b/apps/channels/tasks.py @@ -19,7 +19,7 @@ from core.models import CoreSettings from asgiref.sync import async_to_sync from channels.layers import get_channel_layer -from core.apps import st_model +from core.utils import SentenceTransformer logger = logging.getLogger(__name__) @@ -69,6 +69,8 @@ def match_epg_channels(): """ logger.info("Starting EPG matching logic...") + st_model = SentenceTransformer.get_model() + # Attempt to retrieve a "preferred-region" if configured try: region_obj = CoreSettings.objects.get(key="preferred-region") diff --git a/core/apps.py b/core/apps.py index 0d23849c..3a01f0bd 100644 --- a/core/apps.py +++ b/core/apps.py @@ -2,26 +2,6 @@ from django.apps import AppConfig from django.conf import settings import os, logging -logger = logging.getLogger(__name__) -st_model = None - class CoreConfig(AppConfig): default_auto_field = 'django.db.models.BigAutoField' name = 'core' - - def ready(self): - global st_model - from sentence_transformers import SentenceTransformer - - # Load the sentence-transformers model once at the module level - SENTENCE_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2" - MODEL_PATH = os.path.join(settings.MEDIA_ROOT, "models", "all-MiniLM-L6-v2") - os.makedirs(MODEL_PATH, exist_ok=True) - - # If not present locally, download: - if not os.path.exists(os.path.join(MODEL_PATH, "config.json")): - logger.info(f"Local model not found in {MODEL_PATH}; downloading from {SENTENCE_MODEL_NAME}...") - st_model = SentenceTransformer(SENTENCE_MODEL_NAME, cache_folder=MODEL_PATH) - else: - logger.info(f"Loading local model from {MODEL_PATH}") - st_model = SentenceTransformer(MODEL_PATH) diff --git a/core/utils.py b/core/utils.py index ca6fa75f..16b96a18 100644 --- a/core/utils.py +++ b/core/utils.py @@ -159,3 +159,24 @@ def send_websocket_event(event, success, data): "data": {"success": True, "type": "epg_channels"} } ) + +class SentenceTransformer + _instance = None + + @classmethod + def get_model(cls): + if cls._instance is None: + from sentence_transformers import SentenceTransformer as st + + # Load the sentence-transformers model once at the module level + SENTENCE_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2" + MODEL_PATH = os.path.join(settings.MEDIA_ROOT, "models", "all-MiniLM-L6-v2") + os.makedirs(MODEL_PATH, exist_ok=True) + + # If not present locally, download: + if not os.path.exists(os.path.join(MODEL_PATH, "config.json")): + logger.info(f"Local model not found in {MODEL_PATH}; downloading from {SENTENCE_MODEL_NAME}...") + st_model = st(SENTENCE_MODEL_NAME, cache_folder=MODEL_PATH) + else: + logger.info(f"Loading local model from {MODEL_PATH}") + st_model = st(MODEL_PATH)