Skip to content

Commit

Permalink
Merge pull request #120 from nrbnlulu/fix-directives-batching
Browse files Browse the repository at this point in the history
fix batched fields with no user.
  • Loading branch information
nrbnlulu authored Sep 21, 2022
2 parents 90e925c + 759d809 commit 2123527
Show file tree
Hide file tree
Showing 6 changed files with 73 additions and 32 deletions.
22 changes: 9 additions & 13 deletions gqlauth/core/directives.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from abc import ABC, abstractmethod
from typing import Any, List, Optional, Union
from typing import Any, List, Optional

from django.contrib.auth import get_user_model
from django.contrib.auth.models import AnonymousUser
from django.utils.translation import gettext as _
from jwt import PyJWTError
import strawberry
Expand All @@ -11,7 +10,7 @@

from gqlauth.core.exceptions import TokenExpired
from gqlauth.core.types_ import GQLAuthError, GQLAuthErrors
from gqlauth.core.utils import get_status
from gqlauth.core.utils import get_status, get_user
from gqlauth.jwt.types_ import TokenType
from gqlauth.settings import gqlauth_settings as app_settings

Expand All @@ -24,7 +23,6 @@ class BaseAuthDirective(ABC):
@abstractmethod
def resolve_permission(
self,
user: Union[USER_MODEL, AnonymousUser],
source: Any,
info: Info,
args: list,
Expand All @@ -41,12 +39,7 @@ def resolve_permission(
)
class TokenRequired(BaseAuthDirective):
def resolve_permission(
self,
user: Union[USER_MODEL, AnonymousUser],
source: Any,
info: Info,
args: list,
kwargs: dict,
self, source: Any, info: Info, args: list, kwargs: dict
) -> Optional[GQLAuthError]:
token = app_settings.JWT_TOKEN_FINDER(info)
try:
Expand All @@ -68,7 +61,8 @@ def resolve_permission(
description="This field requires authentication",
)
class IsAuthenticated(BaseAuthDirective):
def resolve_permission(self, user: USER_MODEL, source: Any, info: Info, *args, **kwargs):
def resolve_permission(self, source: Any, info: Info, args, kwargs):
user = get_user(info)
if not user.is_authenticated:
return GQLAuthError(code=GQLAuthErrors.UNAUTHENTICATED)
return None
Expand All @@ -81,7 +75,8 @@ def resolve_permission(self, user: USER_MODEL, source: Any, info: Info, *args, *
description="This field requires the user to be verified",
)
class IsVerified(BaseAuthDirective):
def resolve_permission(self, user: USER_MODEL, source: Any, info: Info, *args, **kwargs):
def resolve_permission(self, source: Any, info: Info, args, kwargs):
user = get_user(info)
if (status := get_status(user)) and status.verified:
return None
return GQLAuthError(code=GQLAuthErrors.NOT_VERIFIED)
Expand All @@ -96,7 +91,8 @@ def resolve_permission(self, user: USER_MODEL, source: Any, info: Info, *args, *
class HasPermission(BaseAuthDirective):
permissions: strawberry.Private[List[str]]

def resolve_permission(self, user: USER_MODEL, source: Any, info: Info, *args, **kwargs):
def resolve_permission(self, source: Any, info: Info, args, kwargs):
user = get_user(info)
for permission in self.permissions:
if not user.has_perm(permission):
return GQLAuthError(
Expand Down
7 changes: 2 additions & 5 deletions gqlauth/core/field_.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@

from gqlauth.core.directives import BaseAuthDirective
from gqlauth.core.types_ import GQLAuthError
from gqlauth.core.utils import get_user

__all__ = ["field"]

Expand All @@ -20,22 +19,20 @@

class GqlAuthField(StrawberryDjangoField):
def _resolve(self, source, info, args, kwargs) -> Union[GQLAuthError, Any]:
user = get_user(info)
for directive in self.directives:
if isinstance(directive, BaseAuthDirective) and (
error := directive.resolve_permission(user, source, info, args, kwargs)
error := directive.resolve_permission(source, info, args, kwargs)
):
return error
return super().get_result(source, info, args, kwargs)

async def _resolve_subscriptions(
self, source, info, args, kwargs
) -> Union[AsyncGenerator, GQLAuthError]:
user = get_user(info)
for directive in self.directives:
if isinstance(directive, BaseAuthDirective) and (
error := await sync_to_async(directive.resolve_permission)(
user, source, info, args, kwargs
source, info, args, kwargs
)
):
yield error
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "strawberry-django-auth"
version = "0.3.5.1"
version = "0.3.5.2"
description = "Graphql authentication system with Strawberry for Django."
license = "MIT"
authors = ["Nir.J Benlulu <[email protected]>"]
Expand Down
5 changes: 5 additions & 0 deletions testproject/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,11 @@ class Query:
def auth_entry(self) -> Union[GQLAuthError, AuthQueries]:
return AuthQueries()

@field(directives=[TokenRequired(), IsVerified()])
@staticmethod
def batched_field() -> Union["AppleType", GQLAuthError]:
return Apple.objects.latest("pk")


@strawberry.type
class Integer:
Expand Down
39 changes: 26 additions & 13 deletions tests/test_auth_directives.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from django.contrib.auth import get_user_model
from django.contrib.auth.models import AnonymousUser
from strawberry.types import Info

from gqlauth.core.directives import HasPermission, IsAuthenticated, IsVerified
from gqlauth.core.types_ import GQLAuthErrors
Expand All @@ -8,27 +8,46 @@
USER_MODEL = get_user_model()


class DottedDict(dict):
def __getattr__(self, item):
return self.__getitem__(item)


def fake_info(user) -> Info:
res = DottedDict(
{"context": DottedDict({"user": user}), "path": DottedDict({"key": "some field"})}
)
return res


class TestAuthDirectives(ArgTestCase):
def test_is_authenticated_fails(self):
res = IsAuthenticated().resolve_permission(AnonymousUser(), None, None)
res = IsAuthenticated().resolve_permission(None, fake_info(None), None, None)
assert res.code == GQLAuthErrors.UNAUTHENTICATED
assert res.message == GQLAuthErrors.UNAUTHENTICATED.value

def test_is_authenticated_success(self, db_verified_user_status):
assert (
IsAuthenticated().resolve_permission(db_verified_user_status.user.obj, None, None)
IsAuthenticated().resolve_permission(
None, fake_info(db_verified_user_status.user.obj), None, None
)
is None
)

def test_is_verified_fails(self, db_unverified_user_status):
res = IsVerified().resolve_permission(db_unverified_user_status.user.obj, None, None)
res = IsVerified().resolve_permission(
None, fake_info(db_unverified_user_status.user.obj), None, None
)
assert res.code == GQLAuthErrors.NOT_VERIFIED
assert res.message == GQLAuthErrors.NOT_VERIFIED.value

def test_is_verified_success(self, db_verified_user_status):
assert (
IsVerified().resolve_permission(
db_verified_user_status.user.obj, None, None, None, None
None,
fake_info(db_verified_user_status.user.obj),
None,
None,
)
is None
)
Expand All @@ -41,14 +60,8 @@ def test_has_permission_fails(self, db_verified_user_status):
]
)

class FakePath:
key = "test"

class FakeInfo:
path = FakePath

assert (
perm.resolve_permission(user, None, FakeInfo).code
perm.resolve_permission(None, fake_info(user), None, None).code
is GQLAuthErrors.NO_SUFFICIENT_PERMISSIONS
)

Expand All @@ -59,7 +72,7 @@ def test_has_permission_success(self, db_verified_user_status_can_eat):
"sample.can_eat",
]
)
assert perm.resolve_permission(user, None, None) is None
assert perm.resolve_permission(None, fake_info(user), None, None) is None


class IsVerifiedDirectivesInSchemaMixin(AbstractTestCase):
Expand Down
30 changes: 30 additions & 0 deletions tests/test_gqlauth_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,36 @@ def test_expired_token(self, app_settings, db_verified_user_status):
"message": GQLAuthErrors.EXPIRED_TOKEN.value,
}

batched_field_query = """
query MyQuery {
batchedField {
... on AppleType {
__typename
name
isEaten
color
}
... on GQLAuthError {
__typename
code
message
}
}
}
"""

def test_batched_field_fails(
self, db_unverified_user_status, db_apple, allow_login_not_verified
):
res = self.make_request(
query=self.batched_field_query, user_status=db_unverified_user_status
)
assert res["code"] == GQLAuthErrors.NOT_VERIFIED.name

def test_batched_field_success(self, db_verified_user_status, db_apple):
res = self.make_request(query=self.batched_field_query, user_status=db_verified_user_status)
assert res["name"] == db_apple.name


class TestGqlAuthRootFieldInSchema(GqlAuthFieldInSchemaMixin, ArgTestCase):
...
Expand Down

0 comments on commit 2123527

Please sign in to comment.