"""JWT authentication middleware for Django Channels connections.

The token is read from the ``token`` query-string parameter, validated with
SimpleJWT, and the matching user (or ``AnonymousUser``) is stored in
``scope["user"]`` before the wrapped application is called.
"""
import logging
from urllib.parse import parse_qs

from channels.auth import AuthMiddlewareStack
from channels.db import database_sync_to_async
from channels.middleware import BaseMiddleware
from django.conf import settings
from django.contrib.auth import get_user_model
from django.contrib.auth.models import AnonymousUser
from django.db import close_old_connections
from jwt import decode as jwt_decode
from rest_framework_simplejwt.authentication import JWTTokenUserAuthentication
from rest_framework_simplejwt.exceptions import InvalidToken, TokenError
from rest_framework_simplejwt.tokens import UntypedToken

from accounts.models import CustomUser

logger = logging.getLogger(__name__)


@database_sync_to_async
def get_user(validated_token):
    """Return the user referenced by the token's ``user_id`` claim.

    Args:
        validated_token: Mapping-like token payload that has already been
            verified; only its ``"user_id"`` key is read.

    Returns:
        The matching user instance, or ``AnonymousUser()`` when the claim is
        absent or no user with that id exists.
    """
    user_model = get_user_model()
    try:
        return user_model.objects.get(id=validated_token["user_id"])
    except (KeyError, user_model.DoesNotExist):
        # Catch DoesNotExist on the model that was actually queried, and
        # treat a token without a user_id claim like an unknown user.
        return AnonymousUser()


def _extract_token(scope):
    """Return the first ``token`` query-string value, or ``""`` if absent."""
    raw_query = scope.get("query_string", b"")
    try:
        params = parse_qs(raw_query.decode("utf8"))
    except (AttributeError, UnicodeDecodeError):
        # Missing/undecodable query string: proceed without a token.
        return ""
    values = params.get("token")
    return values[0] if values else ""


class JwtAuthMiddleware(BaseMiddleware):
    """Populate ``scope["user"]`` from a JWT passed in the query string."""

    async def __call__(self, scope, receive, send):
        """Authenticate the connection, then delegate to the inner app.

        An absent or invalid token never rejects the connection here; the
        scope simply carries ``AnonymousUser`` and downstream consumers
        decide what to do with it.
        """
        # Close old database connections to prevent usage of timed out
        # connections.
        close_old_connections()

        token = _extract_token(scope)
        if not token:
            scope["user"] = AnonymousUser()
            return await super().__call__(scope, receive, send)

        try:
            # Constructing the token validates it and raises if it is
            # invalid or expired.
            validated = UntypedToken(token)
        except (InvalidToken, TokenError) as exc:
            logger.debug("Rejected JWT on connection: %s", exc)
            scope["user"] = AnonymousUser()
        else:
            # Reuse the payload of the token that was just verified rather
            # than decoding it again with settings.SECRET_KEY / HS256, which
            # may not match the SimpleJWT signing configuration. The payload
            # is deliberately not logged.
            scope["user"] = await get_user(validated_token=validated.payload)

        return await super().__call__(scope, receive, send)


def JwtAuthMiddlewareStack(inner):
    """Wrap ``inner`` in ``AuthMiddlewareStack`` and then ``JwtAuthMiddleware``."""
    return JwtAuthMiddleware(AuthMiddlewareStack(inner))