Mirror of https://github.com/Dispatcharr/Dispatcharr.git (synced 2026-01-23 18:54:58 +00:00)
47 lines · 1.9 KiB · Python
from urllib.parse import parse_qs
|
|
from channels.middleware import BaseMiddleware
|
|
from channels.db import database_sync_to_async
|
|
from rest_framework_simplejwt.tokens import UntypedToken
|
|
from django.contrib.auth.models import AnonymousUser
|
|
from django.contrib.auth import get_user_model
|
|
from rest_framework_simplejwt.exceptions import InvalidToken, TokenError
|
|
from rest_framework_simplejwt.authentication import JWTAuthentication
|
|
import logging
|
|
|
|
# Module-level logger shared by the JWT helpers in this file.
logger = logging.getLogger(name=__name__)

# Django's configured user model, resolved once when this module is imported.
User = get_user_model()
|
|
|
|
@database_sync_to_async
def get_user(validated_token):
    """Resolve the user referenced by an already-validated JWT.

    Wrapped in ``database_sync_to_async`` so the (synchronous) user lookup
    done by ``JWTAuthentication.get_user`` runs off the event loop.

    Args:
        validated_token: the token object returned by
            ``JWTAuthentication().get_validated_token(...)``.

    Returns:
        The resolved user, or ``AnonymousUser()`` if the user cannot be
        resolved for any reason (this function never raises).
    """
    # simplejwt's JWTAuthentication.get_user reports a missing/inactive user
    # (or a token without a user-id claim) via its own AuthenticationFailed
    # (InvalidToken is a subclass), not via User.DoesNotExist. Imported locally
    # from the same module the file already imports InvalidToken/TokenError from.
    from rest_framework_simplejwt.exceptions import AuthenticationFailed

    try:
        return JWTAuthentication().get_user(validated_token)
    except (AuthenticationFailed, User.DoesNotExist) as exc:
        # Expected: the token is well-formed but names no usable user.
        logger.warning(
            "User from token could not be resolved. User ID: %s (%s)",
            validated_token.get("user_id", "unknown"),
            exc,
        )
        return AnonymousUser()
    except Exception as exc:
        # Unexpected failure (e.g. a database error). Keep the connection
        # anonymous, but record the traceback for debugging.
        logger.exception("Error getting user from token: %s", exc)
        return AnonymousUser()
|
|
|
|
class JWTAuthMiddleware(BaseMiddleware):
    """Channels middleware that authenticates a connection from a JWT.

    The token is read from the ``token`` query-string parameter. After this
    middleware runs, ``scope["user"]`` is always set: to the resolved user on
    success, or to ``AnonymousUser()`` on any failure (missing, invalid, or
    unresolvable token, or an unexpected error).
    """

    async def __call__(self, scope, receive, send):
        # Work on a shallow copy so the user we attach does not leak back into
        # the caller's scope (BaseMiddleware.__call__ copies for the same reason).
        scope = dict(scope)
        user = AnonymousUser()

        try:
            # Extract the token from the query string; a missing query string
            # is treated as empty rather than relying on the catch-all below.
            query = parse_qs(scope.get("query_string", b"").decode())
            token = query.get("token", [None])[0]

            if token is not None:
                try:
                    validated_token = JWTAuthentication().get_validated_token(token)
                except (InvalidToken, TokenError) as exc:
                    logger.warning("Invalid token: %s", exc)
                else:
                    # get_user never raises; it falls back to AnonymousUser().
                    user = await get_user(validated_token)
        except Exception as exc:
            # Authentication must never block the connection from reaching the
            # inner app; fall back to an anonymous user but keep the traceback.
            logger.exception("Error in JWT authentication: %s", exc)
            user = AnonymousUser()

        scope["user"] = user
        return await super().__call__(scope, receive, send)
|