From 295798b9654ef1747e29612ec7b9074977e0a69a Mon Sep 17 00:00:00 2001 From: Keannu Bernasol Date: Sun, 17 Sep 2023 20:09:10 +0800 Subject: [PATCH] Added JWT middleware for websockets --- stude/config/asgi.py | 11 +++--- stude/config/auth_middleware.py | 60 +++++++++++++++++++++++++++++++++ 2 files changed, 65 insertions(+), 6 deletions(-) create mode 100644 stude/config/auth_middleware.py diff --git a/stude/config/asgi.py b/stude/config/asgi.py index 6384e94..1b83f2e 100644 --- a/stude/config/asgi.py +++ b/stude/config/asgi.py @@ -6,6 +6,7 @@ from channels.auth import AuthMiddlewareStack from django.core.asgi import get_asgi_application import api.routing from django.urls import re_path +from config.auth_middleware import JwtAuthMiddlewareStack os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'config.settings') # Initialize Django ASGI application early to ensure the AppRegistry @@ -14,11 +15,9 @@ django_asgi_app = get_asgi_application() application = ProtocolTypeRouter({ "http": django_asgi_app, - 'websocket': AllowedHostsOriginValidator( - AuthMiddlewareStack( - URLRouter( - [re_path(r'ws/', URLRouter(api.routing.websocket_urlpatterns))] - ) + 'websocket': + JwtAuthMiddlewareStack( + URLRouter(api.routing.websocket_urlpatterns) ), - ) + }) diff --git a/stude/config/auth_middleware.py b/stude/config/auth_middleware.py new file mode 100644 index 0000000..2c32165 --- /dev/null +++ b/stude/config/auth_middleware.py @@ -0,0 +1,60 @@ +from channels.db import database_sync_to_async +from django.contrib.auth import get_user_model +from django.contrib.auth.models import AnonymousUser +from rest_framework_simplejwt.exceptions import InvalidToken, TokenError +from rest_framework_simplejwt.tokens import UntypedToken +from rest_framework_simplejwt.authentication import JWTTokenUserAuthentication +from accounts.models import CustomUser +from channels.middleware import BaseMiddleware +from channels.auth import AuthMiddlewareStack +from django.db import close_old_connections +from urllib.parse import parse_qs +from jwt import decode as jwt_decode +from django.conf import settings + + +@database_sync_to_async +def get_user(validated_token): + try: + user = get_user_model().objects.get(id=validated_token["user_id"]) + # return get_user_model().objects.get(id=toke_id) + print(f"{user}") + return user + + except CustomUser.DoesNotExist: + return AnonymousUser() + + +class JwtAuthMiddleware(BaseMiddleware): + async def __call__(self, scope, receive, send): + # Close old database connections to prevent usage of timed out connections + close_old_connections() + + # Get the token + token = "" + try: + token = parse_qs(scope["query_string"].decode("utf8"))["token"][0] + except: + pass # No token in query string, will proceed as AnonymousUser + + # Try to authenticate the user + try: + # This will automatically validate the token and raise an error if token is invalid + UntypedToken(token) + except (InvalidToken, TokenError) as e: + # Token is invalid, will proceed as AnonymousUser + print(e) + scope["user"] = AnonymousUser() + else: + # Then token is valid, decode it + decoded_data = jwt_decode( + token, settings.SECRET_KEY, algorithms=["HS256"]) + print(decoded_data) + + # Get the user using ID + scope["user"] = await get_user(validated_token=decoded_data) + return await super().__call__(scope, receive, send) + + +def JwtAuthMiddlewareStack(inner): + return JwtAuthMiddleware(AuthMiddlewareStack(inner))