Skip to content

Commit

Permalink
Revert "feat: Implement required_roles"
Browse files Browse the repository at this point in the history
This reverts commit c09ca3f.
  • Loading branch information
jopemachine committed Feb 3, 2025
1 parent 5f479d9 commit d7d1169
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 105 deletions.
68 changes: 0 additions & 68 deletions src/ai/backend/manager/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1013,74 +1013,6 @@ async def wrapped(
return wrap


def required_roles(
roles: UserRole | list[UserRole],
field_name: str | None = None,
):
"""
A flexible function that can act as either:
1) A decorator for custom resolvers
2) A resolver argument for simple fields (using the 'field_name' parameter)
Usage:
------
1) Decorator form:
@require_roles([UserRole.SUPERADMIN, UserRole.ADMIN])
async def resolve_something(root, info, *args, **kwargs):
# original resolver logic
return ...
2) Resolver argument form (for simple fields):
myfield = graphene.String(
resolver=require_roles([UserRole.SUPERADMIN], "myfield")
)
Parameters:
-----------
roles: UserRole | list[UserRole]
A single role or a list of roles required to access the field or resolver.
field_name: str | None
If provided, returns a resolver function that fetches `field_name` from `root`.
If None, returns a decorator for custom resolver functions.
Returns:
--------
- An async resolver function if `field_name` is set.
- A decorator function if `field_name` is None.
"""
from .user import UserRole

if isinstance(roles, UserRole):
roles = [roles]

def decorator(func):
"""Decorator that checks user role before running 'func'."""

@functools.wraps(func)
async def wrapper(root, info, *args, **kwargs):
ctx = info.context
user_role: UserRole = ctx.user["role"]
if user_role not in roles:
raise GenericForbidden(
f"One of {roles} permission is required. Current role: {user_role}"
)
return await func(root, info, *args, **kwargs)

return wrapper

# In case of "field_name" is provided, it returns a dynamic resolver function.
if field_name is not None:

@decorator
async def dynamic_resolver(root, info, *args, **kwargs):
return getattr(root, field_name, None)

return dynamic_resolver

# Otherwise, it returns the decorator function.
return decorator


def scoped_query(
*,
autofill_user: bool = False,
Expand Down
48 changes: 11 additions & 37 deletions src/ai/backend/manager/models/container_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
PaginatedConnectionField,
StrEnumType,
generate_sql_info_for_gql_connection,
required_roles,
set_if_set,
)
from .gql_models.group import GroupConnection, GroupNode
Expand Down Expand Up @@ -376,43 +375,18 @@ class Meta:
description = "Added in 24.09.0."

row_id = graphene.UUID(
description="Added in 24.09.0. The UUID type id of DB container_registries row.",
resolver=required_roles(UserRole.SUPERADMIN, "row_id"),
)
name = graphene.String(resolver=required_roles(UserRole.SUPERADMIN, "name"))
url = graphene.String(
required=True,
description="Added in 24.09.0.",
resolver=required_roles(UserRole.SUPERADMIN, "url"),
)
type = ContainerRegistryTypeField(
required=True,
description="Added in 24.09.0.",
resolver=required_roles(UserRole.SUPERADMIN, "type"),
)
registry_name = graphene.String(
required=True,
description="Added in 24.09.0.",
resolver=required_roles(UserRole.SUPERADMIN, "registry_name"),
)
is_global = graphene.Boolean(
description="Added in 24.09.0.", resolver=required_roles(UserRole.SUPERADMIN, "is_global")
)
project = graphene.String(
description="Added in 24.09.0.", resolver=required_roles(UserRole.SUPERADMIN, "project")
)
username = graphene.String(
description="Added in 24.09.0.", resolver=required_roles(UserRole.SUPERADMIN, "username")
)
password = graphene.String(
description="Added in 24.09.0.", resolver=required_roles(UserRole.SUPERADMIN, "password")
)
ssl_verify = graphene.Boolean(
description="Added in 24.09.0.", resolver=required_roles(UserRole.SUPERADMIN, "ssl_verify")
)
extra = graphene.JSONString(
description="Added in 24.09.3.", resolver=required_roles(UserRole.SUPERADMIN, "extra")
description="Added in 24.09.0. The UUID type id of DB container_registries row."
)
name = graphene.String()
url = graphene.String(required=True, description="Added in 24.09.0.")
type = ContainerRegistryTypeField(required=True, description="Added in 24.09.0.")
registry_name = graphene.String(required=True, description="Added in 24.09.0.")
is_global = graphene.Boolean(description="Added in 24.09.0.")
project = graphene.String(description="Added in 24.09.0.")
username = graphene.String(description="Added in 24.09.0.")
password = graphene.String(description="Added in 24.09.0.")
ssl_verify = graphene.Boolean(description="Added in 24.09.0.")
extra = graphene.JSONString(description="Added in 24.09.3.")
allowed_groups = PaginatedConnectionField(GroupConnection, description="Added in 25.2.0.")

_queryfilter_fieldspec: dict[str, FieldSpecItem] = {
Expand Down
2 changes: 2 additions & 0 deletions src/ai/backend/manager/models/gql.py
Original file line number Diff line number Diff line change
Expand Up @@ -2566,6 +2566,7 @@ async def resolve_container_registries(
return await ContainerRegistry.load_all(ctx)

@staticmethod
@privileged_query(UserRole.SUPERADMIN)
async def resolve_container_registry_node(
root: Any,
info: graphene.ResolveInfo,
Expand All @@ -2574,6 +2575,7 @@ async def resolve_container_registry_node(
return await ContainerRegistryNode.get_node(info, id)

@staticmethod
@privileged_query(UserRole.SUPERADMIN)
async def resolve_container_registry_nodes(
root: Any,
info: graphene.ResolveInfo,
Expand Down

0 comments on commit d7d1169

Please sign in to comment.