diff --git a/rest_framework_dataclasses/field_utils.py b/rest_framework_dataclasses/field_utils.py index 1fa0e31..bbeeb8e 100644 --- a/rest_framework_dataclasses/field_utils.py +++ b/rest_framework_dataclasses/field_utils.py @@ -75,7 +75,7 @@ def get_type_info(tp: type) -> TypeInfo: tp = typing_utils.get_iterable_element_type(tp) if typing_utils.is_type_variable(tp): - tp = typing_utils.get_variable_type_substitute(tp) + tp = typing_utils.get_type_variable_substitution(tp) return TypeInfo(is_many, is_mapping, is_final, is_nullable, tp, cp) diff --git a/rest_framework_dataclasses/typing_utils.py b/rest_framework_dataclasses/typing_utils.py index bb42526..19b70f0 100644 --- a/rest_framework_dataclasses/typing_utils.py +++ b/rest_framework_dataclasses/typing_utils.py @@ -56,8 +56,10 @@ def get_resolved_type_hints(tp: type) -> typing.Dict[str, type]: Resolving the type hints means converting any stringified type hint into an actual type object. These can come from either forward references (PEP 484), or postponed evaluation (PEP 563). """ - # typing.get_type_hints() does the heavy lifting for us, except when using PEP 585 generic types that contain a - # stringified type hint (see https://bugs.python.org/issue41370) + # typing.get_type_hints() does the heavy lifting for us, except: + # - when using PEP 585 generic types that contain a stringified type hint, on Python 3.9 and 3.10. See + # https://bugs.python.org/issue41370. Only references to objects in the global namespace are supported here. + # - when using PEP 695 type aliases def _resolve_type(context_type: type, resolve_type: typing.Union[str, type]) -> type: if isinstance(resolve_type, str): globalsns = sys.modules[context_type.__module__].__dict__ @@ -66,12 +68,14 @@ def _resolve_type(context_type: type, resolve_type: typing.Union[str, type]) -> return _resolve_type_hint(context_type, resolve_type) def _resolve_type_hint(context_type: type, resolve_type: type) -> type: - if not hasattr(types, 'GenericAlias') or not isinstance(resolve_type, types.GenericAlias): + if hasattr(types, 'GenericAlias') and isinstance(resolve_type, types.GenericAlias): + args = tuple(_resolve_type(context_type, arg) for arg in resolve_type.__args__) + return typing.cast(type, types.GenericAlias(resolve_type.__origin__, args)) + elif hasattr(typing, 'TypeAliasType') and isinstance(resolve_type, typing.TypeAliasType): + return _resolve_type_hint(context_type, resolve_type.__value__) + else: return resolve_type - args = tuple(_resolve_type(context_type, arg) for arg in resolve_type.__args__) - return typing.cast(type, types.GenericAlias(resolve_type.__origin__, args)) - return {k: _resolve_type_hint(tp, v) for k, v in typing.get_type_hints(tp).items()} @@ -284,7 +288,7 @@ def is_type_variable(tp: type) -> bool: return isinstance(tp, typing.TypeVar) -def get_variable_type_substitute(tp: type) -> type: +def get_type_variable_substitution(tp: type) -> type: """ Get the substitute for a variable type. """ diff --git a/tests/__init__.py b/tests/__init__.py index e69de29..e20b04a 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -0,0 +1,21 @@ +import unittest +import sys + + +def load_tests(loader: unittest.TestLoader, tests, pattern): + # Manually load tests to avoid loading tests with syntax that's incompatible with the current Python version + for module in ( + 'test_field_generation', + 'test_field_utils', + 'test_fields', + 'test_functional', + 'test_issues', + 'test_serializers', + 'test_typing_utils', + ): + tests.addTests(loader.loadTestsFromName('tests.' + module)) + + if sys.version_info >= (3, 12, 0): + tests.addTests(loader.loadTestsFromName('tests.test_py312')) + + return tests diff --git a/tests/test_py312.py b/tests/test_py312.py new file mode 100644 index 0000000..57f7920 --- /dev/null +++ b/tests/test_py312.py @@ -0,0 +1,35 @@ +import typing +import unittest +import sys + +from rest_framework_dataclasses import typing_utils + + +@unittest.skipIf(sys.version_info < (3, 12, 0), 'Python 3.12 required') +class Python312Test(unittest.TestCase): + def test_resolve_pep695(self): + type Str = str + type StrList = list[str] + type GenericList[T] = list[T] + + class Hinted: + a: Str + b: StrList + c: GenericList + + hints = typing_utils.get_resolved_type_hints(Hinted) + self.assertEqual(hints['a'], str) + self.assertEqual(hints['b'], list[str]) + self.assertEqual(typing.get_origin(hints['c']), list) + + def test_typevar_pep695(self): + type GenericList[T: str] = list[T] + def fn() -> GenericList: + pass + + tp = typing_utils.get_resolved_type_hints(fn)['return'] + + self.assertTrue(typing_utils.is_iterable_type(tp)) + element_type = typing_utils.get_iterable_element_type(tp) + self.assertTrue(typing_utils.is_type_variable(element_type)) + self.assertEqual(typing_utils.get_type_variable_substitution(element_type), str) diff --git a/tests/test_typing_utils.py b/tests/test_typing_utils.py index e7669bf..3a9bc2f 100644 --- a/tests/test_typing_utils.py +++ b/tests/test_typing_utils.py @@ -1,3 +1,4 @@ +import types as types_module import typing import unittest import sys @@ -5,12 +6,46 @@ from rest_framework_dataclasses import types, typing_utils +class GlobalType: + pass + + class TypingTest(unittest.TestCase): def assertAnyTypeEquivalent(self, tp: type): # In some cases we accept either typing.Any (used by Python 3.9+) or an unconstrained typevar (used by Python # 3.7 and 3.8). It's essentially the same, and we strip the typevar before usage anyway. self.assertTrue(tp is typing.Any or (isinstance(tp, typing.TypeVar) and len(tp.__constraints__) == 0)) + def test_resolve(self): + class Hinted: + a: str + b: 'str' + + hints = typing_utils.get_resolved_type_hints(Hinted) + self.assertEqual(hints['a'], str) + self.assertEqual(hints['b'], str) + + @unittest.skipIf(sys.version_info < (3, 9, 0), 'Python 3.9 required') + def test_resolve_pep585(self): + # Pre-Python 3.11 only references to the global namespace are supported + class Hinted: + a: list[GlobalType] + b: list['GlobalType'] + + hints = typing_utils.get_resolved_type_hints(Hinted) + self.assertEqual(hints['a'], types_module.GenericAlias(list, (GlobalType, ))) + self.assertEqual(hints['b'], types_module.GenericAlias(list, (GlobalType, ))) + + @unittest.skipIf(sys.version_info < (3, 11, 0), 'Python 3.11 required') + def test_resolve_pep585_full(self): + class Hinted: + a: list[str] + b: list['str'] + + hints = typing_utils.get_resolved_type_hints(Hinted) + self.assertEqual(hints['a'], types_module.GenericAlias(list, (str, ))) + self.assertEqual(hints['b'], types_module.GenericAlias(list, (str, ))) + def test_iterable(self): self.assertTrue(typing_utils.is_iterable_type(typing.Iterable[str])) self.assertTrue(typing_utils.is_iterable_type(typing.Collection[str])) @@ -169,6 +204,6 @@ def test_variable_type(self): self.assertFalse(typing_utils.is_type_variable(int)) self.assertFalse(typing_utils.is_type_variable(typing.List)) - self.assertEqual(typing_utils.get_variable_type_substitute(T), typing.Any) - self.assertEqual(typing_utils.get_variable_type_substitute(U), typing.Union[int, str]) - self.assertEqual(typing_utils.get_variable_type_substitute(V), Exception) + self.assertEqual(typing_utils.get_type_variable_substitution(T), typing.Any) + self.assertEqual(typing_utils.get_type_variable_substitution(U), typing.Union[int, str]) + self.assertEqual(typing_utils.get_type_variable_substitution(V), Exception)