diff --git a/apps/channels/tasks.py b/apps/channels/tasks.py index c1d16e1b..526bb8f2 100644 --- a/apps/channels/tasks.py +++ b/apps/channels/tasks.py @@ -1,11 +1,13 @@ # apps/channels/tasks.py import logging +import os import re from celery import shared_task from rapidfuzz import fuzz from sentence_transformers import SentenceTransformer, util +from django.conf import settings from django.db import transaction from apps.channels.models import Channel @@ -16,7 +18,16 @@ logger = logging.getLogger(__name__) # Load the model once at module level SENTENCE_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2" -st_model = SentenceTransformer(SENTENCE_MODEL_NAME) +MODEL_PATH = os.path.join(settings.MEDIA_ROOT, "models", "all-MiniLM-L6-v2") +os.makedirs(MODEL_PATH, exist_ok=True) + +# Only download if not already present +if not os.path.exists(os.path.join(MODEL_PATH, "config.json")): + logger.info(f"Local model not found in {MODEL_PATH}; downloading from {SENTENCE_MODEL_NAME}...") + st_model = SentenceTransformer(SENTENCE_MODEL_NAME, cache_folder=MODEL_PATH) +else: + logger.info(f"Loading local model from {MODEL_PATH}") + st_model = SentenceTransformer(MODEL_PATH) # Threshold constants BEST_FUZZY_THRESHOLD = 70 @@ -43,21 +54,12 @@ def normalize_name(name: str) -> str: if not name: return "" - # Lowercase norm = name.lower() - - # Remove bracketed text norm = re.sub(r"\[.*?\]", "", norm) norm = re.sub(r"\(.*?\)", "", norm) - - # Remove punctuation except word chars/spaces norm = re.sub(r"[^\w\s]", "", norm) - - # Remove extraneous tokens tokens = norm.split() tokens = [t for t in tokens if t not in COMMON_EXTRANEOUS_WORDS] - - # Rejoin norm = " ".join(tokens).strip() return norm @@ -76,25 +78,22 @@ def match_epg_channels(): """ logger.info("Starting EPG matching logic...") - # Try to get user's preferred region from CoreSettings try: region_obj = CoreSettings.objects.get(key="preferred-region") region_code = region_obj.value.strip().lower() # e.g. "us" except CoreSettings.DoesNotExist: region_code = None - # 1) Gather EPG rows all_epg = list(EPGData.objects.all()) epg_rows = [] for e in all_epg: epg_rows.append({ "epg_id": e.id, - "tvg_id": e.tvg_id or "", # e.g. "Fox News.us" + "tvg_id": e.tvg_id or "", "raw_name": e.name, "norm_name": normalize_name(e.name), }) - # 2) Pre-encode embeddings if possible epg_embeddings = None if any(row["norm_name"] for row in epg_rows): epg_embeddings = st_model.encode( @@ -119,7 +118,7 @@ def match_epg_channels(): ) continue - # C) No valid tvg_id => name-based matching + # C) Name-based matching fallback_name = chan.tvg_name.strip() if chan.tvg_name else chan.name norm_chan = normalize_name(fallback_name) if not norm_chan: @@ -130,26 +129,16 @@ def match_epg_channels(): best_score = 0 best_epg = None - for row in epg_rows: if not row["norm_name"]: continue - # Base fuzzy ratio base_score = fuzz.ratio(norm_chan, row["norm_name"]) - - # If we have a region_code, add a small bonus if the epg row has that region - # e.g. tvg_id or raw_name might contain ".us" or "us" bonus = 0 if region_code: - # example: if region_code is "us" and row["tvg_id"] ends with ".us" - # or row["raw_name"] has "us" in it, etc. - # We'll do a naive check: combined_text = row["tvg_id"].lower() + " " + row["raw_name"].lower() if region_code in combined_text: - bonus = 15 # pick a small bonus - + bonus = 15 score = base_score + bonus - if score > best_score: best_score = score best_epg = row @@ -158,9 +147,7 @@ def match_epg_channels(): logger.info(f"Channel {chan.id} '{fallback_name}' => no EPG match at all.") continue - # E) Decide acceptance if best_score >= BEST_FUZZY_THRESHOLD: - # Accept chan.tvg_id = best_epg["tvg_id"] chan.save() matched_channels.append((chan.id, fallback_name, best_epg["tvg_id"])) @@ -168,12 +155,10 @@ def match_epg_channels(): f"Channel {chan.id} '{fallback_name}' => matched tvg_id={best_epg['tvg_id']} (score={best_score})" ) elif best_score >= LOWER_FUZZY_THRESHOLD and epg_embeddings is not None: - # borderline => do embedding chan_embedding = st_model.encode(norm_chan, convert_to_tensor=True) sim_scores = util.cos_sim(chan_embedding, epg_embeddings)[0] top_index = int(sim_scores.argmax()) top_value = float(sim_scores[top_index]) - if top_value >= EMBED_SIM_THRESHOLD: matched_epg = epg_rows[top_index] chan.tvg_id = matched_epg["tvg_id"] @@ -189,12 +174,10 @@ def match_epg_channels(): f"cos-sim={top_value:.2f} < {EMBED_SIM_THRESHOLD}, skipping" ) else: - # no match logger.info( f"Channel {chan.id} '{fallback_name}' => fuzzy={best_score} < {LOWER_FUZZY_THRESHOLD}, skipping" ) - # Final summary total_matched = len(matched_channels) if total_matched: logger.info(f"Match Summary: {total_matched} channel(s) matched.") diff --git a/fixtures.json b/fixtures.json index 491b052f..2d42f84e 100644 --- a/fixtures.json +++ b/fixtures.json @@ -77,7 +77,7 @@ "name": "Preferred Region", "value": "us" } - } + }, { "model": "core.coresettings", "fields": {