class StreamFetcher:
    """Download manifests and segments over one pooled ``requests`` session.

    The fetcher remembers the host of the last successful (HTTP 200)
    response and the final target of any redirect it followed, so later
    requests can skip the redirect hop and, on failure, be retried against
    the host that last worked.
    """

    def __init__(self, stream_url):
        """Create the session and mount a small connection pool on it.

        Args:
            stream_url: URL of the HLS manifest this fetcher serves.
        """
        self.stream_url = stream_url
        self.session = requests.Session()

        # Set up connection pooling
        adapter = requests.adapters.HTTPAdapter(
            pool_connections=2,   # Number of connection pools
            pool_maxsize=4,       # Connections per pool
            max_retries=3,        # Auto-retry failed requests
            pool_block=False      # Don't block when pool is full
        )

        # Apply adapter to both HTTP and HTTPS
        self.session.mount('http://', adapter)
        self.session.mount('https://', adapter)

        # Request optimization
        self.last_request_time = 0
        self.min_request_interval = 0.05  # Minimum time between requests
        self.last_host = None             # Cache last successful host
        self.redirect_cache = {}          # Maps requested URL -> redirect target

    def get_base_host(self, url):
        """Return ``scheme://netloc`` for *url* (the URL itself if parsing fails)."""
        try:
            parsed = urlparse(url)
            return f"{parsed.scheme}://{parsed.netloc}"
        except Exception as e:
            logging.error(f"Error extracting base host: {e}")
            return url

    def _rebase_on_last_host(self, url):
        """Return *url* with its scheme and host replaced by ``self.last_host``.

        The path, params, query and fragment are kept. The URL is parsed
        rather than split on ``'://'``, so a query string that itself
        contains ``://`` (e.g. an embedded source URL) or a URL without a
        path no longer produces a bogus retry target.
        """
        parts = urlparse(url)
        relative = parts.path.lstrip('/')
        if parts.params:
            relative += ';' + parts.params
        if parts.query:
            relative += '?' + parts.query
        if parts.fragment:
            relative += '#' + parts.fragment
        return urljoin(self.last_host + '/', relative)

    def download(self, url):
        """Fetch *url* and return ``(content_bytes, final_url)``.

        Requests are spaced at least ``min_request_interval`` seconds apart.
        A known redirect target is requested directly. If the request
        raises, any cached redirect for *url* is dropped (it may have
        expired) and, when *url* is not already on ``last_host``, the
        request is retried once against that host.

        Note: the body is returned for any HTTP status; only a 200 response
        updates ``last_host``.

        Raises:
            Exception: whatever the session raised, when no host fallback
                applies or the fallback fails as well.
        """
        now = time.time()
        wait_time = self.last_request_time + self.min_request_interval - now
        if wait_time > 0:
            time.sleep(wait_time)

        try:
            cached_target = self.redirect_cache.get(url)
            if cached_target is not None:
                # Skip the redirect hop we already resolved earlier
                logging.debug(f"Using cached redirect for {url}")
                response = self.session.get(cached_target, timeout=10)
            else:
                response = self.session.get(url, allow_redirects=True, timeout=10)
                if response.history:  # Cache redirects
                    logging.debug(f"Caching redirect for {url} -> {response.url}")
                    self.redirect_cache[url] = response.url

            self.last_request_time = time.time()

            if response.status_code == 200:
                self.last_host = self.get_base_host(response.url)

            return response.content, response.url

        except Exception as e:
            logging.error(f"Download error: {e}")
            # A failing cached target would otherwise be retried forever.
            self.redirect_cache.pop(url, None)
            if self.last_host and not url.startswith(self.last_host):
                new_url = self._rebase_on_last_host(url)
                logging.debug(f"Retrying with last host: {new_url}")
                # Recursion is bounded: new_url starts with last_host.
                return self.download(new_url)
            raise
def get_segment_sequence(segment_uri: str) -> Optional[int]:
    """
    Extract the numeric sequence from a segment URI such as ``1038_3693.ts``.

    Only the file name is inspected: the query string, fragment and any
    directory components are stripped first, so underscores or dots in a
    token (``...ts?sig=a_b``) or in a folder name cannot hide or corrupt
    the sequence number.

    Args:
        segment_uri: Segment filename, path or URL

    Returns:
        int: The digits between the last ``_`` and the first following ``.``
        None: If the file name has no ``_`` or that part is not an integer
    """
    # Drop "?query" and "#fragment", then keep the last path component.
    filename = segment_uri.split('?', 1)[0].split('#', 1)[0].rsplit('/', 1)[-1]
    if '_' not in filename:
        return None

    candidate = filename.rsplit('_', 1)[-1].split('.', 1)[0]
    try:
        return int(candidate)
    except ValueError:
        return None
class StreamManager:
    """Manages HLS stream state and switching logic.

    Tracks the upstream URL, hands out proxy-side sequence numbers for
    downloaded segments, records the sequences at which the upstream source
    changed (served as ``#EXT-X-DISCONTINUITY``), and owns the background
    thread that runs ``fetch_stream``.
    """

    def __init__(self, initial_url: str):
        # Stream state
        self.current_url = initial_url
        self.running = True
        self.switching_stream = False

        # Sequence tracking
        self.next_sequence = 0
        self.highest_sequence = 0
        self.buffered_sequences = set()
        self.downloaded_sources = {}   # source_id -> assigned sequence
        self.segment_durations = {}    # sequence -> duration in seconds

        # Source tracking
        self.current_source = None
        self.source_changes = set()    # sequences preceded by a discontinuity
        self.stream_switch_count = 0

        # Threading
        self.fetcher = None
        self.fetch_thread = None
        self.url_changed = threading.Event()

    def update_url(self, new_url: str) -> bool:
        """Switch to *new_url* and mark the discontinuity point.

        The discontinuity is recorded at ``next_sequence`` itself, which is
        the number the first segment of the new stream receives from
        ``get_next_sequence``. Earlier markers are kept because their
        segments may still be inside the served playlist window.

        Args:
            new_url: Stream URL to switch to

        Returns:
            bool: True if the URL changed, False if it was already current
        """
        if new_url == self.current_url:
            return False

        with buffer_lock:
            self.switching_stream = True
            self.current_url = new_url

            # Set sequence numbers for stream switch
            if segment_buffers:
                # Continue numbering right after the newest buffered segment
                self.highest_sequence = max(segment_buffers.keys())
                self.next_sequence = self.highest_sequence + 1
            else:
                # Nothing buffered yet: jump to a fresh block of numbers
                self.stream_switch_count += 1
                self.next_sequence = self.stream_switch_count * 1000

            # Mark discontinuity at first segment of new stream
            self.source_changes.add(self.next_sequence)

            logging.info(f"Stream switch - next sequence will start at {self.next_sequence}")

            # Clear state but maintain sequence numbers
            self.downloaded_sources.clear()
            self.segment_durations.clear()
            self.current_source = None

            # Signal thread to switch URL
            self.url_changed.set()

        return True

    def get_next_sequence(self, source_id):
        """Assign the next free sequence number to segment *source_id*.

        Args:
            source_id: Upstream segment name without extension; the part
                before the first ``_`` is treated as the source prefix.

        Returns:
            The assigned sequence number, or None when *source_id* has
            already been given one.
        """
        if source_id in self.downloaded_sources:
            return None

        seq = self.next_sequence
        while seq in self.buffered_sequences:
            seq += 1

        # Track source changes for discontinuity markers
        source_prefix = source_id.split('_')[0]
        if not self.switching_stream and self.current_source and self.current_source != source_prefix:
            self.source_changes.add(seq)
            logging.debug(f"Source change detected at sequence {seq}")
        self.current_source = source_prefix

        # Update tracking
        self.downloaded_sources[source_id] = seq
        self.buffered_sequences.add(seq)
        self.next_sequence = seq + 1
        self.highest_sequence = max(self.highest_sequence, seq)

        return seq

    def _fetch_loop(self):
        """Background thread body: (re)start ``fetch_stream`` until stopped."""
        while self.running:
            try:
                self.fetcher = StreamFetcher(self.current_url)
                fetch_stream(self.fetcher, self.url_changed, self.next_sequence)
            except Exception as e:
                logging.error(f"Stream error: {e}")
                time.sleep(5)  # Wait before retry

            # fetch_stream returns once url_changed is set; re-arm it
            self.url_changed.clear()

    def start(self):
        """Start the background fetch thread if it is not already running."""
        if not self.fetch_thread or not self.fetch_thread.is_alive():
            self.running = True
            self.fetch_thread = threading.Thread(
                target=self._fetch_loop,
                name="StreamFetcher",
                daemon=True  # Thread will exit when main program does
            )
            self.fetch_thread.start()
            logging.info("Stream manager started")

    def stop(self):
        """Stop the background fetch thread, waiting up to 5 seconds for it."""
        self.running = False
        if self.fetch_thread and self.fetch_thread.is_alive():
            self.url_changed.set()  # Signal thread to exit
            self.fetch_thread.join(timeout=5)  # Wait up to 5 seconds
            logging.info("Stream manager stopped")
duration = float(segment.duration) + new_segments[next_seq] = { + 'uri': segment.uri, + 'duration': duration, + 'source_id': source_id + } + stream_manager.segment_durations[next_seq] = duration + segments_mapped += 1 + + manifest_buffer = manifest_content + + # Download segments + for sequence_id, segment_info in new_segments.items(): + try: + segment_url = f"{fetcher.last_host}{segment_info['uri']}" + logging.debug(f"Downloading {segment_info['uri']} as segment {sequence_id}.ts " + f"(source: {segment_info['source_id']}, duration: {segment_info['duration']:.3f}s)") + + segment_data, _ = fetcher.download(segment_url) + validation = verify_segment(segment_data) + + if validation.get('valid', False): + with buffer_lock: + segment_buffers[sequence_id] = segment_data + logging.debug(f"Downloaded and verified segment {sequence_id} (packets: {validation['packets']})") + + if stream_manager.switching_stream: + stream_manager.switching_stream = False + stop_event.set() # Force fetcher restart with new URL + break + elif not buffer_initialized and len(segment_buffers) >= Config.INITIAL_SEGMENTS: + buffer_initialized = True + manifest_update_needed = True + break + except Exception as e: + logging.error(f"Segment download error: {e}") + continue + + else: + # Short sleep to prevent CPU spinning + threading.Event().wait(0.1) + + except Exception as e: + logging.error(f"Manifest error: {e}") + threading.Event().wait(retry_delay) + retry_delay = min(retry_delay * 2, max_retry_delay) + manifest_update_needed = True + +# Flask Routes for HLS Proxy Server +@app.route('/stream.m3u8') +def master_playlist(): + """ + Serve the HLS master playlist + + Handles: + - Initial buffering state + - Ongoing playback with sliding window + - Discontinuity markers for stream switches + - Dynamic segment durations + """ + with buffer_lock: + # Verify buffer state + if not manifest_buffer or not segment_buffers: + logging.warning("No manifest or segments available yet") + return '', 404 + + 
@app.route('/segments/<path:segment_name>')
def get_segment(segment_name):
    """
    Serve individual MPEG-TS segments

    The route declares the ``segment_name`` URL variable so Flask can pass
    the requested file name to this view; the playlist references segments
    as ``/segments/<sequence>.ts``.

    Args:
        segment_name: Segment filename (e.g., '123.ts')

    Returns:
        MPEG-TS segment data, or an empty 404 response when the name is not
        numeric or the segment is not in the buffer
    """
    try:
        segment_id = int(segment_name.split('.')[0])
        with buffer_lock:
            if segment_id in segment_buffers:
                available = sorted(segment_buffers.keys())
                logging.debug(f"Client requested segment {segment_id} (buffer: {min(available)}-{max(available)})")
                return Response(segment_buffers[segment_id], content_type='video/MP2T')

            logging.warning(
                f"Segment {segment_id} not found. "
                f"Available: {min(segment_buffers.keys()) if segment_buffers else 'none'}"
                f"-{max(segment_buffers.keys()) if segment_buffers else 'none'}"
            )
    except Exception as e:
        # Covers non-numeric names as well as unexpected buffer errors
        logging.error(f"Error serving segment {segment_name}: {e}")
    return '', 404
to serve proxy on (default: 5000)' + ) + parser.add_argument( + '--host', '-H', + default='0.0.0.0', + help='Interface to bind server to (default: all interfaces)' + ) + parser.add_argument( + '--debug', + action='store_true', + help='Enable debug logging' + ) + args = parser.parse_args() + + # Configure logging with separate format for access logs + logging.basicConfig( + level=logging.DEBUG if args.debug else logging.INFO, + format='%(asctime)s - %(levelname)s - %(message)s', + datefmt='%Y-%m-%d %H:%M:%S' + ) + + # Initialize proxy components + try: + # Create and start stream manager + stream_manager = StreamManager(args.url) + stream_manager.start() + + logging.info(f"Starting HLS proxy server on {args.host}:{args.port}") + logging.info(f"Initial stream URL: {args.url}") + + # Run Flask development server + # Note: For production, use a proper WSGI server like gunicorn + app.run( + host=args.host, + port=args.port, + threaded=True, # Enable multi-threading for segment handling + debug=args.debug + ) + except Exception as e: + logging.error(f"Failed to start server: {e}") + if stream_manager: + stream_manager.running = False + if stream_manager.fetch_thread: + stream_manager.fetch_thread.join() + sys.exit(1) diff --git a/requirements.txt b/requirements.txt index d479a65c..a2fbbc7a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,3 +14,4 @@ yt-dlp gevent==24.11.1 django-cors-headers djangorestframework-simplejwt +m3u8 \ No newline at end of file From 5b25e2cc6af019c4732faa3e3e2dfe4883caab66 Mon Sep 17 00:00:00 2001 From: SergeantPanda Date: Fri, 28 Feb 2025 13:20:43 -0600 Subject: [PATCH 02/14] Added TS Proxy server --- apps/proxy/ts_proxy | 323 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 323 insertions(+) create mode 100644 apps/proxy/ts_proxy diff --git a/apps/proxy/ts_proxy b/apps/proxy/ts_proxy new file mode 100644 index 00000000..ec6d62d7 --- /dev/null +++ b/apps/proxy/ts_proxy @@ -0,0 +1,323 @@ +""" +Transport Stream (TS) Proxy 
Server +Handles live TS stream proxying with support for: +- Stream switching +- Buffer management +- Multiple client connections +- Connection state tracking +""" + +from flask import Flask, Response, request, jsonify +import requests +import threading +import logging +from collections import deque +import time +import os +from typing import Optional, Set, Deque, Dict + +# Configuration +class Config: + CHUNK_SIZE: int = 8192 # Buffer chunk size (bytes) + BUFFER_SIZE: int = 1000 # Number of chunks to keep in memory + RECONNECT_DELAY: int = 5 # Seconds between reconnection attempts + CLIENT_POLL_INTERVAL: float = 0.1 # Seconds between client buffer checks + MAX_RETRIES: int = 3 # Maximum connection retry attempts + DEFAULT_USER_AGENT: str = 'VLC/3.0.20 LibVLC/3.0.20' # Default user agent + +class StreamManager: + """Manages TS stream state and connection handling""" + + def __init__(self, initial_url: str, channel_id: str, user_agent: Optional[str] = None): + self.current_url: str = initial_url + self.channel_id: str = channel_id + self.user_agent: str = user_agent or Config.DEFAULT_USER_AGENT + self.url_changed: threading.Event = threading.Event() + self.running: bool = True + self.session: requests.Session = self._create_session() + self.connected: bool = False + self.retry_count: int = 0 + logging.info(f"Initialized stream manager for channel {channel_id}") + + def _create_session(self) -> requests.Session: + """Create and configure requests session""" + session = requests.Session() + session.headers.update({ + 'User-Agent': self.user_agent, + 'Connection': 'keep-alive' + }) + return session + + def update_url(self, new_url: str) -> bool: + """Update stream URL and signal connection change""" + if new_url != self.current_url: + logging.info(f"Stream switch initiated: {self.current_url} -> {new_url}") + self.current_url = new_url + self.connected = False + self.url_changed.set() + return True + return False + + def should_retry(self) -> bool: + """Check if 
class ClientManager:
    """Tracks the set of client ids currently attached to a channel."""

    def __init__(self):
        self.active_clients: Set[int] = set()
        self.lock: threading.Lock = threading.Lock()

    def add_client(self, client_id: int) -> None:
        """Register *client_id* as connected."""
        with self.lock:
            self.active_clients.add(client_id)
            logging.info(f"New client connected: {client_id} (total: {len(self.active_clients)})")

    def remove_client(self, client_id: int) -> int:
        """Unregister *client_id* and return how many clients remain.

        Removing an id that is not registered (for example a duplicate
        disconnect) is logged and ignored instead of raising ``KeyError``,
        so callers running clean-up code always get the remaining count.
        """
        with self.lock:
            if client_id in self.active_clients:
                self.active_clients.remove(client_id)
                remaining = len(self.active_clients)
                logging.info(f"Client disconnected: {client_id} (remaining: {remaining})")
            else:
                remaining = len(self.active_clients)
                logging.warning(f"Unknown client disconnected: {client_id} (remaining: {remaining})")
            return remaining
class ProxyServer:
    """Manages TS proxy server instance.

    Owns the Flask app plus, per channel id, the stream manager, the chunk
    buffer, the client registry and the fetch thread.
    """

    def __init__(self, user_agent: Optional[str] = None):
        self.app = Flask(__name__)
        self.stream_managers: Dict[str, StreamManager] = {}
        self.stream_buffers: Dict[str, StreamBuffer] = {}
        self.client_managers: Dict[str, ClientManager] = {}
        self.fetch_threads: Dict[str, threading.Thread] = {}
        self.user_agent: str = user_agent or Config.DEFAULT_USER_AGENT
        self._setup_routes()

    def _setup_routes(self) -> None:
        """Configure Flask routes.

        Both views take ``channel_id``, so the URL rules must declare it as
        a variable part for Flask to pass it in.
        """
        self.app.route('/stream/<channel_id>')(self.stream_endpoint)
        self.app.route('/change_stream/<channel_id>', methods=['POST'])(self.change_stream)

    def initialize_channel(self, url: str, channel_id: str) -> None:
        """Initialize a new channel stream, replacing any existing one."""
        if channel_id in self.stream_managers:
            self.stop_channel(channel_id)

        self.stream_managers[channel_id] = StreamManager(
            url,
            channel_id,
            user_agent=self.user_agent
        )
        self.stream_buffers[channel_id] = StreamBuffer()
        self.client_managers[channel_id] = ClientManager()

        fetcher = StreamFetcher(
            self.stream_managers[channel_id],
            self.stream_buffers[channel_id]
        )

        self.fetch_threads[channel_id] = threading.Thread(
            target=fetcher.fetch_loop,
            name=f"StreamFetcher-{channel_id}",
            daemon=True
        )
        self.fetch_threads[channel_id].start()
        logging.info(f"Initialized channel {channel_id} with URL {url}")

    def stop_channel(self, channel_id: str) -> None:
        """Stop and cleanup a channel (waits up to 5 s for its fetch thread)."""
        if channel_id in self.stream_managers:
            self.stream_managers[channel_id].stop()
            if channel_id in self.fetch_threads:
                self.fetch_threads[channel_id].join(timeout=5)
            self._cleanup_channel(channel_id)

    def _cleanup_channel(self, channel_id: str) -> None:
        """Remove channel resources"""
        for collection in [self.stream_managers, self.stream_buffers,
                           self.client_managers, self.fetch_threads]:
            collection.pop(channel_id, None)

    def stream_endpoint(self, channel_id: str):
        """Stream endpoint that serves TS data to clients.

        Returns a 404 response for an unknown channel, otherwise a streaming
        response that polls the channel buffer and forwards new chunks.
        """
        if channel_id not in self.stream_managers:
            return Response('Channel not found', status=404)

        # Resolve per-channel objects now so the generator keeps working on
        # the objects it registered with even if the channel is replaced.
        manager = self.stream_managers[channel_id]
        buffer = self.stream_buffers[channel_id]
        client_manager = self.client_managers[channel_id]

        def generate():
            client_id = threading.get_ident()
            client_manager.add_client(client_id)
            last_index = buffer.index

            try:
                # End the response once the channel's manager is stopped,
                # instead of polling a buffer nobody fills any more.
                while manager.running:
                    # Copy pending chunks under the lock, send them after
                    # releasing it: yielding while holding the lock would
                    # block the fetch thread (and other clients) for as
                    # long as this client takes to read.
                    with buffer.lock:
                        chunks_behind = buffer.index - last_index
                        if chunks_behind > 0:
                            start_pos = max(0, len(buffer.buffer) - chunks_behind)
                            pending = list(buffer.buffer)[start_pos:]
                            last_index = buffer.index
                        else:
                            pending = []

                    for chunk in pending:
                        yield chunk

                    threading.Event().wait(Config.CLIENT_POLL_INTERVAL)
            finally:
                # Runs on client disconnect (GeneratorExit), on errors and
                # on normal loop exit alike.
                remaining = client_manager.remove_client(client_id)
                # Only tear the channel down if it is still the one this
                # client attached to; a re-initialised channel has a new
                # ClientManager and must not be stopped by a stale client.
                still_current = self.client_managers.get(channel_id) is client_manager
                if remaining == 0 and still_current:
                    logging.info(f"No clients remaining for channel {channel_id}")
                    self.stop_channel(channel_id)

        return Response(generate(), content_type='video/mp2t')

    def change_stream(self, channel_id: str):
        """Handle stream URL changes.

        Expects a JSON body ``{"url": "<new stream url>"}``. A missing or
        non-JSON body is answered with 400 rather than an unhandled error.
        """
        if channel_id not in self.stream_managers:
            return jsonify({'error': 'Channel not found'}), 404

        payload = request.get_json(silent=True)
        new_url = payload.get('url') if isinstance(payload, dict) else None
        if not new_url:
            return jsonify({'error': 'No URL provided'}), 400

        manager = self.stream_managers[channel_id]
        if manager.update_url(new_url):
            return jsonify({
                'message': 'Stream URL updated',
                'channel': channel_id,
                'url': new_url
            })
        return jsonify({
            'message': 'URL unchanged',
            'channel': channel_id,
            'url': new_url
        })

    def run(self, host: str = '0.0.0.0', port: int = 5000) -> None:
        """Start the proxy server"""
        self.app.run(host=host, port=port, threaded=True)

    def shutdown(self) -> None:
        """Stop all channels and cleanup"""
        for channel_id in list(self.stream_managers.keys()):
            self.stop_channel(channel_id)
class StreamBuffer:
    """
    Stores stream segments keyed by sequence number.

    Attributes:
        buffer (Dict[int, bytes]): Maps sequence numbers to segment data
        lock (threading.Lock): Lock offered for synchronizing buffer access

    Features:
        - Sequence number based indexing
        - Automatic eviction of the oldest segments beyond Config.MAX_SEGMENTS

    Note:
        The methods below do not acquire ``lock`` themselves.
        NOTE(review): callers are presumably expected to hold ``lock``
        around multi-step access - confirm against the call sites.
    """

    def __init__(self):
        self.buffer: Dict[int, bytes] = {}  # Maps sequence numbers to segment data
        self.lock: threading.Lock = threading.Lock()

    def __getitem__(self, key: int) -> Optional[bytes]:
        """Get segment data by sequence number (None when absent)"""
        return self.buffer.get(key)

    def __setitem__(self, key: int, value: bytes):
        """Store segment data, then evict the lowest sequences over the cap"""
        self.buffer[key] = value
        overflow = len(self.buffer) - Config.MAX_SEGMENTS
        if overflow > 0:
            # Slicing by the overflow count (rather than a negative index)
            # also behaves sensibly when MAX_SEGMENTS is 0.
            for stale in sorted(self.buffer)[:overflow]:
                del self.buffer[stale]

    def __contains__(self, key: int) -> bool:
        """Check if sequence number exists in buffer"""
        return key in self.buffer

    def keys(self) -> List[int]:
        """Get list of available sequence numbers"""
        return list(self.buffer)

    def cleanup(self, keep_sequences: List[int]):
        """Remove every segment whose sequence is not in *keep_sequences*.

        *keep_sequences* is materialised into a set once, so membership
        tests are O(1) and one-shot iterables (generators) work correctly.
        """
        keep = set(keep_sequences)
        for seq in list(self.buffer):
            if seq not in keep:
                del self.buffer[seq]
+ + Attributes: + active_clients (Set[int]): Set of active client IDs + lock (threading.Lock): Thread safety for client operations + """ + self.active_clients: Set[int] = set() + self.lock: threading.Lock = threading.Lock() + + def add_client(self, client_id: int): + """Add new client connection""" + with self.lock: + self.active_clients.add(client_id) + logging.info(f"New client connected (total: {len(self.active_clients)})") + + def remove_client(self, client_id: int) -> int: + """Remove client and return remaining count""" + with self.lock: + self.active_clients.remove(client_id) + remaining = len(self.active_clients) + logging.info(f"Client disconnected (remaining: {remaining})") + return remaining + + def has_clients(self) -> bool: + """Check if any clients are connected""" + return len(self.active_clients) > 0 + +class StreamManager: + """ + Manages HLS stream state and switching logic. + + Attributes: + current_url (str): Current stream URL + channel_id (str): Unique channel identifier + running (bool): Stream activity flag + next_sequence (int): Next sequence number to assign + buffered_sequences (set): Currently buffered sequence numbers + source_changes (set): Sequences where stream source changed + + Features: + - Stream URL management + - Sequence number assignment + - Discontinuity tracking + - Thread coordination + - Buffer state management + """ + def __init__(self, initial_url: str, channel_id: str, user_agent: Optional[str] = None): + # Stream state + self.current_url = initial_url + self.channel_id = channel_id + self.user_agent = user_agent or Config.DEFAULT_USER_AGENT + self.running = True + self.switching_stream = False + + # Sequence tracking + self.next_sequence = 0 + self.highest_sequence = 0 + self.buffered_sequences = set() + self.downloaded_sources = {} + self.segment_durations = {} + + # Source tracking + self.current_source = None + self.source_changes = set() + self.stream_switch_count = 0 + + # Threading + self.fetcher = None + 
self.fetch_thread = None + self.url_changed = threading.Event() + + # Add manifest info + self.target_duration = 10.0 # Default, will be updated from manifest + self.manifest_version = 3 # Default, will be updated from manifest + + logging.info(f"Initialized stream manager for channel {channel_id}") + + def update_url(self, new_url: str) -> bool: + """ + Handle stream URL changes with proper discontinuity marking. + + Args: + new_url: New stream URL to switch to + + Returns: + bool: True if URL changed, False if unchanged + + Side effects: + - Sets switching_stream flag + - Updates current_url + - Maintains sequence numbering + - Marks discontinuity point + - Signals fetch thread + """ + if new_url != self.current_url: + with buffer_lock: + self.switching_stream = True + self.current_url = new_url + + # Continue sequence numbering from last sequence + if self.buffered_sequences: + self.next_sequence = max(self.buffered_sequences) + 1 + + # Mark discontinuity at next sequence + self.source_changes.add(self.next_sequence) + + logging.info(f"Stream switch - next sequence will start at {self.next_sequence}") + + # Clear state but maintain sequence numbers + self.downloaded_sources.clear() + self.segment_durations.clear() + self.current_source = None + + # Signal thread to switch URL + self.url_changed.set() + + return True + return False + + def get_next_sequence(self, source_id: str) -> Optional[int]: + """ + Assign sequence numbers to segments with source change detection. 
+ + Args: + source_id: Unique identifier for segment source + + Returns: + int: Next available sequence number + None: If segment already downloaded + + Side effects: + - Updates buffered sequences set + - Tracks source changes for discontinuity + - Maintains sequence numbering + """ + if source_id in self.downloaded_sources: + return None + + seq = self.next_sequence + while (seq in self.buffered_sequences): + seq += 1 + + # Track source changes for discontinuity markers + source_prefix = source_id.split('_')[0] + if not self.switching_stream and self.current_source and self.current_source != source_prefix: + self.source_changes.add(seq) + logging.debug(f"Source change detected at sequence {seq}") + self.current_source = source_prefix + + # Update tracking + self.downloaded_sources[source_id] = seq + self.buffered_sequences.add(seq) + self.next_sequence = seq + 1 + self.highest_sequence = max(self.highest_sequence, seq) + + return seq + + def _fetch_loop(self): + """Background thread for continuous stream fetching""" + while self.running: + try: + fetcher = StreamFetcher(self, self.buffer) + fetch_stream(fetcher, self.url_changed, self.next_sequence) + except Exception as e: + logging.error(f"Stream error: {e}") + time.sleep(5) # Wait before retry + + self.url_changed.clear() + + def start(self): + """Start the background fetch thread""" + if not self.fetch_thread or not self.fetch_thread.is_alive(): + self.running = True + self.fetch_thread = threading.Thread( + target=self._fetch_loop, + name="StreamFetcher", + daemon=True # Thread will exit when main program does + ) + self.fetch_thread.start() + logging.info("Stream manager started") + + def stop(self): + """Stop the background fetch thread""" + self.running = False + if self.fetch_thread and self.fetch_thread.is_alive(): + self.url_changed.set() # Signal thread to exit + self.fetch_thread.join(timeout=5) # Wait up to 5 seconds + logging.info("Stream manager stopped") class StreamFetcher: - """Handles HTTP 
requests for stream segments with connection pooling""" - def __init__(self, stream_url): - self.stream_url = stream_url + """ + Handles HTTP requests for stream segments with connection pooling. + + Attributes: + manager (StreamManager): Associated stream manager instance + buffer (StreamBuffer): Buffer for storing segments + session (requests.Session): Persistent HTTP session + redirect_cache (dict): Cache for redirect responses + + Features: + - Connection pooling and reuse + - Redirect caching + - Rate limiting + - Automatic retries + - Host fallback + """ + def __init__(self, manager: StreamManager, buffer: StreamBuffer): + self.manager = manager + self.buffer = buffer + self.stream_url = manager.current_url self.session = requests.Session() + # Configure session headers + self.session.headers.update({ + 'User-Agent': manager.user_agent, + 'Connection': 'keep-alive' + }) + # Set up connection pooling adapter = requests.adapters.HTTPAdapter( pool_connections=2, # Number of connection pools @@ -57,9 +323,26 @@ class StreamFetcher: self.min_request_interval = 0.05 # Minimum time between requests self.last_host = None # Cache last successful host self.redirect_cache = {} # Cache redirect responses + self.redirect_cache_limit = 1000 + + def cleanup_redirect_cache(self): + """Remove old redirect cache entries""" + if len(self.redirect_cache) > self.redirect_cache_limit: + self.redirect_cache.clear() - def get_base_host(self, url): - """Extract base host from URL using urlparse""" + def get_base_host(self, url: str) -> str: + """ + Extract base host from URL. 
+ + Args: + url: Full URL to parse + + Returns: + str: Base host in format 'scheme://hostname' + + Example: + 'http://example.com/path' -> 'http://example.com' + """ try: parsed = urlparse(url) return f"{parsed.scheme}://{parsed.netloc}" @@ -67,8 +350,25 @@ class StreamFetcher: logging.error(f"Error extracting base host: {e}") return url - def download(self, url): - """Download content with connection reuse""" + def download(self, url: str) -> tuple[bytes, str]: + """ + Download content with connection reuse and redirect handling. + + Args: + url: URL to download from + + Returns: + tuple containing: + bytes: Downloaded content + str: Final URL after any redirects + + Features: + - Connection pooling/reuse + - Redirect caching + - Rate limiting + - Host fallback on failure + - Automatic retries + """ now = time.time() wait_time = self.last_request_time + self.min_request_interval - now if (wait_time > 0): @@ -102,69 +402,114 @@ class StreamFetcher: return self.download(new_url) raise -def analyze_ts_packet(data: bytes) -> dict: - """ - Analyze a single MPEG-TS packet (188 bytes) - - Args: - data: Raw packet bytes + def fetch_loop(self): + """ + Main fetching loop that continuously downloads stream content. 
- Returns: - dict with packet analysis: - - sync_valid: True if sync byte is 0x47 - - transport_error: True if TEI bit is set - - payload_start: True if PUSI bit is set - - pid: Packet ID (13 bits) - - hex_dump: First 16 bytes for debugging - """ - # Validate minimum packet size - if len(data) < 188: - return { - 'sync_valid': False, - 'transport_error': True, - 'payload_start': False, - 'pid': 0, - 'error': 'Packet too short' - } - - # Verify sync byte (0x47) - if data[0] != 0x47: - return { - 'sync_valid': False, - 'transport_error': True, - 'payload_start': False, - 'pid': 0, - 'error': 'Invalid sync byte' - } - - # Extract packet header fields - transport_error = (data[1] & 0x80) != 0 # Transport Error Indicator - payload_start = (data[1] & 0x40) != 0 # Payload Unit Start Indicator - pid = ((data[1] & 0x1F) << 8) | data[2] # Packet ID (13 bits) - - # Create hex dump for debugging - hex_dump = ' '.join(f'{b:02x}' for b in data[:16]) - - return { - 'sync_valid': True, - 'transport_error': transport_error, - 'payload_start': payload_start, - 'pid': pid, - 'hex_dump': hex_dump, - 'packet_size': len(data) - } + Features: + - Automatic manifest updates + - Rate-limited downloads + - Exponential backoff on errors + - Stream switch handling + - Segment validation + + Error handling: + - HTTP 509 rate limiting + - Connection drops + - Invalid segments + - Manifest parsing errors + + Thread safety: + Coordinates with StreamManager and StreamBuffer + using proper locking mechanisms + """ + retry_delay = 1 + max_retry_delay = 8 + last_manifest_time = 0 + + while self.manager.running: + try: + now = time.time() + + # Get manifest data + try: + manifest_data, final_url = self.download(self.manager.current_url) + manifest = m3u8.loads(manifest_data.decode()) + + # Reset retry delay on successful fetch + retry_delay = 1 + + except requests.exceptions.HTTPError as e: + if e.response.status_code == 509: + logging.warning("Rate limit exceeded, backing off...") + 
time.sleep(retry_delay) + retry_delay = min(retry_delay * 2, max_retry_delay) + continue + raise + + # Update manifest info + if manifest.target_duration: + self.manager.target_duration = float(manifest.target_duration) + if manifest.version: + self.manager.manifest_version = manifest.version + + if not manifest.segments: + logging.warning("No segments in manifest") + time.sleep(retry_delay) + continue + + # Calculate proper manifest polling interval + target_duration = float(manifest.target_duration) + manifest_interval = target_duration * 0.5 # Poll at half the segment duration + + # Process latest segment + latest_segment = manifest.segments[-1] + try: + segment_url = urljoin(final_url, latest_segment.uri) + segment_data, _ = self.download(segment_url) + + verification = verify_segment(segment_data) + if not verification.get('valid', False): + logging.warning(f"Invalid segment: {verification.get('error')}") + continue + + # Store segment with proper locking + with self.buffer.lock: + seq = self.manager.next_sequence + self.buffer[seq] = segment_data + self.manager.segment_durations[seq] = float(latest_segment.duration) + self.manager.next_sequence += 1 + logging.debug(f"Stored segment {seq} (duration: {latest_segment.duration}s)") + + except Exception as e: + logging.error(f"Segment download error: {e}") + continue + + # Update last manifest time and wait for next interval + last_manifest_time = now + time.sleep(manifest_interval) + + except Exception as e: + logging.error(f"Fetch error: {e}") + time.sleep(retry_delay) + retry_delay = min(retry_delay * 2, max_retry_delay) def get_segment_sequence(segment_uri: str) -> Optional[int]: """ - Extract sequence number from segment URI + Extract sequence number from segment URI pattern. 
Args: segment_uri: Segment filename or path Returns: - int: Sequence number if found - None: If no sequence can be extracted + int: Extracted sequence number if found + None: If no valid sequence number can be extracted + + Handles common patterns like: + - Numerical sequences (e.g., segment_1234.ts) + - Complex patterns with stream IDs (e.g., stream_123_456.ts) """ + try: # Try numerical sequence (e.g., 1038_3693.ts) if '_' in segment_uri: @@ -176,18 +521,25 @@ def get_segment_sequence(segment_uri: str) -> Optional[int]: # Update verify_segment with more thorough checks def verify_segment(data: bytes) -> dict: """ - Verify MPEG-TS segment integrity + Verify MPEG-TS segment integrity and structure. Args: - data: Raw segment bytes + data: Raw segment data bytes Returns: - dict with verification results: - - valid: True if segment passes all checks - - packets: Number of valid packets found - - size: Total segment size in bytes - - error: Description if validation fails + dict containing: + valid (bool): True if segment passes all checks + packets (int): Number of valid packets found + size (int): Total segment size in bytes + error (str): Description if validation fails + + Checks: + - Minimum size requirements + - Packet size alignment + - Sync byte presence + - Transport error indicators """ + # Check minimum size if len(data) < 188: return {'valid': False, 'error': 'Segment too short'} @@ -223,120 +575,22 @@ def verify_segment(data: bytes) -> dict: 'size': len(data) } -class StreamManager: - """Manages HLS stream state and switching logic""" - def __init__(self, initial_url: str): - # Stream state - self.current_url = initial_url - self.running = True - self.switching_stream = False - - # Sequence tracking - self.next_sequence = 0 - self.highest_sequence = 0 - self.buffered_sequences = set() - self.downloaded_sources = {} - self.segment_durations = {} - - # Source tracking - self.current_source = None - self.source_changes = set() - self.stream_switch_count = 0 - 
- # Threading - self.fetcher = None - self.fetch_thread = None - self.url_changed = threading.Event() - - def update_url(self, new_url: str) -> bool: - """Handle stream URL changes with proper discontinuity marking""" - if new_url != self.current_url: - with buffer_lock: - self.switching_stream = True - self.current_url = new_url - - # Set sequence numbers for stream switch - if segment_buffers: - self.highest_sequence = max(segment_buffers.keys()) - self.next_sequence = self.highest_sequence + 1 - # Mark discontinuity at first segment of new stream - self.source_changes = {self.next_sequence + 1} - else: - self.stream_switch_count += 1 - self.next_sequence = self.stream_switch_count * 1000 - self.source_changes = {self.next_sequence} - - logging.info(f"Stream switch - next sequence will start at {self.next_sequence}") - - # Clear state but maintain sequence numbers - self.downloaded_sources.clear() - self.segment_durations.clear() - self.current_source = None - - # Signal thread to switch URL - self.url_changed.set() - - return True - return False - - def get_next_sequence(self, source_id): - """Assign sequence numbers to segments with source change detection""" - if source_id in self.downloaded_sources: - return None - - seq = self.next_sequence - while seq in self.buffered_sequences: - seq += 1 - - # Track source changes for discontinuity markers - source_prefix = source_id.split('_')[0] - if not self.switching_stream and self.current_source and self.current_source != source_prefix: - self.source_changes.add(seq) - logging.debug(f"Source change detected at sequence {seq}") - self.current_source = source_prefix - - # Update tracking - self.downloaded_sources[source_id] = seq - self.buffered_sequences.add(seq) - self.next_sequence = seq + 1 - self.highest_sequence = max(self.highest_sequence, seq) - - return seq - - def _fetch_loop(self): - """Background thread for continuous stream fetching""" - while self.running: - try: - self.fetcher = 
StreamFetcher(self.current_url) - fetch_stream(self.fetcher, self.url_changed, self.next_sequence) - except Exception as e: - logging.error(f"Stream error: {e}") - time.sleep(5) # Wait before retry - - self.url_changed.clear() - - def start(self): - """Start the background fetch thread""" - if not self.fetch_thread or not self.fetch_thread.is_alive(): - self.running = True - self.fetch_thread = threading.Thread( - target=self._fetch_loop, - name="StreamFetcher", - daemon=True # Thread will exit when main program does - ) - self.fetch_thread.start() - logging.info("Stream manager started") - - def stop(self): - """Stop the background fetch thread""" - self.running = False - if self.fetch_thread and self.fetch_thread.is_alive(): - self.url_changed.set() # Signal thread to exit - self.fetch_thread.join(timeout=5) # Wait up to 5 seconds - logging.info("Stream manager stopped") - def fetch_stream(fetcher: StreamFetcher, stop_event: threading.Event, start_sequence: int = 0): - global manifest_buffer, segment_buffers + """ + Main streaming function that handles manifest updates and segment downloads. 
+ + Args: + fetcher: StreamFetcher instance to handle HTTP requests + stop_event: Threading event to signal when to stop fetching + start_sequence: Initial sequence number to start from + + The function implements the core HLS fetching logic: + - Fetches and parses manifest files + - Downloads new segments when they become available + - Handles stream switches with proper discontinuity marking + - Maintains buffer state and segment sequence numbering + """ + # Remove global stream_manager reference retry_delay = 1 max_retry_delay = 8 last_segment_time = 0 @@ -366,22 +620,22 @@ def fetch_stream(fetcher: StreamFetcher, stop_event: threading.Event, start_sequ manifest_content = manifest_data.decode() new_segments = {} - if stream_manager.switching_stream: + if fetcher.manager.switching_stream: # Use fetcher.manager instead of stream_manager # Stream switch - only get latest segment manifest_segments = [manifest.segments[-1]] - seq_start = stream_manager.next_sequence + seq_start = fetcher.manager.next_sequence max_segments = 1 logging.debug(f"Processing stream switch - getting latest segment at sequence {seq_start}") elif not buffer_initialized: # Initial buffer manifest_segments = manifest.segments[-Config.INITIAL_SEGMENTS:] - seq_start = stream_manager.next_sequence + seq_start = fetcher.manager.next_sequence max_segments = Config.INITIAL_SEGMENTS logging.debug(f"Starting initial buffer at sequence {seq_start}") else: # Normal operation manifest_segments = [manifest.segments[-1]] - seq_start = stream_manager.next_sequence + seq_start = fetcher.manager.next_sequence max_segments = 1 # Map segments @@ -391,7 +645,7 @@ def fetch_stream(fetcher: StreamFetcher, stop_event: threading.Event, start_sequ break source_id = segment.uri.split('/')[-1].split('.')[0] - next_seq = stream_manager.get_next_sequence(source_id) + next_seq = fetcher.manager.get_next_sequence(source_id) if next_seq is not None: duration = float(segment.duration) @@ -400,7 +654,7 @@ def 
fetch_stream(fetcher: StreamFetcher, stop_event: threading.Event, start_sequ 'duration': duration, 'source_id': source_id } - stream_manager.segment_durations[next_seq] = duration + fetcher.manager.segment_durations[next_seq] = duration segments_mapped += 1 manifest_buffer = manifest_content @@ -420,8 +674,8 @@ def fetch_stream(fetcher: StreamFetcher, stop_event: threading.Event, start_sequ segment_buffers[sequence_id] = segment_data logging.debug(f"Downloaded and verified segment {sequence_id} (packets: {validation['packets']})") - if stream_manager.switching_stream: - stream_manager.switching_stream = False + if fetcher.manager.switching_stream: + fetcher.manager.switching_stream = False stop_event.set() # Force fetcher restart with new URL break elif not buffer_initialized and len(segment_buffers) >= Config.INITIAL_SEGMENTS: @@ -442,118 +696,23 @@ def fetch_stream(fetcher: StreamFetcher, stop_event: threading.Event, start_sequ retry_delay = min(retry_delay * 2, max_retry_delay) manifest_update_needed = True -# Flask Routes for HLS Proxy Server -@app.route('/stream.m3u8') -def master_playlist(): - """ - Serve the HLS master playlist - - Handles: - - Initial buffering state - - Ongoing playback with sliding window - - Discontinuity markers for stream switches - - Dynamic segment durations - """ - with buffer_lock: - # Verify buffer state - if not manifest_buffer or not segment_buffers: - logging.warning("No manifest or segments available yet") - return '', 404 - - available = sorted(segment_buffers.keys()) - if not available: - logging.warning("No segments available") - return '', 404 - - manifest = m3u8.loads(manifest_buffer) - max_seq = max(available) - - # Calculate window bounds - if len(available) <= Config.INITIAL_SEGMENTS: - # During initial buffering, show all segments - min_seq = min(available) - else: - # For ongoing playback, maintain sliding window - min_seq = max( - min(available), - max_seq - Config.WINDOW_SIZE + 1 - ) - - # Build manifest with 
proper tags - new_manifest = ['#EXTM3U'] - new_manifest.append('#EXT-X-VERSION:3') - new_manifest.append(f'#EXT-X-MEDIA-SEQUENCE:{min_seq}') - new_manifest.append(f'#EXT-X-TARGETDURATION:{int(manifest.target_duration)}') - - # Filter segments within window - window_segments = [s for s in available if min_seq <= s <= max_seq] - - # Add segments with discontinuity handling - for seq in window_segments: - # Mark stream switches with discontinuity - if seq in stream_manager.source_changes: - new_manifest.append('#EXT-X-DISCONTINUITY') - logging.debug(f"Added discontinuity marker before segment {seq}") - - # Use actual segment duration or fallback to target - duration = stream_manager.segment_durations.get(seq, manifest.target_duration) - new_manifest.append(f'#EXTINF:{duration:.3f},') - new_manifest.append(f'/segments/{seq}.ts') - - manifest_content = '\n'.join(new_manifest) - logging.debug(f"Serving manifest with segments {min_seq}-{max_seq} (window: {len(window_segments)})") - return Response(manifest_content, content_type='application/vnd.apple.mpegurl') -@app.route('/segments/') -def get_segment(segment_name): - """ - Serve individual MPEG-TS segments - - Args: - segment_name: Segment filename (e.g., '123.ts') - - Returns: - MPEG-TS segment data or 404 if not found - """ - try: - segment_id = int(segment_name.split('.')[0]) - with buffer_lock: - if segment_id in segment_buffers: - available = sorted(segment_buffers.keys()) - logging.debug(f"Client requested segment {segment_id} (buffer: {min(available)}-{max(available)})") - return Response(segment_buffers[segment_id], content_type='video/MP2T') - - logging.warning( - f"Segment {segment_id} not found. 
" - f"Available: {min(segment_buffers.keys()) if segment_buffers else 'none'}" - f"-{max(segment_buffers.keys()) if segment_buffers else 'none'}" - ) - except Exception as e: - logging.error(f"Error serving segment {segment_name}: {e}") - return '', 404 - -@app.route('/change_stream', methods=['POST']) -def change_stream(): - """ - Handle stream URL changes via HTTP POST - - Expected JSON body: - {"url": "new_stream_url"} - - Returns: - JSON response indicating success/failure - """ - new_url = request.json.get('url') - if not new_url: - return jsonify({'error': 'No URL provided'}), 400 - - if stream_manager.update_url(new_url): - return jsonify({'message': 'Stream URL updated', 'url': new_url}) - return jsonify({'message': 'URL unchanged', 'url': new_url}) @app.before_request def log_request_info(): - """Log client connections and important requests""" + """ + Log client connections and important requests. + + Logs: + INFO: + - First manifest request from new client + - Stream switch requests + DEBUG: + - All other requests + + Format: + {client_ip} - {method} {path} + """ if request.path == '/stream.m3u8' and not segment_buffers: # First manifest request from a client logging.info(f"New client connected from {request.remote_addr}") @@ -567,15 +726,317 @@ def log_request_info(): # Configure Werkzeug logger to DEBUG logging.getLogger('werkzeug').setLevel(logging.DEBUG) +class ProxyServer: + """Manages HLS proxy server instance""" + + def __init__(self, user_agent: Optional[str] = None): + self.app = Flask(__name__) + self.stream_managers: Dict[str, StreamManager] = {} + self.stream_buffers: Dict[str, StreamBuffer] = {} + self.client_managers: Dict[str, ClientManager] = {} + self.fetch_threads: Dict[str, threading.Thread] = {} + self.user_agent: str = user_agent or Config.DEFAULT_USER_AGENT + self._setup_routes() + + def _setup_routes(self) -> None: + """Configure Flask routes""" + self.app.add_url_rule( + '/stream/', # Changed from //stream.m3u8 + 
view_func=self.stream_endpoint + ) + self.app.add_url_rule( + '/stream//segments/', # Updated to match new pattern + view_func=self.get_segment + ) + self.app.add_url_rule( + '/change_stream/', # Changed from //change_stream + view_func=self.change_stream, + methods=['POST'] + ) + + def initialize_channel(self, url: str, channel_id: str) -> None: + """Initialize a new channel stream""" + if channel_id in self.stream_managers: + self.stop_channel(channel_id) + + manager = StreamManager( + url, + channel_id, + user_agent=self.user_agent + ) + buffer = StreamBuffer() + + # Store resources + self.stream_managers[channel_id] = manager + self.stream_buffers[channel_id] = buffer + self.client_managers[channel_id] = ClientManager() + + # Create and store fetcher + fetcher = StreamFetcher(manager, buffer) + manager.fetcher = fetcher # Store reference to fetcher + + # Start fetch thread + self.fetch_threads[channel_id] = threading.Thread( + target=fetcher.fetch_loop, + name=f"StreamFetcher-{channel_id}", + daemon=True + ) + self.fetch_threads[channel_id].start() + + logging.info(f"Initialized channel {channel_id} with URL {url}") + + def stop_channel(self, channel_id: str) -> None: + """Stop and cleanup a channel""" + if channel_id in self.stream_managers: + self.stream_managers[channel_id].stop() + if channel_id in self.fetch_threads: + self.fetch_threads[channel_id].join(timeout=5) + self._cleanup_channel(channel_id) + + def _cleanup_channel(self, channel_id: str) -> None: + """ + Remove all resources associated with a channel. 
+ + Args: + channel_id: Channel to cleanup + + Removes: + - Stream manager instance + - Segment buffer + - Client manager + - Fetch thread reference + + Thread safety: + Should only be called after stream manager is stopped + and fetch thread has completed + """ + + for collection in [self.stream_managers, self.stream_buffers, + self.client_managers, self.fetch_threads]: + collection.pop(channel_id, None) + + def run(self, host: str = '0.0.0.0', port: int = 5000) -> None: + """Start the proxy server""" + try: + self.app.run(host=host, port=port, threaded=True) + except KeyboardInterrupt: + logging.info("Shutting down gracefully...") + self.shutdown() + except Exception as e: + logging.error(f"Server error: {e}") + self.shutdown() + raise + + def shutdown(self) -> None: + """ + Stop all channels and cleanup resources. + + Steps: + 1. Stop all active stream managers + 2. Join fetch threads + 3. Clean up channel resources + 4. Release system resources + + Thread Safety: + Safe to call from signal handlers or during shutdown + """ + for channel_id in list(self.stream_managers.keys()): + self.stop_channel(channel_id) + + def stream_endpoint(self, channel_id: str): + """ + Flask route handler for serving HLS manifests. 
+ + Args: + channel_id: Unique identifier for the stream channel + + Returns: + Flask Response with: + - HLS manifest content + - Proper content type + - 404 if channel/segments not found + + The manifest includes: + - Current segment window + - Proper sequence numbering + - Discontinuity markers + - Accurate segment durations + """ + try: + if channel_id not in self.stream_managers: + return Response('Channel not found', status=404) + + buffer = self.stream_buffers[channel_id] + manager = self.stream_managers[channel_id] + + # Verify buffer state + with buffer.lock: + available = sorted(buffer.keys()) + if not available: + logging.warning("No segments available") + return '', 404 + + max_seq = max(available) + + # Find the first segment after any discontinuity + discontinuity_start = min(available) + for seq in available: + if seq in manager.source_changes: + discontinuity_start = seq + break + + # Calculate window bounds starting from discontinuity + if len(available) <= Config.INITIAL_SEGMENTS: + min_seq = discontinuity_start + else: + min_seq = max( + discontinuity_start, + max_seq - Config.WINDOW_SIZE + 1 + ) + + # Build manifest with proper tags + new_manifest = ['#EXTM3U'] + new_manifest.append(f'#EXT-X-VERSION:{manager.manifest_version}') + new_manifest.append(f'#EXT-X-MEDIA-SEQUENCE:{min_seq}') + new_manifest.append(f'#EXT-X-TARGETDURATION:{int(manager.target_duration)}') + + # Filter segments within window + window_segments = [s for s in available if min_seq <= s <= max_seq] + + # Add segments with discontinuity handling + for seq in window_segments: + if seq in manager.source_changes: + new_manifest.append('#EXT-X-DISCONTINUITY') + logging.debug(f"Added discontinuity marker before segment {seq}") + + duration = manager.segment_durations.get(seq, 10.0) + new_manifest.append(f'#EXTINF:{duration},') + new_manifest.append(f'/stream/{channel_id}/segments/{seq}.ts') + + manifest_content = '\n'.join(new_manifest) + logging.debug(f"Serving manifest with segments 
{min_seq}-{max_seq} (window: {len(window_segments)})") + return Response(manifest_content, content_type='application/vnd.apple.mpegurl') + except ConnectionAbortedError: + logging.debug("Client disconnected") + return '', 499 + except Exception as e: + logging.error(f"Stream endpoint error: {e}") + return '', 500 + + def get_segment(self, channel_id: str, segment_name: str): + """ + Serve individual MPEG-TS segments to clients. + + Args: + channel_id: Unique identifier for the channel + segment_name: Segment filename (e.g., '123.ts') + + Returns: + Flask Response: + - MPEG-TS segment data with video/MP2T content type + - 404 if segment or channel not found + + Error Handling: + - Logs warning if segment not found + - Logs error on unexpected exceptions + - Returns 404 on any error + """ + if channel_id not in self.stream_managers: + return Response('Channel not found', status=404) + + try: + segment_id = int(segment_name.split('.')[0]) + buffer = self.stream_buffers[channel_id] + + with buffer_lock: + if segment_id in buffer: + return Response(buffer[segment_id], content_type='video/MP2T') + + logging.warning(f"Segment {segment_id} not found for channel {channel_id}") + except Exception as e: + logging.error(f"Error serving segment {segment_name}: {e}") + return '', 404 + + def change_stream(self, channel_id: str): + """ + Handle stream URL changes via POST request. 
+ + Args: + channel_id: Channel to modify + + Expected JSON body: + { + "url": "new_stream_url" + } + + Returns: + JSON response with: + - Success/failure message + - Channel ID + - New/current URL + - HTTP 404 if channel not found + - HTTP 400 if URL missing from request + + Side effects: + - Updates stream manager URL + - Triggers stream switch sequence + - Maintains segment numbering + """ + if channel_id not in self.stream_managers: + return jsonify({'error': 'Channel not found'}), 404 + + new_url = request.json.get('url') + if not new_url: + return jsonify({'error': 'No URL provided'}), 400 + + manager = self.stream_managers[channel_id] + if manager.update_url(new_url): + return jsonify({ + 'message': 'Stream URL updated', + 'channel': channel_id, + 'url': new_url + }) + return jsonify({ + 'message': 'URL unchanged', + 'channel': channel_id, + 'url': new_url + }) + + @app.before_request + def log_request_info(): + """ + Log client connections and important requests. + + Log Levels: + INFO: + - First manifest request from new client + - Stream switch requests + DEBUG: + - Segment requests + - Routine manifest updates + + Format: + "{client_ip} - {method} {path}" + + Side Effects: + - Updates logging configuration based on request type + - Tracks client connections + """ + # Main Application Setup if __name__ == '__main__': # Command line argument parsing parser = argparse.ArgumentParser(description='HLS Proxy Server with Stream Switching') parser.add_argument( '--url', '-u', - default='http://example.com/stream.m3u8', + required=True, help='Initial HLS stream URL to proxy' ) + parser.add_argument( + '--channel', '-c', + required=True, + help='Channel ID for the stream (default: default)' + ) parser.add_argument( '--port', '-p', type=int, @@ -587,6 +1048,10 @@ if __name__ == '__main__': default='0.0.0.0', help='Interface to bind server to (default: all interfaces)' ) + parser.add_argument( + '--user-agent', '-ua', + help='Custom User-Agent string' + ) 
parser.add_argument( '--debug', action='store_true', @@ -594,34 +1059,30 @@ if __name__ == '__main__': ) args = parser.parse_args() - # Configure logging with separate format for access logs + # Configure logging logging.basicConfig( level=logging.DEBUG if args.debug else logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s', datefmt='%Y-%m-%d %H:%M:%S' ) - # Initialize proxy components try: - # Create and start stream manager - stream_manager = StreamManager(args.url) - stream_manager.start() + # Initialize proxy server + proxy = ProxyServer(user_agent=args.user_agent) + + # Initialize channel with provided URL + proxy.initialize_channel(args.url, args.channel) logging.info(f"Starting HLS proxy server on {args.host}:{args.port}") logging.info(f"Initial stream URL: {args.url}") + logging.info(f"Channel ID: {args.channel}") # Run Flask development server - # Note: For production, use a proper WSGI server like gunicorn - app.run( - host=args.host, - port=args.port, - threaded=True, # Enable multi-threading for segment handling - debug=args.debug - ) + proxy.run(host=args.host, port=args.port) + except Exception as e: logging.error(f"Failed to start server: {e}") - if stream_manager: - stream_manager.running = False - if stream_manager.fetch_thread: - stream_manager.fetch_thread.join() sys.exit(1) + finally: + if 'proxy' in locals(): + proxy.shutdown() From cc6442a96aad12613200f18821ae460ee0221009 Mon Sep 17 00:00:00 2001 From: SergeantPanda Date: Fri, 28 Feb 2025 16:52:50 -0600 Subject: [PATCH 04/14] Added client connection tracking. Still need to work on buffer performance a bit. 
--- apps/proxy/hls_proxy | 201 ++++++++++++++++++++++++++++++------------- 1 file changed, 141 insertions(+), 60 deletions(-) diff --git a/apps/proxy/hls_proxy b/apps/proxy/hls_proxy index 380e464d..27127787 100644 --- a/apps/proxy/hls_proxy +++ b/apps/proxy/hls_proxy @@ -35,6 +35,10 @@ class Config: WINDOW_SIZE = 12 # Number of segments in manifest window INITIAL_SEGMENTS = 3 # Initial segments to buffer before playback DEFAULT_USER_AGENT = 'VLC/3.0.20 LibVLC/3.0.20' + INITIAL_CONNECTION_WINDOW = 10 # Seconds to wait for first client + CLIENT_TIMEOUT_FACTOR = 1.5 # Multiplier for target duration to determine client timeout + CLIENT_CLEANUP_INTERVAL = 10 # Seconds between client cleanup checks + FIRST_SEGMENT_TIMEOUT = 5.0 # Seconds to wait for first segment class StreamBuffer: """ @@ -84,36 +88,45 @@ class StreamBuffer: del self.buffer[seq] class ClientManager: - """Manages client connections""" + """Manages client connections and activity tracking""" def __init__(self): - """ - Initialize client connection tracking. 
+ self.last_activity = {} # Maps client IPs to last activity timestamp + self.lock = threading.Lock() - Attributes: - active_clients (Set[int]): Set of active client IDs - lock (threading.Lock): Thread safety for client operations - """ - self.active_clients: Set[int] = set() - self.lock: threading.Lock = threading.Lock() - - def add_client(self, client_id: int): - """Add new client connection""" + def record_activity(self, client_ip: str): + """Record client activity timestamp""" with self.lock: - self.active_clients.add(client_id) - logging.info(f"New client connected (total: {len(self.active_clients)})") - - def remove_client(self, client_id: int) -> int: - """Remove client and return remaining count""" + prev_time = self.last_activity.get(client_ip) + current_time = time.time() + self.last_activity[client_ip] = current_time + if not prev_time: + logging.info(f"New client connected: {client_ip}") + else: + logging.debug(f"Client activity: {client_ip}") + + def cleanup_inactive(self, timeout: float) -> bool: + """Remove inactive clients""" + now = time.time() with self.lock: - self.active_clients.remove(client_id) - remaining = len(self.active_clients) - logging.info(f"Client disconnected (remaining: {remaining})") - return remaining - - def has_clients(self) -> bool: - """Check if any clients are connected""" - return len(self.active_clients) > 0 + active_clients = { + ip: last_time + for ip, last_time in self.last_activity.items() + if (now - last_time) < timeout + } + + removed = set(self.last_activity.keys()) - set(active_clients.keys()) + if removed: + for ip in removed: + inactive_time = now - self.last_activity[ip] + logging.warning(f"Client {ip} inactive for {inactive_time:.1f}s, removing") + + self.last_activity = active_clients + if active_clients: + oldest = min(now - t for t in active_clients.values()) + logging.debug(f"Active clients: {len(active_clients)}, oldest activity: {oldest:.1f}s ago") + + return len(active_clients) == 0 class StreamManager: 
""" @@ -163,6 +176,19 @@ class StreamManager: self.target_duration = 10.0 # Default, will be updated from manifest self.manifest_version = 3 # Default, will be updated from manifest + self.cleanup_thread = None + self.cleanup_running = False # New flag to control cleanup thread + self.cleanup_enabled = False # New flag to control when cleanup starts + self.initialization_time = time.time() # Add initialization timestamp + self.first_client_connected = False + self.cleanup_started = False # New flag to track cleanup state + + # Add client manager reference + self.client_manager = None + self.proxy_server = None # Reference to proxy server for cleanup + self.cleanup_thread = None + self.cleanup_interval = Config.CLIENT_CLEANUP_INTERVAL + logging.info(f"Initialized stream manager for channel {channel_id}") def update_url(self, new_url: str) -> bool: @@ -270,12 +296,59 @@ class StreamManager: logging.info("Stream manager started") def stop(self): - """Stop the background fetch thread""" + """Stop the stream manager and cleanup resources""" self.running = False + self.cleanup_running = False if self.fetch_thread and self.fetch_thread.is_alive(): - self.url_changed.set() # Signal thread to exit - self.fetch_thread.join(timeout=5) # Wait up to 5 seconds - logging.info("Stream manager stopped") + self.url_changed.set() + self.fetch_thread.join(timeout=5) + logging.info(f"Stream manager stopped for channel {self.channel_id}") + + def enable_cleanup(self): + """Enable cleanup after first client connects""" + if not self.first_client_connected: + self.first_client_connected = True + logging.info(f"First client connected to channel {self.channel_id}") + + def start_cleanup_thread(self): + """Start background thread for client activity monitoring""" + def cleanup_loop(): + # Wait for initial connection window + start_time = time.time() + while self.cleanup_running and (time.time() - start_time) < Config.INITIAL_CONNECTION_WINDOW: + if self.first_client_connected: + break + 
time.sleep(1) + + if not self.first_client_connected: + logging.info(f"Channel {self.channel_id}: No clients connected within {Config.INITIAL_CONNECTION_WINDOW}s window") + self.proxy_server.stop_channel(self.channel_id) + return + + # Normal client activity monitoring + while self.cleanup_running and self.running: + try: + timeout = self.target_duration * Config.CLIENT_TIMEOUT_FACTOR + if self.client_manager.cleanup_inactive(timeout): + logging.info(f"Channel {self.channel_id}: All clients disconnected for {timeout:.1f}s") + self.proxy_server.stop_channel(self.channel_id) + break + except Exception as e: + logging.error(f"Cleanup error: {e}") + if "cannot join current thread" not in str(e): + time.sleep(Config.CLIENT_CLEANUP_INTERVAL) + time.sleep(Config.CLIENT_CLEANUP_INTERVAL) + + if not self.cleanup_started: + self.cleanup_started = True + self.cleanup_running = True + self.cleanup_thread = threading.Thread( + target=cleanup_loop, + name=f"Cleanup-{self.channel_id}", + daemon=True + ) + self.cleanup_thread.start() + logging.info(f"Started cleanup thread for channel {self.channel_id}") class StreamFetcher: """ @@ -765,15 +838,20 @@ class ProxyServer: user_agent=self.user_agent ) buffer = StreamBuffer() + client_manager = ClientManager() + + # Set up references + manager.client_manager = client_manager + manager.proxy_server = self # Store resources self.stream_managers[channel_id] = manager self.stream_buffers[channel_id] = buffer - self.client_managers[channel_id] = ClientManager() + self.client_managers[channel_id] = client_manager # Create and store fetcher fetcher = StreamFetcher(manager, buffer) - manager.fetcher = fetcher # Store reference to fetcher + manager.fetcher = fetcher # Start fetch thread self.fetch_threads[channel_id] = threading.Thread( @@ -783,6 +861,9 @@ class ProxyServer: ) self.fetch_threads[channel_id].start() + # Start cleanup monitoring immediately + manager.start_cleanup_thread() + logging.info(f"Initialized channel {channel_id} with 
URL {url}") def stop_channel(self, channel_id: str) -> None: @@ -844,40 +925,36 @@ class ProxyServer: self.stop_channel(channel_id) def stream_endpoint(self, channel_id: str): - """ - Flask route handler for serving HLS manifests. - - Args: - channel_id: Unique identifier for the stream channel - - Returns: - Flask Response with: - - HLS manifest content - - Proper content type - - 404 if channel/segments not found - - The manifest includes: - - Current segment window - - Proper sequence numbering - - Discontinuity markers - - Accurate segment durations - """ + """Flask route handler for serving HLS manifests.""" try: - if channel_id not in self.stream_managers: + if (channel_id not in self.stream_managers) or (not self.stream_managers[channel_id].running): return Response('Channel not found', status=404) - - buffer = self.stream_buffers[channel_id] - manager = self.stream_managers[channel_id] - # Verify buffer state + manager = self.stream_managers[channel_id] + buffer = self.stream_buffers[channel_id] + + # Record client activity and enable cleanup + client_ip = request.remote_addr + manager.enable_cleanup() + self.client_managers[channel_id].record_activity(client_ip) + + # Wait for first segment with timeout + start_time = time.time() + while True: + with buffer.lock: + available = sorted(buffer.keys()) + if available: + break + + if time.time() - start_time > Config.FIRST_SEGMENT_TIMEOUT: + logging.warning(f"Timeout waiting for first segment for channel {channel_id}") + return Response('No segments available', status=503) + + time.sleep(0.1) # Short sleep to prevent CPU spinning + + # Rest of manifest generation code... 
with buffer.lock: - available = sorted(buffer.keys()) - if not available: - logging.warning("No segments available") - return '', 404 - max_seq = max(available) - # Find the first segment after any discontinuity discontinuity_start = min(available) for seq in available: @@ -945,6 +1022,10 @@ class ProxyServer: return Response('Channel not found', status=404) try: + # Record client activity + client_ip = request.remote_addr + self.client_managers[channel_id].record_activity(client_ip) + segment_id = int(segment_name.split('.')[0]) buffer = self.stream_buffers[channel_id] From 28cb928e75df5033715abe5eb80df138c1a0fbb3 Mon Sep 17 00:00:00 2001 From: SergeantPanda Date: Sat, 1 Mar 2025 22:09:13 -0600 Subject: [PATCH 05/14] Updated HLS proxy to download more segments at start of streaming and pauses client until segments are ready. --- apps/proxy/hls_proxy | 172 +++++++++++++++++++++++++++++-------------- 1 file changed, 115 insertions(+), 57 deletions(-) diff --git a/apps/proxy/hls_proxy b/apps/proxy/hls_proxy index 27127787..98ba38a5 100644 --- a/apps/proxy/hls_proxy +++ b/apps/proxy/hls_proxy @@ -39,6 +39,11 @@ class Config: CLIENT_TIMEOUT_FACTOR = 1.5 # Multiplier for target duration to determine client timeout CLIENT_CLEANUP_INTERVAL = 10 # Seconds between client cleanup checks FIRST_SEGMENT_TIMEOUT = 5.0 # Seconds to wait for first segment + + # Initial buffering settings + INITIAL_BUFFER_SECONDS = 25.0 # Initial buffer in seconds before allowing clients + MAX_INITIAL_SEGMENTS = 10 # Maximum segments to fetch during initialization + BUFFER_READY_TIMEOUT = 30.0 # Maximum time to wait for initial buffer (seconds) class StreamBuffer: """ @@ -191,6 +196,11 @@ class StreamManager: logging.info(f"Initialized stream manager for channel {channel_id}") + # Buffer state tracking + self.buffer_ready = threading.Event() + self.buffered_duration = 0.0 + self.initial_buffering = True + def update_url(self, new_url: str) -> bool: """ Handle stream URL changes with proper 
discontinuity marking. @@ -476,91 +486,129 @@ class StreamFetcher: raise def fetch_loop(self): - """ - Main fetching loop that continuously downloads stream content. - - Features: - - Automatic manifest updates - - Rate-limited downloads - - Exponential backoff on errors - - Stream switch handling - - Segment validation - - Error handling: - - HTTP 509 rate limiting - - Connection drops - - Invalid segments - - Manifest parsing errors - - Thread safety: - Coordinates with StreamManager and StreamBuffer - using proper locking mechanisms - """ + """Main fetch loop for stream data""" retry_delay = 1 max_retry_delay = 8 last_manifest_time = 0 + downloaded_segments = set() # Track downloaded segment URIs while self.manager.running: try: - now = time.time() + current_time = time.time() - # Get manifest data - try: - manifest_data, final_url = self.download(self.manager.current_url) - manifest = m3u8.loads(manifest_data.decode()) - - # Reset retry delay on successful fetch - retry_delay = 1 - - except requests.exceptions.HTTPError as e: - if e.response.status_code == 509: - logging.warning("Rate limit exceeded, backing off...") - time.sleep(retry_delay) - retry_delay = min(retry_delay * 2, max_retry_delay) + # Check manifest update timing + if last_manifest_time: + time_since_last = current_time - last_manifest_time + if time_since_last < (self.manager.target_duration * 0.5): + time.sleep(self.manager.target_duration * 0.5 - time_since_last) continue - raise + + # Get manifest data + manifest_data, final_url = self.download(self.manager.current_url) + manifest = m3u8.loads(manifest_data.decode()) # Update manifest info if manifest.target_duration: self.manager.target_duration = float(manifest.target_duration) if manifest.version: self.manager.manifest_version = manifest.version - + if not manifest.segments: - logging.warning("No segments in manifest") - time.sleep(retry_delay) continue - # Calculate proper manifest polling interval - target_duration = 
float(manifest.target_duration) - manifest_interval = target_duration * 0.5 # Poll at half the segment duration - - # Process latest segment + if self.manager.initial_buffering: + segments_to_fetch = [] + current_duration = 0.0 + successful_downloads = 0 # Initialize counter here + + # Start from the end of the manifest + for segment in reversed(manifest.segments): + current_duration += float(segment.duration) + segments_to_fetch.append(segment) + + # Stop when we have enough duration or hit max segments + if (current_duration >= Config.INITIAL_BUFFER_SECONDS or + len(segments_to_fetch) >= Config.MAX_INITIAL_SEGMENTS): + break + + # Reverse back to chronological order + segments_to_fetch.reverse() + + # Download initial segments + for segment in segments_to_fetch: + try: + segment_url = urljoin(final_url, segment.uri) + segment_data, _ = self.download(segment_url) + + validation = verify_segment(segment_data) + if validation.get('valid', False): + with self.buffer.lock: + seq = self.manager.next_sequence + self.buffer[seq] = segment_data + duration = float(segment.duration) + self.manager.segment_durations[seq] = duration + self.manager.buffered_duration += duration + self.manager.next_sequence += 1 + successful_downloads += 1 + logging.debug(f"Buffered initial segment {seq} (source: {segment.uri}, duration: {duration}s)") + except Exception as e: + logging.error(f"Initial segment download error: {e}") + + # Only mark buffer ready if we got some segments + if successful_downloads > 0: + self.manager.initial_buffering = False + self.manager.buffer_ready.set() + logging.info(f"Initial buffer ready with {successful_downloads} segments " + f"({self.manager.buffered_duration:.1f}s of content)") + continue + + # Normal operation - get latest segment if we haven't already latest_segment = manifest.segments[-1] + if latest_segment.uri in downloaded_segments: + # Wait for next manifest update + time.sleep(self.manager.target_duration * 0.5) + continue + try: segment_url = 
urljoin(final_url, latest_segment.uri) segment_data, _ = self.download(segment_url) - verification = verify_segment(segment_data) - if not verification.get('valid', False): - logging.warning(f"Invalid segment: {verification.get('error')}") - continue + # Try several times if segment validation fails + max_retries = 3 + retry_count = 0 + while retry_count < max_retries: + verification = verify_segment(segment_data) + if verification.get('valid', False): + break + logging.warning(f"Invalid segment, retry {retry_count + 1}/{max_retries}: {verification.get('error')}") + time.sleep(0.5) # Short delay before retry + segment_data, _ = self.download(segment_url) + retry_count += 1 - # Store segment with proper locking - with self.buffer.lock: - seq = self.manager.next_sequence - self.buffer[seq] = segment_data - self.manager.segment_durations[seq] = float(latest_segment.duration) - self.manager.next_sequence += 1 - logging.debug(f"Stored segment {seq} (duration: {latest_segment.duration}s)") + if verification.get('valid', False): + with self.buffer.lock: + seq = self.manager.next_sequence + self.buffer[seq] = segment_data + self.manager.segment_durations[seq] = float(latest_segment.duration) + self.manager.next_sequence += 1 + downloaded_segments.add(latest_segment.uri) + logging.debug(f"Stored segment {seq} (source: {latest_segment.uri}, " + f"duration: {latest_segment.duration}s, " + f"size: {len(segment_data)})") + + # Update timing + last_manifest_time = time.time() + retry_delay = 1 # Reset retry delay on success + else: + logging.error(f"Segment validation failed after {max_retries} retries") except Exception as e: logging.error(f"Segment download error: {e}") continue - # Update last manifest time and wait for next interval - last_manifest_time = now - time.sleep(manifest_interval) + # Cleanup old segment URIs from tracking + if len(downloaded_segments) > 100: + downloaded_segments.clear() except Exception as e: logging.error(f"Fetch error: {e}") @@ -926,6 +974,16 
@@ class ProxyServer: def stream_endpoint(self, channel_id: str): """Flask route handler for serving HLS manifests.""" + if channel_id not in self.stream_managers: + return Response('Channel not found', status=404) + + manager = self.stream_managers[channel_id] + + # Wait for initial buffer + if not manager.buffer_ready.wait(Config.BUFFER_READY_TIMEOUT): + logging.error(f"Timeout waiting for initial buffer for channel {channel_id}") + return Response('Initial buffer not ready', status=503) + try: if (channel_id not in self.stream_managers) or (not self.stream_managers[channel_id].running): return Response('Channel not found', status=404) From 1c08c71b476628ee490bfd1924ea704d6c5083a2 Mon Sep 17 00:00:00 2001 From: kappa118 Date: Sun, 2 Mar 2025 09:54:04 -0500 Subject: [PATCH 06/14] keeping alpine for now, primary is debian-based again --- docker/Dockerfile | 64 ++++++++++++++++++++++------------------ docker/Dockerfile.alpine | 54 +++++++++++++++++++++++++++++++++ 2 files changed, 90 insertions(+), 28 deletions(-) create mode 100644 docker/Dockerfile.alpine diff --git a/docker/Dockerfile b/docker/Dockerfile index 0b25212d..2c3b97a0 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -1,28 +1,35 @@ -FROM alpine +FROM python:3.13-slim ENV PATH="/dispatcharrpy/bin:$PATH" \ VIRTUAL_ENV=/dispatcharrpy \ DJANGO_SETTINGS_MODULE=dispatcharr.settings \ PYTHONUNBUFFERED=1 -RUN apk add \ - python3 \ - python3-dev \ - gcc \ - musl-dev \ - linux-headers \ - py3-pip \ +RUN apt-get update && \ + apt-get install -y \ + curl \ ffmpeg \ - streamlink \ - vlc \ - libpq-dev \ gcc \ - py3-virtualenv \ - uwsgi \ - uwsgi-python \ - nodejs \ - npm \ git \ + gpg \ + libpq-dev \ + lsb-release \ + python3-virtualenv \ + streamlink \ + uwsgi + +RUN \ + curl -sL https://deb.nodesource.com/setup_23.x -o /tmp/nodesource_setup.sh && \ + bash /tmp/nodesource_setup.sh && \ + curl -fsSL https://packages.redis.io/gpg | gpg --dearmor -o /usr/share/keyrings/redis-archive-keyring.gpg && \ + chmod 
644 /usr/share/keyrings/redis-archive-keyring.gpg && \ + echo "deb [signed-by=/usr/share/keyrings/redis-archive-keyring.gpg] https://packages.redis.io/deb $(lsb_release -cs) main" | tee /etc/apt/sources.list.d/redis.list && \ + apt-get update && \ + apt-get install -y redis + +RUN apt-get update && \ + apt-get install -y \ + nodejs \ redis RUN \ @@ -30,24 +37,25 @@ RUN \ virtualenv /dispatcharrpy && \ git clone https://github.com/Dispatcharr/Dispatcharr /app && \ cd /app && \ - /dispatcharrpy/bin/pip install --no-cache-dir -r requirements.txt && \ + pip install --no-cache-dir -r requirements.txt && \ cd /app/frontend && \ npm install && \ npm run build && \ find . -maxdepth 1 ! -name '.' ! -name 'build' -exec rm -rf '{}' \; && \ cd /app && \ - python manage.py collectstatic --noinput || true - -# Cleanup -RUN \ - apk del \ - nodejs \ - npm \ - git \ + python manage.py collectstatic --noinput || true && \ + apt-get remove -y \ gcc \ - musl-dev \ - python3-dev \ - linux-headers + git \ + gpg \ + libpq-dev \ + lsb-release \ + nodejs && \ + apt-get clean && \ + rm -rf \ + /tmp/* \ + /var/lib/apt/lists/* \ + /var/tmp/* WORKDIR /app diff --git a/docker/Dockerfile.alpine b/docker/Dockerfile.alpine new file mode 100644 index 00000000..0b25212d --- /dev/null +++ b/docker/Dockerfile.alpine @@ -0,0 +1,54 @@ +FROM alpine + +ENV PATH="/dispatcharrpy/bin:$PATH" \ + VIRTUAL_ENV=/dispatcharrpy \ + DJANGO_SETTINGS_MODULE=dispatcharr.settings \ + PYTHONUNBUFFERED=1 + +RUN apk add \ + python3 \ + python3-dev \ + gcc \ + musl-dev \ + linux-headers \ + py3-pip \ + ffmpeg \ + streamlink \ + vlc \ + libpq-dev \ + gcc \ + py3-virtualenv \ + uwsgi \ + uwsgi-python \ + nodejs \ + npm \ + git \ + redis + +RUN \ + mkdir /data && \ + virtualenv /dispatcharrpy && \ + git clone https://github.com/Dispatcharr/Dispatcharr /app && \ + cd /app && \ + /dispatcharrpy/bin/pip install --no-cache-dir -r requirements.txt && \ + cd /app/frontend && \ + npm install && \ + npm run build && \ + find . 
-maxdepth 1 ! -name '.' ! -name 'build' -exec rm -rf '{}' \; && \ + cd /app && \ + python manage.py collectstatic --noinput || true + +# Cleanup +RUN \ + apk del \ + nodejs \ + npm \ + git \ + gcc \ + musl-dev \ + python3-dev \ + linux-headers + +WORKDIR /app + +CMD ["/app/docker/entrypoint.aio.sh"] From d6477cef5526b1cf28f049052b21b982e7b21e0c Mon Sep 17 00:00:00 2001 From: kappa118 Date: Sun, 2 Mar 2025 10:05:39 -0500 Subject: [PATCH 07/14] updated for debian-based package manager --- docker/entrypoint.aio.sh | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/docker/entrypoint.aio.sh b/docker/entrypoint.aio.sh index e8441cd4..f5b37172 100755 --- a/docker/entrypoint.aio.sh +++ b/docker/entrypoint.aio.sh @@ -4,17 +4,17 @@ case "$DISPATCHARR_ENV" in "dev") echo "DISPATCHARR_ENV is set to 'dev'. Running Development Program..." - apk add nodejs npm + apt-get update && apt-get install -y nodejs cd /app/frontend && npm install cd /app - exec /usr/sbin/uwsgi --ini uwsgi.dev.ini + exec /usr/bin/uwsgi --ini uwsgi.dev.ini ;; "aio") echo "DISPATCHARR_ENV is set to 'aio'. Running All-in-One Program..." - exec /usr/sbin/uwsgi --ini uwsgi.aio.ini + exec /usr/bin/uwsgi --ini uwsgi.aio.ini ;; *) echo "DISPATCHARR_ENV is not set or has an unexpected value. Running standalone..." 
- exec /usr/sbin/uwsgi --ini uwsgi.ini + exec /usr/bin/uwsgi --ini uwsgi.ini ;; esac From c9b8d60c4199412947836f09c48cf2457ed5c9f8 Mon Sep 17 00:00:00 2001 From: dekzter Date: Sun, 2 Mar 2025 10:58:45 -0500 Subject: [PATCH 08/14] missing python uwdgi package --- docker/Dockerfile | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/docker/Dockerfile b/docker/Dockerfile index 2c3b97a0..9db6b017 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -16,7 +16,8 @@ RUN apt-get update && \ lsb-release \ python3-virtualenv \ streamlink \ - uwsgi + uwsgi \ + python3-django-uwsgi RUN \ curl -sL https://deb.nodesource.com/setup_23.x -o /tmp/nodesource_setup.sh && \ @@ -52,6 +53,7 @@ RUN \ lsb-release \ nodejs && \ apt-get clean && \ + apt-get autoremove -y && \ rm -rf \ /tmp/* \ /var/lib/apt/lists/* \ From 140e0f7b11474a37b4509c2947ef98499f214a1e Mon Sep 17 00:00:00 2001 From: dekzter Date: Sun, 2 Mar 2025 11:00:01 -0500 Subject: [PATCH 09/14] updated requirements for ai matching --- requirements.txt | 2 ++ 1 file changed, 2 insertions(+) diff --git a/requirements.txt b/requirements.txt index a8c883ca..618c84b3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,3 +14,5 @@ yt-dlp gevent==24.11.1 django-cors-headers djangorestframework-simplejwt +rapidfuzz==3.12.1 +sentence-transformers==3.4.1 From 2812d3da5c09fc2cd6c305f0bacfaab2dbe54eea Mon Sep 17 00:00:00 2001 From: dekzter Date: Sun, 2 Mar 2025 12:52:42 -0500 Subject: [PATCH 10/14] default stream profile to default --- frontend/src/components/forms/Channel.js | 26 +++++++++++++----------- 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/frontend/src/components/forms/Channel.js b/frontend/src/components/forms/Channel.js index 1725b563..2b5d84e2 100644 --- a/frontend/src/components/forms/Channel.js +++ b/frontend/src/components/forms/Channel.js @@ -69,7 +69,7 @@ const Channel = ({ channel = null, isOpen, onClose }) => { channel_name: '', channel_number: '', channel_group_id: '', - 
stream_profile_id: '', + stream_profile_id: '0', tvg_id: '', tvg_name: '', }, @@ -79,6 +79,10 @@ const Channel = ({ channel = null, isOpen, onClose }) => { channel_group_id: Yup.string().required('Channel group is required'), }), onSubmit: async (values, { setSubmitting, resetForm }) => { + if (values.stream_profile_id == '0') { + values.stream_profile_id = null; + } + console.log(values); if (channel?.id) { await API.updateChannel({ @@ -109,21 +113,18 @@ const Channel = ({ channel = null, isOpen, onClose }) => { channel_name: channel.channel_name, channel_number: channel.channel_number, channel_group_id: channel.channel_group?.id, - stream_profile_id: channel.stream_profile_id, + stream_profile_id: channel.stream_profile_id || '0', tvg_id: channel.tvg_id, tvg_name: channel.tvg_name, }); - console.log('channel streams'); - console.log(channel.streams); + console.log(channel); const filteredStreams = streams - .filter((stream) => channel.streams.includes(stream.id)) + .filter((stream) => channel.stream_ids.includes(stream.id)) .sort( (a, b) => - channel.streams.indexOf(a.id) - channel.streams.indexOf(b.id) + channel.stream_ids.indexOf(a.id) - channel.stream_ids.indexOf(b.id) ); - console.log('filtered streams'); - console.log(filteredStreams); setChannelStreams(filteredStreams); } else { formik.resetForm(); @@ -334,7 +335,6 @@ const Channel = ({ channel = null, isOpen, onClose }) => { labelId="stream-profile-label" id="stream_profile_id" name="stream_profile_id" - label="Stream Profile (optional)" value={formik.values.stream_profile_id} onChange={formik.handleChange} onBlur={formik.handleBlur} @@ -345,6 +345,9 @@ const Channel = ({ channel = null, isOpen, onClose }) => { // helperText={formik.touched.channel_group_id && formik.errors.stream_profile_id} variant="standard" > + + Use Default + {streamProfiles.map((option, index) => ( {option.profile_name} @@ -401,7 +404,7 @@ const Channel = ({ channel = null, isOpen, onClose }) => { helperText={formik.touched.tvg_id && 
formik.errors.tvg_id} variant="standard" /> - + { helperText="If you have a direct image URL, set it here." /> - {/* File upload input */} { ); }; -export default Channel; \ No newline at end of file +export default Channel; From c63ddcfe7b75fec5cbfec18d785f5f320b725614 Mon Sep 17 00:00:00 2001 From: Dispatcharr Date: Sun, 2 Mar 2025 12:27:21 -0600 Subject: [PATCH 11/14] AI EPG Matching Added AI EPG matching --- apps/channels/api_views.py | 20 ++ apps/channels/tasks.py | 207 ++++++++++++++++++ frontend/src/api.js | 127 +++++------ .../src/components/tables/ChannelsTable.js | 42 +++- frontend/src/pages/Settings.js | 192 ++++++++++------ 5 files changed, 445 insertions(+), 143 deletions(-) create mode 100644 apps/channels/tasks.py diff --git a/apps/channels/api_views.py b/apps/channels/api_views.py index ea55e3e6..75772509 100644 --- a/apps/channels/api_views.py +++ b/apps/channels/api_views.py @@ -9,6 +9,8 @@ from django.shortcuts import get_object_or_404 from .models import Stream, Channel, ChannelGroup from .serializers import StreamSerializer, ChannelSerializer, ChannelGroupSerializer +from .tasks import match_epg_channels + # ───────────────────────────────────────────────────────── # 1) Stream API (CRUD) @@ -30,6 +32,7 @@ class StreamViewSet(viewsets.ModelViewSet): qs = qs.filter(channels__isnull=True) return qs + # ───────────────────────────────────────────────────────── # 2) Channel Group Management (CRUD) # ───────────────────────────────────────────────────────── @@ -38,6 +41,7 @@ class ChannelGroupViewSet(viewsets.ModelViewSet): serializer_class = ChannelGroupSerializer permission_classes = [IsAuthenticated] + # ───────────────────────────────────────────────────────── # 3) Channel Management (CRUD) # ───────────────────────────────────────────────────────── @@ -178,6 +182,7 @@ class ChannelViewSet(viewsets.ModelViewSet): # Gather current used numbers once. 
used_numbers = set(Channel.objects.all().values_list('channel_number', flat=True)) next_number = 1 + def get_auto_number(): nonlocal next_number while next_number in used_numbers: @@ -236,6 +241,20 @@ class ChannelViewSet(viewsets.ModelViewSet): return Response(response_data, status=status.HTTP_201_CREATED) + # ───────────────────────────────────────────────────────── + # 6) EPG Fuzzy Matching + # ───────────────────────────────────────────────────────── + @swagger_auto_schema( + method='post', + operation_description="Kick off a Celery task that tries to fuzzy-match channels with EPG data.", + responses={202: "EPG matching task initiated"} + ) + @action(detail=False, methods=['post'], url_path='match-epg') + def match_epg(self, request): + match_epg_channels.delay() + return Response({"message": "EPG matching task initiated."}, status=status.HTTP_202_ACCEPTED) + + # ───────────────────────────────────────────────────────── # 4) Bulk Delete Streams # ───────────────────────────────────────────────────────── @@ -262,6 +281,7 @@ class BulkDeleteStreamsAPIView(APIView): Stream.objects.filter(id__in=stream_ids).delete() return Response({"message": "Streams deleted successfully!"}, status=status.HTTP_204_NO_CONTENT) + # ───────────────────────────────────────────────────────── # 5) Bulk Delete Channels # ───────────────────────────────────────────────────────── diff --git a/apps/channels/tasks.py b/apps/channels/tasks.py new file mode 100644 index 00000000..c4bf8177 --- /dev/null +++ b/apps/channels/tasks.py @@ -0,0 +1,207 @@ +# apps/channels/tasks.py + +import logging +import re + +from celery import shared_task +from rapidfuzz import fuzz +from sentence_transformers import SentenceTransformer, util +from django.db import transaction + +from apps.channels.models import Channel +from apps.epg.models import EPGData +from core.models import CoreSettings # to retrieve "preferred-region" setting + +logger = logging.getLogger(__name__) + +# Load the model once at module 
level
+SENTENCE_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
+st_model = SentenceTransformer(SENTENCE_MODEL_NAME)
+
+# Threshold constants
+BEST_FUZZY_THRESHOLD = 70    # fuzzy score (incl. region bonus) accepted outright
+LOWER_FUZZY_THRESHOLD = 40   # below this score a channel is left unmatched
+EMBED_SIM_THRESHOLD = 0.65   # minimum cosine similarity for a borderline match
+REGION_BONUS = 15            # added to the fuzzy score of rows in the preferred region
+
+# Common extraneous words
+COMMON_EXTRANEOUS_WORDS = [
+    "tv", "channel", "network", "television",
+    "east", "west", "hd", "uhd", "us", "usa", "not", "24/7",
+    "1080p", "720p", "540p", "480p",
+    "arabic", "latino", "film", "movie", "movies"
+]
+
+# Punctuation is stripped from a name *before* its tokens are filtered, so the
+# stop words have to be compared in the same punctuation-free form. Without
+# this an entry such as "24/7" (which becomes the token "247") never matches.
+_PUNCTUATION_RE = re.compile(r"[^\w\s]")
+_EXTRANEOUS_TOKENS = frozenset(
+    _PUNCTUATION_RE.sub("", word) for word in COMMON_EXTRANEOUS_WORDS
+)
+
+# Runs of letters/digits; underscores and punctuation act as separators.
+_TOKEN_RE = re.compile(r"[^\W_]+")
+
+
+def _name_tokens(text: str) -> frozenset:
+    """Return the set of lowercase alphanumeric tokens found in *text*."""
+    return frozenset(_TOKEN_RE.findall(text.lower()))
+
+
+def normalize_channel_name(name: str) -> str:
+    """
+    Reduce a channel name to a canonical form for fuzzy comparison.
+
+    The normalization:
+      - Lowercases
+      - Removes bracketed/parenthesized text
+      - Removes punctuation
+      - Strips extraneous words (compared in punctuation-free form)
+      - Collapses extra spaces
+
+    Returns "" for an empty or None name.
+    """
+    if not name:
+        return ""
+
+    norm = name.lower()
+
+    # Remove bracketed text, then parenthesized text
+    norm = re.sub(r"\[.*?\]", "", norm)
+    norm = re.sub(r"\(.*?\)", "", norm)
+
+    # Remove punctuation except word chars/spaces
+    norm = _PUNCTUATION_RE.sub("", norm)
+
+    # Drop extraneous tokens; split()/join() also collapses whitespace
+    tokens = [t for t in norm.split() if t not in _EXTRANEOUS_TOKENS]
+    return " ".join(tokens)
+
+
+@shared_task
+def match_epg_channels():
+    """
+    Give every Channel without a valid EPG link a tvg_id from EPGData.
+
+    For each channel:
+      1) If channel.tvg_id already exists in EPGData, skip it.
+      2) Otherwise normalize its name (tvg_name, falling back to channel_name)
+         and compare it with every EPG row that has both a name and a tvg_id:
+         - the score is the fuzzy ratio plus REGION_BONUS when the preferred
+           region code appears as a whole token in the row's tvg_id or name
+         - if score >= BEST_FUZZY_THRESHOLD => accept the best fuzzy row
+         - if score in [LOWER_FUZZY_THRESHOLD..BEST_FUZZY_THRESHOLD) => accept
+           the row with the highest embedding cosine similarity, provided it
+           reaches EMBED_SIM_THRESHOLD
+         - else skip
+      3) Log a summary.
+
+    All channel updates run inside a single database transaction.
+    """
+    logger.info("Starting EPG matching logic...")
+
+    # Try to get user's preferred region from CoreSettings
+    try:
+        region_obj = CoreSettings.objects.get(key="preferred-region")
+        region_code = (region_obj.value or "").strip().lower()  # e.g. "us"
+    except CoreSettings.DoesNotExist:
+        region_code = None
+    region_tokens = _name_tokens(region_code) if region_code else frozenset()
+
+    # 1) Gather EPG rows. Every tvg_id counts as "known" for the skip check,
+    #    but only rows with both a tvg_id and a usable name can be candidates:
+    #    a row without tvg_id would "match" a channel by assigning it "".
+    known_tvg_ids = set()
+    epg_rows = []
+    for e in EPGData.objects.all():
+        tvg_id = e.tvg_id or ""  # e.g. "Fox News.us"
+        if tvg_id:
+            known_tvg_ids.add(tvg_id)
+        norm_name = normalize_channel_name(e.channel_name)
+        if not tvg_id or not norm_name:
+            continue
+        # Whole-token test so that region "us" does not match "music".
+        in_region = bool(region_tokens) and region_tokens <= _name_tokens(
+            f"{tvg_id} {e.channel_name}"
+        )
+        epg_rows.append({
+            "epg_id": e.id,
+            "tvg_id": tvg_id,
+            "raw_name": e.channel_name,
+            "norm_name": norm_name,
+            "bonus": REGION_BONUS if in_region else 0,
+        })
+
+    # 2) Pre-encode embeddings; indices line up with epg_rows
+    epg_embeddings = None
+    if epg_rows:
+        epg_embeddings = st_model.encode(
+            [row["norm_name"] for row in epg_rows],
+            convert_to_tensor=True
+        )
+
+    matched_channels = []
+
+    with transaction.atomic():
+        for chan in Channel.objects.all():
+            # A) Skip if channel.tvg_id is valid
+            if chan.tvg_id and chan.tvg_id in known_tvg_ids:
+                continue
+
+            # B) No valid tvg_id => name-based matching
+            fallback_name = (chan.tvg_name or "").strip() or chan.channel_name
+            norm_chan = normalize_channel_name(fallback_name)
+            if not norm_chan:
+                logger.info(
+                    f"Channel {chan.id} '{chan.channel_name}' => empty after normalization, skipping"
+                )
+                continue
+
+            # C) Best fuzzy candidate (base ratio + region bonus)
+            best_score = 0
+            best_epg = None
+            for row in epg_rows:
+                score = fuzz.ratio(norm_chan, row["norm_name"]) + row["bonus"]
+                if score > best_score:
+                    best_score = score
+                    best_epg = row
+
+            if not best_epg:
+                logger.info(f"Channel {chan.id} '{fallback_name}' => no EPG match at all.")
+                continue
+
+            # D) Decide acceptance
+            if best_score >= BEST_FUZZY_THRESHOLD:
+                # Accept
+                chan.tvg_id = best_epg["tvg_id"]
+                chan.save()
+                matched_channels.append((chan.id, fallback_name, best_epg["tvg_id"]))
+                logger.info(
+                    f"Channel {chan.id} '{fallback_name}' => matched tvg_id={best_epg['tvg_id']} (score={best_score})"
+                )
+            elif best_score >= LOWER_FUZZY_THRESHOLD and epg_embeddings is not None:
+                # borderline => do embedding
+                chan_embedding = st_model.encode(norm_chan, convert_to_tensor=True)
+                sim_scores = util.cos_sim(chan_embedding, epg_embeddings)[0]
+                top_index = int(sim_scores.argmax())
+                top_value = float(sim_scores[top_index])
+
+                if top_value >= EMBED_SIM_THRESHOLD:
+                    matched_epg = epg_rows[top_index]
+                    chan.tvg_id = matched_epg["tvg_id"]
+                    chan.save()
+                    matched_channels.append((chan.id, fallback_name, matched_epg["tvg_id"]))
+                    logger.info(
+                        f"Channel {chan.id} '{fallback_name}' => matched EPG tvg_id={matched_epg['tvg_id']} "
+                        f"(fuzzy={best_score}, cos-sim={top_value:.2f})"
+                    )
+                else:
+                    logger.info(
+                        f"Channel {chan.id} '{fallback_name}' => fuzzy={best_score}, "
+                        f"cos-sim={top_value:.2f} < {EMBED_SIM_THRESHOLD}, skipping"
+                    )
+            else:
+                # no match
+                logger.info(
+                    f"Channel {chan.id} '{fallback_name}' => fuzzy={best_score} < {LOWER_FUZZY_THRESHOLD}, skipping"
+                )
+
+    # Final summary
+    total_matched = len(matched_channels)
+    if total_matched:
+        logger.info(f"Match Summary: {total_matched} channel(s) matched.")
+        for (cid, cname, tvg) in matched_channels:
+            logger.info(f"  - Channel ID={cid}, Name='{cname}' => tvg_id='{tvg}'")
+    else:
+        logger.info("No new channels were matched.")
+
+
logger.info("Finished EPG matching logic.") + return f"Done. Matched {total_matched} channel(s)." diff --git a/frontend/src/api.js b/frontend/src/api.js index f840d1ab..0ed976f0 100644 --- a/frontend/src/api.js +++ b/frontend/src/api.js @@ -1,3 +1,4 @@ +// src/api.js (updated) import useAuthStore from './store/auth'; import useChannelsStore from './store/channels'; import useUserAgentsStore from './store/userAgents'; @@ -7,18 +8,17 @@ import useStreamsStore from './store/streams'; import useStreamProfilesStore from './store/streamProfiles'; import useSettingsStore from './store/settings'; -// const axios = Axios.create({ -// withCredentials: true, -// }); - +// If needed, you can set a base host or keep it empty if relative requests const host = ''; -export const getAuthToken = async () => { - const token = await useAuthStore.getState().getToken(); // Assuming token is stored in Zustand store - return token; -}; - export default class API { + /** + * A static method so we can do: await API.getAuthToken() + */ + static async getAuthToken() { + return await useAuthStore.getState().getToken(); + } + static async login(username, password) { const response = await fetch(`${host}/api/accounts/token/`, { method: 'POST', @@ -31,11 +31,11 @@ export default class API { return await response.json(); } - static async refreshToken(refreshToken) { + static async refreshToken(refresh) { const response = await fetch(`${host}/api/accounts/token/refresh/`, { method: 'POST', headers: { 'Content-Type': 'application/json' }, - body: JSON.stringify({ refresh: refreshToken }), + body: JSON.stringify({ refresh }), }); const retval = await response.json(); @@ -54,7 +54,7 @@ export default class API { const response = await fetch(`${host}/api/channels/channels/`, { headers: { 'Content-Type': 'application/json', - Authorization: `Bearer ${await getAuthToken()}`, + Authorization: `Bearer ${await API.getAuthToken()}`, }, }); @@ -66,7 +66,7 @@ export default class API { const response = await 
fetch(`${host}/api/channels/groups/`, { headers: { 'Content-Type': 'application/json', - Authorization: `Bearer ${await getAuthToken()}`, + Authorization: `Bearer ${await API.getAuthToken()}`, }, }); @@ -78,7 +78,7 @@ export default class API { const response = await fetch(`${host}/api/channels/groups/`, { method: 'POST', headers: { - Authorization: `Bearer ${await getAuthToken()}`, + Authorization: `Bearer ${await API.getAuthToken()}`, 'Content-Type': 'application/json', }, body: JSON.stringify(values), @@ -97,7 +97,7 @@ export default class API { const response = await fetch(`${host}/api/channels/groups/${id}/`, { method: 'PUT', headers: { - Authorization: `Bearer ${await getAuthToken()}`, + Authorization: `Bearer ${await API.getAuthToken()}`, 'Content-Type': 'application/json', }, body: JSON.stringify(payload), @@ -114,6 +114,7 @@ export default class API { static async addChannel(channel) { let body = null; if (channel.logo_file) { + // Must send FormData for file upload body = new FormData(); for (const prop in channel) { body.append(prop, channel[prop]); @@ -127,7 +128,7 @@ export default class API { const response = await fetch(`${host}/api/channels/channels/`, { method: 'POST', headers: { - Authorization: `Bearer ${await getAuthToken()}`, + Authorization: `Bearer ${await API.getAuthToken()}`, ...(channel.logo_file ? 
{} : { @@ -149,7 +150,7 @@ export default class API { const response = await fetch(`${host}/api/channels/channels/${id}/`, { method: 'DELETE', headers: { - Authorization: `Bearer ${await getAuthToken()}`, + Authorization: `Bearer ${await API.getAuthToken()}`, 'Content-Type': 'application/json', }, }); @@ -162,7 +163,7 @@ export default class API { const response = await fetch(`${host}/api/channels/channels/bulk-delete/`, { method: 'DELETE', headers: { - Authorization: `Bearer ${await getAuthToken()}`, + Authorization: `Bearer ${await API.getAuthToken()}`, 'Content-Type': 'application/json', }, body: JSON.stringify({ channel_ids }), @@ -176,7 +177,7 @@ export default class API { const response = await fetch(`${host}/api/channels/channels/${id}/`, { method: 'PUT', headers: { - Authorization: `Bearer ${await getAuthToken()}`, + Authorization: `Bearer ${await API.getAuthToken()}`, 'Content-Type': 'application/json', }, body: JSON.stringify(payload), @@ -195,26 +196,22 @@ export default class API { const response = await fetch(`${host}/api/channels/channels/assign/`, { method: 'POST', headers: { - Authorization: `Bearer ${await getAuthToken()}`, + Authorization: `Bearer ${await API.getAuthToken()}`, 'Content-Type': 'application/json', }, body: JSON.stringify({ channel_order: channelIds }), }); - // The backend returns something like { "message": "Channels have been auto-assigned!" 
} if (!response.ok) { - // If you want to handle errors gracefully: const text = await response.text(); throw new Error(`Assign channels failed: ${response.status} => ${text}`); } - // Usually it has a { message: "..."} or similar const retval = await response.json(); - // If you want to automatically refresh the channel list in Zustand: + // Optionally refresh the channel list in Zustand await useChannelsStore.getState().fetchChannels(); - // Return the entire JSON result (so the caller can see the "message") return retval; } @@ -222,7 +219,7 @@ export default class API { const response = await fetch(`${host}/api/channels/channels/from-stream/`, { method: 'POST', headers: { - Authorization: `Bearer ${await getAuthToken()}`, + Authorization: `Bearer ${await API.getAuthToken()}`, 'Content-Type': 'application/json', }, body: JSON.stringify(values), @@ -242,7 +239,7 @@ export default class API { { method: 'POST', headers: { - Authorization: `Bearer ${await getAuthToken()}`, + Authorization: `Bearer ${await API.getAuthToken()}`, 'Content-Type': 'application/json', }, body: JSON.stringify(values), @@ -261,7 +258,7 @@ export default class API { const response = await fetch(`${host}/api/channels/streams/`, { headers: { 'Content-Type': 'application/json', - Authorization: `Bearer ${await getAuthToken()}`, + Authorization: `Bearer ${await API.getAuthToken()}`, }, }); @@ -273,7 +270,7 @@ export default class API { const response = await fetch(`${host}/api/channels/streams/`, { method: 'POST', headers: { - Authorization: `Bearer ${await getAuthToken()}`, + Authorization: `Bearer ${await API.getAuthToken()}`, 'Content-Type': 'application/json', }, body: JSON.stringify(values), @@ -292,7 +289,7 @@ export default class API { const response = await fetch(`${host}/api/channels/streams/${id}/`, { method: 'PUT', headers: { - Authorization: `Bearer ${await getAuthToken()}`, + Authorization: `Bearer ${await API.getAuthToken()}`, 'Content-Type': 'application/json', }, body: 
JSON.stringify(payload), @@ -310,7 +307,7 @@ export default class API { const response = await fetch(`${host}/api/channels/streams/${id}/`, { method: 'DELETE', headers: { - Authorization: `Bearer ${await getAuthToken()}`, + Authorization: `Bearer ${await API.getAuthToken()}`, 'Content-Type': 'application/json', }, }); @@ -322,7 +319,7 @@ export default class API { const response = await fetch(`${host}/api/channels/streams/bulk-delete/`, { method: 'DELETE', headers: { - Authorization: `Bearer ${await getAuthToken()}`, + Authorization: `Bearer ${await API.getAuthToken()}`, 'Content-Type': 'application/json', }, body: JSON.stringify({ stream_ids: ids }), @@ -335,7 +332,7 @@ export default class API { const response = await fetch(`${host}/api/core/useragents/`, { headers: { 'Content-Type': 'application/json', - Authorization: `Bearer ${await getAuthToken()}`, + Authorization: `Bearer ${await API.getAuthToken()}`, }, }); @@ -347,7 +344,7 @@ export default class API { const response = await fetch(`${host}/api/core/useragents/`, { method: 'POST', headers: { - Authorization: `Bearer ${await getAuthToken()}`, + Authorization: `Bearer ${await API.getAuthToken()}`, 'Content-Type': 'application/json', }, body: JSON.stringify(values), @@ -366,7 +363,7 @@ export default class API { const response = await fetch(`${host}/api/core/useragents/${id}/`, { method: 'PUT', headers: { - Authorization: `Bearer ${await getAuthToken()}`, + Authorization: `Bearer ${await API.getAuthToken()}`, 'Content-Type': 'application/json', }, body: JSON.stringify(payload), @@ -384,7 +381,7 @@ export default class API { const response = await fetch(`${host}/api/core/useragents/${id}/`, { method: 'DELETE', headers: { - Authorization: `Bearer ${await getAuthToken()}`, + Authorization: `Bearer ${await API.getAuthToken()}`, 'Content-Type': 'application/json', }, }); @@ -395,7 +392,7 @@ export default class API { static async getPlaylist(id) { const response = await fetch(`${host}/api/m3u/accounts/${id}/`, { 
headers: { - Authorization: `Bearer ${await getAuthToken()}`, + Authorization: `Bearer ${await API.getAuthToken()}`, 'Content-Type': 'application/json', }, }); @@ -407,7 +404,7 @@ export default class API { static async getPlaylists() { const response = await fetch(`${host}/api/m3u/accounts/`, { headers: { - Authorization: `Bearer ${await getAuthToken()}`, + Authorization: `Bearer ${await API.getAuthToken()}`, 'Content-Type': 'application/json', }, }); @@ -420,7 +417,7 @@ export default class API { const response = await fetch(`${host}/api/m3u/accounts/`, { method: 'POST', headers: { - Authorization: `Bearer ${await getAuthToken()}`, + Authorization: `Bearer ${await API.getAuthToken()}`, 'Content-Type': 'application/json', }, body: JSON.stringify(values), @@ -438,7 +435,7 @@ export default class API { const response = await fetch(`${host}/api/m3u/refresh/${id}/`, { method: 'POST', headers: { - Authorization: `Bearer ${await getAuthToken()}`, + Authorization: `Bearer ${await API.getAuthToken()}`, 'Content-Type': 'application/json', }, }); @@ -451,7 +448,7 @@ export default class API { const response = await fetch(`${host}/api/m3u/refresh/`, { method: 'POST', headers: { - Authorization: `Bearer ${await getAuthToken()}`, + Authorization: `Bearer ${await API.getAuthToken()}`, 'Content-Type': 'application/json', }, }); @@ -464,7 +461,7 @@ export default class API { const response = await fetch(`${host}/api/m3u/accounts/${id}/`, { method: 'DELETE', headers: { - Authorization: `Bearer ${await getAuthToken()}`, + Authorization: `Bearer ${await API.getAuthToken()}`, 'Content-Type': 'application/json', }, }); @@ -477,7 +474,7 @@ export default class API { const response = await fetch(`${host}/api/m3u/accounts/${id}/`, { method: 'PUT', headers: { - Authorization: `Bearer ${await getAuthToken()}`, + Authorization: `Bearer ${await API.getAuthToken()}`, 'Content-Type': 'application/json', }, body: JSON.stringify(payload), @@ -494,7 +491,7 @@ export default class API { static 
async getEPGs() { const response = await fetch(`${host}/api/epg/sources/`, { headers: { - Authorization: `Bearer ${await getAuthToken()}`, + Authorization: `Bearer ${await API.getAuthToken()}`, 'Content-Type': 'application/json', }, }); @@ -503,18 +500,8 @@ export default class API { return retval; } - static async refreshPlaylist(id) { - const response = await fetch(`${host}/api/m3u/refresh/${id}/`, { - method: 'POST', - headers: { - Authorization: `Bearer ${await getAuthToken()}`, - 'Content-Type': 'application/json', - }, - }); - - const retval = await response.json(); - return retval; - } + // Notice there's a duplicated "refreshPlaylist" method above; + // you might want to rename or remove one if it's not needed. static async addEPG(values) { let body = null; @@ -532,7 +519,7 @@ export default class API { const response = await fetch(`${host}/api/epg/sources/`, { method: 'POST', headers: { - Authorization: `Bearer ${await getAuthToken()}`, + Authorization: `Bearer ${await API.getAuthToken()}`, ...(values.epg_file ? 
{} : { @@ -554,7 +541,7 @@ export default class API { const response = await fetch(`${host}/api/epg/sources/${id}/`, { method: 'DELETE', headers: { - Authorization: `Bearer ${await getAuthToken()}`, + Authorization: `Bearer ${await API.getAuthToken()}`, 'Content-Type': 'application/json', }, }); @@ -566,7 +553,7 @@ export default class API { const response = await fetch(`${host}/api/epg/import/`, { method: 'POST', headers: { - Authorization: `Bearer ${await getAuthToken()}`, + Authorization: `Bearer ${await API.getAuthToken()}`, 'Content-Type': 'application/json', }, body: JSON.stringify({ id }), @@ -579,7 +566,7 @@ export default class API { static async getStreamProfiles() { const response = await fetch(`${host}/api/core/streamprofiles/`, { headers: { - Authorization: `Bearer ${await getAuthToken()}`, + Authorization: `Bearer ${await API.getAuthToken()}`, 'Content-Type': 'application/json', }, }); @@ -592,7 +579,7 @@ export default class API { const response = await fetch(`${host}/api/core/streamprofiles/`, { method: 'POST', headers: { - Authorization: `Bearer ${await getAuthToken()}`, + Authorization: `Bearer ${await API.getAuthToken()}`, 'Content-Type': 'application/json', }, body: JSON.stringify(values), @@ -610,7 +597,7 @@ export default class API { const response = await fetch(`${host}/api/core/streamprofiles/${id}/`, { method: 'PUT', headers: { - Authorization: `Bearer ${await getAuthToken()}`, + Authorization: `Bearer ${await API.getAuthToken()}`, 'Content-Type': 'application/json', }, body: JSON.stringify(payload), @@ -628,7 +615,7 @@ export default class API { const response = await fetch(`${host}/api/core/streamprofiles/${id}/`, { method: 'DELETE', headers: { - Authorization: `Bearer ${await getAuthToken()}`, + Authorization: `Bearer ${await API.getAuthToken()}`, 'Content-Type': 'application/json', }, }); @@ -639,7 +626,7 @@ export default class API { static async getGrid() { const response = await fetch(`${host}/api/epg/grid/`, { headers: { - 
Authorization: `Bearer ${await getAuthToken()}`, + Authorization: `Bearer ${await API.getAuthToken()}`, 'Content-Type': 'application/json', }, }); @@ -654,7 +641,7 @@ export default class API { { method: 'POST', headers: { - Authorization: `Bearer ${await getAuthToken()}`, + Authorization: `Bearer ${await API.getAuthToken()}`, 'Content-Type': 'application/json', }, body: JSON.stringify(values), @@ -663,7 +650,7 @@ export default class API { const retval = await response.json(); if (retval.id) { - // Fetch m3u account to update it with its new playlists + // Refresh the playlist const playlist = await API.getPlaylist(accountId); usePlaylistsStore .getState() @@ -679,7 +666,7 @@ export default class API { { method: 'DELETE', headers: { - Authorization: `Bearer ${await getAuthToken()}`, + Authorization: `Bearer ${await API.getAuthToken()}`, 'Content-Type': 'application/json', }, } @@ -696,7 +683,7 @@ export default class API { { method: 'PUT', headers: { - Authorization: `Bearer ${await getAuthToken()}`, + Authorization: `Bearer ${await API.getAuthToken()}`, 'Content-Type': 'application/json', }, body: JSON.stringify(payload), @@ -711,7 +698,7 @@ export default class API { const response = await fetch(`${host}/api/core/settings/`, { headers: { 'Content-Type': 'application/json', - Authorization: `Bearer ${await getAuthToken()}`, + Authorization: `Bearer ${await API.getAuthToken()}`, }, }); @@ -724,7 +711,7 @@ export default class API { const response = await fetch(`${host}/api/core/settings/${id}/`, { method: 'PUT', headers: { - Authorization: `Bearer ${await getAuthToken()}`, + Authorization: `Bearer ${await API.getAuthToken()}`, 'Content-Type': 'application/json', }, body: JSON.stringify(payload), diff --git a/frontend/src/components/tables/ChannelsTable.js b/frontend/src/components/tables/ChannelsTable.js index e271511c..b09d9566 100644 --- a/frontend/src/components/tables/ChannelsTable.js +++ b/frontend/src/components/tables/ChannelsTable.js @@ -24,13 +24,14 @@ 
import { SwapVert as SwapVertIcon, LiveTv as LiveTvIcon, ContentCopy, + Tv as TvIcon, // <-- ADD THIS IMPORT } from '@mui/icons-material'; import API from '../../api'; import ChannelForm from '../forms/Channel'; import { TableHelper } from '../../helpers'; import utils from '../../utils'; import logo from '../../images/logo.png'; -import useVideoStore from '../../store/useVideoStore'; // NEW import +import useVideoStore from '../../store/useVideoStore'; const ChannelsTable = () => { const [channel, setChannel] = useState(null); @@ -116,6 +117,7 @@ const ChannelsTable = () => { 4, selected.map((chan) => () => deleteChannel(chan.original.id)) ); + // If you have a real bulk-delete endpoint, call it here: // await API.deleteChannels(selected.map((sel) => sel.id)); setIsLoading(false); }; @@ -144,6 +146,32 @@ const ChannelsTable = () => { } }; + // ───────────────────────────────────────────────────────── + // The new "Match EPG" button logic + // ───────────────────────────────────────────────────────── + const matchEpg = async () => { + try { + // Hit our new endpoint that triggers the fuzzy matching Celery task + const resp = await fetch('/api/channels/channels/match-epg/', { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + Authorization: `Bearer ${await API.getAuthToken()}`, + }, + }); + + if (resp.ok) { + setSnackbarMessage('EPG matching task started!'); + } else { + const text = await resp.text(); + setSnackbarMessage(`Failed to start EPG matching: ${text}`); + } + } catch (err) { + setSnackbarMessage(`Error: ${err.message}`); + } + setSnackbarOpen(true); + }; + const closeChannelForm = () => { setChannel(null); setChannelModalOpen(false); @@ -294,6 +322,18 @@ const ChannelsTable = () => { + {/* Our brand-new button for EPG matching */} + + + + + +