Skip to content

Commit

Permalink
Add auth module and create_task authz
Browse files Browse the repository at this point in the history
  • Loading branch information
paulineribeyre committed Oct 8, 2024
1 parent 86d5c0b commit 8a1c81b
Show file tree
Hide file tree
Showing 10 changed files with 268 additions and 49 deletions.
20 changes: 19 additions & 1 deletion gen3workflow/app.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from fastapi import FastAPI
import httpx
from importlib.metadata import version
import os

from cdislogging import get_logger
from gen3authz.client.arborist.async_client import ArboristClient

from gen3workflow import logger
from gen3workflow.config import config
Expand All @@ -15,6 +17,8 @@ def get_app(httpx_client=None) -> FastAPI:
config.validate()

debug = config["DEBUG"]
log_level = "debug" if debug else "info"

app = FastAPI(
title="Gen3Workflow",
version=version("gen3workflow"),
Expand All @@ -26,7 +30,21 @@ def get_app(httpx_client=None) -> FastAPI:
app.include_router(ga4gh_tes_router, tags=["GA4GH TES"])

# Following will update logger level, propagate, and handlers
get_logger("gen3workflow", log_level="debug" if debug == True else "info")
get_logger("gen3workflow", log_level=log_level)

logger.info("Initializing Arborist client")
custom_arborist_url = os.environ.get("ARBORIST_URL", config["ARBORIST_URL"])
if custom_arborist_url:
app.arborist_client = ArboristClient(
arborist_base_url=custom_arborist_url,
authz_provider="gen3-workflow",
logger=get_logger("gen3workflow.gen3authz", log_level=log_level),
)
else:
app.arborist_client = ArboristClient(
authz_provider="gen3-workflow",
logger=get_logger("gen3workflow.gen3authz", log_level=log_level),
)

return app

Expand Down
82 changes: 82 additions & 0 deletions gen3workflow/auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
from authutils.token.fastapi import access_token
from fastapi import HTTPException, Security
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from starlette.requests import Request
from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN

from gen3authz.client.arborist.errors import ArboristError

from gen3workflow import logger


# auto_error=False prevents FastAPI from raising a 403 when the request
# is missing an Authorization header. Instead, we want to return a 401
# to signify that we did not receive valid credentials
# (the missing-token case is handled in `Auth.get_token_claims`, which raises the 401).
bearer = HTTPBearer(auto_error=False)


class Auth:
    """
    Per-request helper that validates the caller's access token and checks
    the caller's permissions with Arborist.

    It is built from the incoming request (to reach the Arborist client stored
    on the app) and from the bearer credentials extracted by the `bearer`
    security scheme, which may be missing.
    """

    def __init__(
        self,
        api_request: Request,
        bearer_token: HTTPAuthorizationCredentials = Security(bearer),
    ):
        self.bearer_token = bearer_token
        self.arborist_client = api_request.app.arborist_client

    async def get_token_claims(self) -> dict:
        """
        Return the claims of the access token sent with the request.

        Raises:
            HTTPException: 401 if no token was provided, or if anything goes
                wrong while verifying, parsing or validating the token.
        """
        if not self.bearer_token:
            err_msg = "Must provide an access token."
            logger.error(err_msg)
            raise HTTPException(HTTP_401_UNAUTHORIZED, err_msg)

        try:
            validate_token = access_token(
                "user", "openid", audience="openid", purpose="access"
            )
            claims = await validate_token(self.bearer_token)
        except Exception as e:
            # Log the `detail` attribute when the exception exposes one,
            # otherwise log the exception itself.
            reason = getattr(e, "detail", e)
            logger.error(
                f"Could not get token claims:\n{reason}",
                exc_info=True,
            )
            raise HTTPException(
                HTTP_401_UNAUTHORIZED,
                "Could not verify, parse, and/or validate scope from provided access token.",
            )

        return claims

    async def authorize(
        self,
        method: str,
        resources: list,
        throw: bool = True,
    ) -> bool:
        """
        Ask Arborist whether the caller's token has `method` access on
        `resources` for the 'gen3-workflow' service.

        Args:
            method: the access method to check, e.g. "create"
            resources: the resource paths to check access on
            throw: when True, raise instead of returning False on denial

        Returns:
            The authorization decision.

        Raises:
            HTTPException: 403 if access is denied and `throw` is True.
        """
        token = None
        if self.bearer_token:
            token = getattr(self.bearer_token, "credentials", None)

        try:
            authorized = await self.arborist_client.auth_request(
                token, "gen3-workflow", method, resources
            )
        except ArboristError as e:
            # An error while talking to Arborist is treated as a denial
            logger.error(f"Error while talking to arborist: {e}")
            authorized = False

        if authorized:
            return authorized

        logger.error(
            f"Authorization error: token must have '{method}' access on {resources} for service 'gen3-workflow'."
        )
        if throw:
            raise HTTPException(
                HTTP_403_FORBIDDEN,
                "Permission denied",
            )
        return authorized
3 changes: 3 additions & 0 deletions gen3workflow/config-default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
DEBUG: true
DOCS_URL_PREFIX: /gen3workflow

# override the default Arborist URL; ignored if already set as an environment variable
ARBORIST_URL:

####################
# GA4GH TES #
####################
Expand Down
18 changes: 17 additions & 1 deletion gen3workflow/config.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import os

from jsonschema import validate

from gen3config import Config

from . import logger
Expand All @@ -20,7 +22,21 @@ def validate(self) -> None:
Perform a series of sanity checks on a loaded config.
"""
logger.info("Validating configuration")
# will do more here when there is more config
self.validate_top_level_configs()

def validate_top_level_configs(self):
    """
    Check the types of the known top-level configuration fields.

    The loaded configuration (`self`) is validated against a JSON schema.
    Unknown keys are allowed, and no key is required by this schema.
    `jsonschema.validate` raises if a listed field has the wrong type.
    """
    # JSON schema type(s) accepted for each known field
    expected_types = {
        "DEBUG": "boolean",
        "DOCS_URL_PREFIX": "string",
        "ARBORIST_URL": ["string", "null"],  # may be left empty
        "TES_SERVER_URL": "string",
    }
    properties = {
        field: {"type": json_type} for field, json_type in expected_types.items()
    }
    schema = {
        "type": "object",
        "additionalProperties": True,
        "properties": properties,
    }
    validate(instance=self, schema=schema)


config = Gen3WorkflowConfig(DEFAULT_CFG_PATH)
try:
Expand Down
16 changes: 13 additions & 3 deletions gen3workflow/routes/ga4gh_tes.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@

import json

from fastapi import APIRouter, HTTPException, Request
from starlette.status import HTTP_200_OK
from fastapi import APIRouter, Depends, HTTPException, Request
from starlette.status import HTTP_200_OK, HTTP_401_UNAUTHORIZED

from gen3workflow.auth import Auth
from gen3workflow.config import config


Expand All @@ -35,8 +36,17 @@ async def service_info(request: Request):


@router.post("/tasks", status_code=HTTP_200_OK)
async def create_task(request: Request):
async def create_task(request: Request, auth=Depends(Auth)):
await auth.authorize("create", ["services/workflow/gen3-workflow/task"])
body = await get_request_body(request)

# add the USER_ID tag to the task
if "tags" not in body:
body["tags"] = {}
body["tags"]["USER_ID"] = (await auth.get_token_claims()).get("sub")
if not body["tags"]["USER_ID"]:
raise HTTPException(HTTP_401_UNAUTHORIZED, "No user sub in token")

res = await request.app.async_client.post(
f"{config['TES_SERVER_URL']}/tasks", json=body
)
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,10 @@ cdislogging = "<2"
fastapi = "<1"
gen3authz = "<3"
gen3config = "<2"
gunicorn = "<24"
httpx = "<1"
jsonschema = "<5"
uvicorn = "<1"
gunicorn = "<24"

[tool.poetry.dev-dependencies]
pytest = "<9"
Expand Down
94 changes: 83 additions & 11 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import os
from unittest.mock import MagicMock
from unittest.mock import AsyncMock, MagicMock, patch
from urllib.parse import urlparse

from fastapi import Request
import httpx
import pytest
import pytest_asyncio
from starlette.config import environ

Expand All @@ -17,6 +18,32 @@
from gen3workflow.config import config


# Value of the `sub` claim returned by the mocked `access_token` in `access_token_patcher`
TEST_USER_ID = "64"


@pytest.fixture(scope="function")
def access_token_patcher(client, request):
    """
    Patch `gen3workflow.auth.access_token` so that token validation returns
    the claims of a test user (`sub` is `TEST_USER_ID`).

    This fixture should be used explicitly instead of the automatic
    `access_token_user_client_patcher` fixture for endpoints that do not
    support client tokens.

    Yields the mock that replaces `access_token`.
    """

    async def fake_token_validator(*args, **kwargs):
        return {"sub": TEST_USER_ID}

    # The auth module calls `access_token(...)` and awaits the result of calling
    # what it returns, so the mock must return an async callable.
    access_token_mock = MagicMock(return_value=fake_token_validator)

    # The patch is undone when the context manager exits at fixture teardown
    with patch("gen3workflow.auth.access_token", access_token_mock):
        yield access_token_mock


def mock_tes_server_request_function(
method: str, path: str, query_params: str, body: str, status_code: int
):
Expand Down Expand Up @@ -62,39 +89,84 @@ async def reset_mock_tes_server_request():
mock_tes_server_request.reset_mock()


def mock_arborist_request(
    method: str, url: str, authorized: bool
):
    """
    Build the mocked Arborist response for a request to `url` with `method`.

    Args:
        method: HTTP method of the request, e.g. "POST"
        url: full URL of the request
        authorized: value of "auth" in the mocked `/auth/request` response

    Returns:
        An `httpx.Response`: 404 if the URL is not mocked, 405 if the method
        is not mocked for that URL, otherwise 200 with the mocked content
        (sent as JSON if it is a dict, as text otherwise).
    """
    # URLs to responses: { URL: { METHOD: response body } }
    urls_to_responses = {
        "http://test-arborist-server/auth/request": {
            "POST": {"auth": authorized},
        },
    }

    responses_by_method = urls_to_responses.get(url)
    if responses_by_method is None:
        print(
            f"Unable to mock Arborist request: '{url}' is not in `urls_to_responses`."
        )
        return httpx.Response(status_code=404, json=None, text="NOT FOUND")

    if method not in responses_by_method:
        return httpx.Response(status_code=405, json=None, text="METHOD NOT ALLOWED")

    content = responses_by_method[method]
    if isinstance(content, dict):
        return httpx.Response(status_code=200, json=content, text=None)
    return httpx.Response(status_code=200, json=None, text=content)


@pytest_asyncio.fixture(scope="session")
async def client(request):
    """
    Requests made by the tests to the app use a real HTTPX client.
    Requests made by the app to external mocked services (the TES server and
    Arborist) use a mocked client.

    The fixture can be parametrized with a dict (`request.param`):
    - tes_resp_code: status code passed to the mocked TES server (default 200)
    - authorized: value of "auth" in the mocked Arborist response (default True)

    Yields the HTTPX client bound to the app. The parameter values are attached
    to it as `tes_resp_code` and `authorized` for easier access in the tests.
    """
    params = getattr(request, "param", {})
    tes_resp_code = params.get("tes_resp_code", 200)
    authorized = params.get("authorized", True)

    async def handle_request(outgoing_request: httpx.Request):
        # `httpx.MockTransport` calls this handler with the outgoing `httpx.Request`
        url = str(outgoing_request.url)
        parsed_url = urlparse(url)

        mocked_response = None
        if url.startswith(config["TES_SERVER_URL"]):
            path = url[len(config["TES_SERVER_URL"]) :].split("?")[0]
            mocked_response = mock_tes_server_request(
                method=outgoing_request.method,
                path=path,
                query_params=parsed_url.query,
                body=outgoing_request.content.decode(),
                status_code=tes_resp_code,
            )
        elif url.startswith(config["ARBORIST_URL"]):
            mocked_response = mock_arborist_request(
                method=outgoing_request.method,
                url=url,
                authorized=authorized,
            )

        if mocked_response is not None:
            # Log the code of the response actually returned: `tes_resp_code`
            # does not apply to mocked Arborist responses
            print(
                f"Mocking request '{outgoing_request.method} {url}' to return code {mocked_response.status_code}"
            )
            return mocked_response

        print(f"Not mocking request '{outgoing_request.method} {url}'")
        # Close the pass-through client once the response has been read
        async with httpx.AsyncClient() as passthrough_client:
            return await passthrough_client.request(outgoing_request.method, url)

    def get_mock_httpx_client():
        return httpx.AsyncClient(transport=httpx.MockTransport(handle_request))

    app = get_app(httpx_client=get_mock_httpx_client())
    # Have the Arborist client build its HTTP clients from the mocked transport
    app.arborist_client.client_cls = get_mock_httpx_client

    # The `app=` shortcut of `httpx.AsyncClient` is deprecated; use an explicit ASGI transport
    async with httpx.AsyncClient(
        transport=httpx.ASGITransport(app=app), base_url="http://test-gen3-wf"
    ) as real_httpx_client:
        # for easier access to the param in the tests
        real_httpx_client.tes_resp_code = tes_resp_code
        real_httpx_client.authorized = authorized
        yield real_httpx_client
1 change: 1 addition & 0 deletions tests/test-gen3workflow-config.yaml
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
TES_SERVER_URL: http://external-tes-server/tes
ARBORIST_URL: http://test-arborist-server
Loading

0 comments on commit 8a1c81b

Please sign in to comment.