diff --git a/tests/test_auth.py b/tests/test_auth.py index 00044d8..544a353 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -25,7 +25,7 @@ @pytest.fixture -def oidc_token(app): +def oidc_auth_profile(app): with app.test_request_context('/api/v1.0/waivers/new'): with mock.patch.dict(session, {'oidc_auth_profile': { 'active': True, @@ -36,6 +36,18 @@ def oidc_token(app): yield mocked['oidc_auth_profile'] +@pytest.fixture +def oidc_token(app): + with app.test_request_context('/api/v1.0/waivers/new'): + with mock.patch.dict(session, {'oidc_auth_token': { + 'active': True, + 'username': 'testuser', + 'preferred_username': 'testuser', + 'scope': 'openid waiverdb_scope', + }, 'oidc_auth_profile': {}}) as mocked: + yield mocked + + @pytest.fixture def verify_authorization(): with mock.patch("waiverdb.api_v1.verify_authorization") as mocked: @@ -93,15 +105,19 @@ def test_get_user_no_auth_methods(self): waiverdb.auth.get_user(request) assert "Authenticated user required. No methods specified." in str(excinfo.value) - def test_get_user_without_token(self, app): + def test_get_user_without_profile(self, app): with app.test_request_context('/api/v1.0/waivers/new'): with pytest.raises(Unauthorized) as excinfo: waiverdb.auth.get_user(request) assert self.auth_missing_error in str(excinfo.value) - def test_get_user_good(self, oidc_token): + def test_get_user_good_profile(self, oidc_auth_profile): + user, header = waiverdb.auth.get_user(request) + assert user == oidc_auth_profile["preferred_username"] + + def test_get_user_good_token(self, oidc_token): user, header = waiverdb.auth.get_user(request) - assert user == oidc_token["username"] + assert user == oidc_token["preferred_username"] # tests only redirect of deprecated resource # not working, causing an exception in flask_oidc library: @@ -111,7 +127,7 @@ def test_create_new_waiver( self, verify_authorization, permissions, - oidc_token, + oidc_auth_profile, client, ): verify_authorization.return_value = True diff --git a/waiverdb/auth.py b/waiverdb/auth.py index c7e4961..7ce3a62 100644 --- a/waiverdb/auth.py +++ b/waiverdb/auth.py @@ -4,6 +4,7 @@ import base64 import binascii import gssapi +from authlib.integrations.flask_oauth2 import current_token from flask import current_app, Request, Response, session from werkzeug.exceptions import Unauthorized, Forbidden @@ -60,13 +61,19 @@ def get_user(request: Request) -> tuple[str, dict[str, str]]: def get_oidc_userinfo(field: str) -> str: - fields = session.get("oidc_auth_profile", {}) - if field not in fields: - current_app.logger.error( - "User info field %r is unavailable; available are: %s", field, fields.keys() - ) - raise Unauthorized("Failed to retrieve username") - return fields[field] + pfields = session.get("oidc_auth_profile", {}) + if field in pfields: + return pfields[field] + tfields = session.get("oidc_auth_token", {}) + if field in tfields: + return tfields[field] + current_app.logger.error( + "User info field %r is unavailable; available are: %s (auth profile), %s (token)", + field, + pfields.keys(), + tfields.keys(), + ) + raise Unauthorized("Failed to retrieve username") def get_user_by_method(request: Request, auth_method: str) -> tuple[str, dict[str, str]]: