diff --git a/backend/config/settings.py b/backend/config/settings.py index 7a79c537..9ffac2f0 100644 --- a/backend/config/settings.py +++ b/backend/config/settings.py @@ -1,10 +1,10 @@ from pathlib import Path import environ +import helusers.defaults from dateutil.relativedelta import relativedelta from django.utils.log import DEFAULT_LOGGING from django.utils.translation import gettext_lazy -from helusers import defaults from rest_framework.authentication import TokenAuthentication # ----- ENV Setup -------------------------------------------------------------------------------------- @@ -273,7 +273,10 @@ def relativedelta_months(value: int) -> relativedelta: SOCIAL_AUTH_TUNNISTAMO_SECRET = env("SOCIAL_AUTH_TUNNISTAMO_SECRET") SOCIAL_AUTH_TUNNISTAMO_AUTH_EXTRA_ARGUMENTS = {"ui_locales": "fi"} SOCIAL_AUTH_TUNNISTAMO_ALLOWED_REDIRECT_HOSTS = env("SOCIAL_AUTH_TUNNISTAMO_ALLOWED_REDIRECT_HOSTS") -SOCIAL_AUTH_TUNNISTAMO_PIPELINE = defaults.SOCIAL_AUTH_PIPELINE +SOCIAL_AUTH_TUNNISTAMO_PIPELINE = ( + *helusers.defaults.SOCIAL_AUTH_PIPELINE, + "hitas.helauth.pipelines.migrate_user_from_tunnistamo_to_tunnistus", +) HELUSERS_PASSWORD_LOGIN_DISABLED = False HELUSERS_BACK_CHANNEL_LOGOUT_ENABLED = False diff --git a/backend/hitas/helauth/pipelines.py b/backend/hitas/helauth/pipelines.py new file mode 100644 index 00000000..7c5abd6f --- /dev/null +++ b/backend/hitas/helauth/pipelines.py @@ -0,0 +1,44 @@ +from typing import Any, Unpack + +from django.core.handlers.wsgi import WSGIRequest +from helusers.tunnistamo_oidc import TunnistamoOIDCAuth +from social_django.models import UserSocialAuth + +from hitas.helauth.types import ExtraKwargs, IDToken, OIDCResponse +from users.models import User + + +def migrate_user_from_tunnistamo_to_tunnistus( + backend: TunnistamoOIDCAuth, + request: WSGIRequest, + response: OIDCResponse, + user: User | None = None, + **kwargs: Unpack[ExtraKwargs], +) -> dict[str, Any]: + if user is None: + return {"user": user} + id_token = IDToken.from_string(response["id_token"]) + if ( + id_token is not None + # Token issued by helsinki-tunnistus + and id_token.iss.endswith("helsinki-tunnistus") + and id_token.is_ad_login + and id_token.email not in ("", None) + ): + old_user = User.objects.filter(email=id_token.email).exclude(pk=user.pk).first() + if old_user is None: + return {"user": user} + new_user = user + # Delete the old UserSocialAuth object to prevent conflicts + UserSocialAuth.objects.filter(user=old_user).delete() + # Assign the new UserSocialAuth to the old user + UserSocialAuth.objects.filter(user=new_user).update(user=old_user) + # Delete the new User object because we want to keep the old User object and its pk and data + new_user.delete() + # Update the old user to match the new user for fields that are used to uniquely identify a user + old_user.uuid = new_user.uuid + old_user.username = new_user.username + old_user.save() + # Pass the old User object along the authentication pipeline + user = old_user + return {"user": user} diff --git a/backend/hitas/helauth/types.py b/backend/hitas/helauth/types.py new file mode 100644 index 00000000..f11a51c5 --- /dev/null +++ b/backend/hitas/helauth/types.py @@ -0,0 +1,161 @@ +from __future__ import annotations + +import enum +from dataclasses import dataclass +from typing import Literal, TypedDict + +from social_django.models import DjangoStorage, UserSocialAuth +from social_django.strategy import DjangoStrategy + +from hitas.helauth.utils import get_jwt_payload + + +class ADLoginAMR(enum.Enum): + HELSINKI_ADFS = "helsinki_adfs" + HELSINKIAD = "helsinkiad" + HELSINKIAZUREAD = "helsinkiazuread" + EDUAD = "eduad" + + +class ProfileLoginAMR(enum.Enum): + SUOMI_FI = "suomi_fi" + HELTUNNISTUSSUOMIFI = "heltunnistussuomifi" + + +@dataclass +class IDToken: + iss: str + """token issuer: tunnistamo url""" + sub: str + """token subject: uuid""" + aud: str + """token audience: tilavaraus-{env}""" + jti: str + """JWT ID: uuid""" + typ: Literal["ID"] + """token type: ID""" + exp: int + """token expiration date: unix epoch timestamp""" + iat: int + """token issued-at: unix epoch timestamp""" + auth_time: int + """when end-user auth occurred: unix epoch timestamp""" + nonce: str + """random string""" + at_hash: str + """access token hash: sha256""" + name: str + """user name""" + preferred_username: str + """user preferred username""" + given_name: str + """user given name""" + family_name: str + """user family name""" + email: str + """user email""" + email_verified: bool + """Whether the is email verified or not""" + ad_groups: list[str] + """list of ad groups the user belongs to""" + azp: str + """authorized party: tilavaraus-{env}""" + sid: str + """session id: uuid""" + session_state: str + """session state: uuid""" + amr: str | list[str] + """ + authentication methods reference: + suomi_fi | heltunnistussuomifi | helsinki_adfs | helsinkiad | helsinkiazuread | eduad + """ + loa: Literal["substantial", "low"] + """level of authentication""" + + @classmethod + def from_string(cls, token: str) -> IDToken | None: + try: + payload = get_jwt_payload(token) + except Exception: + return None + + return cls( + iss=payload["iss"], + sub=payload["sub"], + aud=payload["aud"], + jti=payload["jti"], + typ=payload.get("typ", ""), # type: ignore[arg-type] + exp=payload["exp"], + iat=payload["iat"], + auth_time=payload["auth_time"], + nonce=payload.get("nonce", ""), + at_hash=payload.get("at_hash", ""), + name=payload.get("name", ""), + preferred_username=payload.get("preferred_username", ""), + given_name=payload.get("given_name", ""), + family_name=payload.get("family_name", ""), + email=payload.get("email", ""), + email_verified=payload.get("email_verified", False), + ad_groups=payload.get("ad_groups", []), + azp=payload.get("azp", ""), + sid=payload.get("sid", ""), + session_state=payload.get("session_state", ""), + amr=payload["amr"], + loa=payload["loa"], + ) + + @property + def is_ad_login(self) -> bool: + amr = self.amr + if amr is None: + return False + + if isinstance(amr, str): + amr = [amr] + return any(method.value in amr for method in ADLoginAMR) + + @property + def is_profile_login(self) -> bool: + amr = self.amr + if amr is None: + return False + + if isinstance(amr, str): + amr = [amr] + return any(method.value in amr for method in ProfileLoginAMR) + + @property + def is_strong_login(self) -> bool: + return self.loa == "substantial" + + +class UserDetails(TypedDict): + email: str + first_name: str | None + last_name: str | None + fullname: str | None + username: str | None + + +class OIDCResponse(TypedDict): + access_token: str + email: str + email_verified: bool + expires_in: int + id_token: str + nickname: str + refresh_token: str + sub: str + token_type: str + + +class ExtraKwargs(TypedDict): + details: UserDetails + is_new: bool + new_association: bool + pipeline_index: int + social: UserSocialAuth + storage: DjangoStorage + strategy: DjangoStrategy + uid: str + username: str diff --git a/backend/hitas/helauth/utils.py b/backend/hitas/helauth/utils.py new file mode 100644 index 00000000..efad2304 --- /dev/null +++ b/backend/hitas/helauth/utils.py @@ -0,0 +1,11 @@ +import base64 +import json +from typing import Any + + +def get_jwt_payload(json_web_token: str) -> dict[str, Any]: + jwt_header_part, jwt_payload_part, jwt_signature_part = json_web_token.split(".") + # Add padding to the payload if needed + jwt_payload_part += "=" * divmod(len(jwt_payload_part), 4)[1] + payload_json: str = base64.urlsafe_b64decode(jwt_payload_part).decode() + return json.loads(payload_json) diff --git a/backend/hitas/tests/test_auth.py b/backend/hitas/tests/test_auth.py new file mode 100644 index 00000000..02c2d3b7 --- /dev/null +++ b/backend/hitas/tests/test_auth.py @@ -0,0 +1,77 @@ +import base64 +import json +import uuid +from unittest.mock import MagicMock, patch + +import pytest +from social_django.models import UserSocialAuth + +from hitas.helauth.pipelines import migrate_user_from_tunnistamo_to_tunnistus +from hitas.helauth.types import IDToken +from hitas.tests.factories import UserFactory +from users.models import User + + +@pytest.mark.django_db +def test__migrate_user_from_tunnistamo_to_tunnistus__existing_tunnistamo_user(): + old_user = UserFactory.create(email="test.user@mail.test", username="old-test-user") + new_user = UserFactory.create(email="test.user@mail.test", username="new-test-user") + old_social_uid = str(uuid.uuid4()) + new_social_uid = str(uuid.uuid4()) + UserSocialAuth.objects.create(user=old_user, provider="tunnistamo", uid=old_social_uid) + UserSocialAuth.objects.create(user=new_user, provider="tunnistamo", uid=new_social_uid) + # First login of a not-yet-migrated user + with patch("hitas.helauth.pipelines.IDToken.from_string", MagicMock()) as id_token_mock: + id_token_mock.return_value.email = "test.user@mail.test" + migrate_user_from_tunnistamo_to_tunnistus(None, None, {"id_token": None}, new_user) + assert User.objects.count() == 1, "There should be only one user after migration." + assert User.objects.filter(pk=old_user.pk).exists(), "The old user should exist." + assert User.objects.filter(username="new-test-user").exists(), "The old user should have the new username." + assert UserSocialAuth.objects.count() == 1, "There should be only one UserSocialAuth after migration." + assert UserSocialAuth.objects.filter(uid=new_social_uid).exists(), "The new UserSocialAuth should exist." + # Second login after initial migration + user_logging_in = User.objects.get(email="test.user@mail.test") + with patch("hitas.helauth.pipelines.IDToken.from_string", MagicMock()) as id_token_mock: + id_token_mock.return_value.email = "test.user@mail.test" + result = migrate_user_from_tunnistamo_to_tunnistus(None, None, {"id_token": None}, user_logging_in) + assert ( + result["user"] is user_logging_in + ), "The second login should return the user through the migration unaffected." + + +@pytest.mark.django_db +def test__migrate_user_from_tunnistamo_to_tunnistus__user_is_none(): + result = migrate_user_from_tunnistamo_to_tunnistus(None, None, {"id_token": None}, None) + assert result == {"user": None} + + +@pytest.mark.django_db +def test__migrate_user_from_tunnistamo_to_tunnistus__id_token_is_none(): + user = User() + result = migrate_user_from_tunnistamo_to_tunnistus(None, None, {"id_token": None}, user) + assert result["user"] == user + + +@pytest.mark.django_db +def test__IDToken_from_string(): + payload = { + "iss": "test", + "sub": "test", + "aud": "test", + "jti": "test", + "exp": 1, + "iat": 1, + "auth_time": 1, + "amr": "test", + "loa": "low", + } + payload_json = json.dumps(payload) + jwt_header_part = "" + jwt_payload_part = base64.urlsafe_b64encode(payload_json.encode("utf-8")).decode("utf-8").rstrip("=") + jwt_signature_part = "" + id_token_string = f"{jwt_header_part}.{jwt_payload_part}.{jwt_signature_part}" + id_token = IDToken.from_string(id_token_string) + assert id_token.iss == "test", "The IDToken should have the correct issuer." + assert id_token.is_ad_login is False, "The IDToken should not be an AD login." + assert id_token.is_profile_login is False, "The IDToken should not be a Helsinki profile login." + assert id_token.is_strong_login is False, "The IDToken should not be strongly authenticated."