Added JWT middleware for websockets

This commit is contained in:
Keannu Bernasol 2023-09-17 20:09:10 +08:00
parent 532460e1ba
commit 295798b965
2 changed files with 65 additions and 6 deletions

View file

@ -6,6 +6,7 @@ from channels.auth import AuthMiddlewareStack
from django.core.asgi import get_asgi_application
import api.routing
from django.urls import re_path
from config.auth_middleware import JwtAuthMiddlewareStack
os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'config.settings')
# Initialize Django ASGI application early to ensure the AppRegistry
@ -14,11 +15,9 @@ django_asgi_app = get_asgi_application()
application = ProtocolTypeRouter({
"http": django_asgi_app,
'websocket': AllowedHostsOriginValidator(
AuthMiddlewareStack(
URLRouter(
[re_path(r'ws/', URLRouter(api.routing.websocket_urlpatterns))]
)
'websocket':
JwtAuthMiddlewareStack(
URLRouter(api.routing.websocket_urlpatterns)
),
)
})

View file

@ -0,0 +1,60 @@
from channels.db import database_sync_to_async
from django.contrib.auth import get_user_model
from django.contrib.auth.models import AnonymousUser
from rest_framework_simplejwt.exceptions import InvalidToken, TokenError
from rest_framework_simplejwt.tokens import UntypedToken
from rest_framework_simplejwt.authentication import JWTTokenUserAuthentication
from accounts.models import CustomUser
from channels.middleware import BaseMiddleware
from channels.auth import AuthMiddlewareStack
from django.db import close_old_connections
from urllib.parse import parse_qs
from jwt import decode as jwt_decode
from django.conf import settings
@database_sync_to_async
def get_user(validated_token):
try:
user = get_user_model().objects.get(id=validated_token["user_id"])
# return get_user_model().objects.get(id=toke_id)
print(f"{user}")
return user
except CustomUser.DoesNotExist:
return AnonymousUser()
class JwtAuthMiddleware(BaseMiddleware):
async def __call__(self, scope, receive, send):
# Close old database connections to prevent usage of timed out connections
close_old_connections()
# Get the token
token = ""
try:
token = parse_qs(scope["query_string"].decode("utf8"))["token"][0]
except:
pass # No token in query string, will proceed as AnonymousUser
# Try to authenticate the user
try:
# This will automatically validate the token and raise an error if token is invalid
UntypedToken(token)
except (InvalidToken, TokenError) as e:
# Token is invalid, will proceed as AnonymousUser
print(e)
scope["user"] = AnonymousUser()
else:
# Then token is valid, decode it
decoded_data = jwt_decode(
token, settings.SECRET_KEY, algorithms=["HS256"])
print(decoded_data)
# Get the user using ID
scope["user"] = await get_user(validated_token=decoded_data)
return await super().__call__(scope, receive, send)
def JwtAuthMiddlewareStack(inner):
return JwtAuthMiddleware(AuthMiddlewareStack(inner))