Mirror of https://github.com/Dispatcharr/Dispatcharr.git
Synced 2026-01-23 02:35:14 +00:00
Download Sentence Transformer model locally
parent 54a7a3cf86
commit 87fa1b4ed2
2 changed files with 16 additions and 33 deletions

apps/channels/tasks.py

@@ -1,11 +1,13 @@
 # apps/channels/tasks.py

 import logging
+import os
 import re

 from celery import shared_task
 from rapidfuzz import fuzz
 from sentence_transformers import SentenceTransformer, util
+from django.conf import settings
 from django.db import transaction

 from apps.channels.models import Channel

@@ -16,7 +18,16 @@ logger = logging.getLogger(__name__)

 # Load the model once at module level
 SENTENCE_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
-st_model = SentenceTransformer(SENTENCE_MODEL_NAME)
+MODEL_PATH = os.path.join(settings.MEDIA_ROOT, "models", "all-MiniLM-L6-v2")
+os.makedirs(MODEL_PATH, exist_ok=True)
+
+# Only download if not already present
+if not os.path.exists(os.path.join(MODEL_PATH, "config.json")):
+    logger.info(f"Local model not found in {MODEL_PATH}; downloading from {SENTENCE_MODEL_NAME}...")
+    st_model = SentenceTransformer(SENTENCE_MODEL_NAME, cache_folder=MODEL_PATH)
+else:
+    logger.info(f"Loading local model from {MODEL_PATH}")
+    st_model = SentenceTransformer(MODEL_PATH)

 # Threshold constants
 BEST_FUZZY_THRESHOLD = 70
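
The new startup logic keys its "already downloaded" check on a config.json sitting directly under MODEL_PATH. As a minimal standalone sketch (not part of the diff), here is a download-then-save variant that guarantees that layout on recent sentence-transformers releases; the /data/models path is illustrative and not the project's MEDIA_ROOT:

import os
from sentence_transformers import SentenceTransformer

MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
local_dir = "/data/models/all-MiniLM-L6-v2"  # illustrative path

if os.path.exists(os.path.join(local_dir, "config.json")):
    # A saved copy exists; load it from disk without hitting the network.
    model = SentenceTransformer(local_dir)
else:
    # First run: download from the Hub, then persist an explicit copy so the
    # config.json check passes on the next start.
    model = SentenceTransformer(MODEL_NAME)
    os.makedirs(local_dir, exist_ok=True)
    model.save(local_dir)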

@@ -43,21 +54,12 @@ def normalize_name(name: str) -> str:
     if not name:
         return ""

     # Lowercase
     norm = name.lower()

     # Remove bracketed text
     norm = re.sub(r"\[.*?\]", "", norm)
     norm = re.sub(r"\(.*?\)", "", norm)

     # Remove punctuation except word chars/spaces
     norm = re.sub(r"[^\w\s]", "", norm)

     # Remove extraneous tokens
     tokens = norm.split()
     tokens = [t for t in tokens if t not in COMMON_EXTRANEOUS_WORDS]

     # Rejoin
     norm = " ".join(tokens).strip()
     return norm
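
As a quick illustration of what normalize_name does after this hunk, the snippet below inlines the same steps with a made-up COMMON_EXTRANEOUS_WORDS set (the project defines its own list), so the exact tokens dropped are an assumption:

import re

# Hypothetical stand-in for the module's COMMON_EXTRANEOUS_WORDS constant.
COMMON_EXTRANEOUS_WORDS = {"hd", "uhd", "channel", "tv"}

def normalize_name(name: str) -> str:
    if not name:
        return ""
    norm = name.lower()
    norm = re.sub(r"\[.*?\]", "", norm)   # drop [bracketed] text
    norm = re.sub(r"\(.*?\)", "", norm)   # drop (parenthesized) text
    norm = re.sub(r"[^\w\s]", "", norm)   # drop punctuation
    tokens = [t for t in norm.split() if t not in COMMON_EXTRANEOUS_WORDS]
    return " ".join(tokens).strip()

print(normalize_name("Fox News Channel HD [US]"))  # -> "fox news"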

@@ -76,25 +78,22 @@ def match_epg_channels():
     """
     logger.info("Starting EPG matching logic...")

     # Try to get user's preferred region from CoreSettings
     try:
         region_obj = CoreSettings.objects.get(key="preferred-region")
         region_code = region_obj.value.strip().lower()  # e.g. "us"
     except CoreSettings.DoesNotExist:
         region_code = None

     # 1) Gather EPG rows
     all_epg = list(EPGData.objects.all())
     epg_rows = []
     for e in all_epg:
         epg_rows.append({
             "epg_id": e.id,
-            "tvg_id": e.tvg_id or "",  # e.g. "Fox News.us"
+            "tvg_id": e.tvg_id or "",
             "raw_name": e.name,
             "norm_name": normalize_name(e.name),
         })

     # 2) Pre-encode embeddings if possible
     epg_embeddings = None
     if any(row["norm_name"] for row in epg_rows):
         epg_embeddings = st_model.encode(
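
The hunk above batch-encodes every normalized EPG name once, up front, so the per-channel loop below only has to encode the channel side. A minimal standalone sketch of that pre-encoding step, with made-up rows shaped like the epg_rows dicts:

from sentence_transformers import SentenceTransformer

st_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

# Hypothetical rows shaped like the epg_rows dicts built above.
epg_rows = [
    {"tvg_id": "FoxNews.us", "norm_name": "fox news"},
    {"tvg_id": "BBCOne.uk", "norm_name": "bbc one"},
]

# Encode all normalized names in one call; convert_to_tensor=True keeps the
# result as a single tensor for later cosine-similarity lookups.
epg_embeddings = st_model.encode(
    [row["norm_name"] for row in epg_rows],
    convert_to_tensor=True,
)
print(epg_embeddings.shape)  # torch.Size([2, 384]): one 384-dim vector per row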

@@ -119,7 +118,7 @@ def match_epg_channels():
             )
             continue

-        # C) No valid tvg_id => name-based matching
+        # C) Name-based matching
         fallback_name = chan.tvg_name.strip() if chan.tvg_name else chan.name
         norm_chan = normalize_name(fallback_name)
         if not norm_chan:

@@ -130,26 +129,16 @@ def match_epg_channels():
         best_score = 0
         best_epg = None

         for row in epg_rows:
             if not row["norm_name"]:
                 continue
             # Base fuzzy ratio
             base_score = fuzz.ratio(norm_chan, row["norm_name"])

             # If we have a region_code, add a small bonus if the epg row has that region,
             # e.g. tvg_id or raw_name might contain ".us" or "us"
             bonus = 0
             if region_code:
                 # example: if region_code is "us" and row["tvg_id"] ends with ".us"
                 # or row["raw_name"] has "us" in it, etc.
                 # We'll do a naive check:
                 combined_text = row["tvg_id"].lower() + " " + row["raw_name"].lower()
                 if region_code in combined_text:
-                    bonus = 15  # pick a small bonus
+                    bonus = 15
             score = base_score + bonus

             if score > best_score:
                 best_score = score
                 best_epg = row
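
For reference, a minimal standalone sketch of the scoring inside the loop above: a rapidfuzz ratio plus the flat +15 region bonus. The names and region are made up; only the formula mirrors the diff:

from rapidfuzz import fuzz

region_code = "us"
norm_chan = "fox news"
row = {"tvg_id": "FoxNews.us", "raw_name": "Fox News Channel", "norm_name": "fox news channel"}

# Base fuzzy ratio between the normalized channel and EPG names (0-100).
base_score = fuzz.ratio(norm_chan, row["norm_name"])

# Small flat bonus when the preferred region appears in the EPG identifiers.
bonus = 0
combined_text = row["tvg_id"].lower() + " " + row["raw_name"].lower()
if region_code and region_code in combined_text:
    bonus = 15

score = base_score + bonus
print(base_score, bonus, score)  # e.g. roughly 67, 15, 82 for these strings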

@@ -158,9 +147,7 @@ def match_epg_channels():
             logger.info(f"Channel {chan.id} '{fallback_name}' => no EPG match at all.")
             continue

         # E) Decide acceptance
         if best_score >= BEST_FUZZY_THRESHOLD:
             # Accept
             chan.tvg_id = best_epg["tvg_id"]
             chan.save()
             matched_channels.append((chan.id, fallback_name, best_epg["tvg_id"]))

@@ -168,12 +155,10 @@ def match_epg_channels():
                 f"Channel {chan.id} '{fallback_name}' => matched tvg_id={best_epg['tvg_id']} (score={best_score})"
             )
         elif best_score >= LOWER_FUZZY_THRESHOLD and epg_embeddings is not None:
             # borderline => do embedding
             chan_embedding = st_model.encode(norm_chan, convert_to_tensor=True)
             sim_scores = util.cos_sim(chan_embedding, epg_embeddings)[0]
             top_index = int(sim_scores.argmax())
             top_value = float(sim_scores[top_index])

             if top_value >= EMBED_SIM_THRESHOLD:
                 matched_epg = epg_rows[top_index]
                 chan.tvg_id = matched_epg["tvg_id"]
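
A minimal standalone sketch of the borderline fallback shown above: encode one channel name, compare it against the pre-encoded EPG embeddings with util.cos_sim, and accept the best hit only if it clears a threshold. The names and the 0.65 value are illustrative; the project defines its own EMBED_SIM_THRESHOLD:

from sentence_transformers import SentenceTransformer, util

st_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

EMBED_SIM_THRESHOLD = 0.65  # illustrative, not necessarily the project's value
epg_rows = [
    {"tvg_id": "FoxNews.us", "norm_name": "fox news"},
    {"tvg_id": "BBCOne.uk", "norm_name": "bbc one"},
]
epg_embeddings = st_model.encode([r["norm_name"] for r in epg_rows], convert_to_tensor=True)

norm_chan = "fox news hd"
chan_embedding = st_model.encode(norm_chan, convert_to_tensor=True)

# Cosine similarity of the channel against every EPG row; [0] drops the batch dim.
sim_scores = util.cos_sim(chan_embedding, epg_embeddings)[0]
top_index = int(sim_scores.argmax())
top_value = float(sim_scores[top_index])

if top_value >= EMBED_SIM_THRESHOLD:
    print("accept", epg_rows[top_index]["tvg_id"], round(top_value, 2))
else:
    print("no embedding match", round(top_value, 2))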

@@ -189,12 +174,10 @@ def match_epg_channels():
                     f"cos-sim={top_value:.2f} < {EMBED_SIM_THRESHOLD}, skipping"
                 )
         else:
             # no match
             logger.info(
                 f"Channel {chan.id} '{fallback_name}' => fuzzy={best_score} < {LOWER_FUZZY_THRESHOLD}, skipping"
             )

     # Final summary
     total_matched = len(matched_channels)
     if total_matched:
         logger.info(f"Match Summary: {total_matched} channel(s) matched.")

Second changed file (core.coresettings fixture):

@@ -77,7 +77,7 @@
                 "name": "Preferred Region",
                 "value": "us"
             }
         }
     },
     {
         "model": "core.coresettings",
         "fields": {