diff --git a/CHANGELOG.rst b/CHANGELOG.rst index e0773fa1..3aef2827 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -27,6 +27,9 @@ Fixed - Callable type with subclass return not showing the ``--*.help`` option (`#567 `__). +- Forward referenced types not compatible with `Type` typehint (`#576 + `__) + Changed ^^^^^^^ - Removed shtab experimental warning (`#561 diff --git a/jsonargparse/_postponed_annotations.py b/jsonargparse/_postponed_annotations.py index d1a5e112..5dbdffd4 100644 --- a/jsonargparse/_postponed_annotations.py +++ b/jsonargparse/_postponed_annotations.py @@ -229,6 +229,8 @@ def resolve_subtypes_forward_refs(typehint): typehint_origin = Tuple elif typehint_origin in mapping_origin_types: typehint_origin = Dict + elif typehint_origin == type: + typehint_origin = Type typehint = typehint_origin[tuple(subtypes)] except Exception as ex: if logger: @@ -240,6 +242,9 @@ def resolve_subtypes_forward_refs(typehint): def has_subtypes(typehint): typehint_origin = get_typehint_origin(typehint) + if typehint_origin is type and hasattr(typehint, "__args__"): + return True + return ( typehint_origin == Union or typehint_origin in sequence_origin_types @@ -260,7 +265,6 @@ def get_types(obj: Any, logger: Optional[logging.Logger] = None) -> dict: types = get_type_hints(obj, global_vars) except Exception as ex1: types = ex1 # type: ignore[assignment] - if isinstance(types, dict) and all(not type_requires_eval(t) for t in types.values()): return types diff --git a/jsonargparse_tests/test_postponed_annotations.py b/jsonargparse_tests/test_postponed_annotations.py index 0d6b6baf..05d7d907 100644 --- a/jsonargparse_tests/test_postponed_annotations.py +++ b/jsonargparse_tests/test_postponed_annotations.py @@ -10,7 +10,11 @@ from jsonargparse import Namespace from jsonargparse._parameter_resolvers import get_signature_parameters as get_params -from jsonargparse._postponed_annotations import TypeCheckingVisitor, evaluate_postponed_annotations, get_types +from jsonargparse._postponed_annotations import ( + TypeCheckingVisitor, + evaluate_postponed_annotations, + get_types, +) from jsonargparse.typing import Path_drw from jsonargparse_tests.conftest import capture_logs, source_unavailable @@ -267,6 +271,17 @@ def test_get_types_type_checking_tuple(): assert str(types["p1"]) == f"{tpl}[{__name__}.TypeCheckingClass1, {__name__}.TypeCheckingClass2]" +def function_type_checking_type(p1: Type["TypeCheckingClass2"]): + return p1 + + +def test_get_types_type_checking_type(): + types = get_types(function_type_checking_type) + assert list(types.keys()) == ["p1"] + tpl = "typing.Type" if sys.version_info < (3, 10) else "type" + assert str(types["p1"]) == f"{tpl}[{__name__}.TypeCheckingClass2]" + + def function_type_checking_dict(p1: Dict[str, Union[TypeCheckingClass1, "TypeCheckingClass2"]]): return p1