# Mirror of https://github.com/Dispatcharr/Dispatcharr.git
# Synced 2026-01-23 10:45:27 +00:00 (1088 lines, 39 KiB, Text)
"""
HLS Proxy Server with Advanced Stream Switching Support

This proxy handles HLS live streams with support for:
- Stream switching with proper discontinuity handling
- Buffer management
- Segment validation
- Connection pooling and reuse
"""
|
|
|
|
from flask import Flask, Response, request, jsonify
|
|
import requests
|
|
import threading
|
|
import logging
|
|
import m3u8
|
|
import time
|
|
from urllib.parse import urlparse, urljoin
|
|
import argparse
|
|
from typing import Optional, Dict, List, Set, Deque
|
|
import sys
|
|
import os
|
|
|
|
# Initialize Flask app
# Module-level Flask application; the `log_request_info` hooks in this file
# are registered on this instance.
app = Flask(__name__)

# Global state management
# These globals are read and written by `fetch_stream` and
# `StreamManager.update_url`.
manifest_buffer = None  # Stores current manifest content
segment_buffers = {}  # Maps sequence numbers to segment data
buffer_lock = threading.Lock()  # Synchronizes access to buffers
|
|
|
|
class Config:
    """Tunable constants for segment buffering and upstream HTTP requests."""

    # --- Segment buffer sizing ---
    MIN_SEGMENTS = 12      # lower bound on segments to keep around
    MAX_SEGMENTS = 16      # hard cap on segments held by a StreamBuffer
    WINDOW_SIZE = 12       # how many segments the served manifest lists
    INITIAL_SEGMENTS = 3   # segments fetched up front before playback

    # --- HTTP client identity ---
    DEFAULT_USER_AGENT = 'VLC/3.0.20 LibVLC/3.0.20'
|
|
|
|
class StreamBuffer:
    """
    In-memory store of stream segments keyed by sequence number.

    Attributes:
        buffer (Dict[int, bytes]): Maps sequence numbers to segment data
        lock (threading.Lock): Lock that callers hold around buffer access.
            None of the methods below acquire it themselves, so callers are
            responsible for taking it when sharing the buffer across threads.

    Features:
        - Segment storage and retrieval by sequence number
        - Automatic eviction of the oldest segments past Config.MAX_SEGMENTS
        - Explicit cleanup against a keep-list
    """

    def __init__(self):
        self.buffer: Dict[int, bytes] = {}  # Maps sequence numbers to segment data
        self.lock: threading.Lock = threading.Lock()

    def __getitem__(self, key: int) -> Optional[bytes]:
        """Return the segment stored under ``key``, or None when it is absent."""
        return self.buffer.get(key)

    def __setitem__(self, key: int, value: bytes):
        """Store a segment and evict the lowest sequence numbers past the cap."""
        self.buffer[key] = value

        overflow = len(self.buffer) - Config.MAX_SEGMENTS
        if overflow > 0:
            # Lowest sequence numbers are the oldest segments; drop just enough
            # of them to get back to MAX_SEGMENTS entries.
            for stale in sorted(self.buffer)[:overflow]:
                del self.buffer[stale]

    def __contains__(self, key: int) -> bool:
        """Check if sequence number exists in buffer"""
        return key in self.buffer

    def keys(self) -> List[int]:
        """Return a snapshot list of the stored sequence numbers."""
        return list(self.buffer)

    def cleanup(self, keep_sequences: List[int]):
        """
        Remove every stored segment whose sequence number is not listed.

        Args:
            keep_sequences: Sequence numbers that must survive the cleanup
        """
        # Build the membership set once instead of scanning the list per key.
        keep = set(keep_sequences)
        for seq in [s for s in self.buffer if s not in keep]:
            del self.buffer[seq]
|
|
|
|
class ClientManager:
    """Tracks the set of client IDs currently connected to a channel."""

    def __init__(self):
        """
        Initialize client connection tracking.

        Attributes:
            active_clients (Set[int]): Set of active client IDs
            lock (threading.Lock): Guards mutations of ``active_clients``
        """
        self.active_clients: Set[int] = set()
        self.lock: threading.Lock = threading.Lock()

    def add_client(self, client_id: int):
        """Register ``client_id`` as connected (adding it twice is a no-op)."""
        with self.lock:
            self.active_clients.add(client_id)
            logging.info(f"New client connected (total: {len(self.active_clients)})")

    def remove_client(self, client_id: int) -> int:
        """
        Unregister ``client_id`` and return how many clients remain.

        An unknown or already-removed ID is ignored rather than raising
        KeyError, so a duplicate disconnect cannot crash the caller.
        """
        with self.lock:
            self.active_clients.discard(client_id)
            remaining = len(self.active_clients)
            logging.info(f"Client disconnected (remaining: {remaining})")
            return remaining

    def has_clients(self) -> bool:
        """Check if any clients are connected"""
        return bool(self.active_clients)
|
|
|
|
class StreamManager:
    """
    Manages HLS stream state and switching logic.

    Attributes:
        current_url (str): Current stream URL
        channel_id (str): Unique channel identifier
        user_agent (str): User-Agent sent on upstream requests
        running (bool): Stream activity flag
        switching_stream (bool): Set while a URL switch is being processed
        buffer (Optional[StreamBuffer]): Segment buffer handed to fetchers
            created by ``_fetch_loop``; resolved lazily when left as None
        next_sequence (int): Next sequence number to assign
        buffered_sequences (set): Sequence numbers handed out so far
        source_changes (set): Sequences where stream source changed

    Features:
        - Stream URL management
        - Sequence number assignment
        - Discontinuity tracking
        - Thread coordination
        - Buffer state management
    """

    def __init__(self, initial_url: str, channel_id: str, user_agent: Optional[str] = None,
                 buffer: Optional["StreamBuffer"] = None):
        """
        Args:
            initial_url: URL of the stream to start with
            channel_id: Identifier of the channel this manager serves
            user_agent: Optional User-Agent; Config.DEFAULT_USER_AGENT when falsy
            buffer: Optional segment buffer used by the background fetch
                thread started via ``start()``
        """
        # Stream state
        self.current_url = initial_url
        self.channel_id = channel_id
        self.user_agent = user_agent or Config.DEFAULT_USER_AGENT
        self.running = True
        self.switching_stream = False

        # Previously `_fetch_loop` read `self.buffer` although nothing ever
        # assigned it, so the attribute is now always present.
        self.buffer = buffer

        # Sequence tracking
        self.next_sequence = 0
        self.highest_sequence = 0
        self.buffered_sequences = set()
        self.downloaded_sources = {}
        self.segment_durations = {}

        # Source tracking
        self.current_source = None
        self.source_changes = set()
        self.stream_switch_count = 0

        # Threading
        self.fetcher = None
        self.fetch_thread = None
        self.url_changed = threading.Event()

        # Manifest info
        self.target_duration = 10.0  # Default, will be updated from manifest
        self.manifest_version = 3  # Default, will be updated from manifest

        logging.info(f"Initialized stream manager for channel {channel_id}")

    def update_url(self, new_url: str) -> bool:
        """
        Handle stream URL changes with proper discontinuity marking.

        Args:
            new_url: New stream URL to switch to

        Returns:
            bool: True if URL changed, False if unchanged

        Side effects:
            - Sets switching_stream flag
            - Updates current_url
            - Maintains sequence numbering
            - Marks discontinuity point
            - Signals fetch thread
        """
        if new_url == self.current_url:
            return False

        with buffer_lock:
            self.switching_stream = True
            self.current_url = new_url

            # Continue sequence numbering from last sequence
            if self.buffered_sequences:
                self.next_sequence = max(self.buffered_sequences) + 1

            # Mark discontinuity at next sequence
            self.source_changes.add(self.next_sequence)

            logging.info(f"Stream switch - next sequence will start at {self.next_sequence}")

            # Clear state but maintain sequence numbers
            self.downloaded_sources.clear()
            self.segment_durations.clear()
            self.current_source = None

            # Signal thread to switch URL
            self.url_changed.set()

        return True

    def get_next_sequence(self, source_id: str) -> Optional[int]:
        """
        Assign sequence numbers to segments with source change detection.

        Args:
            source_id: Unique identifier for segment source

        Returns:
            int: Next available sequence number
            None: If segment already downloaded

        Side effects:
            - Updates buffered sequences set
            - Tracks source changes for discontinuity
            - Maintains sequence numbering
        """
        if source_id in self.downloaded_sources:
            return None

        # Skip over any sequence numbers that were already handed out
        seq = self.next_sequence
        while seq in self.buffered_sequences:
            seq += 1

        # Track source changes for discontinuity markers; the source is the
        # part of the ID before the first underscore.
        source_prefix = source_id.split('_')[0]
        if not self.switching_stream and self.current_source and self.current_source != source_prefix:
            self.source_changes.add(seq)
            logging.debug(f"Source change detected at sequence {seq}")
        self.current_source = source_prefix

        # Update tracking
        self.downloaded_sources[source_id] = seq
        self.buffered_sequences.add(seq)
        self.next_sequence = seq + 1
        self.highest_sequence = max(self.highest_sequence, seq)

        return seq

    def _resolve_buffer(self):
        """Return the buffer for new fetchers, creating one if none is attached."""
        if self.buffer is None:
            attached = getattr(self.fetcher, 'buffer', None)
            self.buffer = attached if attached is not None else StreamBuffer()
        return self.buffer

    def _fetch_loop(self):
        """Background thread for continuous stream fetching"""
        while self.running:
            try:
                fetcher = StreamFetcher(self, self._resolve_buffer())
                fetch_stream(fetcher, self.url_changed, self.next_sequence)
            except Exception as e:
                logging.error(f"Stream error: {e}")
                time.sleep(5)  # Wait before retry

            # fetch_stream returns once the event is set; re-arm it so the
            # next fetcher (created with the current URL) can run.
            self.url_changed.clear()

    def start(self):
        """Start the background fetch thread"""
        if not self.fetch_thread or not self.fetch_thread.is_alive():
            self.running = True
            self.fetch_thread = threading.Thread(
                target=self._fetch_loop,
                name="StreamFetcher",
                daemon=True  # Thread will exit when main program does
            )
            self.fetch_thread.start()
            logging.info("Stream manager started")

    def stop(self):
        """Stop the background fetch thread"""
        self.running = False
        if self.fetch_thread and self.fetch_thread.is_alive():
            self.url_changed.set()  # Signal thread to exit
            self.fetch_thread.join(timeout=5)  # Wait up to 5 seconds
        logging.info("Stream manager stopped")
|
|
|
|
class StreamFetcher:
    """
    Handles HTTP requests for stream segments with connection pooling.

    Attributes:
        manager (StreamManager): Associated stream manager instance
        buffer (StreamBuffer): Buffer for storing segments
        session (requests.Session): Persistent HTTP session
        redirect_cache (dict): Maps requested URLs to their redirect targets

    Features:
        - Connection pooling and reuse
        - Redirect caching
        - Rate limiting
        - Automatic retries
        - Host fallback
    """

    def __init__(self, manager: "StreamManager", buffer: "StreamBuffer"):
        self.manager = manager
        self.buffer = buffer
        self.stream_url = manager.current_url
        self.session = requests.Session()

        # Configure session headers
        self.session.headers.update({
            'User-Agent': manager.user_agent,
            'Connection': 'keep-alive'
        })

        # Set up connection pooling
        adapter = requests.adapters.HTTPAdapter(
            pool_connections=2,  # Number of connection pools
            pool_maxsize=4,      # Connections per pool
            max_retries=3,       # Auto-retry failed requests
            pool_block=False     # Don't block when pool is full
        )

        # Apply adapter to both HTTP and HTTPS
        self.session.mount('http://', adapter)
        self.session.mount('https://', adapter)

        # Request optimization
        self.last_request_time = 0
        self.min_request_interval = 0.05  # Minimum time between requests
        self.last_host = None  # Cache last successful host
        self.redirect_cache = {}  # Cache redirect responses
        self.redirect_cache_limit = 1000

    def cleanup_redirect_cache(self):
        """Drop the whole redirect cache once it grows past its limit."""
        if len(self.redirect_cache) > self.redirect_cache_limit:
            self.redirect_cache.clear()

    def get_base_host(self, url: str) -> str:
        """
        Extract base host from URL.

        Args:
            url: Full URL to parse

        Returns:
            str: Base host in format 'scheme://hostname'; the input URL is
            returned unchanged if parsing raises

        Example:
            'http://example.com/path' -> 'http://example.com'
        """
        try:
            parsed = urlparse(url)
            return f"{parsed.scheme}://{parsed.netloc}"
        except Exception as e:
            logging.error(f"Error extracting base host: {e}")
            return url

    def download(self, url: str) -> tuple[bytes, str]:
        """
        Download content with connection reuse and redirect handling.

        Args:
            url: URL to download from

        Returns:
            tuple containing:
                bytes: Downloaded content
                str: Final URL after any redirects

        Raises:
            Exception: Whatever the HTTP session raised (including the
                HTTPError produced for 4xx/5xx responses) when the request
                fails and the last-host fallback is not applicable or also
                fails.

        Features:
            - Connection pooling/reuse
            - Redirect caching
            - Rate limiting
            - Host fallback on failure
            - Automatic retries
        """
        # Rate limiting: keep at least min_request_interval between requests
        wait_time = self.last_request_time + self.min_request_interval - time.time()
        if wait_time > 0:
            time.sleep(wait_time)

        try:
            cached_target = self.redirect_cache.get(url)
            if cached_target is not None:
                logging.debug(f"Using cached redirect for {url}")
                response = self.session.get(cached_target, timeout=10)
            else:
                response = self.session.get(url, allow_redirects=True, timeout=10)

            self.last_request_time = time.time()

            # Surface HTTP errors instead of returning an error page as
            # content; this is also what lets fetch_loop see HTTP 509.
            response.raise_for_status()

            if cached_target is None and response.history:  # Cache redirects
                logging.debug(f"Caching redirect for {url} -> {response.url}")
                self.cleanup_redirect_cache()  # keep the cache bounded
                self.redirect_cache[url] = response.url

            if response.status_code == 200:
                self.last_host = self.get_base_host(response.url)

            return response.content, response.url

        except Exception as e:
            logging.error(f"Download error: {e}")
            # A cached redirect target may have gone stale; resolve it afresh
            # on the next attempt.
            self.redirect_cache.pop(url, None)
            if self.last_host and not url.startswith(self.last_host):
                # Retry the same path on the last host that answered with 200
                new_url = urljoin(self.last_host + '/', url.split('://')[-1].split('/', 1)[-1])
                logging.debug(f"Retrying with last host: {new_url}")
                return self.download(new_url)
            raise

    def fetch_loop(self):
        """
        Main fetching loop that continuously downloads stream content.

        Each iteration downloads the manifest at ``manager.current_url`` and,
        when its last segment has not been stored yet, downloads, validates
        and stores that segment under ``manager.next_sequence``. The loop
        runs until ``manager.running`` becomes false.

        Features:
            - Automatic manifest updates
            - Rate-limited downloads
            - Exponential backoff on errors
            - Stream switch handling (the current URL is re-read every pass)
            - Segment validation

        Error handling:
            - HTTP 509 rate limiting
            - Connection drops
            - Invalid segments
            - Manifest parsing errors

        Thread safety:
            Segment storage happens while holding ``self.buffer.lock``.
        """
        retry_delay = 1
        max_retry_delay = 8
        last_segment_url = None  # URL of the segment stored most recently

        while self.manager.running:
            try:
                # Get manifest data
                try:
                    manifest_data, final_url = self.download(self.manager.current_url)
                    manifest = m3u8.loads(manifest_data.decode())

                    # Reset retry delay on successful fetch
                    retry_delay = 1

                except requests.exceptions.HTTPError as e:
                    if e.response is not None and e.response.status_code == 509:
                        logging.warning("Rate limit exceeded, backing off...")
                        time.sleep(retry_delay)
                        retry_delay = min(retry_delay * 2, max_retry_delay)
                        continue
                    raise

                # Update manifest info
                if manifest.target_duration:
                    self.manager.target_duration = float(manifest.target_duration)
                if manifest.version:
                    self.manager.manifest_version = manifest.version

                if not manifest.segments:
                    logging.warning("No segments in manifest")
                    time.sleep(retry_delay)
                    continue

                # Poll at half the segment duration. The manager's value is
                # used so a manifest without a target duration falls back to
                # the last known one instead of raising.
                manifest_interval = self.manager.target_duration * 0.5

                # Process latest segment
                latest_segment = manifest.segments[-1]
                segment_url = urljoin(final_url, latest_segment.uri)

                # Polling faster than segments are published means the same
                # "latest" segment shows up repeatedly; store it only once.
                if segment_url == last_segment_url:
                    time.sleep(manifest_interval)
                    continue

                try:
                    segment_data, _ = self.download(segment_url)

                    verification = verify_segment(segment_data)
                    if not verification.get('valid', False):
                        logging.warning(f"Invalid segment: {verification.get('error')}")
                        time.sleep(retry_delay)  # don't hammer upstream
                        continue

                    # Store segment with proper locking
                    with self.buffer.lock:
                        seq = self.manager.next_sequence
                        self.buffer[seq] = segment_data
                        self.manager.segment_durations[seq] = float(latest_segment.duration)
                        self.manager.next_sequence += 1
                        logging.debug(f"Stored segment {seq} (duration: {latest_segment.duration}s)")
                    last_segment_url = segment_url

                except Exception as e:
                    logging.error(f"Segment download error: {e}")
                    time.sleep(retry_delay)  # don't hammer upstream
                    continue

                # Wait for next manifest poll
                time.sleep(manifest_interval)

            except Exception as e:
                logging.error(f"Fetch error: {e}")
                time.sleep(retry_delay)
                retry_delay = min(retry_delay * 2, max_retry_delay)
|
|
|
|
def get_segment_sequence(segment_uri: str) -> Optional[int]:
    """
    Pull the trailing numeric sequence out of a segment URI.

    The number is taken from the text between the last underscore and the
    first dot that follows it (e.g. ``1038_3693.ts`` -> ``3693``).

    Args:
        segment_uri: Segment filename or path

    Returns:
        int: The parsed sequence number
        None: When the URI has no underscore or that text is not an integer
    """
    if '_' not in segment_uri:
        return None

    after_underscore = segment_uri.rpartition('_')[2]
    candidate = after_underscore.partition('.')[0]
    try:
        return int(candidate)
    except ValueError:
        return None
|
|
|
|
# Update verify_segment with more thorough checks
|
|
def verify_segment(data: bytes) -> dict:
    """
    Verify MPEG-TS segment integrity and structure.

    Args:
        data: Raw segment data bytes

    Returns:
        dict containing:
            valid (bool): True if segment passes all checks
            packets (int): Number of packets (only when valid)
            size (int): Total segment size in bytes (only when valid)
            error (str): Description of the first failed check (only when invalid)

    Checks:
        - Minimum size requirements
        - Packet size alignment
        - Sync byte presence
        - Transport error indicators
    """
    packet_size = 188   # length of one transport-stream packet
    sync_byte = 0x47    # expected first byte of every packet

    size = len(data)

    # Check minimum size
    if size < packet_size:
        return {'valid': False, 'error': 'Segment too short'}

    # Verify segment size is multiple of packet size
    if size % packet_size != 0:
        return {'valid': False, 'error': 'Invalid segment size'}

    # The size is an exact multiple of the packet size, so every packet is
    # complete; inspect the header bytes in place without slicing copies.
    for offset in range(0, size, packet_size):
        # Verify sync byte
        if data[offset] != sync_byte:
            return {'valid': False, 'error': f'Invalid sync byte at offset {offset}'}

        # Check transport error indicator (top bit of the second byte)
        if data[offset + 1] & 0x80:
            return {'valid': False, 'error': 'Transport error indicator set'}

    return {
        'valid': True,
        'packets': size // packet_size,
        'size': size
    }
|
|
|
|
def fetch_stream(fetcher: "StreamFetcher", stop_event: threading.Event, start_sequence: int = 0):
    """
    Main streaming function that handles manifest updates and segment downloads.

    Args:
        fetcher: StreamFetcher instance to handle HTTP requests
        stop_event: Threading event to signal when to stop fetching
        start_sequence: Initial sequence number to start from (currently not
            read; sequence numbers come from ``fetcher.manager``)

    The function implements the core HLS fetching logic:
    - Fetches and parses manifest files
    - Downloads new segments when they become available
    - Handles stream switches with proper discontinuity marking
    - Maintains buffer state and segment sequence numbering

    Side effects:
        Writes the module-level ``manifest_buffer`` and ``segment_buffers``
        while holding ``buffer_lock``; sets ``stop_event`` itself after the
        first segment of a switched stream has been stored.
    """
    # Without this declaration the assignment below only created a local and
    # the module-level manifest was never updated.
    global manifest_buffer

    retry_delay = 1
    max_retry_delay = 8
    last_segment_time = 0
    buffer_initialized = False
    manifest_update_needed = True
    segment_duration = None

    while not stop_event.is_set():
        try:
            now = time.time()

            # Only update manifest when it's time for next segment
            should_update = (
                manifest_update_needed or
                not segment_duration or
                (last_segment_time and (now - last_segment_time) >= segment_duration * 0.8)
            )

            if not should_update:
                # Short wait to prevent CPU spinning; returns early on stop
                stop_event.wait(0.1)
                continue

            manifest_data, final_url = fetcher.download(fetcher.stream_url)
            manifest_content = manifest_data.decode()
            manifest = m3u8.loads(manifest_content)

            if not manifest.segments:
                # Back off instead of re-requesting an empty manifest in a
                # tight loop.
                stop_event.wait(retry_delay)
                continue

            # Successful poll: reset the backoff and record pacing state so
            # the next poll waits for ~80% of a segment duration. These
            # values were previously never assigned, which made every loop
            # iteration hit the manifest URL.
            retry_delay = 1
            manifest_update_needed = False
            last_segment_time = now
            segment_duration = (
                float(manifest.segments[-1].duration or 0)
                or fetcher.manager.target_duration
            )

            with buffer_lock:
                new_segments = {}

                if fetcher.manager.switching_stream:
                    # Stream switch - only get latest segment
                    manifest_segments = [manifest.segments[-1]]
                    seq_start = fetcher.manager.next_sequence
                    max_segments = 1
                    logging.debug(f"Processing stream switch - getting latest segment at sequence {seq_start}")
                elif not buffer_initialized:
                    # Initial buffer
                    manifest_segments = manifest.segments[-Config.INITIAL_SEGMENTS:]
                    seq_start = fetcher.manager.next_sequence
                    max_segments = Config.INITIAL_SEGMENTS
                    logging.debug(f"Starting initial buffer at sequence {seq_start}")
                else:
                    # Normal operation
                    manifest_segments = [manifest.segments[-1]]
                    max_segments = 1

                # Map segments to proxy sequence numbers; segments the
                # manager has already seen yield None and are skipped.
                segments_mapped = 0
                for segment in manifest_segments:
                    if segments_mapped >= max_segments:
                        break

                    source_id = segment.uri.split('/')[-1].split('.')[0]
                    next_seq = fetcher.manager.get_next_sequence(source_id)

                    if next_seq is not None:
                        duration = float(segment.duration)
                        new_segments[next_seq] = {
                            'uri': segment.uri,
                            'duration': duration,
                            'source_id': source_id
                        }
                        fetcher.manager.segment_durations[next_seq] = duration
                        segments_mapped += 1

                manifest_buffer = manifest_content

            # Download segments
            for sequence_id, segment_info in new_segments.items():
                try:
                    # Resolve against the manifest's final URL so relative,
                    # root-relative and absolute segment URIs all work.
                    segment_url = urljoin(final_url, segment_info['uri'])
                    logging.debug(f"Downloading {segment_info['uri']} as segment {sequence_id}.ts "
                                  f"(source: {segment_info['source_id']}, duration: {segment_info['duration']:.3f}s)")

                    segment_data, _ = fetcher.download(segment_url)
                    validation = verify_segment(segment_data)

                    if validation.get('valid', False):
                        with buffer_lock:
                            segment_buffers[sequence_id] = segment_data
                            logging.debug(f"Downloaded and verified segment {sequence_id} (packets: {validation['packets']})")

                            if fetcher.manager.switching_stream:
                                fetcher.manager.switching_stream = False
                                stop_event.set()  # Force fetcher restart with new URL
                                break
                            elif not buffer_initialized and len(segment_buffers) >= Config.INITIAL_SEGMENTS:
                                buffer_initialized = True
                                manifest_update_needed = True
                                break
                except Exception as e:
                    logging.error(f"Segment download error: {e}")
                    continue

        except Exception as e:
            logging.error(f"Manifest error: {e}")
            stop_event.wait(retry_delay)
            retry_delay = min(retry_delay * 2, max_retry_delay)
            manifest_update_needed = True
|
|
|
|
|
|
|
|
@app.before_request
def log_request_info():
    """
    Emit one log line per incoming request on the module-level ``app``.

    Levels:
        INFO:
            - A request for ``/stream.m3u8`` while ``segment_buffers`` is empty
            - Any request whose path starts with ``/change_stream``
        DEBUG:
            - Everything else, formatted as ``{client_ip} - {method} {path}``

    NOTE(review): this hook is attached to the module-level ``app``, while
    ``ProxyServer`` builds its own Flask instance - confirm it is still
    reachable.
    """
    path = request.path

    # First manifest request from a client
    if path == '/stream.m3u8' and not segment_buffers:
        logging.info(f"New client connected from {request.remote_addr}")
        return

    # Keep stream switch requests as INFO
    if path.startswith('/change_stream'):
        logging.info(f"Stream switch requested from {request.remote_addr}")
        return

    # Routine requests go to DEBUG
    logging.debug(f"{request.remote_addr} - {request.method} {request.path}")
|
|
|
|
# Configure Werkzeug logger to DEBUG
# Runs at import time: sets the level of the logger named 'werkzeug' to DEBUG.
logging.getLogger('werkzeug').setLevel(logging.DEBUG)
|
|
|
|
class ProxyServer:
    """
    Manages HLS proxy server instance.

    Owns a Flask application plus, per channel ID, a StreamManager, a
    StreamBuffer, a ClientManager and the thread running the channel's
    ``StreamFetcher.fetch_loop``.
    """

    def __init__(self, user_agent: Optional[str] = None):
        """
        Args:
            user_agent: User-Agent for upstream requests;
                Config.DEFAULT_USER_AGENT when falsy
        """
        self.app = Flask(__name__)
        self.stream_managers: Dict[str, StreamManager] = {}
        self.stream_buffers: Dict[str, StreamBuffer] = {}
        self.client_managers: Dict[str, ClientManager] = {}
        self.fetch_threads: Dict[str, threading.Thread] = {}
        self.user_agent: str = user_agent or Config.DEFAULT_USER_AGENT
        self._setup_routes()

    def _setup_routes(self) -> None:
        """Configure Flask routes"""
        self.app.add_url_rule(
            '/stream/<channel_id>',
            view_func=self.stream_endpoint
        )
        self.app.add_url_rule(
            '/stream/<channel_id>/segments/<path:segment_name>',
            view_func=self.get_segment
        )
        self.app.add_url_rule(
            '/change_stream/<channel_id>',
            view_func=self.change_stream,
            methods=['POST']
        )

    def initialize_channel(self, url: str, channel_id: str) -> None:
        """
        Initialize a new channel stream.

        An existing channel with the same ID is stopped and replaced.

        Args:
            url: Upstream HLS manifest URL
            channel_id: Identifier under which the channel is served
        """
        if channel_id in self.stream_managers:
            self.stop_channel(channel_id)

        manager = StreamManager(
            url,
            channel_id,
            user_agent=self.user_agent
        )
        buffer = StreamBuffer()

        # Store resources
        self.stream_managers[channel_id] = manager
        self.stream_buffers[channel_id] = buffer
        self.client_managers[channel_id] = ClientManager()

        # Create and store fetcher
        fetcher = StreamFetcher(manager, buffer)
        manager.fetcher = fetcher  # Store reference to fetcher

        # Start fetch thread
        self.fetch_threads[channel_id] = threading.Thread(
            target=fetcher.fetch_loop,
            name=f"StreamFetcher-{channel_id}",
            daemon=True
        )
        self.fetch_threads[channel_id].start()

        logging.info(f"Initialized channel {channel_id} with URL {url}")

    def stop_channel(self, channel_id: str) -> None:
        """Stop and cleanup a channel (unknown channel IDs are ignored)."""
        if channel_id in self.stream_managers:
            self.stream_managers[channel_id].stop()
            if channel_id in self.fetch_threads:
                self.fetch_threads[channel_id].join(timeout=5)
        self._cleanup_channel(channel_id)

    def _cleanup_channel(self, channel_id: str) -> None:
        """
        Remove all resources associated with a channel.

        Args:
            channel_id: Channel to cleanup

        Removes:
            - Stream manager instance
            - Segment buffer
            - Client manager
            - Fetch thread reference

        Thread safety:
            Should only be called after stream manager is stopped
            and fetch thread has completed
        """
        for collection in [self.stream_managers, self.stream_buffers,
                           self.client_managers, self.fetch_threads]:
            collection.pop(channel_id, None)

    def run(self, host: str = '0.0.0.0', port: int = 5000) -> None:
        """
        Start the proxy server (blocks until the Flask server exits).

        Raises:
            Exception: Re-raised after shutdown when the server fails
        """
        try:
            self.app.run(host=host, port=port, threaded=True)
        except KeyboardInterrupt:
            logging.info("Shutting down gracefully...")
            self.shutdown()
        except Exception as e:
            logging.error(f"Server error: {e}")
            self.shutdown()
            raise

    def shutdown(self) -> None:
        """
        Stop all channels and cleanup resources.

        Calls ``stop_channel`` for every registered channel; calling it again
        when no channels remain does nothing.
        """
        for channel_id in list(self.stream_managers.keys()):
            self.stop_channel(channel_id)

    def stream_endpoint(self, channel_id: str):
        """
        Flask route handler for serving HLS manifests.

        Args:
            channel_id: Unique identifier for the stream channel

        Returns:
            Flask Response with:
                - HLS manifest content
                - Proper content type
                - 404 if channel/segments not found
                - 500 on unexpected errors

        The manifest includes:
            - Current segment window
            - Proper sequence numbering
            - Discontinuity markers
            - Accurate segment durations
        """
        try:
            if channel_id not in self.stream_managers:
                return Response('Channel not found', status=404)

            buffer = self.stream_buffers[channel_id]
            manager = self.stream_managers[channel_id]

            # Snapshot the buffered sequence numbers under the buffer's lock
            with buffer.lock:
                available = sorted(buffer.keys())

            if not available:
                logging.warning("No segments available")
                return '', 404

            max_seq = available[-1]

            # Start at the first buffered segment that marks a source change,
            # or at the oldest buffered segment when there is none.
            discontinuity_start = next(
                (seq for seq in available if seq in manager.source_changes),
                available[0],
            )

            # Calculate window bounds starting from discontinuity
            if len(available) <= Config.INITIAL_SEGMENTS:
                min_seq = discontinuity_start
            else:
                min_seq = max(
                    discontinuity_start,
                    max_seq - Config.WINDOW_SIZE + 1
                )

            # Build manifest with proper tags
            new_manifest = [
                '#EXTM3U',
                f'#EXT-X-VERSION:{manager.manifest_version}',
                f'#EXT-X-MEDIA-SEQUENCE:{min_seq}',
                f'#EXT-X-TARGETDURATION:{int(manager.target_duration)}',
            ]

            # Filter segments within window
            window_segments = [s for s in available if min_seq <= s <= max_seq]

            # Add segments with discontinuity handling
            for seq in window_segments:
                if seq in manager.source_changes:
                    new_manifest.append('#EXT-X-DISCONTINUITY')
                    logging.debug(f"Added discontinuity marker before segment {seq}")

                duration = manager.segment_durations.get(seq, 10.0)
                new_manifest.append(f'#EXTINF:{duration},')
                new_manifest.append(f'/stream/{channel_id}/segments/{seq}.ts')

            manifest_content = '\n'.join(new_manifest)
            logging.debug(f"Serving manifest with segments {min_seq}-{max_seq} (window: {len(window_segments)})")
            return Response(manifest_content, content_type='application/vnd.apple.mpegurl')

        except ConnectionAbortedError:
            logging.debug("Client disconnected")
            return '', 499
        except Exception as e:
            logging.error(f"Stream endpoint error: {e}")
            return '', 500

    def get_segment(self, channel_id: str, segment_name: str):
        """
        Serve individual MPEG-TS segments to clients.

        Args:
            channel_id: Unique identifier for the channel
            segment_name: Segment filename (e.g., '123.ts')

        Returns:
            Flask Response:
                - MPEG-TS segment data with video/MP2T content type
                - 404 if segment or channel not found

        Error Handling:
            - Logs warning if segment not found
            - Logs error on unexpected exceptions
            - Returns 404 on any error
        """
        if channel_id not in self.stream_managers:
            return Response('Channel not found', status=404)

        try:
            segment_id = int(segment_name.split('.')[0])
            buffer = self.stream_buffers[channel_id]

            # Use the buffer's own lock: it is the one the fetch thread holds
            # while storing segments (the global buffer_lock guards other state).
            with buffer.lock:
                if segment_id in buffer:
                    return Response(buffer[segment_id], content_type='video/MP2T')

            logging.warning(f"Segment {segment_id} not found for channel {channel_id}")
        except Exception as e:
            logging.error(f"Error serving segment {segment_name}: {e}")
        return '', 404

    def change_stream(self, channel_id: str):
        """
        Handle stream URL changes via POST request.

        Args:
            channel_id: Channel to modify

        Expected JSON body:
            {
                "url": "new_stream_url"
            }

        Returns:
            JSON response with:
                - Success/failure message
                - Channel ID
                - New/current URL
                - HTTP 404 if channel not found
                - HTTP 400 if URL missing from request (including a missing,
                  malformed or non-object JSON body)

        Side effects:
            - Updates stream manager URL
            - Triggers stream switch sequence
            - Maintains segment numbering
        """
        if channel_id not in self.stream_managers:
            return jsonify({'error': 'Channel not found'}), 404

        # silent=True yields None for a missing or invalid JSON body instead
        # of failing; treat anything that is not an object as "no URL".
        payload = request.get_json(silent=True)
        new_url = payload.get('url') if isinstance(payload, dict) else None
        if not new_url:
            return jsonify({'error': 'No URL provided'}), 400

        manager = self.stream_managers[channel_id]
        changed = manager.update_url(new_url)
        return jsonify({
            'message': 'Stream URL updated' if changed else 'URL unchanged',
            'channel': channel_id,
            'url': new_url
        })
|
|
|
|
@app.before_request
def log_request_info():
    """
    Placeholder request hook: the function body is only this docstring.

    NOTE(review): this is a second definition of ``log_request_info`` in this
    module. It rebinds the module-level name and is also decorated with
    ``@app.before_request``, but it performs no logging and has no side
    effects of its own.

    The behaviour this hook was apparently meant to describe:

    Log Levels:
        INFO:
            - First manifest request from new client
            - Stream switch requests
        DEBUG:
            - Segment requests
            - Routine manifest updates

    Format:
        "{client_ip} - {method} {path}"
    """
|
|
|
|
# Main Application Setup
if __name__ == '__main__':
    # Command line argument parsing
    parser = argparse.ArgumentParser(description='HLS Proxy Server with Stream Switching')
    parser.add_argument(
        '--url', '-u',
        required=True,
        help='Initial HLS stream URL to proxy'
    )
    parser.add_argument(
        '--channel', '-c',
        required=True,
        # The option is required, so there is no default to advertise.
        help='Channel ID for the stream (required)'
    )
    parser.add_argument(
        '--port', '-p',
        type=int,
        default=5000,
        help='Local port to serve proxy on (default: 5000)'
    )
    parser.add_argument(
        '--host', '-H',
        default='0.0.0.0',
        help='Interface to bind server to (default: all interfaces)'
    )
    parser.add_argument(
        '--user-agent', '-ua',
        help='Custom User-Agent string'
    )
    parser.add_argument(
        '--debug',
        action='store_true',
        help='Enable debug logging'
    )
    args = parser.parse_args()

    # Configure logging
    logging.basicConfig(
        level=logging.DEBUG if args.debug else logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S'
    )

    # Explicit sentinel so the cleanup below does not need to probe locals()
    proxy = None
    try:
        # Initialize proxy server
        proxy = ProxyServer(user_agent=args.user_agent)

        # Initialize channel with provided URL
        proxy.initialize_channel(args.url, args.channel)

        logging.info(f"Starting HLS proxy server on {args.host}:{args.port}")
        logging.info(f"Initial stream URL: {args.url}")
        logging.info(f"Channel ID: {args.channel}")

        # Run Flask development server (blocks until the server exits)
        proxy.run(host=args.host, port=args.port)

    except Exception as e:
        logging.error(f"Failed to start server: {e}")
        sys.exit(1)
    finally:
        # Always stop channel threads, whether startup failed or the server exited
        if proxy is not None:
            proxy.shutdown()
|