Skip to content

Commit

Permalink
refactor oidc handler and fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
gantoine committed Dec 18, 2024
1 parent 1a97356 commit 37db255
Show file tree
Hide file tree
Showing 4 changed files with 130 additions and 63 deletions.
2 changes: 1 addition & 1 deletion backend/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def str_to_bool(value: str) -> bool:
)

# OIDC
OIDC_ENABLED: Final = os.environ.get("OIDC_ENABLED", "false") == "true"
OIDC_ENABLED: Final = str_to_bool(os.environ.get("OIDC_ENABLED", "false"))
OIDC_CLIENT_ID: Final = os.environ.get("OIDC_CLIENT_ID", "")
OIDC_CLIENT_SECRET: Final = os.environ.get("OIDC_CLIENT_SECRET", "")
OIDC_REDIRECT_URI: Final = os.environ.get("OIDC_REDIRECT_URI", "")
Expand Down
119 changes: 86 additions & 33 deletions backend/handler/auth/base_handler.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,19 @@
import asyncio
import enum
from datetime import datetime, timedelta, timezone
from typing import Any, Final
from typing import Any, Final, Optional

import httpx
from config import OIDC_ENABLED, OIDC_SERVER_APPLICATION_URL, ROMM_AUTH_SECRET_KEY
from exceptions.auth_exceptions import OAuthCredentialsException
from exceptions.auth_exceptions import OAuthCredentialsException, UserDisabledException
from fastapi import HTTPException, status
from joserfc import jwt
from joserfc.errors import BadSignatureError
from joserfc.errors import BadSignatureError, ExpiredTokenError, InvalidPayloadError
from joserfc.jwk import OctKey, RSAKey
from logger.logger import log
from passlib.context import CryptContext
from starlette.requests import HTTPConnection
from utils.context import ctx_httpx_client

ALGORITHM: Final = "HS256"
DEFAULT_OAUTH_TOKEN_EXPIRY: Final = timedelta(minutes=15)
Expand Down Expand Up @@ -140,7 +142,7 @@ async def get_current_active_user_from_bearer_token(self, token: str):

issuer = payload.claims.get("iss")
if not issuer or issuer != "romm:oauth":
return None
return None, None

username = payload.claims.get("sub")
if username is None:
Expand All @@ -151,31 +153,68 @@ async def get_current_active_user_from_bearer_token(self, token: str):
raise OAuthCredentialsException

if not user.enabled:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="Inactive user"
)
raise UserDisabledException

return user, payload.claims


class RSAKeyNotFoundError(Exception): ...


class OpenIDHandler:
def __init__(self) -> None:
if not OIDC_ENABLED:
return
RSA_ALGORITHM = "RS256"

# Fetch the public key from the OIDC server
# JWKS (JSON Web Key Sets) response is a JSON object with a keys array
def __init__(self) -> None:
self._rsa_key: Optional[RSAKey] = None
self._rsa_key_lock = asyncio.Lock()

async def _fetch_rsa_key(self) -> RSAKey:
"""
Fetch the public key from the OIDC server
JWKS (JSON Web Key Sets) response is a JSON object with a keys array
"""
jwks_url = f"{OIDC_SERVER_APPLICATION_URL}/jwks/"
with httpx.Client() as httpx_client:
try:
response = httpx_client.get(jwks_url, timeout=120)
key = response.json()["keys"][0]
self.rsa_key = RSAKey.import_key(key)
except httpx.HTTPStatusError as exc:
raise HTTPException(
status_code=exc.response.status_code,
detail=exc.response.text,
) from exc
log.debug("Fetching JWKS from %s", jwks_url)

httpx_client = ctx_httpx_client.get()
try:
response = await httpx_client.get(jwks_url, timeout=120)
response.raise_for_status()
keys = response.json().get("keys", [])
if not keys:
raise RSAKeyNotFoundError("No RSA keys found in JWKS response.")

return RSAKey.import_key(keys[0])
except (httpx.RequestError, KeyError, RSAKeyNotFoundError) as exc:
log.error("Unable to fetch RSA public key: %s", str(exc))
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="Unable to fetch RSA public key",
) from exc

async def get_rsa_key(self) -> RSAKey:
"""
Retrieves the cached RSA public key, or fetches it if not already cached.
"""
if not self._rsa_key:
async with self._rsa_key_lock:
if not self._rsa_key: # Double-check in case of concurrent calls
self._rsa_key = await self._fetch_rsa_key()
return self._rsa_key

async def validate_token(self, token: str) -> jwt.Token:
"""
Validates a JWT token using the RSA public key.
"""
try:
rsa_key = await self.get_rsa_key()
return jwt.decode(token, rsa_key, algorithms=[self.RSA_ALGORITHM])
except (BadSignatureError, ExpiredTokenError, InvalidPayloadError) as exc:
log.error("Token validation failed: %s", str(exc))
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid token",
) from exc

async def get_current_active_user_from_openid_token(self, token: Any):
from handler.database import db_user_handler
Expand All @@ -184,27 +223,41 @@ async def get_current_active_user_from_openid_token(self, token: Any):
return None, None

id_token = token.get("id_token")
if not id_token:
log.error("ID Token is missing from token.")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="ID Token is missing from token.",
)

try:
payload = jwt.decode(id_token, self.rsa_key, algorithms=["RS256"])
except (BadSignatureError, ValueError) as exc:
raise OAuthCredentialsException from exc
payload = await self.validate_token(id_token)

iss = payload.claims.get("iss")
if OIDC_SERVER_APPLICATION_URL not in str(iss):
raise OAuthCredentialsException
if not iss or OIDC_SERVER_APPLICATION_URL not in str(iss):
log.error("Invalid issuer in token: %s", iss)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Invalid issuer in token.",
)

email = payload.claims.get("email")
if email is None:
raise OAuthCredentialsException
log.error("Email is missing from token.")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Email is missing from token.",
)

user = db_user_handler.get_user_by_email(email)
if user is None:
raise OAuthCredentialsException

if not user.enabled:
log.error("User with email '%s' not found", email)
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="Inactive user"
status_code=status.HTTP_404_NOT_FOUND,
detail="User not found",
)

if not user.enabled:
raise UserDisabledException

log.info("User successfully authenticated: %s", email)
return user, payload.claims
4 changes: 3 additions & 1 deletion backend/handler/auth/tests/test_oauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ async def test_get_current_active_user_from_bearer_token(admin_user):
},
)
user, claims = await oauth_handler.get_current_active_user_from_bearer_token(token)
if not user or not claims:
pytest.fail("User or claims not found")

assert user.id == admin_user.id
assert claims["sub"] == admin_user.username
Expand Down Expand Up @@ -61,7 +63,7 @@ async def test_get_current_active_user_from_bearer_token_disabled_user(admin_use
await oauth_handler.get_current_active_user_from_bearer_token(token)
except HTTPException as e:
assert e.status_code == 401
assert e.detail == "Inactive user"
assert e.detail == "Disabled user"


def test_protected_route():
Expand Down
68 changes: 40 additions & 28 deletions backend/handler/auth/tests/test_oidc.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from unittest.mock import MagicMock
from unittest.mock import AsyncMock, MagicMock

import pytest
from fastapi import HTTPException
from handler.auth.base_handler import OpenIDHandler
from httpx import HTTPStatusError, Request, Response
from handler.auth.base_handler import OpenIDHandler, ctx_httpx_client
from httpx import Request, RequestError, Response
from joserfc.errors import BadSignatureError
from joserfc.jwt import Token

Expand All @@ -26,37 +26,48 @@ def mock_oidc_enabled(mocker):
mocker.patch("handler.auth.base_handler.OIDC_ENABLED", True)


@pytest.fixture
def mock_httpx_client():
"""Fixture to mock the httpx.AsyncClient and set it in the ContextVar."""
mock_client = AsyncMock()
token = ctx_httpx_client.set(mock_client)
yield mock_client
ctx_httpx_client.reset(token)


@pytest.fixture
def mock_request():
return Request("GET", f"{OIDC_SERVER_APPLICATION_URL}/jwks/")


def test_oidc_disabled_initialization(mock_oidc_disabled):
"""Test that the handler initializes correctly when OIDC is disabled."""
oidc_handler = OpenIDHandler()
assert not hasattr(oidc_handler, "rsa_key")
assert oidc_handler._rsa_key is None


def test_oidc_enabled_server_unreachable(mocker, mock_oidc_enabled):
async def test_oidc_enabled_server_unreachable(
mock_httpx_client, mock_request, mock_oidc_enabled
):
"""Test that initialization raises an HTTPException when the OIDC server is unreachable."""
# Mock request and response
mock_request = Request("GET", f"{OIDC_SERVER_APPLICATION_URL}/jwks/")
mock_response = Response(500, request=mock_request)

# Mock the HTTPStatusError
mocker.patch(
"httpx.Client.get",
side_effect=HTTPStatusError(
"Mocked error", request=mock_request, response=mock_response
),
mock_httpx_client.get.side_effect = RequestError(
"Mocked error", request=mock_request
)

oidc_handler = OpenIDHandler()
token = {"id_token": "invalid_signature_token"}
with pytest.raises(HTTPException):
OpenIDHandler()
await oidc_handler.get_current_active_user_from_openid_token(token)


async def test_oidc_valid_token_decoding(mocker, mock_oidc_enabled):
async def test_oidc_valid_token_decoding(
mocker, mock_httpx_client, mock_request, mock_oidc_enabled
):
"""Test token decoding with valid RSA key and token."""
mocker.patch(
"httpx.Client.get",
return_value=MagicMock(
json=lambda: {"keys": [{"kty": "RSA", "n": "fake", "e": "AQAB"}]}
),
mock_httpx_client.get.return_value = Response(
200,
request=mock_request,
json={"keys": [{"kty": "RSA", "n": "fake", "e": "AQAB"}]},
)
mock_rsa_key = MagicMock()
mocker.patch(
Expand All @@ -80,13 +91,14 @@ async def test_oidc_valid_token_decoding(mocker, mock_oidc_enabled):
assert claims == mock_jwt_payload.claims


async def test_oidc_invalid_token_signature(mocker, mock_oidc_enabled):
async def test_oidc_invalid_token_signature(
mocker, mock_httpx_client, mock_request, mock_oidc_enabled
):
"""Test token decoding raises exception for invalid signature."""
mocker.patch(
"httpx.Client.get",
return_value=MagicMock(
json=lambda: {"keys": [{"kty": "RSA", "n": "fake", "e": "AQAB"}]}
),
mock_httpx_client.get.return_value = Response(
200,
request=mock_request,
json={"keys": [{"kty": "RSA", "n": "fake", "e": "AQAB"}]},
)
mock_rsa_key = MagicMock()
mocker.patch(
Expand Down

0 comments on commit 37db255

Please sign in to comment.