Skip to content

Commit

Permalink
Add USER_ID tag to list_tasks calls
Browse files Browse the repository at this point in the history
  • Loading branch information
paulineribeyre committed Oct 8, 2024
1 parent 13ddb83 commit 369f90f
Show file tree
Hide file tree
Showing 3 changed files with 117 additions and 19 deletions.
50 changes: 45 additions & 5 deletions gen3workflow/routes/ga4gh_tes.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@
https://editor.swagger.io/?url=https://raw.githubusercontent.com/ga4gh/task-execution-schemas/develop/openapi/task_execution_service.openapi.yaml
"""

from collections import defaultdict
import json

from fastapi import APIRouter, Depends, HTTPException, Request
from starlette.status import HTTP_200_OK, HTTP_401_UNAUTHORIZED
from starlette.datastructures import QueryParams
from starlette.status import HTTP_200_OK, HTTP_400_BAD_REQUEST, HTTP_401_UNAUTHORIZED

from gen3workflow.auth import Auth
from gen3workflow.config import config
Expand Down Expand Up @@ -55,8 +57,45 @@ async def create_task(request: Request, auth=Depends(Auth)):
return res.json()


def generate_list_tasks_query_params(
original_query_params: QueryParams,
supported_params: list,
user_id: str,
):
"""
The `tag_key` and `tag_value` params support setting multiple values, for example:
`?tag_key=tagA&tag_value=valueA&tag_key=tagB&tag_value=valueB` means that that tasks
are filtered on: `tagA == valueA and tagB == valueB`.
We need to maintain this support, as well as add the `USER_ID` tag so users can only
list their own tasks.
"""
# Convert the query params to a data struct that's easier to work with:
# [(tag_key, tagA), (tag_value, valueA), (tag_key, tagB), (tag_value, valueB)]
# becomes {tag_key: [tagA, tagB], tag_value: [valueA, valueB]}
query_params = defaultdict(list)
for k, v in original_query_params.multi_items():
if k in supported_params: # filter out any unsupported params
query_params[k].append(v)

if len(query_params["tag_key"]) != len(query_params["tag_value"]):
raise Exception(
HTTP_400_BAD_REQUEST, "Parameters `tag_key` and `tag_value` mismatch"
)

# Check if there is already a `USER_ID` tag. If so, its value must be replaced. If not, add one.
try:
user_id_tag_index = query_params.get("tag_key", []).index("USER_ID")
except ValueError:
query_params["tag_key"].append("USER_ID")
query_params["tag_value"].append(user_id)
else:
query_params["tag_value"][user_id_tag_index] = user_id

return query_params


@router.get("/tasks", status_code=HTTP_200_OK)
async def list_tasks(request: Request):
async def list_tasks(request: Request, auth=Depends(Auth)):
supported_params = {
"name_prefix",
"state",
Expand All @@ -66,9 +105,10 @@ async def list_tasks(request: Request):
"page_token",
"view",
}
query_params = {
k: v for k, v in dict(request.query_params).items() if k in supported_params
}
user_id = (await auth.get_token_claims()).get("sub")
query_params = generate_list_tasks_query_params(
request.query_params, supported_params, user_id
)
res = await request.app.async_client.get(
f"{config['TES_SERVER_URL']}/tasks", params=query_params
)
Expand Down
16 changes: 8 additions & 8 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,14 +89,10 @@ async def reset_mock_tes_server_request():
mock_tes_server_request.reset_mock()


def mock_arborist_request(
method: str, url: str, authorized: bool
):
def mock_arborist_request(method: str, url: str, authorized: bool):
# URLs to reponses: { URL: { METHOD: response body } }
urls_to_responses = {
"http://test-arborist-server/auth/request": {
"POST": {"auth": authorized}
},
"http://test-arborist-server/auth/request": {"POST": {"auth": authorized}},
}

text, body = None, None
Expand Down Expand Up @@ -164,8 +160,12 @@ async def handle_request(request: Request):

mock_httpx_client = httpx.AsyncClient(transport=httpx.MockTransport(handle_request))
app = get_app(httpx_client=mock_httpx_client)
app.arborist_client.client_cls = lambda: httpx.AsyncClient(transport=httpx.MockTransport(handle_request))
async with httpx.AsyncClient(app=app, base_url="http://test-gen3-wf") as real_httpx_client:
app.arborist_client.client_cls = lambda: httpx.AsyncClient(
transport=httpx.MockTransport(handle_request)
)
async with httpx.AsyncClient(
app=app, base_url="http://test-gen3-wf"
) as real_httpx_client:
# for easier access to the param in the tests
real_httpx_client.tes_resp_code = tes_resp_code
real_httpx_client.authorized = authorized
Expand Down
70 changes: 64 additions & 6 deletions tests/test_ga4gh_tes.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,20 +28,70 @@ async def test_service_info_endpoint(client):
ids=["success", "failure"],
indirect=True,
)
async def test_list_tasks(client):
async def test_list_tasks(client, access_token_patcher):
"""
Calls to `GET /ga4gh-tes/v1/tasks` should be forwarded to the TES server, and any
unsupported query params should be filtered out.
unsupported query params should be filtered out. The USER_ID tag should be added.
When the TES server returns an error, gen3-workflow should return it as well.
"""
res = await client.get("/ga4gh-tes/v1/tasks?state=COMPLETE&unsupported_param=value")
res = await client.get(
"/ga4gh-tes/v1/tasks?state=COMPLETE&unsupported_param=value",
headers={"Authorization": f"bearer 123"},
)
assert res.status_code == client.tes_resp_code
if client.tes_resp_code == 500:
assert res.json() == {"detail": "TES server error"}
mock_tes_server_request.assert_called_once_with(
method="GET",
path="/tasks",
query_params="state=COMPLETE",
query_params=f"state=COMPLETE&tag_key=USER_ID&tag_value={TEST_USER_ID}",
body="",
status_code=client.tes_resp_code,
)


@pytest.mark.asyncio
@pytest.mark.parametrize(
"tags",
[
{"provided": "", "expected": f"tag_key=USER_ID&tag_value={TEST_USER_ID}"},
{
"provided": "tag_key=foo1&tag_value=bar1",
"expected": f"tag_key=foo1&tag_key=USER_ID&tag_value=bar1&tag_value={TEST_USER_ID}",
},
{
"provided": "tag_key=foo1&tag_value=bar1&tag_key=USER_ID&tag_value=should_be_removed&tag_key=foo2&tag_value=bar2",
"expected": f"tag_key=foo1&tag_key=USER_ID&tag_key=foo2&tag_value=bar1&tag_value={TEST_USER_ID}&tag_value=bar2",
},
{
"provided": "tag_key=foo1&tag_key=USER_ID&tag_value=bar1&tag_value=should_be_removed",
"expected": f"tag_key=foo1&tag_key=USER_ID&tag_value=bar1&tag_value={TEST_USER_ID}",
},
],
ids=[
"no previous tags",
"previous tags without user id",
"previous tags with user id, order 1",
"previous tags with user id, order 2",
],
)
async def test_list_tasks_tag_replacement(client, access_token_patcher, tags):
"""
Check that the USER_ID tag is added or replaced in `GET /ga4gh-tes/v1/tasks` calls before
they are forwarded to the TES server, and that multiple `tag_key` and `tag_value` params are
supported.
"""
res = await client.get(
f"/ga4gh-tes/v1/tasks?{tags['provided']}",
headers={"Authorization": f"bearer 123"},
)
assert res.status_code == client.tes_resp_code
if client.tes_resp_code == 500:
assert res.json() == {"detail": "TES server error"}
mock_tes_server_request.assert_called_once_with(
method="GET",
path=f"/tasks",
query_params=tags["expected"],
body="",
status_code=client.tes_resp_code,
)
Expand Down Expand Up @@ -78,7 +128,11 @@ async def test_get_task(client):
@pytest.mark.asyncio
@pytest.mark.parametrize(
"client",
[{"authorized": False}, {"authorized": True, "tes_resp_code": 200}, {"authorized": True, "tes_resp_code": 500}],
[
{"authorized": False},
{"authorized": True, "tes_resp_code": 200},
{"authorized": True, "tes_resp_code": 500},
],
ids=["unauthorized", "success", "failure"],
indirect=True,
)
Expand All @@ -90,7 +144,11 @@ async def test_create_task(client, access_token_patcher):
If the user is not authorized, we should get a 403 error and no TES server requests should
be made.
"""
res = await client.post("/ga4gh-tes/v1/tasks", json={"name": "test-task"}, headers={"Authorization": f"bearer 123"})
res = await client.post(
"/ga4gh-tes/v1/tasks",
json={"name": "test-task"},
headers={"Authorization": f"bearer 123"},
)
if not client.authorized:
assert res.status_code == 403
mock_tes_server_request.assert_not_called()
Expand Down

0 comments on commit 369f90f

Please sign in to comment.