diff --git a/server/core.py b/server/core.py index 4c2cf69..d64f136 100644 --- a/server/core.py +++ b/server/core.py @@ -18,7 +18,14 @@ ) from supertokens_python.asyncio import list_users_by_account_info from supertokens_python.auth_utils import LinkingToSessionUserFailedError -from supertokens_python.recipe import dashboard, emailpassword, session, thirdparty +from supertokens_python.exceptions import GeneralError +from supertokens_python.recipe import ( + dashboard, + emailpassword, + session, + thirdparty, + userroles, +) from supertokens_python.recipe.session.interfaces import SessionContainer # isort: off @@ -171,6 +178,7 @@ def __init__( ) ), dashboard.init(), + userroles.init(), ], mode="asgi", ) @@ -183,6 +191,10 @@ def __init__( RequestValidationError, self.request_validation_error_handler, # type: ignore ) + self.add_exception_handler( + GeneralError, + self.general_error_handler, # type: ignore + ) # SuperTokens recipes overrides @@ -386,9 +398,7 @@ async def request_validation_error_handler( ) -> ORJSONResponse: message = RequestValidationErrorMessage( errors=[ - RequestValidationErrorDetails( - detail=exception["msg"], context=exception["ctx"]["error"] - ) + RequestValidationErrorDetails(detail=exception["msg"], context="") for exception in exc.errors() ] ) @@ -397,6 +407,13 @@ async def request_validation_error_handler( content=message.model_dump(), status_code=status.HTTP_400_BAD_REQUEST ) + async def general_error_handler( + self, request: RouteRequest, exc: GeneralError + ) -> ORJSONResponse: + return ORJSONResponse( + content={"error": str(exc)}, status_code=status.HTTP_400_BAD_REQUEST + ) + ### Server-related utilities @asynccontextmanager diff --git a/server/routes/events.py b/server/routes/events.py index 38afa07..1f98c51 100644 --- a/server/routes/events.py +++ b/server/routes/events.py @@ -9,6 +9,7 @@ from utils.errors import NotFoundException, NotFoundMessage from utils.pages import KanaePages, KanaeParams, paginate from utils.request import RouteRequest +from utils.roles import has_any_role from utils.router import KanaeRouter router = KanaeRouter(tags=["Events"]) @@ -100,11 +101,11 @@ class ModifiedEventWithDatetime(ModifiedEvent): end_at: datetime.datetime -# Depends on scopes @router.put( "/events/{id}", responses={200: {"model": EventsWithID}, 404: {"model": NotFoundMessage}}, ) +@has_any_role("admin", "leads") @router.limiter.limit("10/minute") async def edit_event( request: RouteRequest, @@ -150,11 +151,11 @@ class DeleteResponse(BaseModel, frozen=True): message: str = "ok" -# Depends on scopes @router.delete( "/events/{id}", responses={200: {"model": DeleteResponse}, 404: {"model": NotFoundMessage}}, ) +@has_any_role("admin", "leads") @router.limiter.limit("10/minute") async def delete_event( request: RouteRequest, @@ -173,8 +174,8 @@ async def delete_event( return DeleteResponse() -# Depends on scopes @router.post("/events/create", responses={200: {"model": EventsWithAllID}}) +@has_any_role("admin", "leads") @router.limiter.limit("15/minute") async def create_events( request: RouteRequest, diff --git a/server/routes/projects.py b/server/routes/projects.py index e4482cd..c9a8750 100644 --- a/server/routes/projects.py +++ b/server/routes/projects.py @@ -7,6 +7,7 @@ from pydantic import BaseModel from supertokens_python.recipe.session import SessionContainer from supertokens_python.recipe.session.framework.fastapi import verify_session +from supertokens_python.recipe.userroles import UserRoleClaim from utils.errors import ( BadRequestException, HTTPExceptionMessage, @@ -16,6 +17,7 @@ from utils.pages import KanaePages, KanaeParams, paginate from utils.request import RouteRequest from utils.responses import DeleteResponse +from utils.roles import has_admin_role, has_any_role from utils.router import KanaeRouter router = KanaeRouter(tags=["Projects"]) @@ -107,6 +109,7 @@ async def list_projects( args.extend((until, active)) constraint = f"WHERE {time_constraint} GROUP BY projects.id" + # ruff: noqa: S608 query = f""" SELECT projects.id, projects.name, projects.description, projects.link, @@ -150,11 +153,11 @@ class ModifiedProject(BaseModel): link: str -# Depends on scopes - Requires project lead and/or admin scopes @router.put( "/projects/{id}", responses={200: {"model": Projects}, 404: {"model": NotFoundMessage}}, ) +@has_any_role("admin", "leads") @router.limiter.limit("3/minute") async def edit_project( request: RouteRequest, @@ -164,7 +167,6 @@ async def edit_project( ): """Updates the specified project""" - # todo: add query for admins query = """ WITH project_member AS ( SELECT members.id, members.role @@ -189,20 +191,41 @@ async def edit_project( RETURNING *; """ - rows = await request.app.pool.fetchrow( - query, id, session.get_user_id(), *req.model_dump().values() - ) + roles = await session.get_claim_value(UserRoleClaim) + + if roles and "admin" in roles: + # Effectively admins can override projects + query = """ + WITH project_member AS ( + SELECT members.id, members.role + FROM projects + INNER JOIN project_members ON project_members.project_id = projects.id + INNER JOIN members ON project_members.member_id = members.id + WHERE projects.id = $1 + ) + UPDATE projects + SET + name = $2, + description = $3, + link = $4 + WHERE + id = $1 + RETURNING *; + """ + + args = (id) if roles and "admin" in roles else (id, session.get_user_id()) + rows = await request.app.pool.fetchrow(query, *args, *req.model_dump().values()) if not rows: raise NotFoundException(detail="Resource cannot be updated") return Projects(**dict(rows)) -# Depends on scopes. Only admins should be able to delete them. @router.delete( "/projects/{id}", responses={200: {"model": DeleteResponse}, 400: {"model": NotFoundMessage}}, ) +@has_admin_role() @router.limiter.limit("3/minute") async def delete_project( request: RouteRequest, @@ -238,11 +261,11 @@ class CreateProject(BaseModel): founded_at: datetime.datetime -# Depends on roles, admins can only use this endpoint @router.post( "/projects/create", responses={200: {"model": PartialProjects}, 422: {"model": HTTPExceptionMessage}}, ) +@has_admin_role() @router.limiter.limit("5/minute") async def create_project( request: RouteRequest, @@ -337,7 +360,6 @@ class BulkJoinMember(BaseModel): id: uuid.UUID -# Depends on admin roles @router.post( "/projects/{id}/bulk-join", responses={ @@ -346,6 +368,7 @@ class BulkJoinMember(BaseModel): 409: {"model": HTTPExceptionMessage}, }, ) +@has_any_role("admin", "leads") @router.limiter.limit("1/minute") async def bulk_join_project( request: RouteRequest, @@ -421,6 +444,7 @@ class UpgradeMemberRole(BaseModel): include_in_schema=False, responses={200: {"model": DeleteResponse}}, ) +@has_admin_role() @router.limiter.limit("3/minute") async def modify_member( request: RouteRequest, diff --git a/server/routes/tags.py b/server/routes/tags.py index 26570b1..606454c 100644 --- a/server/routes/tags.py +++ b/server/routes/tags.py @@ -10,6 +10,7 @@ ) from utils.request import RouteRequest from utils.responses import DeleteResponse +from utils.roles import has_admin_role from utils.router import KanaeRouter router = KanaeRouter(tags=["Tags"]) @@ -73,8 +74,14 @@ class ModifiedTag(BaseModel): "/tags/{id}", responses={200: {"model": Tags}, 404: {"model": NotFoundMessage}}, ) +@has_admin_role() @router.limiter.limit("5/minute") -async def edit_tag(request: RouteRequest, id: int, req: ModifiedTag) -> Tags: +async def edit_tag( + request: RouteRequest, + id: int, + req: ModifiedTag, + session: Annotated[SessionContainer, Depends(verify_session())], +) -> Tags: """Modify specified tag""" query = """ UPDATE tags @@ -94,6 +101,7 @@ async def edit_tag(request: RouteRequest, id: int, req: ModifiedTag) -> Tags: "/tags/{id}", responses={200: {"model": DeleteResponse}, 404: {"model": NotFoundMessage}}, ) +@has_admin_role() @router.limiter.limit("5/minute") async def delete_tag( request: RouteRequest, @@ -113,6 +121,7 @@ async def delete_tag( @router.post("/tags/create", responses={200: {"model": Tags}}) +@has_admin_role() @router.limiter.limit("5/minute") async def create_tags( request: RouteRequest, @@ -130,6 +139,7 @@ async def create_tags( @router.post("/tags/bulk-create", responses={200: {"model": list[Tags]}}) +@has_admin_role() @router.limiter.limit("1/minute") async def bulk_create_tags( request: RouteRequest, diff --git a/server/routes/user.py b/server/routes/user.py index 9ffd7aa..2e98e1d 100644 --- a/server/routes/user.py +++ b/server/routes/user.py @@ -20,7 +20,6 @@ class GetUser(BaseModel): responses={200: {"model": GetUser}, 404: {"model": NotFound}}, name="Get users", ) -@router.limiter.limit("1/minute") async def get_users(request: RouteRequest) -> GetUser: query = "SELECT 1;" status = await request.app.pool.execute(query) diff --git a/server/utils/roles.py b/server/utils/roles.py new file mode 100644 index 0000000..321f2d1 --- /dev/null +++ b/server/utils/roles.py @@ -0,0 +1,101 @@ +# Current scopes: +# read: +# all +# projects +# events +# tags +# write: +# all +# events +# projects +# --------------- +# And current roles: admin, leads +import functools +import inspect +from typing import Any, Callable, Coroutine, Optional, TypeVar + +from supertokens_python.exceptions import GeneralError +from supertokens_python.recipe.session import SessionContainer +from supertokens_python.recipe.session.exceptions import ( + ClaimValidationError, + InvalidClaimsError, +) +from supertokens_python.recipe.userroles import UserRoleClaim + +T = TypeVar("T") + +Coro = Coroutine[Any, Any, T] +CoroFunc = Callable[..., Coro[Any]] + + +def validate_parameters(func: CoroFunc): + sig = inspect.signature(func) + if not sig.parameters.get("session"): + raise GeneralError( + f"No argument found within function <{func.__name__}>" + ) + + +def has_role(item: str, /): + def decorator(func: CoroFunc) -> CoroFunc: + validate_parameters(func) + + @functools.wraps(func) + async def wrapper( + session: Optional[SessionContainer], *args, **kwargs + ) -> CoroFunc: + if not session: + raise GeneralError("Must have valid session") + + roles = await session.get_claim_value(UserRoleClaim) + if not roles or item not in roles: + raise InvalidClaimsError( + f"User does not have role <{item}>", + [ClaimValidationError(UserRoleClaim.key, None)], + ) + + return await func(*args, **kwargs) + + return wrapper + + return decorator + + +def has_any_role(*items: str): + def decorator(func: CoroFunc) -> CoroFunc: + validate_parameters(func) + + @functools.wraps(func) + async def wrapper( + session: Optional[SessionContainer], *args, **kwargs + ) -> CoroFunc: + if not session: + raise GeneralError("Must have valid session") + + user_roles = await session.get_claim_value(UserRoleClaim) + + if not user_roles: + raise InvalidClaimsError( + f"User does not any roles listed: {', '.join(role for role in items).rstrip()}", + [ClaimValidationError(UserRoleClaim.key, None)], + ) + if not any(role in user_roles for role in items): + # May need to be tested more + raise InvalidClaimsError( + f"Missing Roles: {', '.join(role for role in items if role not in user_roles).rstrip()}", + [ClaimValidationError(UserRoleClaim.key, None)], + ) + + return await func(*args, **kwargs) + + return wrapper + + return decorator + + +def has_admin_role(): + return has_role("admin") + + +def has_leads_role(): + return has_role("leads")