Skip to content

Commit

Permalink
feat: Implement field restricted graphene object type
Browse files Browse the repository at this point in the history
  • Loading branch information
jopemachine committed Jan 30, 2025
1 parent c99a090 commit 84328f7
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 3 deletions.
59 changes: 59 additions & 0 deletions src/ai/backend/manager/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from aiodataloader import DataLoader
from aiotools import apartial
from graphene.types import Scalar
from graphene.types.objecttype import ObjectTypeMeta
from graphene.types.scalars import MAX_INT, MIN_INT
from graphql import Undefined
from graphql.language.ast import IntValueNode
Expand Down Expand Up @@ -1013,6 +1014,64 @@ async def wrapped(
return wrap


def restricted_field_resolver(field_name: str):
from .user import UserRole

async def resolver(root, info, *args, **kwargs):
cls = type(root)
required_roles = getattr(cls, "required_roles_for_fields", {}).get(field_name)

if required_roles is None:
return getattr(root, field_name, None)

if isinstance(required_roles, UserRole):
required_roles = [required_roles]

ctx: GraphQueryContext = info.context
user_role: UserRole = ctx.user["role"]

if user_role not in required_roles:
raise GenericForbidden(f"Access denied for the '{field_name}' field")

return getattr(root, field_name, None)

return resolver


class FieldRestrictedMeta(ObjectTypeMeta):
def __new__(mcs, name, bases, attrs, **options):
cls = super().__new__(mcs, name, bases, attrs, **options)

for field_name, field_obj in cls._meta.fields.items():
# Skip if the field has a custom resolver
if hasattr(cls, f"resolve_{field_name}") or field_obj.resolver:
continue

field_obj.resolver = restricted_field_resolver(field_name)

return cls


class FieldRestrictedObjectType(graphene.ObjectType, metaclass=FieldRestrictedMeta):
"""
This base class automatically assigns a resolver to each field
that checks if the user's role is in the list (or single value)
defined in `required_roles_for_fields`.
Usage example in a subclass:
required_roles_for_fields = {
"some_field": UserRole.SUPERADMIN,
"another_field": [UserRole.SUPERADMIN, UserRole.ADMIN],
}
If the field is not listed in `required_roles_for_fields`, no special role is required.
Note that if there's already a custom resolver for a field, that field is skipped.
"""

pass


def scoped_query(
*,
autofill_user: bool = False,
Expand Down
17 changes: 16 additions & 1 deletion src/ai/backend/manager/models/container_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
)
from .base import (
Base,
FieldRestrictedObjectType,
FilterExprArg,
IDColumn,
OrderExprArg,
Expand Down Expand Up @@ -369,7 +370,7 @@ async def handle_allowed_groups_update(
raise ContainerRegistryNotFound()


class ContainerRegistryNode(graphene.ObjectType):
class ContainerRegistryNode(FieldRestrictedObjectType):
class Meta:
interfaces = (AsyncNode,)
description = "Added in 24.09.0."
Expand Down Expand Up @@ -398,6 +399,20 @@ class Meta:
"registry_name": ("registry_name", None),
}

required_roles_for_fields = {
"row_id": UserRole.SUPERADMIN,
"name": UserRole.SUPERADMIN,
"url": UserRole.SUPERADMIN,
"type": UserRole.SUPERADMIN,
"registry_name": UserRole.SUPERADMIN,
"is_global": UserRole.SUPERADMIN,
"project": UserRole.SUPERADMIN,
"username": UserRole.SUPERADMIN,
"password": UserRole.SUPERADMIN,
"ssl_verify": UserRole.SUPERADMIN,
"extra": UserRole.SUPERADMIN,
}

@classmethod
async def get_node(cls, info: graphene.ResolveInfo, id: str) -> ContainerRegistryNode:
graph_ctx: GraphQueryContext = info.context
Expand Down
2 changes: 0 additions & 2 deletions src/ai/backend/manager/models/gql.py
Original file line number Diff line number Diff line change
Expand Up @@ -2566,7 +2566,6 @@ async def resolve_container_registries(
return await ContainerRegistry.load_all(ctx)

@staticmethod
@privileged_query(UserRole.SUPERADMIN)
async def resolve_container_registry_node(
root: Any,
info: graphene.ResolveInfo,
Expand All @@ -2575,7 +2574,6 @@ async def resolve_container_registry_node(
return await ContainerRegistryNode.get_node(info, id)

@staticmethod
@privileged_query(UserRole.SUPERADMIN)
async def resolve_container_registry_nodes(
root: Any,
info: graphene.ResolveInfo,
Expand Down

0 comments on commit 84328f7

Please sign in to comment.