Improve client disconnect handling: heartbeat client presence into Redis, publish connect/disconnect events between workers, and start the channel grace period only after a successful upstream connection.

SergeantPanda 2025-03-10 17:56:09 -05:00
parent f3b1636579
commit fa5fc86c99
2 changed files with 268 additions and 70 deletions


@@ -29,7 +29,7 @@ class TSConfig(BaseConfig):
MAX_RETRIES = 3 # maximum connection retry attempts
# Buffer settings
INITIAL_BEHIND_CHUNKS = 30 # How many chunks behind to start a client
INITIAL_BEHIND_CHUNKS = 100 # How many chunks behind to start a client
CHUNK_BATCH_SIZE = 5 # How many chunks to fetch in one batch
KEEPALIVE_INTERVAL = 0.5 # Seconds between keepalive packets when at buffer head
@@ -46,4 +46,5 @@ class TSConfig(BaseConfig):
CLIENT_RECORD_TTL = 5 # How long client records persist in Redis (seconds). Client will be considered MIA after this time.
CLEANUP_CHECK_INTERVAL = 1 # How often to check for disconnected clients (seconds)
CHANNEL_INIT_GRACE_PERIOD = 5 # How long to wait for first client after initialization (seconds)
CLIENT_HEARTBEAT_INTERVAL = 1 # How often to send client heartbeats (seconds)
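
These tunables are read with getattr() fallbacks elsewhere in this commit, so a value missing from TSConfig degrades to a built-in default instead of raising. A minimal sketch of that pattern (the import path comes from the diff below; the fallback values here are illustrative):

# Sketch only: mirrors the getattr(Config, ...) pattern used in this commit.
from apps.proxy.config import TSConfig as Config

client_ttl = getattr(Config, 'CLIENT_RECORD_TTL', 5)         # seconds a client record lives in Redis
heartbeat = getattr(Config, 'CLIENT_HEARTBEAT_INTERVAL', 1)  # how often that record is refreshed
grace = getattr(Config, 'CHANNEL_INIT_GRACE_PERIOD', 5)      # post-connect wait for a first client

# Heartbeats must land well inside the record TTL, or live clients
# would expire between refreshes (here: 1s heartbeats vs a 5s TTL).
assert heartbeat < client_ttl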


@@ -16,6 +16,7 @@ from collections import deque
import time
import sys
from typing import Optional, Set, Deque, Dict
import json
from apps.proxy.config import TSConfig as Config
# Configure root logger for this module
@@ -216,7 +217,7 @@ class StreamManager:
try:
# Create an initial connection to get socket
response = session.get(self.url, stream=True)
current_response = response # Store reference for cleanup
current_response = response
if response.status_code == 200:
self.connected = True
@@ -224,6 +225,9 @@ class StreamManager:
self.healthy = True
logging.info("Successfully connected to stream source")
# Connection successful - START GRACE PERIOD HERE
self._set_waiting_for_clients()
# Main fetch loop
while self.running and self.connected:
if self.fetch_chunk():
@@ -367,6 +371,30 @@ class StreamManager:
logging.error(f"Error in fetch_chunk: {e}")
return False
def _set_waiting_for_clients(self):
"""Set channel state to waiting for clients after successful connection"""
try:
if hasattr(self.buffer, 'channel_id') and hasattr(self.buffer, 'redis_client'):
channel_id = self.buffer.channel_id
redis_client = self.buffer.redis_client
if channel_id and redis_client:
# Set state to waiting
state_key = f"ts_proxy:channel:{channel_id}:state"
redis_client.set(state_key, "waiting_for_clients")
# Set grace period start time
grace_key = f"ts_proxy:channel:{channel_id}:grace_start"
redis_client.setex(grace_key, 120, str(time.time()))
# Get configured grace period or default
grace_period = getattr(Config, 'CHANNEL_INIT_GRACE_PERIOD', 20)
logging.info(f"Started initial connection grace period ({grace_period}s) for channel {channel_id}")
except Exception as e:
logging.error(f"Error setting waiting for clients state: {e}")
class StreamBuffer:
"""Manages stream data buffering using Redis for persistence"""
@@ -551,14 +579,95 @@ class StreamBuffer:
class ClientManager:
"""Manages connected clients for a channel with cross-worker visibility"""
def __init__(self, channel_id, redis_client=None):
def __init__(self, channel_id, redis_client=None, worker_id=None):
self.channel_id = channel_id
self.redis_client = redis_client
self.worker_id = worker_id # Store worker_id directly
self.clients = set() # Local clients only
self.lock = threading.Lock()
self.last_active_time = time.time()
self.client_set_key = f"ts_proxy:channel:{channel_id}:clients"
self.client_ttl = getattr(Config, 'CLIENT_RECORD_TTL', 5)
self.client_ttl = getattr(Config, 'CLIENT_RECORD_TTL', 60)
self.heartbeat_interval = getattr(Config, 'CLIENT_HEARTBEAT_INTERVAL', 10)
self.last_heartbeat_time = {} # Track last heartbeat time per client
# Start heartbeat thread for local clients
self._start_heartbeat_thread()
def _start_heartbeat_thread(self):
"""Start thread to regularly refresh client presence in Redis"""
def heartbeat_task():
while True:
try:
# Wait for the interval
time.sleep(self.heartbeat_interval)
# Send heartbeat for all local clients
with self.lock:
if not self.clients or not self.redis_client:
continue
# Use pipeline for efficiency
pipe = self.redis_client.pipeline()
current_time = time.time()
# For each client, update its TTL and timestamp
for client_id in self.clients:
# Skip if we just sent a heartbeat recently
if client_id in self.last_heartbeat_time:
time_since_last = current_time - self.last_heartbeat_time[client_id]
if time_since_last < self.heartbeat_interval * 0.8:
continue
# Update the client's individual key with new TTL
client_key = f"ts_proxy:client:{self.channel_id}:{client_id}"
pipe.setex(client_key, self.client_ttl, str(current_time))
# Keep client in the set with TTL
pipe.sadd(self.client_set_key, client_id)
# Update last activity timestamp in a separate key
activity_key = f"ts_proxy:client:{self.channel_id}:{client_id}:last_active"
pipe.setex(activity_key, self.client_ttl, str(current_time))
# Track last heartbeat locally
self.last_heartbeat_time[client_id] = current_time
# Always refresh the TTL on the set itself
pipe.expire(self.client_set_key, self.client_ttl)
# Execute all commands atomically
pipe.execute()
# Notify channel owner of client activity
self._notify_owner_of_activity()
except Exception as e:
logging.error(f"Error in client heartbeat thread: {e}")
thread = threading.Thread(target=heartbeat_task, daemon=True)
thread.name = f"client-heartbeat-{self.channel_id}"
thread.start()
logging.debug(f"Started client heartbeat thread for channel {self.channel_id} (interval: {self.heartbeat_interval}s)")
def _notify_owner_of_activity(self):
"""Notify channel owner that clients are active on this worker"""
if not self.redis_client or not self.clients:
return
try:
# Use the stored worker_id
worker_id = self.worker_id or "unknown"
# Store count of clients on this worker
worker_key = f"ts_proxy:channel:{self.channel_id}:worker:{worker_id}"
self.redis_client.setex(worker_key, self.client_ttl, str(len(self.clients)))
# Update channel activity timestamp
activity_key = f"ts_proxy:channel:{self.channel_id}:activity"
self.redis_client.setex(activity_key, self.client_ttl, str(time.time()))
except Exception as e:
logging.error(f"Error notifying owner of client activity: {e}")
def add_client(self, client_id):
"""Add a client to this channel locally and in Redis"""
@@ -568,24 +677,46 @@ class ClientManager:
# Track in Redis if available
if self.redis_client:
current_time = str(time.time())
# Add to channel's client set
self.redis_client.sadd(self.client_set_key, client_id)
# Set TTL on the whole set
self.redis_client.expire(self.client_set_key, self.client_ttl)
# Also track client individually with TTL for cleanup
# Set up client key with timestamp as value
client_key = f"ts_proxy:client:{self.channel_id}:{client_id}"
self.redis_client.setex(client_key, self.client_ttl, "1")
self.redis_client.setex(client_key, self.client_ttl, current_time)
# Also track last activity time separately
activity_key = f"ts_proxy:client:{self.channel_id}:{client_id}:last_active"
self.redis_client.setex(activity_key, self.client_ttl, current_time)
# Clear any initialization timer by removing the init_time key
init_key = f"ts_proxy:channel:{self.channel_id}:init_time"
self.redis_client.delete(init_key)
# Update worker count in Redis
self._notify_owner_of_activity()
# Also publish an event that the client connected
event_data = json.dumps({
"event": "client_connected",
"channel_id": self.channel_id,
"client_id": client_id,
"worker_id": self.worker_id or "unknown",
"timestamp": time.time()
})
self.redis_client.publish(f"ts_proxy:events:{self.channel_id}", event_data)
# Get total clients across all workers
total_clients = self.get_total_client_count()
logging.info(f"New client connected: {client_id} (local: {len(self.clients)}, total: {total_clients})")
# Record last heartbeat time
self.last_heartbeat_time[client_id] = time.time()
return len(self.clients)
def remove_client(self, client_id):
@@ -593,6 +724,10 @@ class ClientManager:
with self.lock:
if client_id in self.clients:
self.clients.remove(client_id)
if client_id in self.last_heartbeat_time:
del self.last_heartbeat_time[client_id]
self.last_active_time = time.time()
# Remove from Redis
@@ -600,9 +735,23 @@ class ClientManager:
# Remove from channel's client set
self.redis_client.srem(self.client_set_key, client_id)
# Delete individual client key
# Delete individual client keys
client_key = f"ts_proxy:client:{self.channel_id}:{client_id}"
self.redis_client.delete(client_key)
activity_key = f"ts_proxy:client:{self.channel_id}:{client_id}:last_active"
self.redis_client.delete(client_key, activity_key)
# Update worker count in Redis
self._notify_owner_of_activity()
# Also publish an event that the client disconnected
event_data = json.dumps({
"event": "client_disconnected",
"channel_id": self.channel_id,
"client_id": client_id,
"worker_id": self.worker_id or "unknown",
"timestamp": time.time()
})
self.redis_client.publish(f"ts_proxy:events:{self.channel_id}", event_data)
# Get remaining clients across all workers
total_clients = self.get_total_client_count()
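
Since connect and disconnect both publish JSON on ts_proxy:events:{channel_id}, the event path can be exercised without a real client by publishing a synthetic payload. A sketch (the field names mirror the publishes above; the ids and connection settings are assumptions):

import json
import time
import redis

r = redis.Redis()  # connection settings are an assumption
payload = json.dumps({
    "event": "client_disconnected",
    "channel_id": "some-channel",   # hypothetical id
    "client_id": "debug-client",    # hypothetical id
    "worker_id": "debug-worker",    # hypothetical id
    "timestamp": time.time(),
})
r.publish("ts_proxy:events:some-channel", payload)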
@@ -758,7 +907,64 @@ class ProxyServer:
# Start cleanup thread
self.cleanup_interval = getattr(Config, 'CLEANUP_INTERVAL', 60)
self._start_cleanup_thread()
# Start event listener for Redis pubsub messages
self._start_event_listener()
def _start_event_listener(self):
"""Listen for events from other workers"""
if not self.redis_client:
return
def event_listener():
try:
pubsub = self.redis_client.pubsub()
pubsub.psubscribe("ts_proxy:events:*")
logging.info("Started Redis event listener for client activity")
for message in pubsub.listen():
if message["type"] != "pmessage":
continue
try:
channel = message["channel"].decode("utf-8")
data = json.loads(message["data"].decode("utf-8"))
event_type = data.get("event")
channel_id = data.get("channel_id")
if channel_id and event_type:
# For owner, update client status immediately
if self.am_i_owner(channel_id):
if event_type == "client_connected":
logging.debug(f"Owner received client_connected event for channel {channel_id}")
# Reset any no-clients timer
no_clients_key = f"ts_proxy:channel:{channel_id}:no_clients_since"
self.redis_client.delete(no_clients_key)
elif event_type == "client_disconnected":
logging.debug(f"Owner received client_disconnected event for channel {channel_id}")
# Check if any clients remain
if channel_id in self.client_managers:
total = self.client_managers[channel_id].get_total_client_count()
if total == 0:
logging.info(f"No clients left after disconnect event, starting shutdown timer")
# Start the no-clients timer
no_clients_key = f"ts_proxy:channel:{channel_id}:no_clients_since"
self.redis_client.setex(no_clients_key, 60, str(time.time()))
except Exception as e:
logging.error(f"Error processing event message: {e}")
except Exception as e:
logging.error(f"Error in event listener: {e}")
time.sleep(5) # Wait before reconnecting
# Try to restart the listener
self._start_event_listener()
thread = threading.Thread(target=event_listener, daemon=True)
thread.name = "redis-event-listener"
thread.start()
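
When the owner handles a client_disconnected event and finds no clients left, it records the moment in no_clients_since, and the cleanup thread further down uses that key as its shutdown timer. Whether a countdown is pending can be checked directly; a sketch with a hypothetical channel id and assumed connection settings:

import time
import redis

r = redis.Redis()  # connection settings are an assumption
since = r.get("ts_proxy:channel:some-channel:no_clients_since")
if since:
    print(f"shutdown timer running for {time.time() - float(since.decode('utf-8')):.1f}s")
else:
    print("no shutdown pending")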
def get_channel_owner(self, channel_id):
"""Get the worker ID that owns this channel with proper error handling"""
if not self.redis_client:
@@ -884,7 +1090,7 @@ class ProxyServer:
self.stream_buffers[channel_id] = buffer
# Create client manager with channel_id and redis_client
client_manager = ClientManager(channel_id=channel_id, redis_client=self.redis_client)
client_manager = ClientManager(channel_id=channel_id, redis_client=self.redis_client, worker_id=self.worker_id)
self.client_managers[channel_id] = client_manager
return True
@@ -905,7 +1111,7 @@ class ProxyServer:
self.stream_buffers[channel_id] = buffer
# Create client manager with channel_id and redis_client
client_manager = ClientManager(channel_id=channel_id, redis_client=self.redis_client)
client_manager = ClientManager(channel_id=channel_id, redis_client=self.redis_client, worker_id=self.worker_id)
self.client_managers[channel_id] = client_manager
return True
@@ -923,8 +1129,12 @@ class ProxyServer:
logging.debug(f"Created StreamManager for channel {channel_id}")
self.stream_managers[channel_id] = stream_manager
# Create client manager with channel_id and redis_client
client_manager = ClientManager(channel_id=channel_id, redis_client=self.redis_client)
# Create client manager with channel_id, redis_client AND worker_id
client_manager = ClientManager(
channel_id=channel_id,
redis_client=self.redis_client,
worker_id=self.worker_id
)
self.client_managers[channel_id] = client_manager
# Set channel activity key (separate from lock key)
@@ -938,34 +1148,19 @@ class ProxyServer:
thread.start()
logging.info(f"Started stream manager thread for channel {channel_id}")
# If we're the owner, start a grace period timer for first client
# If we're the owner, we need to set the channel state rather than starting a grace period immediately
if self.am_i_owner(channel_id):
# Set a timestamp to track initialization time
# Set channel state to "connecting"
if self.redis_client:
init_key = f"ts_proxy:channel:{channel_id}:init_time"
self.redis_client.setex(init_key, Config.CLIENT_RECORD_TTL, str(time.time()))
state_key = f"ts_proxy:channel:{channel_id}:state"
self.redis_client.set(state_key, "connecting")
# Start a timer thread to check for first client
def check_first_client():
# Wait for the grace period
time.sleep(getattr(Config, 'CHANNEL_INIT_GRACE_PERIOD', 10))
# After grace period, check if any clients connected
if channel_id in self.client_managers:
total_clients = self.client_managers[channel_id].get_total_client_count()
if total_clients == 0:
logging.info(f"No clients connected to channel {channel_id} within grace period, shutting down")
self.stop_channel(channel_id)
else:
logging.info(f"Channel {channel_id} has {total_clients} clients, staying active")
# Set connection start time for monitoring
connect_key = f"ts_proxy:channel:{channel_id}:connect_time"
self.redis_client.setex(connect_key, 60, str(time.time()))
logging.info(f"Channel {channel_id} in connecting state - will start grace period after connection")
# Start the timer thread
timer_thread = threading.Thread(target=check_first_client, daemon=True)
timer_thread.name = f"init-timer-{channel_id}"
timer_thread.start()
logging.info(f"Started initial connection grace period ({Config.CHANNEL_INIT_GRACE_PERIOD}s) for channel {channel_id}")
return True
except Exception as e:
@@ -1075,9 +1270,6 @@ class ProxyServer:
def _start_cleanup_thread(self):
"""Start background thread to maintain ownership and clean up resources"""
def cleanup_task():
# Wait for initialization
time.sleep(5)
while True:
try:
# For channels we own, check total clients and cleanup as needed
@@ -1086,66 +1278,71 @@ class ProxyServer:
# Extend ownership lease
self.extend_ownership(channel_id)
# Get channel state
channel_state = "unknown"
if self.redis_client:
state_bytes = self.redis_client.get(f"ts_proxy:channel:{channel_id}:state")
if state_bytes:
channel_state = state_bytes.decode('utf-8')
# Check if channel has any clients left
if channel_id in self.client_managers:
client_manager = self.client_managers[channel_id]
total_clients = client_manager.get_total_client_count()
if total_clients == 0:
# Either we're in initialization grace period or shutdown delay
init_key = f"ts_proxy:channel:{channel_id}:init_time"
no_clients_key = f"ts_proxy:channel:{channel_id}:no_clients_since"
# If in waiting_for_clients state, check if grace period expired
if channel_state == "waiting_for_clients" and total_clients == 0:
grace_key = f"ts_proxy:channel:{channel_id}:grace_start"
grace_start = None
init_time = None
if self.redis_client:
grace_value = self.redis_client.get(grace_key)
if grace_value:
grace_start = float(grace_value.decode('utf-8'))
if grace_start:
grace_period = getattr(Config, 'CHANNEL_INIT_GRACE_PERIOD', 20)
grace_elapsed = time.time() - grace_start
if grace_elapsed > grace_period:
logging.info(f"No clients connected within grace period ({grace_elapsed:.1f}s > {grace_period}s), stopping channel {channel_id}")
self.stop_channel(channel_id)
else:
logging.debug(f"Channel {channel_id} in grace period - {grace_elapsed:.1f}s of {grace_period}s elapsed, waiting for clients")
# If active and no clients, start normal shutdown procedure
elif channel_state not in ["connecting", "waiting_for_clients"] and total_clients == 0:
# Check if there's a pending no-clients timeout
key = f"ts_proxy:channel:{channel_id}:no_clients_since"
no_clients_since = None
if self.redis_client:
# Get initialization time if exists
init_value = self.redis_client.get(init_key)
if init_value:
init_time = float(init_value.decode('utf-8'))
# Get no clients timestamp if exists
no_clients_value = self.redis_client.get(no_clients_key)
no_clients_value = self.redis_client.get(key)
if no_clients_value:
no_clients_since = float(no_clients_value.decode('utf-8'))
current_time = time.time()
# Handle initialization grace period expiration
if init_time and current_time - init_time > getattr(Config, 'CHANNEL_INIT_GRACE_PERIOD', 10):
logging.info(f"No clients connected to channel {channel_id} within grace period, shutting down")
self.stop_channel(channel_id)
continue
# Handle no clients since tracking
if not no_clients_since:
# First time seeing zero clients, set timestamp
if self.redis_client:
self.redis_client.setex(no_clients_key, Config.CLIENT_RECORD_TTL, str(current_time))
self.redis_client.setex(key, Config.CLIENT_RECORD_TTL, str(current_time))
logging.info(f"No clients detected for channel {channel_id}, starting shutdown timer")
elif current_time - no_clients_since > getattr(Config, 'CHANNEL_SHUTDOWN_DELAY', 5):
# We've had no clients for the shutdown delay period
logging.info(f"No clients for {current_time - no_clients_since:.1f}s, stopping channel {channel_id}")
self.stop_channel(channel_id)
else:
# There are clients - clear any no-clients timestamp
# There are clients or we're still connecting - clear any no-clients timestamp
if self.redis_client:
self.redis_client.delete(f"ts_proxy:channel:{channel_id}:no_clients_since")
# Non-owner workers just refresh client TTLs
for channel_id, client_manager in list(self.client_managers.items()):
if not self.am_i_owner(channel_id):
client_manager.refresh_client_ttl()
# Check for orphaned channels in Redis
self._check_orphaned_channels()
# Rest of the cleanup thread...
except Exception as e:
logging.error(f"Error in cleanup thread: {e}", exc_info=True)
# Run more frequently to detect client disconnects quickly
time.sleep(getattr(Config, 'CLEANUP_CHECK_INTERVAL', 3))
time.sleep(getattr(Config, 'CLEANUP_CHECK_INTERVAL', 1))
thread = threading.Thread(target=cleanup_task, daemon=True)
thread.name = "ts-proxy-cleanup"
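
Condensed, the stop decision this loop implements looks roughly like the following. This is a restatement for clarity, not code from the commit; the default arguments mirror the getattr fallbacks used above:

import time

def should_stop(channel_state, total_clients, grace_start, no_clients_since,
                grace_period=20, shutdown_delay=5):
    """Sketch: restatement of the cleanup loop's shutdown rules."""
    now = time.time()
    if total_clients > 0:
        return False  # clients present: any pending timers get cleared
    if channel_state == "waiting_for_clients":
        # Post-connection grace window: stop only once it has elapsed.
        return grace_start is not None and now - grace_start > grace_period
    if channel_state != "connecting":
        # Active channel with no clients: stop after the shutdown delay.
        return no_clients_since is not None and now - no_clients_since > shutdown_delay
    return False  # still connecting upstream: never stop here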