Skip to content

Commit 81746fb

Browse files
Support Async
1 parent ec1c7af commit 81746fb

13 files changed

+185
-770
lines changed

README.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ query = '''
6262
}
6363
}
6464
'''
65-
result = schema.execute(query)
65+
result = await schema.execute_async(query)
6666
```
6767

6868
To learn more check out the following [examples](examples/):

README.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ Then you can simply query the schema:
7171
}
7272
}
7373
'''
74-
result = schema.execute(query)
74+
result = await schema.execute_async(query)
7575
7676
To learn more check out the `Flask MongoEngine example <https://github.com/graphql-python/graphene-mongo/tree/master/examples/flask_mongoengine>`__
7777

examples/falcon_mongoengine/api.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def on_post(self, req, resp):
2323
class GraphQLResource:
2424
def on_get(self, req, resp):
2525
query = req.params["query"]
26-
result = schema.execute(query)
26+
result = await schema.execute_async(query)
2727

2828
if result.data:
2929
data_ret = {"data": result.data}
@@ -32,7 +32,7 @@ def on_get(self, req, resp):
3232

3333
def on_post(self, req, resp):
3434
query = req.params["query"]
35-
result = schema.execute(query)
35+
result = await schema.execute_async(query)
3636
if result.data:
3737
data_ret = {"data": result.data}
3838
resp.status = falcon.HTTP_200

graphene_mongo/converter.py

+33-29
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from . import advanced_types
1010
from .utils import import_single_dispatch, get_field_description, get_query_fields
1111
from concurrent.futures import ThreadPoolExecutor, as_completed
12+
from asgiref.sync import sync_to_async
1213

1314
singledispatch = import_single_dispatch()
1415

@@ -42,6 +43,14 @@ def convert_field_to_id(field, registry=None):
4243
)
4344

4445

46+
@convert_mongoengine_field.register(mongoengine.Decimal128Field)
47+
@convert_mongoengine_field.register(mongoengine.DecimalField)
48+
def convert_field_to_decimal(field, registry=None):
49+
return graphene.Decimal(
50+
description=get_field_description(field, registry), required=field.required
51+
)
52+
53+
4554
@convert_mongoengine_field.register(mongoengine.IntField)
4655
@convert_mongoengine_field.register(mongoengine.LongField)
4756
@convert_mongoengine_field.register(mongoengine.SequenceField)
@@ -58,21 +67,13 @@ def convert_field_to_boolean(field, registry=None):
5867
)
5968

6069

61-
@convert_mongoengine_field.register(mongoengine.DecimalField)
6270
@convert_mongoengine_field.register(mongoengine.FloatField)
6371
def convert_field_to_float(field, registry=None):
6472
return graphene.Float(
6573
description=get_field_description(field, registry), required=field.required
6674
)
6775

6876

69-
@convert_mongoengine_field.register(mongoengine.Decimal128Field)
70-
def convert_field_to_decimal(field, registry=None):
71-
return graphene.Decimal(
72-
description=get_field_description(field, registry), required=field.required
73-
)
74-
75-
7677
@convert_mongoengine_field.register(mongoengine.DateTimeField)
7778
def convert_field_to_datetime(field, registry=None):
7879
return graphene.DateTime(
@@ -246,7 +247,7 @@ def convert_field_to_union(field, registry=None):
246247
Meta = type("Meta", (object,), {"types": tuple(_types)})
247248
_union = type(name, (graphene.Union,), {"Meta": Meta})
248249

249-
def reference_resolver(root, *args, **kwargs):
250+
async def reference_resolver(root, *args, **kwargs):
250251
de_referenced = getattr(root, field.name or field.db_name)
251252
if de_referenced:
252253
document = get_document(de_referenced["_cls"])
@@ -265,13 +266,14 @@ def reference_resolver(root, *args, **kwargs):
265266
item = to_snake_case(each)
266267
if item in document._fields_ordered + tuple(filter_args):
267268
queried_fields.append(item)
268-
return document.objects().no_dereference().only(*list(
269-
set(list(_type._meta.required_fields) + queried_fields))).get(
270-
pk=de_referenced["_ref"].id)
271-
return document()
269+
return await sync_to_async(document.objects().no_dereference().only(*list(
270+
set(list(_type._meta.required_fields) + queried_fields))).get, thread_sensitive=False,
271+
executor=ThreadPoolExecutor())(pk=de_referenced["_ref"].id)
272+
return await sync_to_async(document, thread_sensitive=False,
273+
executor=ThreadPoolExecutor())()
272274
return None
273275

274-
def lazy_reference_resolver(root, *args, **kwargs):
276+
async def lazy_reference_resolver(root, *args, **kwargs):
275277
document = getattr(root, field.name or field.db_name)
276278
if document:
277279
queried_fields = list()
@@ -288,10 +290,11 @@ def lazy_reference_resolver(root, *args, **kwargs):
288290
if item in document.document_type._fields_ordered + tuple(filter_args):
289291
queried_fields.append(item)
290292
_type = registry.get_type_for_model(document.document_type)
291-
return document.document_type.objects().no_dereference().only(
292-
*(set((list(_type._meta.required_fields) + queried_fields)))).get(
293-
pk=document.pk)
294-
return document.document_type()
293+
return await sync_to_async(document.document_type.objects().no_dereference().only(
294+
*(set((list(_type._meta.required_fields) + queried_fields)))).get, thread_sensitive=False,
295+
executor=ThreadPoolExecutor())(pk=document.pk)
296+
return await sync_to_async(document.document_type, thread_sensitive=False,
297+
executor=ThreadPoolExecutor())()
295298
return None
296299

297300
if isinstance(field, mongoengine.GenericLazyReferenceField):
@@ -327,7 +330,7 @@ def lazy_reference_resolver(root, *args, **kwargs):
327330
def convert_field_to_dynamic(field, registry=None):
328331
model = field.document_type
329332

330-
def reference_resolver(root, *args, **kwargs):
333+
async def reference_resolver(root, *args, **kwargs):
331334
document = getattr(root, field.name or field.db_name)
332335
if document:
333336
queried_fields = list()
@@ -341,12 +344,12 @@ def reference_resolver(root, *args, **kwargs):
341344
item = to_snake_case(each)
342345
if item in field.document_type._fields_ordered + tuple(filter_args):
343346
queried_fields.append(item)
344-
return field.document_type.objects().no_dereference().only(
345-
*(set(list(_type._meta.required_fields) + queried_fields))).get(
346-
pk=document.id)
347+
return await sync_to_async(field.document_type.objects().no_dereference().only(
348+
*(set(list(_type._meta.required_fields) + queried_fields))).get, thread_sensitive=False,
349+
executor=ThreadPoolExecutor())(pk=document.id)
347350
return None
348351

349-
def cached_reference_resolver(root, *args, **kwargs):
352+
async def cached_reference_resolver(root, *args, **kwargs):
350353
if field:
351354
queried_fields = list()
352355
_type = registry.get_type_for_model(field.document_type)
@@ -359,9 +362,10 @@ def cached_reference_resolver(root, *args, **kwargs):
359362
item = to_snake_case(each)
360363
if item in field.document_type._fields_ordered + tuple(filter_args):
361364
queried_fields.append(item)
362-
return field.document_type.objects().no_dereference().only(
365+
return await sync_to_async(field.document_type.objects().no_dereference().only(
363366
*(set(
364-
list(_type._meta.required_fields) + queried_fields))).get(
367+
list(_type._meta.required_fields) + queried_fields))).get, thread_sensitive=False,
368+
executor=ThreadPoolExecutor())(
365369
pk=getattr(root, field.name or field.db_name))
366370
return None
367371

@@ -394,7 +398,7 @@ def dynamic_type():
394398
def convert_lazy_field_to_dynamic(field, registry=None):
395399
model = field.document_type
396400

397-
def lazy_resolver(root, *args, **kwargs):
401+
async def lazy_resolver(root, *args, **kwargs):
398402
document = getattr(root, field.name or field.db_name)
399403
if document:
400404
queried_fields = list()
@@ -408,9 +412,9 @@ def lazy_resolver(root, *args, **kwargs):
408412
item = to_snake_case(each)
409413
if item in document.document_type._fields_ordered + tuple(filter_args):
410414
queried_fields.append(item)
411-
return document.document_type.objects().no_dereference().only(
412-
*(set((list(_type._meta.required_fields) + queried_fields)))).get(
413-
pk=document.pk)
415+
return await sync_to_async(document.document_type.objects().no_dereference().only(
416+
*(set((list(_type._meta.required_fields) + queried_fields)))).get, thread_sensitive=False,
417+
executor=ThreadPoolExecutor())(pk=document.pk)
414418
return None
415419

416420
def dynamic_type():

graphene_mongo/fields.py

+26-26
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from __future__ import absolute_import
22

3-
import logging
43
from collections import OrderedDict
54
from functools import partial, reduce
65

@@ -22,6 +21,8 @@
2221
from mongoengine.base import get_document
2322
from promise import Promise
2423
from pymongo.errors import OperationFailure
24+
from asgiref.sync import sync_to_async
25+
from concurrent.futures import ThreadPoolExecutor
2526

2627
from .advanced_types import (
2728
FileFieldType,
@@ -314,7 +315,7 @@ def get_queryset(self, model, info, required_fields=None, skip=None, limit=None,
314315
skip)
315316
return model.objects(**args).no_dereference().only(*required_fields).order_by(self.order_by)
316317

317-
def default_resolver(self, _root, info, required_fields=None, resolved=None, **args):
318+
async def default_resolver(self, _root, info, required_fields=None, resolved=None, **args):
318319
if required_fields is None:
319320
required_fields = list()
320321
args = args or {}
@@ -357,7 +358,8 @@ def default_resolver(self, _root, info, required_fields=None, resolved=None, **a
357358

358359
if isinstance(items, QuerySet):
359360
try:
360-
count = items.count(with_limit_and_skip=True)
361+
count = await sync_to_async(items.count, thread_sensitive=False,
362+
executor=ThreadPoolExecutor())(with_limit_and_skip=True)
361363
except OperationFailure:
362364
count = len(items)
363365
else:
@@ -400,12 +402,13 @@ def default_resolver(self, _root, info, required_fields=None, resolved=None, **a
400402
args_copy[key] = args_copy[key].value
401403

402404
if PYMONGO_VERSION >= (3, 7):
403-
if hasattr(self.model, '_meta') and 'db_alias' in self.model._meta:
404-
count = (mongoengine.get_db(self.model._meta['db_alias'])[self.model._get_collection_name()]).count_documents(args_copy)
405-
else:
406-
count = (mongoengine.get_db()[self.model._get_collection_name()]).count_documents(args_copy)
405+
count = await sync_to_async(
406+
(mongoengine.get_db()[self.model._get_collection_name()]).count_documents,
407+
thread_sensitive=False,
408+
executor=ThreadPoolExecutor())(args_copy)
407409
else:
408-
count = self.model.objects(args_copy).count()
410+
count = await sync_to_async(self.model.objects(args_copy).count, thread_sensitive=False,
411+
executor=ThreadPoolExecutor())()
409412
if count != 0:
410413
skip, limit, reverse = find_skip_and_limit(first=first, after=after, last=last, before=before,
411414
count=count)
@@ -467,7 +470,7 @@ def default_resolver(self, _root, info, required_fields=None, resolved=None, **a
467470
connection.list_length = list_length
468471
return connection
469472

470-
def chained_resolver(self, resolver, is_partial, root, info, **args):
473+
async def chained_resolver(self, resolver, is_partial, root, info, **args):
471474

472475
for key, value in dict(args).items():
473476
if value is None:
@@ -511,13 +514,13 @@ def chained_resolver(self, resolver, is_partial, root, info, **args):
511514
elif not isinstance(resolved[0], DBRef):
512515
return resolved
513516
else:
514-
return self.default_resolver(root, info, required_fields, **args_copy)
517+
return await self.default_resolver(root, info, required_fields, **args_copy)
515518
elif isinstance(resolved, QuerySet):
516519
args.update(resolved._query)
517520
args_copy = args.copy()
518521
for arg_name, arg in args.copy().items():
519-
if "." in arg_name or arg_name not in self.model._fields_ordered \
520-
+ ('first', 'last', 'before', 'after') + tuple(self.filter_args.keys()):
522+
if "." in arg_name or arg_name not in self.model._fields_ordered + (
523+
'first', 'last', 'before', 'after') + tuple(self.filter_args.keys()):
521524
args_copy.pop(arg_name)
522525
if arg_name == '_id' and isinstance(arg, dict):
523526
operation = list(arg.keys())[0]
@@ -537,38 +540,35 @@ def chained_resolver(self, resolver, is_partial, root, info, **args):
537540
operation = list(arg.keys())[0]
538541
args_copy[arg_name + operation.replace('$', '__')] = arg[operation]
539542
del args_copy[arg_name]
540-
return self.default_resolver(root, info, required_fields, resolved=resolved, **args_copy)
543+
544+
return await self.default_resolver(root, info, required_fields, resolved=resolved, **args_copy)
541545
elif isinstance(resolved, Promise):
542546
return resolved.value
543547
else:
544-
return resolved
548+
return await resolved
545549

546-
return self.default_resolver(root, info, required_fields, **args)
550+
return await self.default_resolver(root, info, required_fields, **args)
547551

548552
@classmethod
549-
def connection_resolver(cls, resolver, connection_type, root, info, **args):
553+
async def connection_resolver(cls, resolver, connection_type, root, info, **args):
550554
if root:
551555
for key, value in root.__dict__.items():
552556
if value:
553557
try:
554558
setattr(root, key, from_global_id(value)[1])
555-
except Exception as error:
556-
logging.error("Exception Occurred: ", exc_info=error)
557-
iterable = resolver(root, info, **args)
558-
559+
except Exception:
560+
pass
561+
iterable = await resolver(root, info, **args)
559562
if isinstance(connection_type, graphene.NonNull):
560563
connection_type = connection_type.of_type
561564

562-
on_resolve = partial(cls.resolve_connection, connection_type, args)
563-
564-
if Promise.is_thenable(iterable):
565-
return Promise.resolve(iterable).then(on_resolve)
566-
567-
return on_resolve(iterable)
565+
return await sync_to_async(cls.resolve_connection, thread_sensitive=False,
566+
executor=ThreadPoolExecutor())(connection_type, args, iterable)
568567

569568
def get_resolver(self, parent_resolver):
570569
super_resolver = self.resolver or parent_resolver
571570
resolver = partial(
572571
self.chained_resolver, super_resolver, isinstance(super_resolver, partial)
573572
)
573+
574574
return partial(self.connection_resolver, resolver, self.type)

graphene_mongo/tests/test_converter.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,8 @@ def test_should_boolean_convert_boolean():
7070
assert_conversion(mongoengine.BooleanField, graphene.Boolean)
7171

7272

73-
def test_should_decimal_convert_float():
74-
assert_conversion(mongoengine.DecimalField, graphene.Float)
73+
def test_should_decimal_convert_decimal():
74+
assert_conversion(mongoengine.DecimalField, graphene.Decimal)
7575

7676

7777
def test_should_float_convert_float():

graphene_mongo/tests/test_fields.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -44,16 +44,16 @@ def test_field_args_with_unconverted_field():
4444
assert set(field.field_args.keys()) == set(field_args)
4545

4646

47-
def test_default_resolver_with_colliding_objects_field():
47+
async def test_default_resolver_with_colliding_objects_field():
4848
field = MongoengineConnectionField(nodes.ErroneousModelNode)
4949

50-
connection = field.default_resolver(None, {})
50+
connection = await field.default_resolver(None, {})
5151
assert 0 == len(connection.iterable)
5252

5353

54-
def test_default_resolver_connection_list_length(fixtures):
54+
async def test_default_resolver_connection_list_length(fixtures):
5555
field = MongoengineConnectionField(nodes.ArticleNode)
5656

57-
connection = field.default_resolver(None, {}, **{"first": 1})
57+
connection = await field.default_resolver(None, {}, **{"first": 1})
5858
assert hasattr(connection, "list_length")
5959
assert connection.list_length == 1

graphene_mongo/tests/test_inputs.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,14 @@
77
from .types import ArticleInput, EditorInput
88

99

10-
def test_should_create(fixtures):
10+
async def test_should_create(fixtures):
1111
class CreateArticle(graphene.Mutation):
1212
class Arguments:
1313
article = ArticleInput(required=True)
1414

1515
article = graphene.Field(ArticleNode)
1616

17-
def mutate(self, info, article):
17+
async def mutate(self, info, article):
1818
article = Article(**article)
1919
article.save()
2020

@@ -39,20 +39,20 @@ class Mutation(graphene.ObjectType):
3939
"""
4040
expected = {"createArticle": {"article": {"headline": "My Article"}}}
4141
schema = graphene.Schema(query=Query, mutation=Mutation)
42-
result = schema.execute(query)
42+
result = await schema.execute_async(query)
4343
assert not result.errors
4444
assert result.data == expected
4545

4646

47-
def test_should_update(fixtures):
47+
async def test_should_update(fixtures):
4848
class UpdateEditor(graphene.Mutation):
4949
class Arguments:
5050
id = graphene.ID(required=True)
5151
editor = EditorInput(required=True)
5252

5353
editor = graphene.Field(EditorNode)
5454

55-
def mutate(self, info, id, editor):
55+
async def mutate(self, info, id, editor):
5656
editor_to_update = Editor.objects.get(id=id)
5757
for key, value in editor.items():
5858
if value:
@@ -85,7 +85,7 @@ class Mutation(graphene.ObjectType):
8585
"""
8686
expected = {"updateEditor": {"editor": {"firstName": "Penny", "lastName": "Lane"}}}
8787
schema = graphene.Schema(query=Query, mutation=Mutation)
88-
result = schema.execute(query)
88+
result = await schema.execute_async(query)
8989
# print(result.data)
9090
assert not result.errors
9191
assert result.data == expected

0 commit comments

Comments
 (0)