Skip to content

Commit 302fe61

Browse files
sgabbmostafa
andauthored
Custom JWT plugin (& RelayState small fix) (#35)
* fix: check RelayState exist but not a token - RelayState can also be a "/" (%2F char) * feat(trigger): add custom token query - can pas custom function to create token query in settings - update README - update Contributors * Update README.md Better CUSTOM_TOKEN_QUERY description Co-authored-by: Mostafa Moradian <[email protected]> * feat(jwt): better check of RelayState - check if RelayState is a token before trying to decode it - add test of is_jwt_well_formed * feat(trigger): add custom jwt creator and decoder - use custom trigger or default function for jwt management - update README - update tests with new functions * fix: remove unused imports * Update README.md Change jwt to JWT Co-authored-by: Mostafa Moradian <[email protected]>
1 parent 9213bee commit 302fe61

File tree

7 files changed

+145
-20
lines changed

7 files changed

+145
-20
lines changed

AUTHORS.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,3 +51,4 @@ an issue.
5151
- [Akshit Dhar](https://github.com/akshit-wwstay)
5252
- [Jean Vincent](https://github.com/jean-sh)
5353
- [Søren Howe Gersager](https://github.com/syre)
54+
- [Gabrio Mauri](https://github.com/sgabb)

README.md

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,9 @@ For IdP-initiated SSO, the user will be created if it doesn't exist, but for SP-
177177
| **TRIGGER.BEFORE\_LOGIN** | A method to be called when an existing user logs in. This method will be called before the user is logged in and after the SAML2 identity provider returns user attributes. This method should accept ONE parameter of user dict. | `str` | `None` | `my_app.models.users.before_login` |
178178
| **TRIGGER.AFTER\_LOGIN** | A method to be called when an existing user logs in. This method will be called after the user is logged in and after the SAML2 identity provider returns user attributes. This method should accept TWO parameters of session and user dict. | `str` | `None` | `my_app.models.users.after_login` |
179179
| **TRIGGER.GET\_METADATA\_AUTO\_CONF\_URLS** | A hook function that returns a list of metadata Autoconf URLs. This can override the `METADATA_AUTO_CONF_URL` to enumerate all existing metadata autoconf URLs. | `str` | `None` | `my_app.models.users.get_metadata_autoconf_urls` |
180+
| **TRIGGER.CUSTOM\_DECODE\_JWT** | A hook function to decode the user JWT. This method will be called instead of the `decode_jwt_token` default function and should return the user_model.USERNAME_FIELD. This method accepts one parameter: `token`. | `str` | `None` | `my_app.models.users.decode_custom_token` |
181+
| **TRIGGER.CUSTOM\_CREATE\_JWT** | A hook function to create a custom JWT for the user. This method will be called instead of the `create_jwt_token` default function and should return the token. This method accepts one parameter: `user`. | `str` | `None` | `my_app.models.users.create_custom_token` |
182+
| **TRIGGER.CUSTOM\_TOKEN\_QUERY** | A hook function to create a custom query params with the JWT for the user. This method will be called after `CUSTOM_CREATE_JWT` to populate a query and attach it to a URL; should return the query params containing the token (e.g., `?token=encoded.jwt.token`). This method accepts one parameter: `token`. | `str` | `None` | `my_app.models.users.get_custom_token_query` |
180183
| **ASSERTION\_URL** | A URL to validate incoming SAML responses against. By default, `django-saml2-auth` will validate the SAML response's Service Provider address against the actual HTTP request's host and scheme. If this value is set, it will validate against `ASSERTION_URL` instead - perfect for when Django is running behind a reverse proxy. | `str` | `https://example.com` | |
181184
| **ENTITY\_ID** | The optional entity ID string to be passed in the 'Issuer' element of authentication request, if required by the IDP. | `str` | `None` | `https://exmaple.com/sso/acs` |
182185
| **NAME\_ID\_FORMAT** | Set to the string `'None'`, to exclude sending the `'Format'` property of the `'NameIDPolicy'` element in authentication requests. | `str` | `<urn:oasis:names:tc:SAML:2.0:nameid-format:transient>` | |
@@ -240,6 +243,25 @@ Otherwise if you want to use your PKI key-pair to sign JWT tokens, use either of
240243

241244
*Note:* If both PKI fields and `JWT_SECRET` are defined, the `JWT_ALGORITHM` decides which method to use for signing tokens.
242245

246+
### Custom token triggers
247+
248+
This is an example of the functions that could be passed to the `TRIGGER.CUSTOM_CREATE_JWT` (it uses the [DRF Simple JWT library](https://github.com/jazzband/djangorestframework-simplejwt/blob/master/docs/index.rst)) and to `TRIGGER.CUSTOM_TOKEN_QUERY`:
249+
250+
``` python
251+
from rest_framework_simplejwt.tokens import RefreshToken
252+
253+
254+
def get_custom_jwt(user):
255+
"""Create token for user and return it"""
256+
return RefreshToken.for_user(user)
257+
258+
259+
def get_custom_token_query(refresh):
260+
"""Create url query with refresh and access token"""
261+
return "?%s%s%s%s%s" % ("refresh=", str(refresh), "&", "access=", str(refresh.access_token))
262+
263+
```
264+
243265
## Customize
244266

245267
The default permission `denied`, `error` and user `welcome` page can be overridden.

django_saml2_auth/tests/test_user.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33
import pytest
44
from django.contrib.auth.models import Group
55
from django_saml2_auth.exceptions import SAMLAuthError
6-
from django_saml2_auth.user import (create_jwt_token, create_new_user,
7-
decode_jwt_token, get_or_create_user,
6+
from django_saml2_auth.user import (create_custom_or_default_jwt, create_new_user,
7+
decode_custom_or_default_jwt, get_or_create_user,
88
get_user, get_user_id)
99
from jwt.exceptions import PyJWTError
1010
from pytest_django.fixtures import SettingsWrapper
@@ -306,8 +306,8 @@ def test_create_and_decode_jwt_token_success(
306306
"""
307307
settings.SAML2_AUTH = saml2_settings
308308

309-
jwt_token = create_jwt_token("[email protected]")
310-
user_id = decode_jwt_token(jwt_token)
309+
jwt_token = create_custom_or_default_jwt("[email protected]")
310+
user_id = decode_custom_or_default_jwt(jwt_token)
311311
assert user_id == "[email protected]"
312312

313313

@@ -347,7 +347,7 @@ def test_create_jwt_token_with_incorrect_jwt_settings(
347347
settings.SAML2_AUTH = saml2_settings
348348

349349
with pytest.raises(SAMLAuthError) as exc_info:
350-
create_jwt_token("[email protected]")
350+
create_custom_or_default_jwt("[email protected]")
351351

352352
assert str(exc_info.value) == error_msg
353353

@@ -393,15 +393,15 @@ def test_decode_jwt_token_with_incorrect_jwt_settings(
393393
settings.SAML2_AUTH = saml2_settings
394394

395395
with pytest.raises(SAMLAuthError) as exc_info:
396-
decode_jwt_token("WHATEVER")
396+
decode_custom_or_default_jwt("WHATEVER")
397397

398398
assert str(exc_info.value) == error_msg
399399

400400

401401
def test_decode_jwt_token_failure():
402402
"""Test decode_jwt_token function by passing an invalid JWT token (None, in this case)."""
403403
with pytest.raises(SAMLAuthError) as exc_info:
404-
decode_jwt_token(None)
404+
decode_custom_or_default_jwt(None)
405405

406406
assert str(exc_info.value) == "Cannot decode JWT token."
407407
assert isinstance(exc_info.value.extra["exc"], PyJWTError)

django_saml2_auth/tests/test_utils.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from django.http import HttpRequest, HttpResponse
33
from django.urls import NoReverseMatch
44
from django_saml2_auth.exceptions import SAMLAuthError
5-
from django_saml2_auth.utils import exception_handler, get_reverse, run_hook
5+
from django_saml2_auth.utils import exception_handler, get_reverse, run_hook, is_jwt_well_formed
66

77

88
def divide(a: int, b: int = 1) -> int:
@@ -124,3 +124,12 @@ def test_exception_handler_handle_exception():
124124
contents = result.content.decode("utf-8")
125125
assert result.status_code == 500
126126
assert "Reason: Internal world error!" in contents
127+
128+
129+
def test_jwt_well_formed():
130+
"""Test if passed RelayState is a well formed JWT"""
131+
token = 'eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJzdWIiOiI0MjQyIiwibmFtZSI6Ikplc3NpY2EgVGVtcG9yYWwiLCJuaWNrbmFtZSI6Ikplc3MifQ.EDkUUxaM439gWLsQ8a8mJWIvQtgZe0et3O3z4Fd_J8o' # noqa
132+
res = is_jwt_well_formed(token) # True
133+
assert res is True
134+
res = is_jwt_well_formed('/') # False
135+
assert res is False

django_saml2_auth/user.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -327,6 +327,44 @@ def create_jwt_token(user_id: str) -> Optional[str]:
327327
return jwt.encode(payload, secret, algorithm=jwt_algorithm)
328328

329329

330+
def create_custom_or_default_jwt(user: Union[str, Type[Model]]):
331+
"""Create a new JWT token, eventually using custom trigger
332+
333+
Args:
334+
user (dict or str): User instance or User's username or email based on User.USERNAME_FIELD
335+
336+
Returns:
337+
Optional[str]: JWT token
338+
"""
339+
saml2_auth_settings = settings.SAML2_AUTH
340+
custom_create_jwt_trigger = dictor(saml2_auth_settings, "TRIGGER.CUSTOM_CREATE_JWT")
341+
342+
# If user is the id (user_model.USERNAME_FIELD), set it as user_id
343+
user_id = None
344+
if isinstance(user, str):
345+
user_id = user
346+
347+
# Check if there is a custom trigger for creating the JWT and URL query
348+
if custom_create_jwt_trigger:
349+
target_user = user
350+
# If user is user_id, get user instance
351+
if user_id:
352+
user_model = get_user_model()
353+
user = {
354+
user_model.USERNAME_FIELD: user_id
355+
}
356+
target_user = get_user(user)
357+
jwt_token = run_hook(custom_create_jwt_trigger, target_user)
358+
else:
359+
# If user_id is not set, retrieve it from user instance
360+
if not user_id:
361+
user_id = getattr(user, user_model.USERNAME_FIELD)
362+
# Create a new JWT token with PyJWT
363+
jwt_token = create_jwt_token(user_id)
364+
365+
return jwt_token
366+
367+
330368
def decode_jwt_token(jwt_token: str) -> Optional[str]:
331369
"""Decode a JWT token
332370
@@ -366,3 +404,24 @@ def decode_jwt_token(jwt_token: str) -> Optional[str]:
366404
"reason": "Cannot decode JWT token.",
367405
"status_code": 500
368406
})
407+
408+
409+
def decode_custom_or_default_jwt(jwt_token: str) -> Optional[str]:
410+
"""Decode a JWT token, eventually using custom trigger
411+
412+
Args:
413+
jwt_token (str): The token to decode
414+
415+
Raises:
416+
SAMLAuthError: Cannot decode JWT token.
417+
418+
Returns:
419+
Optional[str]: A user_id as str or None.
420+
"""
421+
saml2_auth_settings = settings.SAML2_AUTH
422+
custom_decode_jwt_trigger = dictor(saml2_auth_settings, "TRIGGER.CUSTOM_DECODE_JWT")
423+
if custom_decode_jwt_trigger:
424+
user_id = run_hook(custom_decode_jwt_trigger, jwt_token)
425+
else:
426+
user_id = decode_jwt_token(jwt_token)
427+
return user_id

django_saml2_auth/utils.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
"""
44

55
import logging
6+
import base64
67
from functools import wraps
78
from typing import Any, Callable, Iterable, Mapping, Optional, Tuple, Union
89

@@ -157,3 +158,29 @@ def wrapper(request: HttpRequest) -> HttpResponse:
157158
result = handle_exception(exc, request)
158159
return result
159160
return wrapper
161+
162+
163+
def is_jwt_well_formed(jwt: str):
164+
"""Check if JWT is well formed
165+
166+
Args:
167+
jwt (str): Json Web Token
168+
169+
Returns:
170+
Boolean: True if JWT is well formed, otherwise False
171+
"""
172+
if isinstance(jwt, str):
173+
# JWT should contain three segments, separated by two period ('.') characters.
174+
jwt_segments = jwt.split('.')
175+
if len(jwt_segments) == 3:
176+
jose_header = jwt_segments[0]
177+
# base64-encoded string length should be a multiple of 4
178+
if len(jose_header) % 4 == 0:
179+
try:
180+
jh_decoded = base64.b64decode(jose_header).decode('utf-8')
181+
if jh_decoded and jh_decoded.find('JWT') > -1:
182+
return True
183+
except Exception:
184+
return False
185+
# If tests not passed return False
186+
return False

django_saml2_auth/views.py

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55

66
import urllib.parse as urlparse
77
from urllib.parse import unquote
8-
import json
98

109
from dictor import dictor
1110
from django import get_version
@@ -26,8 +25,8 @@
2625
extract_user_identity, get_assertion_url,
2726
get_default_next_url, get_saml_client)
2827
from django_saml2_auth.user import (
29-
get_or_create_user, create_jwt_token, decode_jwt_token, get_user_id)
30-
from django_saml2_auth.utils import exception_handler, get_reverse, run_hook
28+
create_custom_or_default_jwt, decode_custom_or_default_jwt, get_or_create_user, get_user_id)
29+
from django_saml2_auth.utils import exception_handler, get_reverse, is_jwt_well_formed, run_hook
3130
from pkg_resources import parse_version
3231

3332

@@ -71,11 +70,15 @@ def acs(request: HttpRequest):
7170

7271
next_url = request.session.get("login_next_url") or get_default_next_url()
7372

74-
# If RelayState params is passed, it is a JWT token that identifies the user trying to login
75-
# via sp_initiated_login endpoint
73+
# A RelayState is an HTTP parameter that can be included as part of the SAML request and SAML response;
74+
# usually is meant to be an opaque identifier that is passed back without any modification or inspection,
75+
# and it is used to specify additional information to the SP or the IdP.
76+
# If RelayState params is passed, it could be JWT token that identifies the user trying to login
77+
# via sp_initiated_login endpoint, or it could be a URL used for redirection.
7678
relay_state = request.POST.get("RelayState")
77-
if relay_state:
78-
redirected_user_id = decode_jwt_token(relay_state)
79+
relay_state_is_token = is_jwt_well_formed(relay_state)
80+
if relay_state_is_token:
81+
redirected_user_id = decode_custom_or_default_jwt(relay_state)
7982

8083
# This prevents users from entering an email on the SP, but use a different email on IdP
8184
if get_user_id(user) != redirected_user_id:
@@ -97,10 +100,14 @@ def acs(request: HttpRequest):
97100
use_jwt = dictor(saml2_auth_settings, "USE_JWT", False)
98101
if use_jwt and target_user.is_active:
99102
# Create a new JWT token for IdP-initiated login (acs)
100-
jwt_token = create_jwt_token(target_user.email)
101-
# Use JWT auth to send token to frontend
102-
query = f"?token={jwt_token}"
103+
jwt_token = create_custom_or_default_jwt(target_user)
104+
custom_token_query_trigger = dictor(saml2_auth_settings, "TRIGGER.CUSTOM_TOKEN_QUERY")
105+
if custom_token_query_trigger:
106+
query = run_hook(custom_token_query_trigger, jwt_token)
107+
else:
108+
query = f"?token={jwt_token}"
103109

110+
# Use JWT auth to send token to frontend
104111
frontend_url = dictor(saml2_auth_settings, "FRONTEND_URL", next_url)
105112

106113
return HttpResponseRedirect(frontend_url + query)
@@ -134,9 +141,9 @@ def sp_initiated_login(request: HttpRequest) -> HttpResponseRedirect:
134141
# User must be created first by the IdP-initiated SSO (acs)
135142
if request.method == "GET":
136143
if request.GET.get("token"):
137-
user_id = decode_jwt_token(request.GET.get("token"))
144+
user_id = decode_custom_or_default_jwt(request.GET.get("token"))
138145
saml_client = get_saml_client(get_assertion_url(request), acs, user_id)
139-
jwt_token = create_jwt_token(user_id)
146+
jwt_token = create_custom_or_default_jwt(user_id)
140147
_, info = saml_client.prepare_for_authenticate(sign=False, relay_state=jwt_token)
141148
redirect_url = dict(info["headers"]).get("Location", "")
142149
if not redirect_url:

0 commit comments

Comments
 (0)