Skip to content
This repository has been archived by the owner on May 18, 2024. It is now read-only.

fix: supporting model subclasses #22

Open
wants to merge 13 commits into
base: master
Choose a base branch
from
Open
149 changes: 126 additions & 23 deletions graphene_pynamodb/converter.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,27 @@
import inspect
import json
from collections import OrderedDict
from functools import partial

from graphene import Dynamic, Field, Float
from graphene import ID, Boolean, List, String
from graphene.types.json import JSONString
from pynamodb import attributes
from singledispatch import singledispatch

from graphene import ID, Boolean, Dynamic, Field, Float, Int, List, String
from graphene.types import ObjectType
from graphene.types.json import JSONString
from graphene.types.resolver import default_resolver
from graphene_pynamodb import relationships
from graphene_pynamodb.fields import PynamoConnectionField
from graphene_pynamodb.relationships import OneToOne, OneToMany
from graphene_pynamodb.registry import Registry
from graphene_pynamodb.relationships import OneToMany, OneToOne


@singledispatch
def convert_pynamo_attribute(type, attribute, registry=None):
raise Exception(
"Don't know how to convert the PynamoDB attribute %s (%s)" % (attribute, attribute.__class__))
"Don't know how to convert the PynamoDB attribute %s (%s)"
% (attribute, attribute.__class__)
)


@convert_pynamo_attribute.register(attributes.BinaryAttribute)
Expand All @@ -23,14 +30,18 @@ def convert_column_to_string(type, attribute, registry=None):
if attribute.is_hash_key:
return ID(description=attribute.attr_name, required=not attribute.null)

return String(description=getattr(attribute, 'attr_name'),
required=not (getattr(attribute, 'null', True)))
return String(
description=getattr(attribute, "attr_name"),
required=not (getattr(attribute, "null", True)),
)


@convert_pynamo_attribute.register(attributes.UTCDateTimeAttribute)
def convert_date_to_string(type, attribute, registry=None):
return String(description=getattr(attribute, 'attr_name'),
required=not (getattr(attribute, 'null', True)))
return String(
description=getattr(attribute, "attr_name"),
required=not (getattr(attribute, "null", True)),
)


@convert_pynamo_attribute.register(relationships.Relationship)
Expand Down Expand Up @@ -77,7 +88,7 @@ def convert_json_to_string(type, attribute, registry=None):


class MapToJSONString(JSONString):
'''JSON String Converter for MapAttribute'''
"""JSON String Converter for MapAttribute"""

@staticmethod
def serialize(dt):
Expand All @@ -98,25 +109,117 @@ def serialize(dt):
return dt


def map_attribute_to_object_type(attribute, registry: Registry):
if not hasattr(registry, "map_attr_types"):
registry.map_attr_types = {}
if attribute in registry.map_attr_types:
return registry.map_attr_types[attribute]

fields = OrderedDict()
for name, attr in attribute.get_attributes().items():
fields[name] = convert_pynamo_attribute(attr, attr, registry)

map_attribute_type = type(
f"MapAttribute_{attribute.__name__}",
(ObjectType,),
fields,
)

registry.map_attr_types[attribute] = map_attribute_type
return map_attribute_type


@convert_pynamo_attribute.register(attributes.MapAttribute)
def convert_map_to_json(type, attribute, registry=None):
def convert_map_to_object_type(attribute, _, registry=None):
try:
name = attribute.attr_name
except (KeyError, AttributeError):
name = "MapAttribute"
required = not attribute.null if hasattr(attribute, 'null') else False
return MapToJSONString(description=name, required=required)
required = not attribute.null if hasattr(attribute, "null") else False
return map_attribute_to_object_type(attribute, registry)(
description=name, required=required
)


def list_resolver(attname, default_value):
def _resolver(
parent,
info,
index: int = None,
start_index: int = None,
end_index: int = None,
**kwargs,
):

data = default_resolver(
attname=attname,
default_value=default_value,
root=parent,
info=info,
**kwargs,
)
if index is not None:
return [data[index]]
if (start_index is not None) and (end_index is not None):
return data[start_index:end_index]
if start_index is not None:
return data[start_index:]
if end_index is not None:
return data[:end_index]
return data

return _resolver


def get_list_field_kwargs(attribute):
try:
name = attribute.attr_name
default = attribute.default
except KeyError:
name = attribute.element_type.__name__
default = None
return (
dict(
index=Int(description="Return element at the position"),
start_index=Int(description="Start of the slice of the list"),
end_index=Int(
description="End of the slice of the list. Negative numbers can be given to access from the end."
),
resolver=list_resolver(name, default),
),
name,
)


default_fields_mapping = {
attributes.NumberAttribute: Float,
attributes.BooleanAttribute: Boolean,
}


def get_list_field_converter(attribute, registry, mapping: dict = None):
mapping = mapping or {}
kwargs, name = get_list_field_kwargs(attribute)
if attribute.element_type and inspect.isclass(attribute.element_type):

required = not attribute.null if hasattr(attribute, "null") else False
cls = None

for elem_type, mapped in default_fields_mapping.items():
if issubclass(attribute.element_type, elem_type):
cls = mapping.get(elem_type, mapped)

if cls is None:
if issubclass(attribute.element_type, attributes.MapAttribute):
cls = map_attribute_to_object_type(attribute.element_type, registry)
else:
cls = String

return List(cls, description=name, required=required, **kwargs)
else:
return List(String, description=name, **kwargs)


@convert_pynamo_attribute.register(attributes.ListAttribute)
def convert_list_to_list(type, attribute, registry=None):
if attribute.element_type:
try:
name = attribute.attr_name
except KeyError:
name = "MapAttribute"

required = not attribute.null if hasattr(attribute, 'null') else False
return ListOfMapToObject(description=name, required=required)
else:
return List(String, description=attribute.attr_name)
return get_list_field_converter(attribute, registry)
95 changes: 61 additions & 34 deletions graphene_pynamodb/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,20 @@

from functools import partial

from graphene import Int
from graphene import relay
from graphene.relay.connection import PageInfo
from graphql_relay import from_global_id
from graphql_relay import to_global_id
from graphql_relay.connection.connectiontypes import Edge

from graphene import Int, relay
from graphene.relay.connection import PageInfo
from graphene_pynamodb.relationships import RelationshipResultList
from graphene_pynamodb.utils import get_key_name
from graphene_pynamodb.utils import from_cursor, get_key_name, to_cursor


class PynamoConnectionField(relay.ConnectionField):
total_count = Int()

def __init__(self, type, *args, **kwargs):
super(PynamoConnectionField, self).__init__(
type._meta.connection,
*args,
**kwargs
type._meta.connection, *args, **kwargs
)

@property
Expand All @@ -36,81 +31,113 @@ def get_query(cls, model, info, **args):
def connection_resolver(cls, resolver, connection, model, root, info, **args):
iterable = resolver(root, info, **args)

first = args.get('first')
last = args.get('last')
(_, after) = from_global_id(args.get('after')) if args.get('after') else (None, None)
(_, before) = from_global_id(args.get('before')) if args.get('before') else (None, None)
first = args.get("first")
last = args.get("last")
(_, after) = (
from_cursor(args.get("after")) if args.get("after") else (None, None)
)
(_, before) = (
from_cursor(args.get("before")) if args.get("before") else (None, None)
)
has_previous_page = bool(after)
page_size = first if first else last if last else None

# get a full scan query since we have no resolved iterable from relationship or resolver function
if not iterable and not root:
query = cls.get_query(model, info, **args)
iterable = query()
if first or last or after or before:
raise NotImplementedError(
"DynamoDB scan operations have no predictable sort. Arguments first, last, after " +
"and before will have unpredictable results")

iterable = iterable if isinstance(iterable, list) else list(iterable) if iterable else []
query_params = dict(limit=page_size or 20, consistent_read=True)
if after:
query_params["last_evaluated_key"] = after

result_iterator = query(**query_params)
iterable = list(result_iterator)
# if first or last or after or before:
# raise NotImplementedError(
# "DynamoDB scan operations have no predictable sort. Arguments first, last, after "
# + "and before will have unpredictable results"
# )

iterable = (
iterable
if isinstance(iterable, list)
else list(iterable)
if iterable
else []
)
if last:
iterable = iterable[-last:]

(has_next, edges) = cls.get_edges_from_iterable(iterable, model, info, edge_type=connection.Edge, after=after,
page_size=page_size)
(has_next, edges) = cls.get_edges_from_iterable(
iterable,
model,
info,
edge_type=connection.Edge,
# after=after,
page_size=page_size,
)

key_name = get_key_name(model)
try:
start_cursor = getattr(edges[0].node, key_name)
end_cursor = getattr(edges[-1].node, key_name)
start_cursor = to_cursor(iterable[0])
end_cursor = to_cursor(iterable[-1])
except IndexError:
start_cursor = None
end_cursor = None

optional_args = {}
total_count = len(iterable)
if 'total_count' in connection._meta.fields:
if "total_count" in connection._meta.fields:
optional_args["total_count"] = total_count

# Construct the connection
return connection(
edges=edges,
page_info=PageInfo(
start_cursor=start_cursor if start_cursor else '',
end_cursor=end_cursor if end_cursor else '',
start_cursor=start_cursor if start_cursor else "",
end_cursor=end_cursor if end_cursor else "",
has_previous_page=has_previous_page,
has_next_page=has_next
has_next_page=has_next,
),
**optional_args
**optional_args,
)

def get_resolver(self, parent_resolver):
return partial(self.connection_resolver, parent_resolver, self.type, self.model)

@classmethod
def get_edges_from_iterable(cls, iterable, model, info, edge_type=Edge, after=None, page_size=None):
def get_edges_from_iterable(
cls, iterable, model, info, edge_type=Edge, after=None, page_size=None
):
has_next = False

key_name = get_key_name(model)
after_index = 0
if after:
after_index = next((i for i, item in enumerate(iterable) if str(getattr(item, key_name)) == after), None)
after_index = next(
(
i
for i, item in enumerate(iterable)
if str(getattr(item, key_name)) == after
),
None,
)
if after_index is None:
return None
else:
after_index += 1

if page_size:
has_next = len(iterable) - after_index > page_size
iterable = iterable[after_index:after_index + page_size]
iterable = iterable[after_index : after_index + page_size]
else:
iterable = iterable[after_index:]

# trigger a batch get to speed up query instead of relying on lazy individual gets
if isinstance(iterable, RelationshipResultList):
iterable = iterable.resolve()

edges = [edge_type(node=entity, cursor=to_global_id(model.__name__, getattr(entity, key_name)))
for entity in iterable]
edges = [
edge_type(node=entity, cursor=to_cursor(entity)) for entity in iterable
]

return [has_next, edges]
8 changes: 3 additions & 5 deletions graphene_pynamodb/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,11 @@ def get_model_fields(model, excluding=None):
if excluding is None:
excluding = []
attributes = dict()
for attr_name in vars(model):

for attr_name, attr in model.get_attributes().items():
if attr_name in excluding:
continue
attr = getattr(model, attr_name)
if isinstance(attr, Attribute):
attributes[attr_name] = attr

attributes[attr_name] = attr
return OrderedDict(sorted(attributes.items(), key=lambda t: t[0]))


Expand Down
Loading