Dispatcharr/dispatcharr/jwt_ws_auth.py

47 lines
1.9 KiB
Python

from urllib.parse import parse_qs
from channels.middleware import BaseMiddleware
from channels.db import database_sync_to_async
from rest_framework_simplejwt.tokens import UntypedToken
from django.contrib.auth.models import AnonymousUser
from django.contrib.auth import get_user_model
from rest_framework_simplejwt.exceptions import InvalidToken, TokenError
from rest_framework_simplejwt.authentication import JWTAuthentication
import logging
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
# User model class as returned by Django's get_user_model(), resolved once at import time.
User = get_user_model()
@database_sync_to_async
def get_user(validated_token):
    """Return the user identified by an already-validated JWT.

    The lookup is delegated to ``JWTAuthentication.get_user``. Because the
    function is wrapped with ``database_sync_to_async``, callers ``await`` it.

    Args:
        validated_token: Token object previously produced by
            ``JWTAuthentication.get_validated_token``.

    Returns:
        The matching user, or ``AnonymousUser()`` when the user cannot be
        resolved or any error occurs. This function never raises for
        ordinary exceptions; failures are logged instead.
    """
    # Imported locally so this function is self-contained with respect to the
    # module's import list (the package itself is already a dependency).
    from rest_framework_simplejwt.exceptions import AuthenticationFailed

    try:
        return JWTAuthentication().get_user(validated_token)
    except (AuthenticationFailed, User.DoesNotExist) as exc:
        # JWTAuthentication.get_user reports unknown/inactive users (and a
        # missing user-id claim) as AuthenticationFailed rather than letting
        # User.DoesNotExist escape, so that is the exception to treat as an
        # expected rejection. DoesNotExist is kept for backward compatibility.
        logger.warning(
            "User from token does not exist or cannot be authenticated. "
            "User ID: %s (%s)",
            validated_token.get("user_id", "unknown"),
            exc,
        )
        return AnonymousUser()
    except Exception as e:
        # Anything else is unexpected; degrade to anonymous rather than
        # failing the connection.
        logger.error(f"Error getting user from token: {str(e)}")
        return AnonymousUser()
class JWTAuthMiddleware(BaseMiddleware):
    """Channels middleware that authenticates a connection from a JWT.

    The token is read from the ``token`` parameter of the connection's query
    string, validated with ``JWTAuthentication.get_validated_token`` and
    resolved to a user via ``get_user``. The result is stored in
    ``scope["user"]``; every failure path stores ``AnonymousUser()`` instead,
    so the connection is never rejected by this middleware itself.
    """

    async def __call__(self, scope, receive, send):
        """Populate ``scope["user"]`` and hand off to the wrapped application.

        Args:
            scope: ASGI connection scope; ``scope["user"]`` is set in place.
            receive: ASGI receive callable, passed through unchanged.
            send: ASGI send callable, passed through unchanged.

        Returns:
            Whatever ``BaseMiddleware.__call__`` returns.
        """
        scope["user"] = await self._resolve_user(scope)
        return await super().__call__(scope, receive, send)

    async def _resolve_user(self, scope):
        """Return the user for the token in the query string, else anonymous."""
        try:
            # The query string may be absent or None in a scope; treat that as
            # "no token supplied" instead of letting a KeyError/AttributeError
            # be reported as an authentication error.
            raw_query = scope.get("query_string") or b""
            token = parse_qs(raw_query.decode()).get("token", [None])[0]
            if token is None:
                return AnonymousUser()

            try:
                validated_token = JWTAuthentication().get_validated_token(token)
                return await get_user(validated_token)
            except (InvalidToken, TokenError) as e:
                # Expired, malformed or otherwise rejected token.
                logger.warning("Invalid token: %s", e)
                return AnonymousUser()
        except Exception as e:
            # Last-resort boundary: log with traceback and fall back to an
            # anonymous user so the connection can still proceed.
            logger.exception("Error in JWT authentication: %s", e)
            return AnonymousUser()