diff --git a/django_saml2_auth/tests/test_saml.py b/django_saml2_auth/tests/test_saml.py index a81198e..eb8563e 100644 --- a/django_saml2_auth/tests/test_saml.py +++ b/django_saml2_auth/tests/test_saml.py @@ -845,3 +845,51 @@ def test_acs_view_use_jwt_set_inactive_user( result = acs(post_request) assert result.status_code == 500 assert f"Error code: {INACTIVE_USER}" in result.content.decode() + +@pytest.mark.django_db +@responses.activate +def test_acs_view_when_use_jwt_next_url_has_query_parameters( + settings: SettingsWrapper, + monkeypatch: "MonkeyPatch", # type: ignore # noqa: F821 +): + """Test Acs view when login_next_url has query parameters in the session""" + responses.add(responses.GET, METADATA_URL1, body=METADATA1) + settings.SAML2_AUTH = { + "ASSERTION_URL": "https://api.example.com", + "DEFAULT_NEXT_URL": "default_next_url", + "USE_JWT": True, + "JWT_SECRET": "JWT_SECRET", + "JWT_ALGORITHM": "HS256", + "TRIGGER": { + "BEFORE_LOGIN": None, + "AFTER_LOGIN": None, + "GET_METADATA_AUTO_CONF_URLS": GET_METADATA_AUTO_CONF_URLS, + }, + } + post_request = RequestFactory().post(METADATA_URL1, {"SAMLResponse": "SAML RESPONSE"}) + + monkeypatch.setattr( + Saml2Client, "parse_authn_request_response", mock_parse_authn_request_response + ) + + created, mock_user = user.get_or_create_user( + {"username": "test@example.com", "first_name": "John", "last_name": "Doe"} + ) + + monkeypatch.setattr( + user, + "get_or_create_user", + ( + created, + mock_user, + ), + ) + + middleware = SessionMiddleware(MagicMock()) + middleware.process_request(post_request) + post_request.session["login_next_url"] = "/endpoint/?query=param&another=param" + post_request.session.save() + + result = acs(post_request) + assert result["Location"].count("?") == 1 + assert result["Location"].count("&") == 2 diff --git a/django_saml2_auth/views.py b/django_saml2_auth/views.py index 05f9462..ae91686 100644 --- a/django_saml2_auth/views.py +++ b/django_saml2_auth/views.py @@ -179,8 +179,6 @@ def acs(request: HttpRequest): custom_token_query_trigger = dictor(saml2_auth_settings, "TRIGGER.CUSTOM_TOKEN_QUERY") if custom_token_query_trigger: query = run_hook(custom_token_query_trigger, jwt_token) - else: - query = f"?token={jwt_token}" # Use JWT auth to send token to frontend frontend_url = dictor(saml2_auth_settings, "FRONTEND_URL", next_url) @@ -188,6 +186,21 @@ def acs(request: HttpRequest): if custom_frontend_url_trigger: frontend_url = run_hook(custom_frontend_url_trigger, relay_state) # type: ignore + parsed_url = urlparse.urlparse(frontend_url) + new_parse = list(parsed_url) # urlparse.urlparse returns a read-only tuple + if not custom_token_query_trigger: + # We run it here in order to make sure that if a custom token trigger function does exist, + # it runs before the custom frontend url trigger function. + existing_query = urlparse.parse_qs(parsed_url.query) + existing_query.setdefault("token", []).append(jwt_token) + query = urlparse.urlencode(existing_query) + new_parse[4] = query # The query field is the 5th item or 4th index + # We put this here because if people were returning weird strings for the query, + # they might no longer work when using urlunparse + destination_url = urlparse.urlunpa + else: + + return HttpResponseRedirect(frontend_url + query)