diff --git a/dispatcharr/asgi.py b/dispatcharr/asgi.py index 5e60f635..45923fe0 100644 --- a/dispatcharr/asgi.py +++ b/dispatcharr/asgi.py @@ -1,14 +1,17 @@ +import django import os from django.core.asgi import get_asgi_application from channels.routing import ProtocolTypeRouter, URLRouter -from channels.auth import AuthMiddlewareStack import dispatcharr.routing os.environ.setdefault("DJANGO_SETTINGS_MODULE", "dispatcharr.settings") +django.setup() + +from .jwt_ws_auth import JWTAuthMiddleware application = ProtocolTypeRouter({ "http": get_asgi_application(), - "websocket": AuthMiddlewareStack( + "websocket": JWTAuthMiddleware( URLRouter(dispatcharr.routing.websocket_urlpatterns) ), }) diff --git a/dispatcharr/consumers.py b/dispatcharr/consumers.py index f7d4a47c..4e21bdae 100644 --- a/dispatcharr/consumers.py +++ b/dispatcharr/consumers.py @@ -6,9 +6,15 @@ logger = logging.getLogger(__name__) class MyWebSocketConsumer(AsyncWebsocketConsumer): async def connect(self): + self.room_name = "updates" + + user = self.scope["user"] + if not user.is_authenticated: + await self.close() + return + try: await self.accept() - self.room_name = "updates" await self.channel_layer.group_add(self.room_name, self.channel_name) # Send a connection confirmation to the client with consistent format await self.send(text_data=json.dumps({ diff --git a/dispatcharr/jwt_ws_auth.py b/dispatcharr/jwt_ws_auth.py new file mode 100644 index 00000000..3c7afeab --- /dev/null +++ b/dispatcharr/jwt_ws_auth.py @@ -0,0 +1,36 @@ +from urllib.parse import parse_qs +from channels.middleware import BaseMiddleware +from channels.db import database_sync_to_async +from rest_framework_simplejwt.tokens import UntypedToken +from django.contrib.auth.models import AnonymousUser +from django.contrib.auth import get_user_model +from rest_framework_simplejwt.exceptions import InvalidToken, TokenError +from rest_framework_simplejwt.authentication import JWTAuthentication + +User = get_user_model() + +@database_sync_to_async +def get_user(validated_token): + try: + jwt_auth = JWTAuthentication() + user = jwt_auth.get_user(validated_token) + return user + except: + return AnonymousUser() + +class JWTAuthMiddleware(BaseMiddleware): + async def __call__(self, scope, receive, send): + try: + # Extract the token from the query string + query_string = parse_qs(scope["query_string"].decode()) + token = query_string.get("token", [None])[0] + + if token is not None: + validated_token = JWTAuthentication().get_validated_token(token) + scope["user"] = await get_user(validated_token) + else: + scope["user"] = AnonymousUser() + except (InvalidToken, TokenError): + scope["user"] = AnonymousUser() + + return await super().__call__(scope, receive, send) diff --git a/frontend/src/WebSocket.jsx b/frontend/src/WebSocket.jsx index d12665fd..5e80f2f9 100644 --- a/frontend/src/WebSocket.jsx +++ b/frontend/src/WebSocket.jsx @@ -7,7 +7,6 @@ import React, { useMemo, useCallback, } from 'react'; -import useStreamsStore from './store/streams'; import { notifications } from '@mantine/notifications'; import useChannelsStore from './store/channels'; import usePlaylistsStore from './store/playlists'; @@ -15,8 +14,9 @@ import useEPGsStore from './store/epgs'; import { Box, Button, Stack, Alert, Group } from '@mantine/core'; import API from './api'; import useSettingsStore from './store/settings'; +import useAuthStore from './store/auth'; -export const WebsocketContext = createContext([false, () => { }, null]); +export const WebsocketContext = createContext([false, () => {}, null]); export const WebsocketProvider = ({ children }) => { const [isReady, setIsReady] = useState(false); @@ -28,10 +28,21 @@ export const WebsocketProvider = ({ children }) => { const maxReconnectAttempts = 5; const initialBackoffDelay = 1000; // 1 second initial delay const env_mode = useSettingsStore((s) => s.environment.env_mode); + const accessToken = useAuthStore((s) => s.accessToken); + + const epgs = useEPGsStore((s) => s.epgs); + const updateEPG = useEPGsStore((s) => s.updateEPG); + const updateEPGProgress = useEPGsStore((s) => s.updateEPGProgress); + + const playlists = usePlaylistsStore((s) => s.playlists); + const updatePlaylist = usePlaylistsStore((s) => s.updatePlaylist); // Calculate reconnection delay with exponential backoff const getReconnectDelay = useCallback(() => { - return Math.min(initialBackoffDelay * Math.pow(1.5, reconnectAttempts), 30000); // max 30 seconds + return Math.min( + initialBackoffDelay * Math.pow(1.5, reconnectAttempts), + 30000 + ); // max 30 seconds }, [reconnectAttempts]); // Clear any existing reconnect timers @@ -50,15 +61,15 @@ export const WebsocketProvider = ({ children }) => { // In development mode, connect directly to the WebSocket server on port 8001 if (env_mode === 'dev') { - return `${protocol}//${host}:8001/ws/`; + return `${protocol}//${host}:8001/ws/?token=${accessToken}`; } else { // In production mode, use the same port as the main application // This allows nginx to handle the WebSocket forwarding return appPort - ? `${protocol}//${host}:${appPort}/ws/` - : `${protocol}//${host}/ws/`; + ? `${protocol}//${host}:${appPort}/ws/?token=${accessToken}` + : `${protocol}//${host}/ws/?token=${accessToken}`; } - }, [env_mode]); + }, [env_mode, accessToken]); // Function to handle websocket connection const connectWebSocket = useCallback(() => { @@ -76,12 +87,14 @@ export const WebsocketProvider = ({ children }) => { try { ws.current.close(); } catch (e) { - console.warn("Error closing existing WebSocket:", e); + console.warn('Error closing existing WebSocket:', e); } } try { - console.log(`Attempting WebSocket connection (attempt ${reconnectAttempts + 1}/${maxReconnectAttempts})...`); + console.log( + `Attempting WebSocket connection (attempt ${reconnectAttempts + 1}/${maxReconnectAttempts})...` + ); // Use the function to get the correct WebSocket URL const wsUrl = getWebSocketUrl(); @@ -91,42 +104,50 @@ export const WebsocketProvider = ({ children }) => { const socket = new WebSocket(wsUrl); socket.onopen = () => { - console.log("WebSocket connected successfully"); + console.log('WebSocket connected successfully'); setIsReady(true); setConnectionError(null); setReconnectAttempts(0); }; socket.onerror = (error) => { - console.error("WebSocket connection error:", error); + console.error('WebSocket connection error:', error); // Don't show error notification on initial page load, // only show it after a connection was established then lost if (reconnectAttempts > 0 || isReady) { - setConnectionError("Failed to connect to WebSocket server."); + setConnectionError('Failed to connect to WebSocket server.'); } else { - console.log("Initial connection attempt failed, will retry..."); + console.log('Initial connection attempt failed, will retry...'); } }; socket.onclose = (event) => { - console.warn("WebSocket connection closed", event); + console.warn('WebSocket connection closed', event); setIsReady(false); // Only attempt reconnect if we haven't reached max attempts if (reconnectAttempts < maxReconnectAttempts) { const delay = getReconnectDelay(); - setConnectionError(`Connection lost. Reconnecting in ${Math.ceil(delay / 1000)} seconds...`); - console.log(`Scheduling reconnect in ${delay}ms (attempt ${reconnectAttempts + 1}/${maxReconnectAttempts})...`); + setConnectionError( + `Connection lost. Reconnecting in ${Math.ceil(delay / 1000)} seconds...` + ); + console.log( + `Scheduling reconnect in ${delay}ms (attempt ${reconnectAttempts + 1}/${maxReconnectAttempts})...` + ); // Store timer reference so we can cancel it if needed reconnectTimerRef.current = setTimeout(() => { - setReconnectAttempts(prev => prev + 1); + setReconnectAttempts((prev) => prev + 1); connectWebSocket(); }, delay); } else { - setConnectionError("Maximum reconnection attempts reached. Please reload the page."); - console.error("Maximum reconnection attempts reached. WebSocket connection failed."); + setConnectionError( + 'Maximum reconnection attempts reached. Please reload the page.' + ); + console.error( + 'Maximum reconnection attempts reached. WebSocket connection failed.' + ); } }; @@ -137,7 +158,10 @@ export const WebsocketProvider = ({ children }) => { // Handle connection_established event if (parsedEvent.type === 'connection_established') { - console.log('WebSocket connection established:', parsedEvent.data?.message); + console.log( + 'WebSocket connection established:', + parsedEvent.data?.message + ); // Don't need to do anything else for this event type return; } @@ -167,8 +191,9 @@ export const WebsocketProvider = ({ children }) => { // Update the playlist status whenever we receive a status update // Not just when progress is 100% or status is pending_setup if (parsedEvent.data.status && parsedEvent.data.account) { - const playlistsState = usePlaylistsStore.getState(); - const playlist = playlistsState.playlists.find(p => p.id === parsedEvent.data.account); + const playlist = playlists.find( + (p) => p.id === parsedEvent.data.account + ); if (playlist) { // When we receive a "success" status with 100% progress, this is a completed refresh @@ -176,15 +201,19 @@ export const WebsocketProvider = ({ children }) => { const updateData = { ...playlist, status: parsedEvent.data.status, - last_message: parsedEvent.data.message || playlist.last_message + last_message: + parsedEvent.data.message || playlist.last_message, }; // Update the timestamp when we complete a successful refresh - if (parsedEvent.data.status === 'success' && parsedEvent.data.progress === 100) { + if ( + parsedEvent.data.status === 'success' && + parsedEvent.data.progress === 100 + ) { updateData.updated_at = new Date().toISOString(); } - playlistsState.updatePlaylist(updateData); + updatePlaylist(updateData); } } break; @@ -201,12 +230,11 @@ export const WebsocketProvider = ({ children }) => { // If source_id is provided, update that specific EPG's status if (parsedEvent.data.source_id) { - const epgsState = useEPGsStore.getState(); - const epg = epgsState.epgs[parsedEvent.data.source_id]; + const epg = epgs[parsedEvent.data.source_id]; if (epg) { - epgsState.updateEPG({ + updateEPG({ ...epg, - status: 'success' + status: 'success', }); } } @@ -221,13 +249,19 @@ export const WebsocketProvider = ({ children }) => { }); // Check if we have associations data and use the more efficient batch API - if (parsedEvent.data.associations && parsedEvent.data.associations.length > 0) { + if ( + parsedEvent.data.associations && + parsedEvent.data.associations.length > 0 + ) { API.batchSetEPG(parsedEvent.data.associations); } break; case 'm3u_profile_test': - setProfilePreview(parsedEvent.data.search_preview, parsedEvent.data.result); + setProfilePreview( + parsedEvent.data.search_preview, + parsedEvent.data.result + ); break; case 'recording_started': @@ -254,13 +288,12 @@ export const WebsocketProvider = ({ children }) => { // Update EPG status in store if (parsedEvent.data.source_id) { - const epgsState = useEPGsStore.getState(); - const epg = epgsState.epgs[parsedEvent.data.source_id]; + const epg = epgs[parsedEvent.data.source_id]; if (epg) { - epgsState.updateEPG({ + updateEPG({ ...epg, status: 'error', - last_message: parsedEvent.data.message + last_message: parsedEvent.data.message, }); } } @@ -268,28 +301,33 @@ export const WebsocketProvider = ({ children }) => { case 'epg_refresh': // Update the store with progress information - const epgsState = useEPGsStore.getState(); - epgsState.updateEPGProgress(parsedEvent.data); + updateEPGProgress(parsedEvent.data); // If we have source_id/account info, update the EPG source status if (parsedEvent.data.source_id || parsedEvent.data.account) { - const sourceId = parsedEvent.data.source_id || parsedEvent.data.account; - const epg = epgsState.epgs[sourceId]; + const sourceId = + parsedEvent.data.source_id || parsedEvent.data.account; + const epg = epgs[sourceId]; if (epg) { // Check for any indication of an error (either via status or error field) - const hasError = parsedEvent.data.status === "error" || + const hasError = + parsedEvent.data.status === 'error' || !!parsedEvent.data.error || - (parsedEvent.data.message && parsedEvent.data.message.toLowerCase().includes("error")); + (parsedEvent.data.message && + parsedEvent.data.message.toLowerCase().includes('error')); if (hasError) { // Handle error state - const errorMessage = parsedEvent.data.error || parsedEvent.data.message || "Unknown error occurred"; + const errorMessage = + parsedEvent.data.error || + parsedEvent.data.message || + 'Unknown error occurred'; - epgsState.updateEPG({ + updateEPG({ ...epg, status: 'error', - last_message: errorMessage + last_message: errorMessage, }); // Show notification for the error @@ -301,14 +339,15 @@ export const WebsocketProvider = ({ children }) => { } // Update status on completion only if no errors else if (parsedEvent.data.progress === 100) { - epgsState.updateEPG({ + updateEPG({ ...epg, status: parsedEvent.data.status || 'success', - last_message: parsedEvent.data.message || epg.last_message + last_message: + parsedEvent.data.message || epg.last_message, }); // Only show success notification if we've finished parsing programs and had no errors - if (parsedEvent.data.action === "parsing_programs") { + if (parsedEvent.data.action === 'parsing_programs') { notifications.show({ title: 'EPG Processing Complete', message: 'EPG data has been updated successfully', @@ -323,29 +362,41 @@ export const WebsocketProvider = ({ children }) => { break; default: - console.error(`Unknown websocket event type: ${parsedEvent.data?.type}`); + console.error( + `Unknown websocket event type: ${parsedEvent.data?.type}` + ); break; } } catch (error) { - console.error('Error processing WebSocket message:', error, event.data); + console.error( + 'Error processing WebSocket message:', + error, + event.data + ); } }; ws.current = socket; } catch (error) { - console.error("Error creating WebSocket connection:", error); + console.error('Error creating WebSocket connection:', error); setConnectionError(`WebSocket error: ${error.message}`); // Schedule a reconnect if we haven't reached max attempts if (reconnectAttempts < maxReconnectAttempts) { const delay = getReconnectDelay(); reconnectTimerRef.current = setTimeout(() => { - setReconnectAttempts(prev => prev + 1); + setReconnectAttempts((prev) => prev + 1); connectWebSocket(); }, delay); } } - }, [reconnectAttempts, clearReconnectTimer, getReconnectDelay, getWebSocketUrl, isReady]); + }, [ + reconnectAttempts, + clearReconnectTimer, + getReconnectDelay, + getWebSocketUrl, + isReady, + ]); // Initial connection and cleanup useEffect(() => { @@ -355,7 +406,7 @@ export const WebsocketProvider = ({ children }) => { clearReconnectTimer(); // Clear any pending reconnect timers if (ws.current) { - console.log("Closing WebSocket connection due to component unmount"); + console.log('Closing WebSocket connection due to component unmount'); ws.current.onclose = null; // Remove handlers to avoid reconnection ws.current.close(); ws.current = null; @@ -364,7 +415,6 @@ export const WebsocketProvider = ({ children }) => { }, [connectWebSocket, clearReconnectTimer]); const setChannelStats = useChannelsStore((s) => s.setChannelStats); - const fetchChannelGroups = useChannelsStore((s) => s.fetchChannelGroups); const fetchPlaylists = usePlaylistsStore((s) => s.fetchPlaylists); const setRefreshProgress = usePlaylistsStore((s) => s.setRefreshProgress); const setProfilePreview = usePlaylistsStore((s) => s.setProfilePreview); @@ -377,22 +427,51 @@ export const WebsocketProvider = ({ children }) => { return ( - {connectionError && !isReady && reconnectAttempts >= maxReconnectAttempts && ( - - {connectionError} - - - )} - {connectionError && !isReady && reconnectAttempts < maxReconnectAttempts && reconnectAttempts > 0 && ( - - {connectionError} - - )} + {connectionError && + !isReady && + reconnectAttempts >= maxReconnectAttempts && ( + + {connectionError} + + + )} + {connectionError && + !isReady && + reconnectAttempts < maxReconnectAttempts && + reconnectAttempts > 0 && ( + + {connectionError} + + )} {children} );