class StreamBuffer:
    """
    Bounded store of stream segments keyed by sequence number.

    The container does no locking of its own. ``lock`` is exposed so callers
    can group several operations atomically (``with buffer.lock: ...``). It
    is a plain, non-reentrant ``threading.Lock``, so the methods below must
    never acquire it themselves.

    Attributes:
        buffer (Dict[int, bytes]): Maps sequence numbers to segment data
        lock (threading.Lock): Held by callers around buffer access
        max_segments (int): Maximum number of segments retained
    """

    def __init__(self, max_segments: Optional[int] = None):
        """
        Create an empty buffer.

        Args:
            max_segments: Retention limit. When omitted, Config.MAX_SEGMENTS
                is used.
        """
        self.buffer: Dict[int, bytes] = {}
        self.lock: threading.Lock = threading.Lock()
        self.max_segments: int = (
            Config.MAX_SEGMENTS if max_segments is None else max_segments
        )

    def __getitem__(self, key: int) -> Optional[bytes]:
        """Return the segment stored under ``key``, or None when absent."""
        return self.buffer.get(key)

    def __setitem__(self, key: int, value: bytes):
        """Store a segment, then evict the lowest sequence numbers if over the limit."""
        self.buffer[key] = value
        # Count the surplus explicitly: a negative-index slice such as
        # keys[:-limit] silently evicts nothing when the limit is 0.
        surplus = len(self.buffer) - self.max_segments
        if surplus > 0:
            for stale in sorted(self.buffer)[:surplus]:
                del self.buffer[stale]

    def __contains__(self, key: int) -> bool:
        """Report whether a segment is stored under ``key``."""
        return key in self.buffer

    def keys(self) -> List[int]:
        """Return the stored sequence numbers as a new list (insertion order)."""
        return list(self.buffer)

    def cleanup(self, keep_sequences: List[int]):
        """Delete every stored segment whose sequence number is not in ``keep_sequences``."""
        keep = set(keep_sequences)  # O(1) membership tests
        for seq in list(self.buffer):
            if seq not in keep:
                del self.buffer[seq]


class ClientManager:
    """
    Track the set of connected client IDs.

    Attributes:
        active_clients (Set[int]): IDs of currently connected clients
        lock (threading.Lock): Guards every access to ``active_clients``
    """

    def __init__(self):
        self.active_clients: Set[int] = set()
        self.lock: threading.Lock = threading.Lock()

    def add_client(self, client_id: int):
        """Register a client and log the new total."""
        with self.lock:
            self.active_clients.add(client_id)
            logging.info(f"New client connected (total: {len(self.active_clients)})")

    def remove_client(self, client_id: int) -> int:
        """
        Unregister a client.

        Removing an ID that is not registered (for example a repeated
        disconnect) is a no-op and does not raise.

        Returns:
            int: Number of clients still connected
        """
        with self.lock:
            if client_id in self.active_clients:
                self.active_clients.remove(client_id)
                logging.info(f"Client disconnected (remaining: {len(self.active_clients)})")
            else:
                logging.debug(f"Ignoring disconnect for unknown client {client_id}")
            return len(self.active_clients)

    def has_clients(self) -> bool:
        """Report whether at least one client is connected."""
        with self.lock:
            return bool(self.active_clients)


class StreamManager:
    """
    Hold per-channel HLS stream state: URL, sequence numbering, discontinuities.

    Attributes:
        current_url (str): Current stream URL
        channel_id (str): Channel identifier
        user_agent (str): User-Agent handed to the fetcher
        buffer (Optional[StreamBuffer]): Buffer used by the legacy fetch thread
        running (bool): Stream activity flag
        switching_stream (bool): True between a URL change and the first new segment
        next_sequence (int): Next sequence number to assign
        highest_sequence (int): Largest sequence number assigned so far
        buffered_sequences (set): Sequence numbers assigned by get_next_sequence
        downloaded_sources (dict): source_id -> assigned sequence number
        segment_durations (dict): sequence number -> duration in seconds
        source_changes (set): Sequences that must be preceded by a discontinuity
        target_duration (float): Target duration copied from the manifest
        manifest_version (int): Playlist version copied from the manifest
    """

    def __init__(self, initial_url: str, channel_id: str,
                 user_agent: Optional[str] = None,
                 buffer: Optional[StreamBuffer] = None):
        """
        Args:
            initial_url: Stream URL to start with
            channel_id: Channel identifier
            user_agent: User-Agent override; Config.DEFAULT_USER_AGENT when omitted
            buffer: Segment buffer for the thread started by start(); one is
                created lazily when omitted
        """
        # Stream state
        self.current_url = initial_url
        self.channel_id = channel_id
        self.user_agent = user_agent or Config.DEFAULT_USER_AGENT
        self.buffer: Optional[StreamBuffer] = buffer
        self.running = True
        self.switching_stream = False

        # Sequence tracking
        self.next_sequence = 0
        self.highest_sequence = 0
        self.buffered_sequences = set()
        self.downloaded_sources = {}
        self.segment_durations = {}

        # Source tracking
        self.current_source = None
        self.source_changes = set()
        self.stream_switch_count = 0

        # Threading
        self.fetcher = None
        self.fetch_thread = None
        self.url_changed = threading.Event()

        # Manifest info; defaults until a manifest has been parsed
        self.target_duration = 10.0
        self.manifest_version = 3

        logging.info(f"Initialized stream manager for channel {channel_id}")

    def update_url(self, new_url: str) -> bool:
        """
        Switch to a new stream URL and mark a discontinuity.

        Args:
            new_url: New stream URL to switch to

        Returns:
            bool: True if the URL changed, False if it was already current

        Side effects (only when the URL changed):
            - Sets switching_stream and current_url
            - Continues numbering after the highest buffered sequence
            - Adds next_sequence to source_changes
            - Clears downloaded_sources, segment_durations, current_source
            - Sets the url_changed event
        """
        if new_url == self.current_url:
            return False

        # NOTE(review): `buffer_lock` is a module-level name defined outside
        # this block; confirm it still exists and is the lock the fetch path
        # actually holds for this channel.
        with buffer_lock:
            self.switching_stream = True
            self.current_url = new_url

            # Continue sequence numbering from last sequence
            if self.buffered_sequences:
                self.next_sequence = max(self.buffered_sequences) + 1

            # Mark discontinuity at next sequence
            self.source_changes.add(self.next_sequence)

            logging.info(f"Stream switch - next sequence will start at {self.next_sequence}")

            # Clear state but maintain sequence numbers
            self.downloaded_sources.clear()
            self.segment_durations.clear()
            self.current_source = None

            # Signal thread to switch URL
            self.url_changed.set()

        return True

    def get_next_sequence(self, source_id: str) -> Optional[int]:
        """
        Assign a sequence number to a segment identified by ``source_id``.

        Args:
            source_id: Identifier of the upstream segment; the text before the
                first '_' is treated as the source prefix

        Returns:
            int: Newly assigned sequence number
            None: If ``source_id`` was already assigned one

        Side effects:
            - Records the assignment in downloaded_sources / buffered_sequences
            - Advances next_sequence and highest_sequence
            - Adds the sequence to source_changes when the source prefix
              changes outside of a stream switch
        """
        if source_id in self.downloaded_sources:
            return None

        # Skip over numbers that are already taken
        seq = self.next_sequence
        while seq in self.buffered_sequences:
            seq += 1

        # Track source changes for discontinuity markers
        source_prefix = source_id.split('_')[0]
        prefix_changed = bool(self.current_source) and self.current_source != source_prefix
        if prefix_changed and not self.switching_stream:
            self.source_changes.add(seq)
            logging.debug(f"Source change detected at sequence {seq}")
        self.current_source = source_prefix

        # Update tracking
        self.downloaded_sources[source_id] = seq
        self.buffered_sequences.add(seq)
        self.next_sequence = seq + 1
        self.highest_sequence = max(self.highest_sequence, seq)

        return seq

    def _fetch_loop(self):
        """Background thread body: run fetch_stream, retrying 5 s after any error."""
        while self.running:
            try:
                # __init__ may have been given no buffer; create it here so
                # the attribute always exists before the fetcher is built.
                if self.buffer is None:
                    self.buffer = StreamBuffer()
                fetcher = StreamFetcher(self, self.buffer)
                fetch_stream(fetcher, self.url_changed, self.next_sequence)
            except Exception as e:
                logging.error(f"Stream error: {e}")
                time.sleep(5)  # Wait before retry

            self.url_changed.clear()

    def start(self):
        """Start the background fetch thread unless one is already alive."""
        if self.fetch_thread and self.fetch_thread.is_alive():
            return
        self.running = True
        self.fetch_thread = threading.Thread(
            target=self._fetch_loop,
            name="StreamFetcher",
            daemon=True  # Thread will exit when main program does
        )
        self.fetch_thread.start()
        logging.info("Stream manager started")

    def stop(self):
        """Clear the running flag and wait up to 5 seconds for the fetch thread."""
        self.running = False
        if self.fetch_thread and self.fetch_thread.is_alive():
            self.url_changed.set()  # Signal thread to exit
            self.fetch_thread.join(timeout=5)
        logging.info("Stream manager stopped")
requests for stream segments with connection pooling""" - def __init__(self, stream_url): - self.stream_url = stream_url + """ + Handles HTTP requests for stream segments with connection pooling. + + Attributes: + manager (StreamManager): Associated stream manager instance + buffer (StreamBuffer): Buffer for storing segments + session (requests.Session): Persistent HTTP session + redirect_cache (dict): Cache for redirect responses + + Features: + - Connection pooling and reuse + - Redirect caching + - Rate limiting + - Automatic retries + - Host fallback + """ + def __init__(self, manager: StreamManager, buffer: StreamBuffer): + self.manager = manager + self.buffer = buffer + self.stream_url = manager.current_url self.session = requests.Session() + # Configure session headers + self.session.headers.update({ + 'User-Agent': manager.user_agent, + 'Connection': 'keep-alive' + }) + # Set up connection pooling adapter = requests.adapters.HTTPAdapter( pool_connections=2, # Number of connection pools @@ -57,9 +323,26 @@ class StreamFetcher: self.min_request_interval = 0.05 # Minimum time between requests self.last_host = None # Cache last successful host self.redirect_cache = {} # Cache redirect responses + self.redirect_cache_limit = 1000 + + def cleanup_redirect_cache(self): + """Remove old redirect cache entries""" + if len(self.redirect_cache) > self.redirect_cache_limit: + self.redirect_cache.clear() - def get_base_host(self, url): - """Extract base host from URL using urlparse""" + def get_base_host(self, url: str) -> str: + """ + Extract base host from URL. 
def fetch_loop(self):
    """
    Poll the upstream manifest and buffer each newly published segment.

    Runs until ``self.manager.running`` becomes false. Each iteration:
        1. Downloads and parses the manifest at ``manager.current_url``
        2. Copies target duration / playlist version to the manager
        3. Downloads the last segment listed, unless its resolved URL is the
           one stored on the previous iteration
        4. Validates it with verify_segment and stores it under
           ``manager.next_sequence`` while holding ``self.buffer.lock``
        5. Sleeps for half the target duration

    Error handling:
        - HTTP 509 on the manifest: sleep with exponential backoff (1 s to 8 s)
        - Empty manifest, invalid segment, segment download error: sleep
          ``retry_delay`` and poll again
        - Any other exception: log, sleep, double ``retry_delay`` (max 8 s)
    """
    retry_delay = 1
    max_retry_delay = 8
    last_segment_url = None  # Resolved URL of the segment stored most recently

    while self.manager.running:
        try:
            # Get manifest data
            try:
                manifest_data, final_url = self.download(self.manager.current_url)
                manifest = m3u8.loads(manifest_data.decode())

                # Reset retry delay on successful fetch
                retry_delay = 1

            except requests.exceptions.HTTPError as e:
                # e.response may be None when the error was raised by hand
                status = getattr(e.response, 'status_code', None)
                if status == 509:
                    logging.warning("Rate limit exceeded, backing off...")
                    time.sleep(retry_delay)
                    retry_delay = min(retry_delay * 2, max_retry_delay)
                    continue
                raise

            # Update manifest info
            if manifest.target_duration:
                self.manager.target_duration = float(manifest.target_duration)
            if manifest.version:
                self.manager.manifest_version = manifest.version

            if not manifest.segments:
                logging.warning("No segments in manifest")
                time.sleep(retry_delay)
                continue

            # Poll at half the segment duration. Read the manager's value so
            # a manifest without a target duration falls back to the last
            # known one instead of raising on float(None).
            manifest_interval = float(self.manager.target_duration) * 0.5

            # Process latest segment
            latest_segment = manifest.segments[-1]
            segment_url = urljoin(final_url, latest_segment.uri)

            # Polling is faster than segments are published, so the same
            # segment is usually still last; storing it again would duplicate
            # it under a new sequence number.
            if segment_url == last_segment_url:
                time.sleep(manifest_interval)
                continue

            try:
                segment_data, _ = self.download(segment_url)

                verification = verify_segment(segment_data)
                if not verification.get('valid', False):
                    logging.warning(f"Invalid segment: {verification.get('error')}")
                    time.sleep(retry_delay)  # avoid a hot retry loop
                    continue

                # Store segment with proper locking
                with self.buffer.lock:
                    seq = self.manager.next_sequence
                    self.buffer[seq] = segment_data
                    self.manager.segment_durations[seq] = float(latest_segment.duration)
                    self.manager.next_sequence += 1
                    logging.debug(f"Stored segment {seq} (duration: {latest_segment.duration}s)")
                last_segment_url = segment_url

            except Exception as e:
                logging.error(f"Segment download error: {e}")
                time.sleep(retry_delay)  # avoid a hot retry loop
                continue

            # Wait for next manifest poll
            time.sleep(manifest_interval)

        except Exception as e:
            logging.error(f"Fetch error: {e}")
            time.sleep(retry_delay)
            retry_delay = min(retry_delay * 2, max_retry_delay)
Args: segment_uri: Segment filename or path Returns: - int: Sequence number if found - None: If no sequence can be extracted + int: Extracted sequence number if found + None: If no valid sequence number can be extracted + + Handles common patterns like: + - Numerical sequences (e.g., segment_1234.ts) + - Complex patterns with stream IDs (e.g., stream_123_456.ts) """ + try: # Try numerical sequence (e.g., 1038_3693.ts) if '_' in segment_uri: @@ -176,18 +521,25 @@ def get_segment_sequence(segment_uri: str) -> Optional[int]: # Update verify_segment with more thorough checks def verify_segment(data: bytes) -> dict: """ - Verify MPEG-TS segment integrity + Verify MPEG-TS segment integrity and structure. Args: - data: Raw segment bytes + data: Raw segment data bytes Returns: - dict with verification results: - - valid: True if segment passes all checks - - packets: Number of valid packets found - - size: Total segment size in bytes - - error: Description if validation fails + dict containing: + valid (bool): True if segment passes all checks + packets (int): Number of valid packets found + size (int): Total segment size in bytes + error (str): Description if validation fails + + Checks: + - Minimum size requirements + - Packet size alignment + - Sync byte presence + - Transport error indicators """ + # Check minimum size if len(data) < 188: return {'valid': False, 'error': 'Segment too short'} @@ -223,120 +575,22 @@ def verify_segment(data: bytes) -> dict: 'size': len(data) } -class StreamManager: - """Manages HLS stream state and switching logic""" - def __init__(self, initial_url: str): - # Stream state - self.current_url = initial_url - self.running = True - self.switching_stream = False - - # Sequence tracking - self.next_sequence = 0 - self.highest_sequence = 0 - self.buffered_sequences = set() - self.downloaded_sources = {} - self.segment_durations = {} - - # Source tracking - self.current_source = None - self.source_changes = set() - self.stream_switch_count = 0 - 
- # Threading - self.fetcher = None - self.fetch_thread = None - self.url_changed = threading.Event() - - def update_url(self, new_url: str) -> bool: - """Handle stream URL changes with proper discontinuity marking""" - if new_url != self.current_url: - with buffer_lock: - self.switching_stream = True - self.current_url = new_url - - # Set sequence numbers for stream switch - if segment_buffers: - self.highest_sequence = max(segment_buffers.keys()) - self.next_sequence = self.highest_sequence + 1 - # Mark discontinuity at first segment of new stream - self.source_changes = {self.next_sequence + 1} - else: - self.stream_switch_count += 1 - self.next_sequence = self.stream_switch_count * 1000 - self.source_changes = {self.next_sequence} - - logging.info(f"Stream switch - next sequence will start at {self.next_sequence}") - - # Clear state but maintain sequence numbers - self.downloaded_sources.clear() - self.segment_durations.clear() - self.current_source = None - - # Signal thread to switch URL - self.url_changed.set() - - return True - return False - - def get_next_sequence(self, source_id): - """Assign sequence numbers to segments with source change detection""" - if source_id in self.downloaded_sources: - return None - - seq = self.next_sequence - while seq in self.buffered_sequences: - seq += 1 - - # Track source changes for discontinuity markers - source_prefix = source_id.split('_')[0] - if not self.switching_stream and self.current_source and self.current_source != source_prefix: - self.source_changes.add(seq) - logging.debug(f"Source change detected at sequence {seq}") - self.current_source = source_prefix - - # Update tracking - self.downloaded_sources[source_id] = seq - self.buffered_sequences.add(seq) - self.next_sequence = seq + 1 - self.highest_sequence = max(self.highest_sequence, seq) - - return seq - - def _fetch_loop(self): - """Background thread for continuous stream fetching""" - while self.running: - try: - self.fetcher = 
StreamFetcher(self.current_url) - fetch_stream(self.fetcher, self.url_changed, self.next_sequence) - except Exception as e: - logging.error(f"Stream error: {e}") - time.sleep(5) # Wait before retry - - self.url_changed.clear() - - def start(self): - """Start the background fetch thread""" - if not self.fetch_thread or not self.fetch_thread.is_alive(): - self.running = True - self.fetch_thread = threading.Thread( - target=self._fetch_loop, - name="StreamFetcher", - daemon=True # Thread will exit when main program does - ) - self.fetch_thread.start() - logging.info("Stream manager started") - - def stop(self): - """Stop the background fetch thread""" - self.running = False - if self.fetch_thread and self.fetch_thread.is_alive(): - self.url_changed.set() # Signal thread to exit - self.fetch_thread.join(timeout=5) # Wait up to 5 seconds - logging.info("Stream manager stopped") - def fetch_stream(fetcher: StreamFetcher, stop_event: threading.Event, start_sequence: int = 0): - global manifest_buffer, segment_buffers + """ + Main streaming function that handles manifest updates and segment downloads. 
+ + Args: + fetcher: StreamFetcher instance to handle HTTP requests + stop_event: Threading event to signal when to stop fetching + start_sequence: Initial sequence number to start from + + The function implements the core HLS fetching logic: + - Fetches and parses manifest files + - Downloads new segments when they become available + - Handles stream switches with proper discontinuity marking + - Maintains buffer state and segment sequence numbering + """ + # Remove global stream_manager reference retry_delay = 1 max_retry_delay = 8 last_segment_time = 0 @@ -366,22 +620,22 @@ def fetch_stream(fetcher: StreamFetcher, stop_event: threading.Event, start_sequ manifest_content = manifest_data.decode() new_segments = {} - if stream_manager.switching_stream: + if fetcher.manager.switching_stream: # Use fetcher.manager instead of stream_manager # Stream switch - only get latest segment manifest_segments = [manifest.segments[-1]] - seq_start = stream_manager.next_sequence + seq_start = fetcher.manager.next_sequence max_segments = 1 logging.debug(f"Processing stream switch - getting latest segment at sequence {seq_start}") elif not buffer_initialized: # Initial buffer manifest_segments = manifest.segments[-Config.INITIAL_SEGMENTS:] - seq_start = stream_manager.next_sequence + seq_start = fetcher.manager.next_sequence max_segments = Config.INITIAL_SEGMENTS logging.debug(f"Starting initial buffer at sequence {seq_start}") else: # Normal operation manifest_segments = [manifest.segments[-1]] - seq_start = stream_manager.next_sequence + seq_start = fetcher.manager.next_sequence max_segments = 1 # Map segments @@ -391,7 +645,7 @@ def fetch_stream(fetcher: StreamFetcher, stop_event: threading.Event, start_sequ break source_id = segment.uri.split('/')[-1].split('.')[0] - next_seq = stream_manager.get_next_sequence(source_id) + next_seq = fetcher.manager.get_next_sequence(source_id) if next_seq is not None: duration = float(segment.duration) @@ -400,7 +654,7 @@ def 
fetch_stream(fetcher: StreamFetcher, stop_event: threading.Event, start_sequ 'duration': duration, 'source_id': source_id } - stream_manager.segment_durations[next_seq] = duration + fetcher.manager.segment_durations[next_seq] = duration segments_mapped += 1 manifest_buffer = manifest_content @@ -420,8 +674,8 @@ def fetch_stream(fetcher: StreamFetcher, stop_event: threading.Event, start_sequ segment_buffers[sequence_id] = segment_data logging.debug(f"Downloaded and verified segment {sequence_id} (packets: {validation['packets']})") - if stream_manager.switching_stream: - stream_manager.switching_stream = False + if fetcher.manager.switching_stream: + fetcher.manager.switching_stream = False stop_event.set() # Force fetcher restart with new URL break elif not buffer_initialized and len(segment_buffers) >= Config.INITIAL_SEGMENTS: @@ -442,118 +696,23 @@ def fetch_stream(fetcher: StreamFetcher, stop_event: threading.Event, start_sequ retry_delay = min(retry_delay * 2, max_retry_delay) manifest_update_needed = True -# Flask Routes for HLS Proxy Server -@app.route('/stream.m3u8') -def master_playlist(): - """ - Serve the HLS master playlist - - Handles: - - Initial buffering state - - Ongoing playback with sliding window - - Discontinuity markers for stream switches - - Dynamic segment durations - """ - with buffer_lock: - # Verify buffer state - if not manifest_buffer or not segment_buffers: - logging.warning("No manifest or segments available yet") - return '', 404 - - available = sorted(segment_buffers.keys()) - if not available: - logging.warning("No segments available") - return '', 404 - - manifest = m3u8.loads(manifest_buffer) - max_seq = max(available) - - # Calculate window bounds - if len(available) <= Config.INITIAL_SEGMENTS: - # During initial buffering, show all segments - min_seq = min(available) - else: - # For ongoing playback, maintain sliding window - min_seq = max( - min(available), - max_seq - Config.WINDOW_SIZE + 1 - ) - - # Build manifest with 
proper tags - new_manifest = ['#EXTM3U'] - new_manifest.append('#EXT-X-VERSION:3') - new_manifest.append(f'#EXT-X-MEDIA-SEQUENCE:{min_seq}') - new_manifest.append(f'#EXT-X-TARGETDURATION:{int(manifest.target_duration)}') - - # Filter segments within window - window_segments = [s for s in available if min_seq <= s <= max_seq] - - # Add segments with discontinuity handling - for seq in window_segments: - # Mark stream switches with discontinuity - if seq in stream_manager.source_changes: - new_manifest.append('#EXT-X-DISCONTINUITY') - logging.debug(f"Added discontinuity marker before segment {seq}") - - # Use actual segment duration or fallback to target - duration = stream_manager.segment_durations.get(seq, manifest.target_duration) - new_manifest.append(f'#EXTINF:{duration:.3f},') - new_manifest.append(f'/segments/{seq}.ts') - - manifest_content = '\n'.join(new_manifest) - logging.debug(f"Serving manifest with segments {min_seq}-{max_seq} (window: {len(window_segments)})") - return Response(manifest_content, content_type='application/vnd.apple.mpegurl') -@app.route('/segments/') -def get_segment(segment_name): - """ - Serve individual MPEG-TS segments - - Args: - segment_name: Segment filename (e.g., '123.ts') - - Returns: - MPEG-TS segment data or 404 if not found - """ - try: - segment_id = int(segment_name.split('.')[0]) - with buffer_lock: - if segment_id in segment_buffers: - available = sorted(segment_buffers.keys()) - logging.debug(f"Client requested segment {segment_id} (buffer: {min(available)}-{max(available)})") - return Response(segment_buffers[segment_id], content_type='video/MP2T') - - logging.warning( - f"Segment {segment_id} not found. 
" - f"Available: {min(segment_buffers.keys()) if segment_buffers else 'none'}" - f"-{max(segment_buffers.keys()) if segment_buffers else 'none'}" - ) - except Exception as e: - logging.error(f"Error serving segment {segment_name}: {e}") - return '', 404 - -@app.route('/change_stream', methods=['POST']) -def change_stream(): - """ - Handle stream URL changes via HTTP POST - - Expected JSON body: - {"url": "new_stream_url"} - - Returns: - JSON response indicating success/failure - """ - new_url = request.json.get('url') - if not new_url: - return jsonify({'error': 'No URL provided'}), 400 - - if stream_manager.update_url(new_url): - return jsonify({'message': 'Stream URL updated', 'url': new_url}) - return jsonify({'message': 'URL unchanged', 'url': new_url}) @app.before_request def log_request_info(): - """Log client connections and important requests""" + """ + Log client connections and important requests. + + Logs: + INFO: + - First manifest request from new client + - Stream switch requests + DEBUG: + - All other requests + + Format: + {client_ip} - {method} {path} + """ if request.path == '/stream.m3u8' and not segment_buffers: # First manifest request from a client logging.info(f"New client connected from {request.remote_addr}") @@ -567,15 +726,317 @@ def log_request_info(): # Configure Werkzeug logger to DEBUG logging.getLogger('werkzeug').setLevel(logging.DEBUG) +class ProxyServer: + """Manages HLS proxy server instance""" + + def __init__(self, user_agent: Optional[str] = None): + self.app = Flask(__name__) + self.stream_managers: Dict[str, StreamManager] = {} + self.stream_buffers: Dict[str, StreamBuffer] = {} + self.client_managers: Dict[str, ClientManager] = {} + self.fetch_threads: Dict[str, threading.Thread] = {} + self.user_agent: str = user_agent or Config.DEFAULT_USER_AGENT + self._setup_routes() + + def _setup_routes(self) -> None: + """Configure Flask routes""" + self.app.add_url_rule( + '/stream/', # Changed from //stream.m3u8 + 
view_func=self.stream_endpoint + ) + self.app.add_url_rule( + '/stream//segments/', # Updated to match new pattern + view_func=self.get_segment + ) + self.app.add_url_rule( + '/change_stream/', # Changed from //change_stream + view_func=self.change_stream, + methods=['POST'] + ) + + def initialize_channel(self, url: str, channel_id: str) -> None: + """Initialize a new channel stream""" + if channel_id in self.stream_managers: + self.stop_channel(channel_id) + + manager = StreamManager( + url, + channel_id, + user_agent=self.user_agent + ) + buffer = StreamBuffer() + + # Store resources + self.stream_managers[channel_id] = manager + self.stream_buffers[channel_id] = buffer + self.client_managers[channel_id] = ClientManager() + + # Create and store fetcher + fetcher = StreamFetcher(manager, buffer) + manager.fetcher = fetcher # Store reference to fetcher + + # Start fetch thread + self.fetch_threads[channel_id] = threading.Thread( + target=fetcher.fetch_loop, + name=f"StreamFetcher-{channel_id}", + daemon=True + ) + self.fetch_threads[channel_id].start() + + logging.info(f"Initialized channel {channel_id} with URL {url}") + + def stop_channel(self, channel_id: str) -> None: + """Stop and cleanup a channel""" + if channel_id in self.stream_managers: + self.stream_managers[channel_id].stop() + if channel_id in self.fetch_threads: + self.fetch_threads[channel_id].join(timeout=5) + self._cleanup_channel(channel_id) + + def _cleanup_channel(self, channel_id: str) -> None: + """ + Remove all resources associated with a channel. 
+ + Args: + channel_id: Channel to cleanup + + Removes: + - Stream manager instance + - Segment buffer + - Client manager + - Fetch thread reference + + Thread safety: + Should only be called after stream manager is stopped + and fetch thread has completed + """ + + for collection in [self.stream_managers, self.stream_buffers, + self.client_managers, self.fetch_threads]: + collection.pop(channel_id, None) + + def run(self, host: str = '0.0.0.0', port: int = 5000) -> None: + """Start the proxy server""" + try: + self.app.run(host=host, port=port, threaded=True) + except KeyboardInterrupt: + logging.info("Shutting down gracefully...") + self.shutdown() + except Exception as e: + logging.error(f"Server error: {e}") + self.shutdown() + raise + + def shutdown(self) -> None: + """ + Stop all channels and cleanup resources. + + Steps: + 1. Stop all active stream managers + 2. Join fetch threads + 3. Clean up channel resources + 4. Release system resources + + Thread Safety: + Safe to call from signal handlers or during shutdown + """ + for channel_id in list(self.stream_managers.keys()): + self.stop_channel(channel_id) + + def stream_endpoint(self, channel_id: str): + """ + Flask route handler for serving HLS manifests. 
+ + Args: + channel_id: Unique identifier for the stream channel + + Returns: + Flask Response with: + - HLS manifest content + - Proper content type + - 404 if channel/segments not found + + The manifest includes: + - Current segment window + - Proper sequence numbering + - Discontinuity markers + - Accurate segment durations + """ + try: + if channel_id not in self.stream_managers: + return Response('Channel not found', status=404) + + buffer = self.stream_buffers[channel_id] + manager = self.stream_managers[channel_id] + + # Verify buffer state + with buffer.lock: + available = sorted(buffer.keys()) + if not available: + logging.warning("No segments available") + return '', 404 + + max_seq = max(available) + + # Find the first segment after any discontinuity + discontinuity_start = min(available) + for seq in available: + if seq in manager.source_changes: + discontinuity_start = seq + break + + # Calculate window bounds starting from discontinuity + if len(available) <= Config.INITIAL_SEGMENTS: + min_seq = discontinuity_start + else: + min_seq = max( + discontinuity_start, + max_seq - Config.WINDOW_SIZE + 1 + ) + + # Build manifest with proper tags + new_manifest = ['#EXTM3U'] + new_manifest.append(f'#EXT-X-VERSION:{manager.manifest_version}') + new_manifest.append(f'#EXT-X-MEDIA-SEQUENCE:{min_seq}') + new_manifest.append(f'#EXT-X-TARGETDURATION:{int(manager.target_duration)}') + + # Filter segments within window + window_segments = [s for s in available if min_seq <= s <= max_seq] + + # Add segments with discontinuity handling + for seq in window_segments: + if seq in manager.source_changes: + new_manifest.append('#EXT-X-DISCONTINUITY') + logging.debug(f"Added discontinuity marker before segment {seq}") + + duration = manager.segment_durations.get(seq, 10.0) + new_manifest.append(f'#EXTINF:{duration},') + new_manifest.append(f'/stream/{channel_id}/segments/{seq}.ts') + + manifest_content = '\n'.join(new_manifest) + logging.debug(f"Serving manifest with segments 
{min_seq}-{max_seq} (window: {len(window_segments)})") + return Response(manifest_content, content_type='application/vnd.apple.mpegurl') + except ConnectionAbortedError: + logging.debug("Client disconnected") + return '', 499 + except Exception as e: + logging.error(f"Stream endpoint error: {e}") + return '', 500 + + def get_segment(self, channel_id: str, segment_name: str): + """ + Serve individual MPEG-TS segments to clients. + + Args: + channel_id: Unique identifier for the channel + segment_name: Segment filename (e.g., '123.ts') + + Returns: + Flask Response: + - MPEG-TS segment data with video/MP2T content type + - 404 if segment or channel not found + + Error Handling: + - Logs warning if segment not found + - Logs error on unexpected exceptions + - Returns 404 on any error + """ + if channel_id not in self.stream_managers: + return Response('Channel not found', status=404) + + try: + segment_id = int(segment_name.split('.')[0]) + buffer = self.stream_buffers[channel_id] + + with buffer_lock: + if segment_id in buffer: + return Response(buffer[segment_id], content_type='video/MP2T') + + logging.warning(f"Segment {segment_id} not found for channel {channel_id}") + except Exception as e: + logging.error(f"Error serving segment {segment_name}: {e}") + return '', 404 + + def change_stream(self, channel_id: str): + """ + Handle stream URL changes via POST request. 
+ + Args: + channel_id: Channel to modify + + Expected JSON body: + { + "url": "new_stream_url" + } + + Returns: + JSON response with: + - Success/failure message + - Channel ID + - New/current URL + - HTTP 404 if channel not found + - HTTP 400 if URL missing from request + + Side effects: + - Updates stream manager URL + - Triggers stream switch sequence + - Maintains segment numbering + """ + if channel_id not in self.stream_managers: + return jsonify({'error': 'Channel not found'}), 404 + + new_url = request.json.get('url') + if not new_url: + return jsonify({'error': 'No URL provided'}), 400 + + manager = self.stream_managers[channel_id] + if manager.update_url(new_url): + return jsonify({ + 'message': 'Stream URL updated', + 'channel': channel_id, + 'url': new_url + }) + return jsonify({ + 'message': 'URL unchanged', + 'channel': channel_id, + 'url': new_url + }) + + @app.before_request + def log_request_info(): + """ + Log client connections and important requests. + + Log Levels: + INFO: + - First manifest request from new client + - Stream switch requests + DEBUG: + - Segment requests + - Routine manifest updates + + Format: + "{client_ip} - {method} {path}" + + Side Effects: + - Updates logging configuration based on request type + - Tracks client connections + """ + # Main Application Setup if __name__ == '__main__': # Command line argument parsing parser = argparse.ArgumentParser(description='HLS Proxy Server with Stream Switching') parser.add_argument( '--url', '-u', - default='http://example.com/stream.m3u8', + required=True, help='Initial HLS stream URL to proxy' ) + parser.add_argument( + '--channel', '-c', + required=True, + help='Channel ID for the stream (default: default)' + ) parser.add_argument( '--port', '-p', type=int, @@ -587,6 +1048,10 @@ if __name__ == '__main__': default='0.0.0.0', help='Interface to bind server to (default: all interfaces)' ) + parser.add_argument( + '--user-agent', '-ua', + help='Custom User-Agent string' + ) 
parser.add_argument( '--debug', action='store_true', @@ -594,34 +1059,30 @@ if __name__ == '__main__': ) args = parser.parse_args() - # Configure logging with separate format for access logs + # Configure logging logging.basicConfig( level=logging.DEBUG if args.debug else logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s', datefmt='%Y-%m-%d %H:%M:%S' ) - # Initialize proxy components try: - # Create and start stream manager - stream_manager = StreamManager(args.url) - stream_manager.start() + # Initialize proxy server + proxy = ProxyServer(user_agent=args.user_agent) + + # Initialize channel with provided URL + proxy.initialize_channel(args.url, args.channel) logging.info(f"Starting HLS proxy server on {args.host}:{args.port}") logging.info(f"Initial stream URL: {args.url}") + logging.info(f"Channel ID: {args.channel}") # Run Flask development server - # Note: For production, use a proper WSGI server like gunicorn - app.run( - host=args.host, - port=args.port, - threaded=True, # Enable multi-threading for segment handling - debug=args.debug - ) + proxy.run(host=args.host, port=args.port) + except Exception as e: logging.error(f"Failed to start server: {e}") - if stream_manager: - stream_manager.running = False - if stream_manager.fetch_thread: - stream_manager.fetch_thread.join() sys.exit(1) + finally: + if 'proxy' in locals(): + proxy.shutdown()