diff --git a/gqlauth/core/directives.py b/gqlauth/core/directives.py index cdd2f179..70eaff69 100644 --- a/gqlauth/core/directives.py +++ b/gqlauth/core/directives.py @@ -1,8 +1,7 @@ from abc import ABC, abstractmethod -from typing import Any, List, Optional, Union +from typing import Any, List, Optional from django.contrib.auth import get_user_model -from django.contrib.auth.models import AnonymousUser from django.utils.translation import gettext as _ from jwt import PyJWTError import strawberry @@ -11,7 +10,7 @@ from gqlauth.core.exceptions import TokenExpired from gqlauth.core.types_ import GQLAuthError, GQLAuthErrors -from gqlauth.core.utils import get_status +from gqlauth.core.utils import get_status, get_user from gqlauth.jwt.types_ import TokenType from gqlauth.settings import gqlauth_settings as app_settings @@ -24,7 +23,6 @@ class BaseAuthDirective(ABC): @abstractmethod def resolve_permission( self, - user: Union[USER_MODEL, AnonymousUser], source: Any, info: Info, args: list, @@ -41,12 +39,7 @@ def resolve_permission( ) class TokenRequired(BaseAuthDirective): def resolve_permission( - self, - user: Union[USER_MODEL, AnonymousUser], - source: Any, - info: Info, - args: list, - kwargs: dict, + self, source: Any, info: Info, args: list, kwargs: dict ) -> Optional[GQLAuthError]: token = app_settings.JWT_TOKEN_FINDER(info) try: @@ -68,7 +61,8 @@ def resolve_permission( description="This field requires authentication", ) class IsAuthenticated(BaseAuthDirective): - def resolve_permission(self, user: USER_MODEL, source: Any, info: Info, *args, **kwargs): + def resolve_permission(self, source: Any, info: Info, args, kwargs): + user = get_user(info) if not user.is_authenticated: return GQLAuthError(code=GQLAuthErrors.UNAUTHENTICATED) return None @@ -81,7 +75,8 @@ def resolve_permission(self, user: USER_MODEL, source: Any, info: Info, *args, * description="This field requires the user to be verified", ) class IsVerified(BaseAuthDirective): - def resolve_permission(self, user: USER_MODEL, source: Any, info: Info, *args, **kwargs): + def resolve_permission(self, source: Any, info: Info, args, kwargs): + user = get_user(info) if (status := get_status(user)) and status.verified: return None return GQLAuthError(code=GQLAuthErrors.NOT_VERIFIED) @@ -96,7 +91,8 @@ def resolve_permission(self, user: USER_MODEL, source: Any, info: Info, *args, * class HasPermission(BaseAuthDirective): permissions: strawberry.Private[List[str]] - def resolve_permission(self, user: USER_MODEL, source: Any, info: Info, *args, **kwargs): + def resolve_permission(self, source: Any, info: Info, args, kwargs): + user = get_user(info) for permission in self.permissions: if not user.has_perm(permission): return GQLAuthError( diff --git a/gqlauth/core/field_.py b/gqlauth/core/field_.py index b9e91bdd..6200af9b 100644 --- a/gqlauth/core/field_.py +++ b/gqlauth/core/field_.py @@ -11,7 +11,6 @@ from gqlauth.core.directives import BaseAuthDirective from gqlauth.core.types_ import GQLAuthError -from gqlauth.core.utils import get_user __all__ = ["field"] @@ -20,10 +19,9 @@ class GqlAuthField(StrawberryDjangoField): def _resolve(self, source, info, args, kwargs) -> Union[GQLAuthError, Any]: - user = get_user(info) for directive in self.directives: if isinstance(directive, BaseAuthDirective) and ( - error := directive.resolve_permission(user, source, info, args, kwargs) + error := directive.resolve_permission(source, info, args, kwargs) ): return error return super().get_result(source, info, args, kwargs) @@ -31,11 +29,10 @@ def _resolve(self, source, info, args, kwargs) -> Union[GQLAuthError, Any]: async def _resolve_subscriptions( self, source, info, args, kwargs ) -> Union[AsyncGenerator, GQLAuthError]: - user = get_user(info) for directive in self.directives: if isinstance(directive, BaseAuthDirective) and ( error := await sync_to_async(directive.resolve_permission)( - user, source, info, args, kwargs + source, info, args, kwargs ) ): yield error diff --git a/pyproject.toml b/pyproject.toml index 89ec6910..9f0edc2e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "strawberry-django-auth" -version = "0.3.5.1" +version = "0.3.5.2" description = "Graphql authentication system with Strawberry for Django." license = "MIT" authors = ["Nir.J Benlulu "] diff --git a/testproject/schema.py b/testproject/schema.py index 7aa5552f..1421cb88 100644 --- a/testproject/schema.py +++ b/testproject/schema.py @@ -79,6 +79,11 @@ class Query: def auth_entry(self) -> Union[GQLAuthError, AuthQueries]: return AuthQueries() + @field(directives=[TokenRequired(), IsVerified()]) + @staticmethod + def batched_field() -> Union["AppleType", GQLAuthError]: + return Apple.objects.latest("pk") + @strawberry.type class Integer: diff --git a/tests/test_auth_directives.py b/tests/test_auth_directives.py index 34caabe1..02cb8de0 100644 --- a/tests/test_auth_directives.py +++ b/tests/test_auth_directives.py @@ -1,5 +1,5 @@ from django.contrib.auth import get_user_model -from django.contrib.auth.models import AnonymousUser +from strawberry.types import Info from gqlauth.core.directives import HasPermission, IsAuthenticated, IsVerified from gqlauth.core.types_ import GQLAuthErrors @@ -8,27 +8,46 @@ USER_MODEL = get_user_model() +class DottedDict(dict): + def __getattr__(self, item): + return self.__getitem__(item) + + +def fake_info(user) -> Info: + res = DottedDict( + {"context": DottedDict({"user": user}), "path": DottedDict({"key": "some field"})} + ) + return res + + class TestAuthDirectives(ArgTestCase): def test_is_authenticated_fails(self): - res = IsAuthenticated().resolve_permission(AnonymousUser(), None, None) + res = IsAuthenticated().resolve_permission(None, fake_info(None), None, None) assert res.code == GQLAuthErrors.UNAUTHENTICATED assert res.message == GQLAuthErrors.UNAUTHENTICATED.value def test_is_authenticated_success(self, db_verified_user_status): assert ( - IsAuthenticated().resolve_permission(db_verified_user_status.user.obj, None, None) + IsAuthenticated().resolve_permission( + None, fake_info(db_verified_user_status.user.obj), None, None + ) is None ) def test_is_verified_fails(self, db_unverified_user_status): - res = IsVerified().resolve_permission(db_unverified_user_status.user.obj, None, None) + res = IsVerified().resolve_permission( + None, fake_info(db_unverified_user_status.user.obj), None, None + ) assert res.code == GQLAuthErrors.NOT_VERIFIED assert res.message == GQLAuthErrors.NOT_VERIFIED.value def test_is_verified_success(self, db_verified_user_status): assert ( IsVerified().resolve_permission( - db_verified_user_status.user.obj, None, None, None, None + None, + fake_info(db_verified_user_status.user.obj), + None, + None, ) is None ) @@ -41,14 +60,8 @@ def test_has_permission_fails(self, db_verified_user_status): ] ) - class FakePath: - key = "test" - - class FakeInfo: - path = FakePath - assert ( - perm.resolve_permission(user, None, FakeInfo).code + perm.resolve_permission(None, fake_info(user), None, None).code is GQLAuthErrors.NO_SUFFICIENT_PERMISSIONS ) @@ -59,7 +72,7 @@ def test_has_permission_success(self, db_verified_user_status_can_eat): "sample.can_eat", ] ) - assert perm.resolve_permission(user, None, None) is None + assert perm.resolve_permission(None, fake_info(user), None, None) is None class IsVerifiedDirectivesInSchemaMixin(AbstractTestCase): diff --git a/tests/test_gqlauth_fields.py b/tests/test_gqlauth_fields.py index 58e93712..05d1d089 100644 --- a/tests/test_gqlauth_fields.py +++ b/tests/test_gqlauth_fields.py @@ -21,6 +21,36 @@ def test_expired_token(self, app_settings, db_verified_user_status): "message": GQLAuthErrors.EXPIRED_TOKEN.value, } + batched_field_query = """ + query MyQuery { + batchedField { + ... on AppleType { + __typename + name + isEaten + color + } + ... on GQLAuthError { + __typename + code + message + } + } + } + """ + + def test_batched_field_fails( + self, db_unverified_user_status, db_apple, allow_login_not_verified + ): + res = self.make_request( + query=self.batched_field_query, user_status=db_unverified_user_status + ) + assert res["code"] == GQLAuthErrors.NOT_VERIFIED.name + + def test_batched_field_success(self, db_verified_user_status, db_apple): + res = self.make_request(query=self.batched_field_query, user_status=db_verified_user_status) + assert res["name"] == db_apple.name + class TestGqlAuthRootFieldInSchema(GqlAuthFieldInSchemaMixin, ArgTestCase): ...