Async Support To Avoid Blocking Request #218

Merged 4 commits on Apr 9, 2023
README.md (2 changes: 1 addition & 1 deletion)
@@ -62,7 +62,7 @@ query = '''
}
}
'''
result = schema.execute(query)
result = await schema.execute_async(query)
```

To learn more check out the following [examples](examples/):
README.rst (2 changes: 1 addition & 1 deletion)
@@ -71,7 +71,7 @@ Then you can simply query the schema:
}
}
'''
result = schema.execute(query)
result = await schema.execute_async(query)

To learn more check out the `Flask MongoEngine example <https://github.com/graphql-python/graphene-mongo/tree/master/examples/flask_mongoengine>`__
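
`execute_async` returns a coroutine, so it has to be awaited from a running event loop. A minimal usage sketch, assuming the `schema` and `query` objects defined in the README snippet above:

```python
import asyncio

async def main():
    # schema and query are the objects built earlier in the README example;
    # execute_async keeps the event loop free while resolvers run.
    result = await schema.execute_async(query)
    print(result.data)

asyncio.run(main())
```
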

examples/falcon_mongoengine/api.py (4 changes: 2 additions & 2 deletions)
@@ -23,7 +23,7 @@ def on_post(self, req, resp):
class GraphQLResource:
def on_get(self, req, resp):
query = req.params["query"]
result = schema.execute(query)
result = await schema.execute_async(query)

if result.data:
data_ret = {"data": result.data}
@@ -32,7 +32,7 @@ def on_get(self, req, resp):

def on_post(self, req, resp):
query = req.params["query"]
result = schema.execute(query)
result = await schema.execute_async(query)
if result.data:
data_ret = {"data": result.data}
resp.status = falcon.HTTP_200
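
`await` is only valid inside an `async def` responder, so the resource itself has to be asynchronous to use `execute_async`. A minimal sketch under Falcon 3's ASGI interface; the toy schema and route are illustrative, not part of this diff:

```python
import falcon
import falcon.asgi
import graphene

class Query(graphene.ObjectType):
    hello = graphene.String()

    async def resolve_hello(root, info):
        return "world"

schema = graphene.Schema(query=Query)

class AsyncGraphQLResource:
    async def on_get(self, req, resp):
        query = req.params["query"]
        # execute_async keeps the event loop free while resolvers run
        result = await schema.execute_async(query)
        if result.data:
            resp.media = {"data": result.data}
            resp.status = falcon.HTTP_200
        else:
            resp.media = {"errors": [str(err) for err in result.errors or []]}
            resp.status = falcon.HTTP_400

app = falcon.asgi.App()
app.add_route("/graphql", AsyncGraphQLResource())
```
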
graphene_mongo/converter.py (62 changes: 33 additions & 29 deletions)
@@ -9,6 +9,7 @@
from . import advanced_types
from .utils import import_single_dispatch, get_field_description, get_query_fields
from concurrent.futures import ThreadPoolExecutor, as_completed
from asgiref.sync import sync_to_async

singledispatch = import_single_dispatch()

@@ -42,6 +43,14 @@ def convert_field_to_id(field, registry=None):
)


@convert_mongoengine_field.register(mongoengine.Decimal128Field)
@convert_mongoengine_field.register(mongoengine.DecimalField)
def convert_field_to_decimal(field, registry=None):
return graphene.Decimal(
description=get_field_description(field, registry), required=field.required
)


@convert_mongoengine_field.register(mongoengine.IntField)
@convert_mongoengine_field.register(mongoengine.LongField)
@convert_mongoengine_field.register(mongoengine.SequenceField)
@@ -58,21 +67,13 @@ def convert_field_to_boolean(field, registry=None):
)


@convert_mongoengine_field.register(mongoengine.DecimalField)
@convert_mongoengine_field.register(mongoengine.FloatField)
def convert_field_to_float(field, registry=None):
return graphene.Float(
description=get_field_description(field, registry), required=field.required
)


@convert_mongoengine_field.register(mongoengine.Decimal128Field)
def convert_field_to_decimal(field, registry=None):
return graphene.Decimal(
description=get_field_description(field, registry), required=field.required
)


@convert_mongoengine_field.register(mongoengine.DateTimeField)
def convert_field_to_datetime(field, registry=None):
return graphene.DateTime(
@@ -246,7 +247,7 @@ def convert_field_to_union(field, registry=None):
Meta = type("Meta", (object,), {"types": tuple(_types)})
_union = type(name, (graphene.Union,), {"Meta": Meta})

def reference_resolver(root, *args, **kwargs):
async def reference_resolver(root, *args, **kwargs):
de_referenced = getattr(root, field.name or field.db_name)
if de_referenced:
document = get_document(de_referenced["_cls"])
@@ -265,13 +266,14 @@ def reference_resolver(root, *args, **kwargs):
item = to_snake_case(each)
if item in document._fields_ordered + tuple(filter_args):
queried_fields.append(item)
return document.objects().no_dereference().only(*list(
set(list(_type._meta.required_fields) + queried_fields))).get(
pk=de_referenced["_ref"].id)
return document()
return await sync_to_async(document.objects().no_dereference().only(*list(
set(list(_type._meta.required_fields) + queried_fields))).get, thread_sensitive=False,
executor=ThreadPoolExecutor())(pk=de_referenced["_ref"].id)
return await sync_to_async(document, thread_sensitive=False,
executor=ThreadPoolExecutor())()
return None

def lazy_reference_resolver(root, *args, **kwargs):
async def lazy_reference_resolver(root, *args, **kwargs):
document = getattr(root, field.name or field.db_name)
if document:
queried_fields = list()
@@ -288,10 +290,11 @@ def lazy_reference_resolver(root, *args, **kwargs):
if item in document.document_type._fields_ordered + tuple(filter_args):
queried_fields.append(item)
_type = registry.get_type_for_model(document.document_type)
return document.document_type.objects().no_dereference().only(
*(set((list(_type._meta.required_fields) + queried_fields)))).get(
pk=document.pk)
return document.document_type()
return await sync_to_async(document.document_type.objects().no_dereference().only(
*(set((list(_type._meta.required_fields) + queried_fields)))).get, thread_sensitive=False,
executor=ThreadPoolExecutor())(pk=document.pk)
return await sync_to_async(document.document_type, thread_sensitive=False,
executor=ThreadPoolExecutor())()
return None

if isinstance(field, mongoengine.GenericLazyReferenceField):
@@ -327,7 +330,7 @@ def lazy_reference_resolver(root, *args, **kwargs):
def convert_field_to_dynamic(field, registry=None):
model = field.document_type

def reference_resolver(root, *args, **kwargs):
async def reference_resolver(root, *args, **kwargs):
document = getattr(root, field.name or field.db_name)
if document:
queried_fields = list()
@@ -341,12 +344,12 @@ def reference_resolver(root, *args, **kwargs):
item = to_snake_case(each)
if item in field.document_type._fields_ordered + tuple(filter_args):
queried_fields.append(item)
return field.document_type.objects().no_dereference().only(
*(set(list(_type._meta.required_fields) + queried_fields))).get(
pk=document.id)
return await sync_to_async(field.document_type.objects().no_dereference().only(
*(set(list(_type._meta.required_fields) + queried_fields))).get, thread_sensitive=False,
executor=ThreadPoolExecutor())(pk=document.id)
return None

def cached_reference_resolver(root, *args, **kwargs):
async def cached_reference_resolver(root, *args, **kwargs):
if field:
queried_fields = list()
_type = registry.get_type_for_model(field.document_type)
@@ -359,9 +362,10 @@ def cached_reference_resolver(root, *args, **kwargs):
item = to_snake_case(each)
if item in field.document_type._fields_ordered + tuple(filter_args):
queried_fields.append(item)
return field.document_type.objects().no_dereference().only(
return await sync_to_async(field.document_type.objects().no_dereference().only(
*(set(
list(_type._meta.required_fields) + queried_fields))).get(
list(_type._meta.required_fields) + queried_fields))).get, thread_sensitive=False,
executor=ThreadPoolExecutor())(
pk=getattr(root, field.name or field.db_name))
return None

@@ -394,7 +398,7 @@ def dynamic_type():
def convert_lazy_field_to_dynamic(field, registry=None):
model = field.document_type

def lazy_resolver(root, *args, **kwargs):
async def lazy_resolver(root, *args, **kwargs):
document = getattr(root, field.name or field.db_name)
if document:
queried_fields = list()
@@ -408,9 +412,9 @@ def lazy_resolver(root, *args, **kwargs):
item = to_snake_case(each)
if item in document.document_type._fields_ordered + tuple(filter_args):
queried_fields.append(item)
return document.document_type.objects().no_dereference().only(
*(set((list(_type._meta.required_fields) + queried_fields)))).get(
pk=document.pk)
return await sync_to_async(document.document_type.objects().no_dereference().only(
*(set((list(_type._meta.required_fields) + queried_fields)))).get, thread_sensitive=False,
executor=ThreadPoolExecutor())(pk=document.pk)
return None

def dynamic_type():
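
The recurring pattern in these resolvers is to wrap the blocking MongoEngine queryset call with `sync_to_async` so it runs on a worker thread instead of the event loop. A condensed, self-contained sketch of that pattern; the `Article` model is illustrative and not part of this diff:

```python
import mongoengine
from asgiref.sync import sync_to_async
from concurrent.futures import ThreadPoolExecutor

class Article(mongoengine.Document):  # illustrative model, not from the diff
    headline = mongoengine.StringField()

async def resolve_article(pk):
    # Article.objects().get() blocks on network I/O; sync_to_async returns an
    # awaitable wrapper that runs it in the thread pool, mirroring the
    # thread_sensitive=False / ThreadPoolExecutor() arguments used above.
    return await sync_to_async(
        Article.objects().no_dereference().only("headline").get,
        thread_sensitive=False,
        executor=ThreadPoolExecutor(),
    )(pk=pk)
```
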
graphene_mongo/fields.py (52 changes: 26 additions & 26 deletions)
@@ -1,6 +1,5 @@
from __future__ import absolute_import

import logging
from collections import OrderedDict
from functools import partial, reduce

@@ -22,6 +21,8 @@
from mongoengine.base import get_document
from promise import Promise
from pymongo.errors import OperationFailure
from asgiref.sync import sync_to_async
from concurrent.futures import ThreadPoolExecutor

from .advanced_types import (
FileFieldType,
@@ -314,7 +315,7 @@ def get_queryset(self, model, info, required_fields=None, skip=None, limit=None,
skip)
return model.objects(**args).no_dereference().only(*required_fields).order_by(self.order_by)

def default_resolver(self, _root, info, required_fields=None, resolved=None, **args):
async def default_resolver(self, _root, info, required_fields=None, resolved=None, **args):
if required_fields is None:
required_fields = list()
args = args or {}
@@ -357,7 +358,8 @@ def default_resolver(self, _root, info, required_fields=None, resolved=None, **a

if isinstance(items, QuerySet):
try:
count = items.count(with_limit_and_skip=True)
count = await sync_to_async(items.count, thread_sensitive=False,
executor=ThreadPoolExecutor())(with_limit_and_skip=True)
except OperationFailure:
count = len(items)
else:
@@ -400,12 +402,13 @@ def default_resolver(self, _root, info, required_fields=None, resolved=None, **a
args_copy[key] = args_copy[key].value

if PYMONGO_VERSION >= (3, 7):
if hasattr(self.model, '_meta') and 'db_alias' in self.model._meta:
count = (mongoengine.get_db(self.model._meta['db_alias'])[self.model._get_collection_name()]).count_documents(args_copy)
else:
count = (mongoengine.get_db()[self.model._get_collection_name()]).count_documents(args_copy)
count = await sync_to_async(
(mongoengine.get_db()[self.model._get_collection_name()]).count_documents,
thread_sensitive=False,
executor=ThreadPoolExecutor())(args_copy)
else:
count = self.model.objects(args_copy).count()
count = await sync_to_async(self.model.objects(args_copy).count, thread_sensitive=False,
executor=ThreadPoolExecutor())()
if count != 0:
skip, limit, reverse = find_skip_and_limit(first=first, after=after, last=last, before=before,
count=count)
@@ -467,7 +470,7 @@ def default_resolver(self, _root, info, required_fields=None, resolved=None, **a
connection.list_length = list_length
return connection

def chained_resolver(self, resolver, is_partial, root, info, **args):
async def chained_resolver(self, resolver, is_partial, root, info, **args):

for key, value in dict(args).items():
if value is None:
@@ -511,13 +514,13 @@ def chained_resolver(self, resolver, is_partial, root, info, **args):
elif not isinstance(resolved[0], DBRef):
return resolved
else:
return self.default_resolver(root, info, required_fields, **args_copy)
return await self.default_resolver(root, info, required_fields, **args_copy)
elif isinstance(resolved, QuerySet):
args.update(resolved._query)
args_copy = args.copy()
for arg_name, arg in args.copy().items():
if "." in arg_name or arg_name not in self.model._fields_ordered \
+ ('first', 'last', 'before', 'after') + tuple(self.filter_args.keys()):
if "." in arg_name or arg_name not in self.model._fields_ordered + (
'first', 'last', 'before', 'after') + tuple(self.filter_args.keys()):
args_copy.pop(arg_name)
if arg_name == '_id' and isinstance(arg, dict):
operation = list(arg.keys())[0]
@@ -537,38 +540,35 @@ def chained_resolver(self, resolver, is_partial, root, info, **args):
operation = list(arg.keys())[0]
args_copy[arg_name + operation.replace('$', '__')] = arg[operation]
del args_copy[arg_name]
return self.default_resolver(root, info, required_fields, resolved=resolved, **args_copy)

return await self.default_resolver(root, info, required_fields, resolved=resolved, **args_copy)
elif isinstance(resolved, Promise):
return resolved.value
else:
return resolved
return await resolved

return self.default_resolver(root, info, required_fields, **args)
return await self.default_resolver(root, info, required_fields, **args)

@classmethod
def connection_resolver(cls, resolver, connection_type, root, info, **args):
async def connection_resolver(cls, resolver, connection_type, root, info, **args):
if root:
for key, value in root.__dict__.items():
if value:
try:
setattr(root, key, from_global_id(value)[1])
except Exception as error:
logging.error("Exception Occurred: ", exc_info=error)
iterable = resolver(root, info, **args)

except Exception:
pass
iterable = await resolver(root, info, **args)
if isinstance(connection_type, graphene.NonNull):
connection_type = connection_type.of_type

on_resolve = partial(cls.resolve_connection, connection_type, args)

if Promise.is_thenable(iterable):
return Promise.resolve(iterable).then(on_resolve)

return on_resolve(iterable)
return await sync_to_async(cls.resolve_connection, thread_sensitive=False,
executor=ThreadPoolExecutor())(connection_type, args, iterable)

def get_resolver(self, parent_resolver):
super_resolver = self.resolver or parent_resolver
resolver = partial(
self.chained_resolver, super_resolver, isinstance(super_resolver, partial)
)

return partial(self.connection_resolver, resolver, self.type)
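
`connection_resolver` now awaits the chained resolver and then offloads the synchronous connection construction to a worker thread. A simplified, runnable sketch of that flow; `build_connection` stands in for `resolve_connection` and is not from the diff:

```python
import asyncio
from asgiref.sync import sync_to_async
from concurrent.futures import ThreadPoolExecutor

def build_connection(iterable):
    # stands in for cls.resolve_connection: synchronous work over the results
    return {"edges": list(iterable), "count": len(iterable)}

async def connection_resolver(resolver, **args):
    # the chained resolver is itself a coroutine function, so await it first
    iterable = await resolver(**args)
    # then push the sync post-processing onto a thread, as in the diff above
    return await sync_to_async(
        build_connection, thread_sensitive=False, executor=ThreadPoolExecutor()
    )(iterable)

async def main():
    async def fake_resolver(**args):
        return [1, 2, 3]

    print(await connection_resolver(fake_resolver))

asyncio.run(main())
```
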
graphene_mongo/tests/test_converter.py (4 changes: 2 additions & 2 deletions)
@@ -70,8 +70,8 @@ def test_should_boolean_convert_boolean():
assert_conversion(mongoengine.BooleanField, graphene.Boolean)


def test_should_decimal_convert_float():
assert_conversion(mongoengine.DecimalField, graphene.Float)
def test_should_decimal_convert_decimal():
assert_conversion(mongoengine.DecimalField, graphene.Decimal)


def test_should_float_convert_float():
graphene_mongo/tests/test_fields.py (8 changes: 4 additions & 4 deletions)
@@ -44,16 +44,16 @@ def test_field_args_with_unconverted_field():
assert set(field.field_args.keys()) == set(field_args)


def test_default_resolver_with_colliding_objects_field():
async def test_default_resolver_with_colliding_objects_field():
field = MongoengineConnectionField(nodes.ErroneousModelNode)

connection = field.default_resolver(None, {})
connection = await field.default_resolver(None, {})
assert 0 == len(connection.iterable)


def test_default_resolver_connection_list_length(fixtures):
async def test_default_resolver_connection_list_length(fixtures):
field = MongoengineConnectionField(nodes.ArticleNode)

connection = field.default_resolver(None, {}, **{"first": 1})
connection = await field.default_resolver(None, {}, **{"first": 1})
assert hasattr(connection, "list_length")
assert connection.list_length == 1
graphene_mongo/tests/test_inputs.py (12 changes: 6 additions & 6 deletions)
@@ -7,14 +7,14 @@
from .types import ArticleInput, EditorInput


def test_should_create(fixtures):
async def test_should_create(fixtures):
class CreateArticle(graphene.Mutation):
class Arguments:
article = ArticleInput(required=True)

article = graphene.Field(ArticleNode)

def mutate(self, info, article):
async def mutate(self, info, article):
article = Article(**article)
article.save()

@@ -39,20 +39,20 @@ class Mutation(graphene.ObjectType):
"""
expected = {"createArticle": {"article": {"headline": "My Article"}}}
schema = graphene.Schema(query=Query, mutation=Mutation)
result = schema.execute(query)
result = await schema.execute_async(query)
assert not result.errors
assert result.data == expected


def test_should_update(fixtures):
async def test_should_update(fixtures):
class UpdateEditor(graphene.Mutation):
class Arguments:
id = graphene.ID(required=True)
editor = EditorInput(required=True)

editor = graphene.Field(EditorNode)

def mutate(self, info, id, editor):
async def mutate(self, info, id, editor):
editor_to_update = Editor.objects.get(id=id)
for key, value in editor.items():
if value:
@@ -85,7 +85,7 @@ class Mutation(graphene.ObjectType):
"""
expected = {"updateEditor": {"editor": {"firstName": "Penny", "lastName": "Lane"}}}
schema = graphene.Schema(query=Query, mutation=Mutation)
result = schema.execute(query)
result = await schema.execute_async(query)
# print(result.data)
assert not result.errors
assert result.data == expected
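
The rewritten tests are coroutine functions, so the suite needs an async-aware test runner. A hedged sketch of how such a test can be collected with pytest-asyncio; the plugin and the toy schema below are assumptions, not part of this diff:

```python
import graphene
import pytest

@pytest.mark.asyncio  # requires the pytest-asyncio plugin
async def test_execute_async_round_trip():
    class Query(graphene.ObjectType):
        ping = graphene.String()

        async def resolve_ping(root, info):
            return "pong"

    schema = graphene.Schema(query=Query)
    result = await schema.execute_async("{ ping }")
    assert not result.errors
    assert result.data == {"ping": "pong"}
```
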