Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: Ability to use @skip @include graphql directives to exclude fields #231

Merged
merged 4 commits into from
Feb 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions graphene_mongo/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from .fields import MongoengineConnectionField
from .fields_async import AsyncMongoengineConnectionField

from .types import MongoengineObjectType, MongoengineInputType, MongoengineInterfaceType
from .types import MongoengineInputType, MongoengineInterfaceType, MongoengineObjectType
from .types_async import AsyncMongoengineObjectType

__version__ = "0.1.1"
__version__ = "0.4.2"

__all__ = [
"__version__",
Expand Down
114 changes: 35 additions & 79 deletions graphene_mongo/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
find_skip_and_limit,
get_model_reference_fields,
get_query_fields,
has_page_info,
)

PYMONGO_VERSION = tuple(pymongo.version_tuple[:2])
Expand Down Expand Up @@ -276,7 +277,7 @@ def fields(self):
return self._type._meta.fields

def get_queryset(
self, model, info, required_fields=None, skip=None, limit=None, reversed=False, **args
self, model, info, required_fields=None, skip=None, limit=None, **args
) -> QuerySet:
if required_fields is None:
required_fields = list()
Expand Down Expand Up @@ -325,49 +326,22 @@ def get_queryset(
else:
args.update(queryset_or_filters)
if limit is not None:
if reversed:
if self.order_by:
order_by = self.order_by + ",-pk"
else:
order_by = "-pk"
return (
model.objects(**args)
.no_dereference()
.only(*required_fields)
.order_by(order_by)
.skip(skip if skip else 0)
.limit(limit)
)
else:
return (
model.objects(**args)
.no_dereference()
.only(*required_fields)
.order_by(self.order_by)
.skip(skip if skip else 0)
.limit(limit)
)
return (
model.objects(**args)
.no_dereference()
.only(*required_fields)
.order_by(self.order_by)
.skip(skip if skip else 0)
.limit(limit)
)
elif skip is not None:
if reversed:
if self.order_by:
order_by = self.order_by + ",-pk"
else:
order_by = "-pk"
return (
model.objects(**args)
.no_dereference()
.only(*required_fields)
.order_by(order_by)
.skip(skip)
)
else:
return (
model.objects(**args)
.no_dereference()
.only(*required_fields)
.order_by(self.order_by)
.skip(skip)
)
return (
model.objects(**args)
.no_dereference()
.only(*required_fields)
.order_by(self.order_by)
.skip(skip)
)
return model.objects(**args).no_dereference().only(*required_fields).order_by(self.order_by)

def default_resolver(self, _root, info, required_fields=None, resolved=None, **args):
Expand Down Expand Up @@ -401,7 +375,6 @@ def default_resolver(self, _root, info, required_fields=None, resolved=None, **a
skip = 0
count = 0
limit = None
reverse = False
first = args.pop("first", None)
after = args.pop("after", None)
if after:
Expand All @@ -410,14 +383,15 @@ def default_resolver(self, _root, info, required_fields=None, resolved=None, **a
before = args.pop("before", None)
if before:
before = cursor_to_offset(before)
requires_page_info = has_page_info(info)
has_next_page = False

if resolved is not None:
items = resolved

if isinstance(items, QuerySet):
try:
if last is not None and after is not None:
if last is not None:
count = items.count(with_limit_and_skip=False)
else:
count = None
Expand All @@ -426,29 +400,24 @@ def default_resolver(self, _root, info, required_fields=None, resolved=None, **a
else:
count = len(items)

skip, limit, reverse = find_skip_and_limit(
skip, limit = find_skip_and_limit(
first=first, last=last, after=after, before=before, count=count
)

if isinstance(items, QuerySet):
if limit:
_base_query: QuerySet = (
items.order_by("-pk").skip(skip) if reverse else items.skip(skip)
)
_base_query: QuerySet = items.skip(skip)
items = _base_query.limit(limit)
has_next_page = len(_base_query.skip(limit).only("id").limit(1)) != 0
has_next_page = len(_base_query.skip(skip + limit).only("id").limit(1)) != 0
elif skip:
items = items.skip(skip)
else:
if limit:
if reverse:
_base_query = items[::-1]
items = _base_query[skip : skip + limit]
has_next_page = (skip + limit) < len(_base_query)
else:
_base_query = items
items = items[skip : skip + limit]
has_next_page = (skip + limit) < len(_base_query)
_base_query = items
items = items[skip : skip + limit]
has_next_page = (
(skip + limit) < len(_base_query) if requires_page_info else False
)
elif skip:
items = items[skip:]
iterables = list(items)
Expand Down Expand Up @@ -503,11 +472,11 @@ def default_resolver(self, _root, info, required_fields=None, resolved=None, **a
else:
count = self.model.objects(args_copy).count()
if count != 0:
skip, limit, reverse = find_skip_and_limit(
skip, limit = find_skip_and_limit(
first=first, after=after, last=last, before=before, count=count
)
iterables = self.get_queryset(
self.model, info, required_fields, skip, limit, reverse, **args
self.model, info, required_fields, skip, limit, **args
)
list_length = len(iterables)
if isinstance(info, GraphQLResolveInfo):
Expand All @@ -519,14 +488,11 @@ def default_resolver(self, _root, info, required_fields=None, resolved=None, **a

elif "pk__in" in args and args["pk__in"]:
count = len(args["pk__in"])
skip, limit, reverse = find_skip_and_limit(
skip, limit = find_skip_and_limit(
first=first, last=last, after=after, before=before, count=count
)
if limit:
if reverse:
args["pk__in"] = args["pk__in"][::-1][skip : skip + limit]
else:
args["pk__in"] = args["pk__in"][skip : skip + limit]
args["pk__in"] = args["pk__in"][skip : skip + limit]
elif skip:
args["pk__in"] = args["pk__in"][skip:]
iterables = self.get_queryset(self.model, info, required_fields, **args)
Expand All @@ -542,18 +508,13 @@ def default_resolver(self, _root, info, required_fields=None, resolved=None, **a
field_name = to_snake_case(info.field_name)
items = getattr(_root, field_name, [])
count = len(items)
skip, limit, reverse = find_skip_and_limit(
skip, limit = find_skip_and_limit(
first=first, last=last, after=after, before=before, count=count
)
if limit:
if reverse:
_base_query = items[::-1]
items = _base_query[skip : skip + limit]
has_next_page = (skip + limit) < len(_base_query)
else:
_base_query = items
items = items[skip : skip + limit]
has_next_page = (skip + limit) < len(_base_query)
_base_query = items
items = items[skip : skip + limit]
has_next_page = (skip + limit) < len(_base_query) if requires_page_info else False
elif skip:
items = items[skip:]
iterables = items
Expand All @@ -567,11 +528,6 @@ def default_resolver(self, _root, info, required_fields=None, resolved=None, **a
)
has_previous_page = True if skip else False

if reverse:
iterables = list(iterables)
iterables.reverse()
skip = limit

connection = connection_from_iterables(
edges=iterables,
start_offset=skip,
Expand Down
49 changes: 14 additions & 35 deletions graphene_mongo/fields_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@
connection_from_iterables,
find_skip_and_limit,
get_query_fields,
sync_to_async,
has_page_info,
sync_to_async,
)

PYMONGO_VERSION = tuple(pymongo.version_tuple[:2])
Expand Down Expand Up @@ -92,7 +92,6 @@ async def default_resolver(self, _root, info, required_fields=None, resolved=Non
skip = 0
count = 0
limit = None
reverse = False
first = args.pop("first", None)
after = args.pop("after", None)
if after:
Expand All @@ -109,7 +108,7 @@ async def default_resolver(self, _root, info, required_fields=None, resolved=Non

if isinstance(items, QuerySet):
try:
if last is not None and after is not None:
if last is not None:
count = await sync_to_async(items.count)(with_limit_and_skip=False)
else:
count = None
Expand All @@ -118,22 +117,18 @@ async def default_resolver(self, _root, info, required_fields=None, resolved=Non
else:
count = len(items)

skip, limit, reverse = find_skip_and_limit(
skip, limit = find_skip_and_limit(
first=first, last=last, after=after, before=before, count=count
)

if isinstance(items, QuerySet):
if limit:
_base_query: QuerySet = (
await sync_to_async(items.order_by("-pk").skip)(skip)
if reverse
else await sync_to_async(items.skip)(skip)
)
_base_query: QuerySet = await sync_to_async(items.skip)(skip)
items = await sync_to_async(_base_query.limit)(limit)
has_next_page = (
(
await sync_to_async(len)(
await sync_to_async(_base_query.skip(limit).only("id").limit)(1)
_base_query.skip(skip + limit).only("id").limit(1)
)
!= 0
)
Expand All @@ -144,12 +139,8 @@ async def default_resolver(self, _root, info, required_fields=None, resolved=Non
items = await sync_to_async(items.skip)(skip)
else:
if limit:
if reverse:
_base_query = items[::-1]
items = _base_query[skip : skip + limit]
else:
_base_query = items
items = items[skip : skip + limit]
_base_query = items
items = items[skip : skip + limit]
has_next_page = (
(skip + limit) < len(_base_query) if requires_page_info else False
)
Expand Down Expand Up @@ -200,11 +191,11 @@ async def default_resolver(self, _root, info, required_fields=None, resolved=Non
else:
count = await sync_to_async(self.model.objects(args_copy).count)()
if count != 0:
skip, limit, reverse = find_skip_and_limit(
skip, limit = find_skip_and_limit(
first=first, after=after, last=last, before=before, count=count
)
iterables = self.get_queryset(
self.model, info, required_fields, skip, limit, reverse, **args
self.model, info, required_fields, skip, limit, **args
)
iterables = await sync_to_async(list)(iterables)
list_length = len(iterables)
Expand All @@ -217,14 +208,11 @@ async def default_resolver(self, _root, info, required_fields=None, resolved=Non

elif "pk__in" in args and args["pk__in"]:
count = len(args["pk__in"])
skip, limit, reverse = find_skip_and_limit(
skip, limit = find_skip_and_limit(
first=first, last=last, after=after, before=before, count=count
)
if limit:
if reverse:
args["pk__in"] = args["pk__in"][::-1][skip : skip + limit]
else:
args["pk__in"] = args["pk__in"][skip : skip + limit]
args["pk__in"] = args["pk__in"][skip : skip + limit]
elif skip:
args["pk__in"] = args["pk__in"][skip:]
iterables = self.get_queryset(self.model, info, required_fields, **args)
Expand All @@ -241,16 +229,12 @@ async def default_resolver(self, _root, info, required_fields=None, resolved=Non
field_name = to_snake_case(info.field_name)
items = getattr(_root, field_name, [])
count = len(items)
skip, limit, reverse = find_skip_and_limit(
skip, limit = find_skip_and_limit(
first=first, last=last, after=after, before=before, count=count
)
if limit:
if reverse:
_base_query = items[::-1]
items = _base_query[skip : skip + limit]
else:
_base_query = items
items = items[skip : skip + limit]
_base_query = items
items = items[skip : skip + limit]
has_next_page = (skip + limit) < len(_base_query) if requires_page_info else False
elif skip:
items = items[skip:]
Expand All @@ -266,11 +250,6 @@ async def default_resolver(self, _root, info, required_fields=None, resolved=Non
)
has_previous_page = True if requires_page_info and skip else False

if reverse:
iterables = await sync_to_async(list)(iterables)
iterables.reverse()
skip = limit

connection = connection_from_iterables(
edges=iterables,
start_offset=skip,
Expand Down
Loading