forked from Mirrors/Dispatcharr
epg match run externally to keep memory usage low
This commit is contained in:
parent
7a90cc8ae3
commit
5570562960
4 changed files with 205 additions and 160 deletions
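
In outline: the task now serializes its matching inputs to a temp file, runs scripts/epg_match.py in a separate interpreter so the sentence-transformers model's memory is returned to the OS when the child exits, and reads the results back as JSON on stdout. A minimal sketch of that round trip (not part of the commit, and simplified: subprocess.run with check=True, where the diff below streams stderr live with Popen):

import json
import os
import subprocess
import tempfile

def run_external_match(payload):
    # Write the payload where the child process can read it
    with tempfile.NamedTemporaryFile(delete=False, suffix=".json") as f:
        f.write(json.dumps(payload).encode("utf-8"))
        path = f.name
    try:
        # The worker prints its result as JSON on stdout; progress logs go to stderr
        proc = subprocess.run(
            ["python", "/app/scripts/epg_match.py", path],
            capture_output=True, text=True, check=True,
        )
        return json.loads(proc.stdout)
    finally:
        os.remove(path)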
@@ -4,17 +4,15 @@ import os
import re
import requests
import time
import gc
import json
import subprocess
from datetime import datetime

from celery import shared_task
from rapidfuzz import fuzz
from django.conf import settings
from django.db import transaction
from django.utils.text import slugify

from apps.channels.models import Channel
from apps.epg.models import EPGData, EPGSource
from apps.epg.models import EPGData
from core.models import CoreSettings

from channels.layers import get_channel_layer
@@ -22,15 +20,10 @@ from asgiref.sync import async_to_sync

from asgiref.sync import async_to_sync
from channels.layers import get_channel_layer
from core.utils import SentenceTransformer
import tempfile

logger = logging.getLogger(__name__)

# Thresholds
BEST_FUZZY_THRESHOLD = 85
LOWER_FUZZY_THRESHOLD = 40
EMBED_SIM_THRESHOLD = 0.65

# Words we remove to help with fuzzy + embedding matching
COMMON_EXTRANEOUS_WORDS = [
    "tv", "channel", "network", "television",
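
Taken together, these constants form a three-tier gate; a compact sketch (illustrative, not in the commit) of how they are applied further down:

def accept_match(fuzzy_score, cos_sim):
    if fuzzy_score >= BEST_FUZZY_THRESHOLD:    # >= 85: accept on the fuzzy score alone
        return True
    if fuzzy_score >= LOWER_FUZZY_THRESHOLD:   # 40-84: defer to the embedding check
        return cos_sim >= EMBED_SIM_THRESHOLD  # cosine similarity >= 0.65
    return False                               # < 40: too weak to consider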
@@ -70,12 +63,8 @@ def match_epg_channels():
    4) If a match is found, we set channel.tvg_id
    5) Summarize and log results.
    """
    from sentence_transformers import util

    logger.info("Starting EPG matching logic...")

    st_model = SentenceTransformer.get_model()

    # Attempt to retrieve a "preferred-region" if configured
    try:
        region_obj = CoreSettings.objects.get(key="preferred-region")
@@ -83,130 +72,61 @@ def match_epg_channels():
    except CoreSettings.DoesNotExist:
        region_code = None

    # Gather EPGData rows so we can do fuzzy matching in memory
    all_epg = {e.id: e for e in EPGData.objects.all()}

    epg_rows = []
    for e in list(all_epg.values()):
        epg_rows.append({
            "epg_id": e.id,
            "tvg_id": e.tvg_id or "",
            "raw_name": e.name,
            "norm_name": normalize_name(e.name),
        })

    epg_embeddings = None
    if any(row["norm_name"] for row in epg_rows):
        epg_embeddings = st_model.encode(
            [row["norm_name"] for row in epg_rows],
            convert_to_tensor=True
        )

    matched_channels = []
    channels_to_update = []

    source = EPGSource.objects.filter(is_active=True).first()
    epg_file_path = getattr(source, 'file_path', None) if source else None
    channels_json = [{
        "id": channel.id,
        "name": channel.name,
        "tvg_id": channel.tvg_id,
        "fallback_name": channel.tvg_id.strip() if channel.tvg_id else channel.name,
        "norm_chan": normalize_name(channel.tvg_id.strip() if channel.tvg_id else channel.name)
    } for channel in Channel.objects.all() if not channel.epg_data]

    with transaction.atomic():
        for chan in Channel.objects.all():
            # skip if channel already assigned an EPG
            if chan.epg_data:
                continue
    epg_json = [{
        'id': epg.id,
        'tvg_id': epg.tvg_id,
        'name': epg.name,
        'norm_name': normalize_name(epg.name),
        'epg_source_id': epg.epg_source.id,
    } for epg in EPGData.objects.all()]

            # If channel has a tvg_id that doesn't exist in EPGData, do direct check.
            # I don't THINK this should happen now that we assign EPG on channel creation.
            if chan.tvg_id:
                epg_match = EPGData.objects.filter(tvg_id=chan.tvg_id).first()
                if epg_match:
                    chan.epg_data = epg_match
                    logger.info(f"Channel {chan.id} '{chan.name}' => EPG found by tvg_id={chan.tvg_id}")
                    channels_to_update.append(chan)
                    continue
    payload = {
        "channels": channels_json,
        "epg_data": epg_json,
        "region_code": region_code,
    }

            # C) Perform name-based fuzzy matching
            fallback_name = chan.tvg_id.strip() if chan.tvg_id else chan.name
            norm_chan = normalize_name(fallback_name)
            if not norm_chan:
                logger.info(f"Channel {chan.id} '{chan.name}' => empty after normalization, skipping")
                continue
    with tempfile.NamedTemporaryFile(delete=False) as temp_file:
        temp_file.write(json.dumps(payload).encode('utf-8'))
        temp_file_path = temp_file.name

            best_score = 0
            best_epg = None
            for row in epg_rows:
                if not row["norm_name"]:
                    continue
                base_score = fuzz.ratio(norm_chan, row["norm_name"])
                bonus = 0
                # Region-based bonus/penalty
                combined_text = row["tvg_id"].lower() + " " + row["raw_name"].lower()
                dot_regions = re.findall(r'\.([a-z]{2})', combined_text)
                if region_code:
                    if dot_regions:
                        if region_code in dot_regions:
                            bonus = 30  # bigger bonus if .us or .ca matches
                        else:
                            bonus = -15
                    elif region_code in combined_text:
                        bonus = 15
                score = base_score + bonus
    process = subprocess.Popen(
        ['python', '/app/scripts/epg_match.py', temp_file_path],
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        text=True
    )

                logger.debug(
                    f"Channel {chan.id} '{fallback_name}' => EPG row {row['epg_id']}: "
                    f"raw_name='{row['raw_name']}', norm_name='{row['norm_name']}', "
                    f"combined_text='{combined_text}', dot_regions={dot_regions}, "
                    f"base_score={base_score}, bonus={bonus}, total_score={score}"
                )
    # Log stderr in real-time
    for line in iter(process.stderr.readline, ''):
        if line:
            logger.info(line.strip())

                if score > best_score:
                    best_score = score
                    best_epg = row
    process.stderr.close()
    stdout, stderr = process.communicate()

            # If no best match was found, skip
            if not best_epg:
                logger.info(f"Channel {chan.id} '{fallback_name}' => no EPG match at all.")
                continue
    os.remove(temp_file_path)

            # If best_score is above BEST_FUZZY_THRESHOLD => direct accept
            if best_score >= BEST_FUZZY_THRESHOLD:
                chan.epg_data = all_epg[best_epg["epg_id"]]
                chan.save()
    if process.returncode != 0:
        return f"Failed to process EPG matching: {stderr}"

                matched_channels.append((chan.id, fallback_name, best_epg["tvg_id"]))
                logger.info(
                    f"Channel {chan.id} '{fallback_name}' => matched tvg_id={best_epg['tvg_id']} "
                    f"(score={best_score})"
                )
    result = json.loads(stdout)
    channels_to_update = result["channels_to_update"]
    matched_channels = result["matched_channels"]

            # If best_score is in the “middle range,” do embedding check
            elif best_score >= LOWER_FUZZY_THRESHOLD and epg_embeddings is not None:
                chan_embedding = st_model.encode(norm_chan, convert_to_tensor=True)
                sim_scores = util.cos_sim(chan_embedding, epg_embeddings)[0]
                top_index = int(sim_scores.argmax())
                top_value = float(sim_scores[top_index])
                if top_value >= EMBED_SIM_THRESHOLD:
                    matched_epg = epg_rows[top_index]
                    chan.epg_data = all_epg[matched_epg["epg_id"]]
                    chan.save()

                    matched_channels.append((chan.id, fallback_name, matched_epg["tvg_id"]))
                    logger.info(
                        f"Channel {chan.id} '{fallback_name}' => matched EPG tvg_id={matched_epg['tvg_id']} "
                        f"(fuzzy={best_score}, cos-sim={top_value:.2f})"
                    )
                else:
                    logger.info(
                        f"Channel {chan.id} '{fallback_name}' => fuzzy={best_score}, "
                        f"cos-sim={top_value:.2f} < {EMBED_SIM_THRESHOLD}, skipping"
                    )
            else:
                logger.info(
                    f"Channel {chan.id} '{fallback_name}' => fuzzy={best_score} < "
                    f"{LOWER_FUZZY_THRESHOLD}, skipping"
                )

        if channels_to_update:
            Channel.objects.bulk_update(channels_to_update, ['epg_data'])
    if channels_to_update:
        Channel.objects.bulk_update(channels_to_update, ['epg_data'])

    total_matched = len(matched_channels)
    if total_matched:
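
A note on the Popen pattern above (sketch, not part of the commit): stderr is drained line-by-line while the child runs, which both forwards the worker's progress logs in real time and keeps the stderr pipe from filling up and blocking the child; communicate() then collects stdout and reaps the process. In isolation:

import logging
import subprocess

logger = logging.getLogger(__name__)

proc = subprocess.Popen(
    ["python", "worker.py"],                 # hypothetical worker script
    stdout=subprocess.PIPE,
    stderr=subprocess.PIPE,
    text=True,
)
for line in iter(proc.stderr.readline, ""):  # "" signals EOF in text mode
    logger.info(line.strip())                # forward logs as they arrive
proc.stderr.close()
stdout, _ = proc.communicate()               # drain stdout, wait for exit

If the worker can emit a large volume of stdout before it exits, stdout should be drained concurrently (for example from a thread) to avoid the reverse deadlock.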
@@ -227,8 +147,6 @@ def match_epg_channels():
            }
        )

    SentenceTransformer.clear()
    gc.collect()
    return f"Done. Matched {total_matched} channel(s)."

@shared_task
@@ -160,34 +160,3 @@ def send_websocket_event(event, success, data):
            "data": {"success": True, "type": "epg_channels"}
        }
    )

class SentenceTransformer:
    _instance = None

    @classmethod
    def get_model(cls):
        if cls._instance is None:
            from sentence_transformers import SentenceTransformer as st

            # Load the sentence-transformers model once at the module level
            SENTENCE_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
            MODEL_PATH = os.path.join(settings.MEDIA_ROOT, "models", "all-MiniLM-L6-v2")
            os.makedirs(MODEL_PATH, exist_ok=True)

            # If not present locally, download:
            if not os.path.exists(os.path.join(MODEL_PATH, "config.json")):
                logger.info(f"Local model not found in {MODEL_PATH}; downloading from {SENTENCE_MODEL_NAME}...")
                cls._instance = st(SENTENCE_MODEL_NAME, cache_folder=MODEL_PATH)
            else:
                logger.info(f"Loading local model from {MODEL_PATH}")
                cls._instance = st(MODEL_PATH)

        return cls._instance

    @classmethod
    def clear(cls):
        """Clear the model instance and release memory."""
        if cls._instance is not None:
            del cls._instance
            cls._instance = None
            gc.collect()
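
Context for the removal above: del plus gc.collect() rarely returns a loaded transformer's memory to the OS, since allocator arenas and native buffers persist in-process; that is what motivates loading the model in a short-lived child instead. A quick way to observe this yourself, assuming psutil is installed (illustrative, not part of the commit):

import gc
import os

import psutil

def rss_mib():
    # Resident set size of the current process, in MiB
    return psutil.Process(os.getpid()).memory_info().rss / 2**20

before = rss_mib()
from sentence_transformers import SentenceTransformer
model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
loaded = rss_mib()
del model
gc.collect()
print(f"before={before:.0f} loaded={loaded:.0f} after-del={rss_mib():.0f}")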
@@ -1,7 +1,6 @@
import os
from pathlib import Path
from datetime import timedelta
from celery.schedules import crontab

BASE_DIR = Path(__file__).resolve().parent.parent

159 scripts/epg_match.py Normal file
@@ -0,0 +1,159 @@
# scripts/epg_match.py

import sys
import json
import re
import os
from rapidfuzz import fuzz
from sentence_transformers import util
from sentence_transformers import SentenceTransformer as st

# Load the sentence-transformers model once at the module level
SENTENCE_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
MODEL_PATH = os.path.join("/app/media", "models", "all-MiniLM-L6-v2")

# Thresholds
BEST_FUZZY_THRESHOLD = 85
LOWER_FUZZY_THRESHOLD = 40
EMBED_SIM_THRESHOLD = 0.65

def eprint(*args, **kwargs):
    print(*args, file=sys.stderr, **kwargs)

def process_data(input_data):
    os.makedirs(MODEL_PATH, exist_ok=True)

    # If not present locally, download:
    if not os.path.exists(os.path.join(MODEL_PATH, "config.json")):
        eprint(f"Local model not found in {MODEL_PATH}; downloading from {SENTENCE_MODEL_NAME}...")
        st_model = st(SENTENCE_MODEL_NAME, cache_folder=MODEL_PATH)
    else:
        eprint(f"Loading local model from {MODEL_PATH}")
        st_model = st(MODEL_PATH)

    channels = input_data["channels"]
    epg_data = input_data["epg_data"]
    region_code = input_data["region_code"]

    epg_embeddings = None
    if any(row["norm_name"] for row in epg_data):
        epg_embeddings = st_model.encode(
            [row["norm_name"] for row in epg_data],
            convert_to_tensor=True
        )

    channels_to_update = []
    matched_channels = []

    for chan in channels:
        # If channel has a tvg_id that doesn't exist in EPGData, do direct check.
        # I don't THINK this should happen now that we assign EPG on channel creation.
        if chan["tvg_id"]:
            epg_match = [epg for epg in epg_data if epg["tvg_id"] == chan["tvg_id"]]
            if epg_match:
                chan["epg_data_id"] = epg_match[0]["id"]
                eprint(f"Channel {chan['id']} '{chan['name']}' => EPG found by tvg_id={chan['tvg_id']}")
                channels_to_update.append(chan)
                continue

        # C) Perform name-based fuzzy matching
        fallback_name = chan["tvg_id"].strip() if chan["tvg_id"] else chan["name"]
        if not chan["norm_chan"]:
            eprint(f"Channel {chan['id']} '{chan['name']}' => empty after normalization, skipping")
            continue

        best_score = 0
        best_epg = None
        for row in epg_data:
            if not row["norm_name"]:
                continue

            base_score = fuzz.ratio(chan["norm_chan"], row["norm_name"])
            bonus = 0
            # Region-based bonus/penalty
            combined_text = row["tvg_id"].lower() + " " + row["name"].lower()
            dot_regions = re.findall(r'\.([a-z]{2})', combined_text)
            if region_code:
                if dot_regions:
                    if region_code in dot_regions:
                        bonus = 30  # bigger bonus if .us or .ca matches
                    else:
                        bonus = -15
                elif region_code in combined_text:
                    bonus = 15
            score = base_score + bonus

            eprint(
                f"Channel {chan['id']} '{fallback_name}' => EPG row {row['id']}: "
                f"name='{row['name']}', norm_name='{row['norm_name']}', "
                f"combined_text='{combined_text}', dot_regions={dot_regions}, "
                f"base_score={base_score}, bonus={bonus}, total_score={score}"
            )

            if score > best_score:
                best_score = score
                best_epg = row

        # If no best match was found, skip
        if not best_epg:
            eprint(f"Channel {chan['id']} '{fallback_name}' => no EPG match at all.")
            continue

        # If best_score is above BEST_FUZZY_THRESHOLD => direct accept
        if best_score >= BEST_FUZZY_THRESHOLD:
            chan["epg_data_id"] = best_epg["id"]
            channels_to_update.append(chan)

            matched_channels.append((chan['id'], fallback_name, best_epg["tvg_id"]))
            eprint(
                f"Channel {chan['id']} '{fallback_name}' => matched tvg_id={best_epg['tvg_id']} "
                f"(score={best_score})"
            )

        # If best_score is in the “middle range,” do embedding check
        elif best_score >= LOWER_FUZZY_THRESHOLD and epg_embeddings is not None:
            chan_embedding = st_model.encode(chan["norm_chan"], convert_to_tensor=True)
            sim_scores = util.cos_sim(chan_embedding, epg_embeddings)[0]
            top_index = int(sim_scores.argmax())
            top_value = float(sim_scores[top_index])
            if top_value >= EMBED_SIM_THRESHOLD:
                matched_epg = epg_data[top_index]
                chan["epg_data_id"] = matched_epg["id"]
                channels_to_update.append(chan)

                matched_channels.append((chan['id'], fallback_name, matched_epg["tvg_id"]))
                eprint(
                    f"Channel {chan['id']} '{fallback_name}' => matched EPG tvg_id={matched_epg['tvg_id']} "
                    f"(fuzzy={best_score}, cos-sim={top_value:.2f})"
                )
            else:
                eprint(
                    f"Channel {chan['id']} '{fallback_name}' => fuzzy={best_score}, "
                    f"cos-sim={top_value:.2f} < {EMBED_SIM_THRESHOLD}, skipping"
                )
        else:
            eprint(
                f"Channel {chan['id']} '{fallback_name}' => fuzzy={best_score} < "
                f"{LOWER_FUZZY_THRESHOLD}, skipping"
            )

    return {
        "channels_to_update": channels_to_update,
        "matched_channels": matched_channels,
    }

def main():
    # Read input data from a file
    input_file_path = sys.argv[1]
    with open(input_file_path, 'r') as f:
        input_data = json.load(f)

    # Process data with the ML model (or your logic)
    result = process_data(input_data)

    # Output result to stdout
    print(json.dumps(result))

if __name__ == "__main__":
    main()
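
The script can also be exercised by hand outside Celery; a hypothetical smoke test with made-up IDs and names (the key names match what process_data expects):

import json
import subprocess

payload = {
    "channels": [{"id": 1, "name": "Example TV", "tvg_id": "",
                  "fallback_name": "Example TV", "norm_chan": "example"}],
    "epg_data": [{"id": 10, "tvg_id": "example.us", "name": "Example TV",
                  "norm_name": "example"}],
    "region_code": "us",
}
with open("payload.json", "w") as f:
    json.dump(payload, f)

proc = subprocess.run(
    ["python", "scripts/epg_match.py", "payload.json"],
    capture_output=True, text=True, check=True,
)
result = json.loads(proc.stdout)
# Expect something like: result["matched_channels"] == [[1, "Example TV", "example.us"]]
print(result)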