Skip to content

Commit

Permalink
Add support for python 3.10+ typing
Browse files Browse the repository at this point in the history
  • Loading branch information
tlconnor committed Oct 14, 2024
1 parent 98c1719 commit 69953a4
Show file tree
Hide file tree
Showing 6 changed files with 64 additions and 54 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@ jobs:
- name: Checkout repository
uses: actions/checkout@v3

- name: Set up Python 3.9
- name: Set up Python 3.12
uses: actions/setup-python@v4
with:
python-version: 3.9
python-version: 3.12

- name: Install dependencies
run: |
Expand All @@ -44,7 +44,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ['3.7', '3.8', '3.9', '3.10', '3.11']
python-version: ['3.10', '3.11', '3.12']

steps:
- name: Checkout repository
Expand Down
75 changes: 40 additions & 35 deletions hologram/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@
Generic,
Hashable,
ClassVar,
get_origin,
get_args,
)
from types import UnionType
import re
from datetime import datetime
from dataclasses import fields, is_dataclass, Field, MISSING, dataclass, asdict
Expand All @@ -37,8 +40,6 @@
JsonEncodable = Union[int, float, str, bool, None]
JsonDict = Dict[str, Any]

OPTIONAL_TYPES = ["Union", "Optional"]


class ValidationError(jsonschema.ValidationError):
pass
Expand Down Expand Up @@ -86,15 +87,12 @@ def issubclass_safe(klass: Any, base: Type) -> bool:
return False


def is_optional(field: Any) -> bool:
if str(field).startswith("typing.Union") or str(field).startswith(
"typing.Optional"
):
for arg in field.__args__:
if isinstance(arg, type) and issubclass(arg, type(None)):
return True
def is_union(field: Any) -> bool:
return get_origin(field) in [Union, UnionType]

return False

def is_optional(field: Any) -> bool:
return is_union(field) and (type(None) in get_args(field))


TV = TypeVar("TV")
Expand Down Expand Up @@ -233,7 +231,7 @@ def get_union_fields(field_type: Union[Any]) -> List[Variant]:
end.
"""
fields: List[Variant] = []
for variant in field_type.__args__:
for variant in get_args(field_type):
restrictions: Optional[Restriction] = _get_restrictions(variant)
if not restrictions:
restrictions = None
Expand All @@ -247,7 +245,7 @@ def get_union_fields(field_type: Union[Any]) -> List[Variant]:
def _encode_restrictions_met(
value: Any, restrict_fields: Optional[List[Tuple[Field, str]]]
) -> bool:
if restrict_fields is None:
if restrict_fields is None or len(restrict_fields) == 0:
return True
return all(
(
Expand Down Expand Up @@ -345,7 +343,7 @@ def encoder(ft, v, __):
def encoder(_, v, __):
return v.value

elif field_type_name in OPTIONAL_TYPES:
elif is_union(field_type):
# Attempt to encode the field with each union variant.
# TODO: Find a more reliable method than this since in the case 'Union[List[str], Dict[str, int]]' this
# will just output the dict keys as a list
Expand All @@ -367,7 +365,7 @@ def encoder(_, v, __):
)
)
return encoded
elif field_type_name in ("Mapping", "Dict"):
elif field_type_name in ("Mapping", "Dict", "dict"):

def encoder(ft, val, o):
return {
Expand All @@ -381,15 +379,17 @@ def encoder(ft, val, o):
# TODO: is there some way to set __args__ on this so it can
# just re-use Dict/Mapping?
def encoder(ft, val, o):

return {
cls._encode_field(str, k, o): cls._encode_field(
ft.TARGET_TYPE, v, o
)
for k, v in val.items()
}

elif field_type_name == "List" or (
field_type_name == "Tuple" and ... in field_type.__args__
elif field_type_name in ("List", "list") or (
field_type_name in ("Tuple", "tuple")
and ... in field_type.__args__
):

def encoder(ft, val, o):
Expand All @@ -410,7 +410,7 @@ def encoder(ft, val, o):
cls._encode_field(ft.__args__[0], v, o) for v in val
]

elif field_type_name == "Tuple":
elif field_type_name in ("Tuple", "tuple"):

def encoder(ft, val, o):
return [
Expand Down Expand Up @@ -517,7 +517,7 @@ def decoder(_, ft, val):
def decoder(_, ft, val):
return ft.from_dict(val, validate=validate)

elif field_type_name in OPTIONAL_TYPES:
elif is_union(field_type):
# Attempt to decode the value using each decoder in turn
union_excs = (
AttributeError,
Expand All @@ -541,7 +541,7 @@ def decoder(_, ft, val):
# none of the unions decoded, so report about all of them
raise FutureValidationError(field, errors)

elif field_type_name in ("Mapping", "Dict"):
elif field_type_name in ("Mapping", "Dict", "dict"):

def decoder(f, ft, val):
return {
Expand All @@ -551,10 +551,13 @@ def decoder(f, ft, val):
for k, v in val.items()
}

elif field_type_name == "List" or (
field_type_name == "Tuple" and ... in field_type.__args__
elif field_type_name in ("List", "list") or (
field_type_name in ("Tuple", "tuple")
and ... in field_type.__args__
):
seq_type = tuple if field_type_name == "Tuple" else list
seq_type = (
tuple if field_type_name in ("Tuple", "tuple") else list
)

def decoder(f, ft, val):
if not isinstance(val, (tuple, list)):
Expand All @@ -578,7 +581,7 @@ def decoder(f, ft, val):
for v in val
)

elif field_type_name == "Tuple":
elif field_type_name in ("Tuple", "tuple"):

def decoder(f, ft, val):
return tuple(
Expand Down Expand Up @@ -739,8 +742,10 @@ def _get_schema_for_type(
required: bool = True,
restrictions: Optional[List[Any]] = None,
) -> Tuple[JsonDict, bool]:

field_schema: JsonDict = {"type": "object"}

type_args = get_args(target)
type_name = cls._get_field_type_name(target)

if target in cls._field_encoders:
Expand All @@ -749,12 +754,12 @@ def _get_schema_for_type(
elif restrictions:
field_schema.update(cls._encode_restrictions(restrictions))

# if Union[..., None] or Optional[...]
elif type_name in OPTIONAL_TYPES:
# if ... | None, Union[..., None] or Optional[...]
elif is_union(target):
field_schema = {
"oneOf": [
cls._get_field_schema(variant)[0]
for variant in target.__args__
for variant in get_args(target)
]
}

Expand All @@ -764,7 +769,7 @@ def _get_schema_for_type(
elif is_enum(target):
field_schema.update(cls._encode_restrictions(target))

elif type_name in ("Dict", "Mapping"):
elif type_name in ("Dict", "dict", "Mapping"):
field_schema = {"type": "object"}
if target.__args__[1] is not Any:
field_schema["additionalProperties"] = cls._get_field_schema(
Expand All @@ -776,16 +781,16 @@ def _get_schema_for_type(
".*": cls._get_field_schema(target.TARGET_TYPE)[0]
}

elif type_name in ("Sequence", "List") or (
type_name == "Tuple" and ... in target.__args__
elif type_name in ("Sequence", "List", "list") or (
type_name in ("Tuple", "tuple") and ... in target.__args__
):
field_schema = {"type": "array"}
if target.__args__[0] is not Any:
field_schema["items"] = cls._get_field_schema(
target.__args__[0]
)[0]

elif type_name == "Tuple":
elif type_name in ("Tuple", "tuple"):
tuple_len = len(target.__args__)
# TODO: How do we handle Optional type within lists / tuples
field_schema = {
Expand Down Expand Up @@ -842,19 +847,19 @@ def _get_field_schema(
@classmethod
def _get_field_definitions(cls, field_type: Any, definitions: JsonDict):
field_type_name = cls._get_field_type_name(field_type)
if field_type_name == "Tuple":
if field_type_name in ("Tuple", "tuple"):
# tuples are either like Tuple[T, ...] or Tuple[T1, T2, T3].
for member in field_type.__args__:
if member is not ...:
cls._get_field_definitions(member, definitions)
elif field_type_name in ("Sequence", "List"):
elif field_type_name in ("Sequence", "List", "list"):
cls._get_field_definitions(field_type.__args__[0], definitions)
elif field_type_name in ("Dict", "Mapping"):
elif field_type_name in ("Dict", "dict", "Mapping"):
cls._get_field_definitions(field_type.__args__[1], definitions)
elif field_type_name == "PatternProperty":
cls._get_field_definitions(field_type.TARGET_TYPE, definitions)
elif field_type_name in OPTIONAL_TYPES:
for variant in field_type.__args__:
elif is_union(field_type):
for variant in get_args(field_type):
cls._get_field_definitions(variant, definitions)
elif cls._is_json_schema_subclass(field_type):
# Prevent recursion from forward refs & circular type dependencies
Expand Down
7 changes: 3 additions & 4 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,9 @@ def read(f):
"Intended Audience :: Developers",
"License :: OSI Approved :: MIT License",
"Programming Language :: Python",
"Programming Language :: Python :: 3.6",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Topic :: Software Development :: Libraries",
],
)
2 changes: 1 addition & 1 deletion tests/test_dict_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class SecondDictFieldValue(JsonSchemaMixin):
class HasDictFields(JsonSchemaMixin):
a: str
x: Dict[str, str]
z: Dict[str, Union[DictFieldValue, SecondDictFieldValue]]
z: dict[str, Union[DictFieldValue, SecondDictFieldValue]]


def test_schema():
Expand Down
11 changes: 8 additions & 3 deletions tests/test_tuple.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ class TupleMember(JsonSchemaMixin):

@dataclass
class TupleEllipsisHolder(JsonSchemaMixin):
member: Tuple[TupleMember, ...]
member1: Tuple[TupleMember, ...]
member2: tuple[TupleMember, ...]


@dataclass
Expand All @@ -25,9 +26,13 @@ class TupleMemberSecondHolder(JsonSchemaMixin):


def test_ellipsis_tuples():
dct = {"member": [{"a": 1}, {"a": 2}, {"a": 3}]}
dct = {
"member1": [{"a": 1}, {"a": 2}, {"a": 3}],
"member2": [{"a": 1}, {"a": 2}, {"a": 3}],
}
value = TupleEllipsisHolder(
member=(TupleMember(1), TupleMember(2), TupleMember(3))
member1=(TupleMember(1), TupleMember(2), TupleMember(3)),
member2=(TupleMember(1), TupleMember(2), TupleMember(3)),
)
assert value.to_dict() == dct
assert TupleEllipsisHolder.from_dict(dct) == value
Expand Down
17 changes: 9 additions & 8 deletions tests/test_union.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@

@dataclass
class IHaveAnnoyingUnions(JsonSchemaMixin):
my_field: Optional[Union[List[str], str]]
my_field1: list[str] | str | None
my_field2: Optional[Union[List[str], str]]


@dataclass
Expand All @@ -19,15 +20,12 @@ class IHaveAnnoyingUnionsReversed(JsonSchemaMixin):

def test_union_decoding():
for field_value in (None, [">=0.0.0"], ">=0.0.0"):
obj = IHaveAnnoyingUnions(my_field=field_value)
dct = {"my_field": field_value}
obj = IHaveAnnoyingUnions(my_field1=field_value, my_field2=field_value)
dct = {"my_field1": field_value, "my_field2": field_value}
decoded = IHaveAnnoyingUnions.from_dict(dct)
assert decoded == obj
assert obj.to_dict(omit_none=False) == dct

# this is allowed, for backwards-compatibility reasons
IHaveAnnoyingUnions(my_field=(">=0.0.0",)) == {"my_field": (">=0.0.0",)}


def test_union_decoding_ordering():
for field_value in (None, [">=0.0.0"], ">=0.0.0"):
Expand All @@ -44,12 +42,15 @@ def test_union_decoding_ordering():


def test_union_decode_error():
x = IHaveAnnoyingUnions(my_field={">=0.0.0"})
x = IHaveAnnoyingUnions(my_field1={">=0.0.0"}, my_field2={">=0.0.0"})
with pytest.raises(ValidationError):
x.to_dict(validate=True)

with pytest.raises(ValidationError):
IHaveAnnoyingUnions.from_dict({"my_field": {">=0.0.0"}})
IHaveAnnoyingUnions.from_dict({"my_field1": {">=0.0.0"}})

with pytest.raises(ValidationError):
IHaveAnnoyingUnions.from_dict({"my_field2": {">=0.0.0"}})


@dataclass
Expand Down

0 comments on commit 69953a4

Please sign in to comment.