Skip to content

Commit

Permalink
Implement administrative and leads user roles (#39)
Browse files Browse the repository at this point in the history
  • Loading branch information
No767 authored Feb 10, 2025
1 parent 01a1a03 commit a4b4c7a
Show file tree
Hide file tree
Showing 6 changed files with 169 additions and 17 deletions.
25 changes: 21 additions & 4 deletions server/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,14 @@
)
from supertokens_python.asyncio import list_users_by_account_info
from supertokens_python.auth_utils import LinkingToSessionUserFailedError
from supertokens_python.recipe import dashboard, emailpassword, session, thirdparty
from supertokens_python.exceptions import GeneralError
from supertokens_python.recipe import (
dashboard,
emailpassword,
session,
thirdparty,
userroles,
)
from supertokens_python.recipe.session.interfaces import SessionContainer

# isort: off
Expand Down Expand Up @@ -171,6 +178,7 @@ def __init__(
)
),
dashboard.init(),
userroles.init(),
],
mode="asgi",
)
Expand All @@ -183,6 +191,10 @@ def __init__(
RequestValidationError,
self.request_validation_error_handler, # type: ignore
)
self.add_exception_handler(
GeneralError,
self.general_error_handler, # type: ignore
)

# SuperTokens recipes overrides

Expand Down Expand Up @@ -386,9 +398,7 @@ async def request_validation_error_handler(
) -> ORJSONResponse:
message = RequestValidationErrorMessage(
errors=[
RequestValidationErrorDetails(
detail=exception["msg"], context=exception["ctx"]["error"]
)
RequestValidationErrorDetails(detail=exception["msg"], context="")
for exception in exc.errors()
]
)
Expand All @@ -397,6 +407,13 @@ async def request_validation_error_handler(
content=message.model_dump(), status_code=status.HTTP_400_BAD_REQUEST
)

async def general_error_handler(
self, request: RouteRequest, exc: GeneralError
) -> ORJSONResponse:
return ORJSONResponse(
content={"error": str(exc)}, status_code=status.HTTP_400_BAD_REQUEST
)

### Server-related utilities

@asynccontextmanager
Expand Down
7 changes: 4 additions & 3 deletions server/routes/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from utils.errors import NotFoundException, NotFoundMessage
from utils.pages import KanaePages, KanaeParams, paginate
from utils.request import RouteRequest
from utils.roles import has_any_role
from utils.router import KanaeRouter

router = KanaeRouter(tags=["Events"])
Expand Down Expand Up @@ -100,11 +101,11 @@ class ModifiedEventWithDatetime(ModifiedEvent):
end_at: datetime.datetime


# Depends on scopes
@router.put(
"/events/{id}",
responses={200: {"model": EventsWithID}, 404: {"model": NotFoundMessage}},
)
@has_any_role("admin", "leads")
@router.limiter.limit("10/minute")
async def edit_event(
request: RouteRequest,
Expand Down Expand Up @@ -150,11 +151,11 @@ class DeleteResponse(BaseModel, frozen=True):
message: str = "ok"


# Depends on scopes
@router.delete(
"/events/{id}",
responses={200: {"model": DeleteResponse}, 404: {"model": NotFoundMessage}},
)
@has_any_role("admin", "leads")
@router.limiter.limit("10/minute")
async def delete_event(
request: RouteRequest,
Expand All @@ -173,8 +174,8 @@ async def delete_event(
return DeleteResponse()


# Depends on scopes
@router.post("/events/create", responses={200: {"model": EventsWithAllID}})
@has_any_role("admin", "leads")
@router.limiter.limit("15/minute")
async def create_events(
request: RouteRequest,
Expand Down
40 changes: 32 additions & 8 deletions server/routes/projects.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from pydantic import BaseModel
from supertokens_python.recipe.session import SessionContainer
from supertokens_python.recipe.session.framework.fastapi import verify_session
from supertokens_python.recipe.userroles import UserRoleClaim
from utils.errors import (
BadRequestException,
HTTPExceptionMessage,
Expand All @@ -16,6 +17,7 @@
from utils.pages import KanaePages, KanaeParams, paginate
from utils.request import RouteRequest
from utils.responses import DeleteResponse
from utils.roles import has_admin_role, has_any_role
from utils.router import KanaeRouter

router = KanaeRouter(tags=["Projects"])
Expand Down Expand Up @@ -107,6 +109,7 @@ async def list_projects(
args.extend((until, active))
constraint = f"WHERE {time_constraint} GROUP BY projects.id"

# ruff: noqa: S608
query = f"""
SELECT
projects.id, projects.name, projects.description, projects.link,
Expand Down Expand Up @@ -150,11 +153,11 @@ class ModifiedProject(BaseModel):
link: str


# Depends on scopes - Requires project lead and/or admin scopes
@router.put(
"/projects/{id}",
responses={200: {"model": Projects}, 404: {"model": NotFoundMessage}},
)
@has_any_role("admin", "leads")
@router.limiter.limit("3/minute")
async def edit_project(
request: RouteRequest,
Expand All @@ -164,7 +167,6 @@ async def edit_project(
):
"""Updates the specified project"""

# todo: add query for admins
query = """
WITH project_member AS (
SELECT members.id, members.role
Expand All @@ -189,20 +191,41 @@ async def edit_project(
RETURNING *;
"""

rows = await request.app.pool.fetchrow(
query, id, session.get_user_id(), *req.model_dump().values()
)
roles = await session.get_claim_value(UserRoleClaim)

if roles and "admin" in roles:
# Effectively admins can override projects
query = """
WITH project_member AS (
SELECT members.id, members.role
FROM projects
INNER JOIN project_members ON project_members.project_id = projects.id
INNER JOIN members ON project_members.member_id = members.id
WHERE projects.id = $1
)
UPDATE projects
SET
name = $2,
description = $3,
link = $4
WHERE
id = $1
RETURNING *;
"""

args = (id) if roles and "admin" in roles else (id, session.get_user_id())
rows = await request.app.pool.fetchrow(query, *args, *req.model_dump().values())

if not rows:
raise NotFoundException(detail="Resource cannot be updated")
return Projects(**dict(rows))


# Depends on scopes. Only admins should be able to delete them.
@router.delete(
"/projects/{id}",
responses={200: {"model": DeleteResponse}, 400: {"model": NotFoundMessage}},
)
@has_admin_role()
@router.limiter.limit("3/minute")
async def delete_project(
request: RouteRequest,
Expand Down Expand Up @@ -238,11 +261,11 @@ class CreateProject(BaseModel):
founded_at: datetime.datetime


# Depends on roles, admins can only use this endpoint
@router.post(
"/projects/create",
responses={200: {"model": PartialProjects}, 422: {"model": HTTPExceptionMessage}},
)
@has_admin_role()
@router.limiter.limit("5/minute")
async def create_project(
request: RouteRequest,
Expand Down Expand Up @@ -337,7 +360,6 @@ class BulkJoinMember(BaseModel):
id: uuid.UUID


# Depends on admin roles
@router.post(
"/projects/{id}/bulk-join",
responses={
Expand All @@ -346,6 +368,7 @@ class BulkJoinMember(BaseModel):
409: {"model": HTTPExceptionMessage},
},
)
@has_any_role("admin", "leads")
@router.limiter.limit("1/minute")
async def bulk_join_project(
request: RouteRequest,
Expand Down Expand Up @@ -421,6 +444,7 @@ class UpgradeMemberRole(BaseModel):
include_in_schema=False,
responses={200: {"model": DeleteResponse}},
)
@has_admin_role()
@router.limiter.limit("3/minute")
async def modify_member(
request: RouteRequest,
Expand Down
12 changes: 11 additions & 1 deletion server/routes/tags.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
)
from utils.request import RouteRequest
from utils.responses import DeleteResponse
from utils.roles import has_admin_role
from utils.router import KanaeRouter

router = KanaeRouter(tags=["Tags"])
Expand Down Expand Up @@ -73,8 +74,14 @@ class ModifiedTag(BaseModel):
"/tags/{id}",
responses={200: {"model": Tags}, 404: {"model": NotFoundMessage}},
)
@has_admin_role()
@router.limiter.limit("5/minute")
async def edit_tag(request: RouteRequest, id: int, req: ModifiedTag) -> Tags:
async def edit_tag(
request: RouteRequest,
id: int,
req: ModifiedTag,
session: Annotated[SessionContainer, Depends(verify_session())],
) -> Tags:
"""Modify specified tag"""
query = """
UPDATE tags
Expand All @@ -94,6 +101,7 @@ async def edit_tag(request: RouteRequest, id: int, req: ModifiedTag) -> Tags:
"/tags/{id}",
responses={200: {"model": DeleteResponse}, 404: {"model": NotFoundMessage}},
)
@has_admin_role()
@router.limiter.limit("5/minute")
async def delete_tag(
request: RouteRequest,
Expand All @@ -113,6 +121,7 @@ async def delete_tag(


@router.post("/tags/create", responses={200: {"model": Tags}})
@has_admin_role()
@router.limiter.limit("5/minute")
async def create_tags(
request: RouteRequest,
Expand All @@ -130,6 +139,7 @@ async def create_tags(


@router.post("/tags/bulk-create", responses={200: {"model": list[Tags]}})
@has_admin_role()
@router.limiter.limit("1/minute")
async def bulk_create_tags(
request: RouteRequest,
Expand Down
1 change: 0 additions & 1 deletion server/routes/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ class GetUser(BaseModel):
responses={200: {"model": GetUser}, 404: {"model": NotFound}},
name="Get users",
)
@router.limiter.limit("1/minute")
async def get_users(request: RouteRequest) -> GetUser:
query = "SELECT 1;"
status = await request.app.pool.execute(query)
Expand Down
101 changes: 101 additions & 0 deletions server/utils/roles.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
# Current scopes:
# read:
# all
# projects
# events
# tags
# write:
# all
# events
# projects
# ---------------
# And current roles: admin, leads
import functools
import inspect
from typing import Any, Callable, Coroutine, Optional, TypeVar

from supertokens_python.exceptions import GeneralError
from supertokens_python.recipe.session import SessionContainer
from supertokens_python.recipe.session.exceptions import (
ClaimValidationError,
InvalidClaimsError,
)
from supertokens_python.recipe.userroles import UserRoleClaim

T = TypeVar("T")

Coro = Coroutine[Any, Any, T]
CoroFunc = Callable[..., Coro[Any]]


def validate_parameters(func: CoroFunc):
sig = inspect.signature(func)
if not sig.parameters.get("session"):
raise GeneralError(
f"No <session> argument found within function <{func.__name__}>"
)


def has_role(item: str, /):
def decorator(func: CoroFunc) -> CoroFunc:
validate_parameters(func)

@functools.wraps(func)
async def wrapper(
session: Optional[SessionContainer], *args, **kwargs
) -> CoroFunc:
if not session:
raise GeneralError("Must have valid session")

roles = await session.get_claim_value(UserRoleClaim)
if not roles or item not in roles:
raise InvalidClaimsError(
f"User does not have role <{item}>",
[ClaimValidationError(UserRoleClaim.key, None)],
)

return await func(*args, **kwargs)

return wrapper

return decorator


def has_any_role(*items: str):
def decorator(func: CoroFunc) -> CoroFunc:
validate_parameters(func)

@functools.wraps(func)
async def wrapper(
session: Optional[SessionContainer], *args, **kwargs
) -> CoroFunc:
if not session:
raise GeneralError("Must have valid session")

user_roles = await session.get_claim_value(UserRoleClaim)

if not user_roles:
raise InvalidClaimsError(
f"User does not any roles listed: {', '.join(role for role in items).rstrip()}",
[ClaimValidationError(UserRoleClaim.key, None)],
)
if not any(role in user_roles for role in items):
# May need to be tested more
raise InvalidClaimsError(
f"Missing Roles: {', '.join(role for role in items if role not in user_roles).rstrip()}",
[ClaimValidationError(UserRoleClaim.key, None)],
)

return await func(*args, **kwargs)

return wrapper

return decorator


def has_admin_role():
return has_role("admin")


def has_leads_role():
return has_role("leads")

0 comments on commit a4b4c7a

Please sign in to comment.