Skip to content

Commit

Permalink
Handle forward references in Type typehints (#576)
Browse files Browse the repository at this point in the history
---------

Co-authored-by: Mauricio Villegas <[email protected]>
  • Loading branch information
EthanMarx and mauvilsa authored Sep 13, 2024
1 parent 1bab38a commit d539f42
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 2 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ Fixed
- Callable type with subclass return not showing the ``--*.help`` option (`#567
<https://github.com/omni-us/jsonargparse/pull/567>`__).

- Forward referenced types not compatible with `Type` typehint (`#576
<https://github.com/omni-us/jsonargparse/pull/576/>`__)

Changed
^^^^^^^
- Removed shtab experimental warning (`#561
Expand Down
6 changes: 5 additions & 1 deletion jsonargparse/_postponed_annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,8 @@ def resolve_subtypes_forward_refs(typehint):
typehint_origin = Tuple
elif typehint_origin in mapping_origin_types:
typehint_origin = Dict
elif typehint_origin == type:
typehint_origin = Type
typehint = typehint_origin[tuple(subtypes)]
except Exception as ex:
if logger:
Expand All @@ -240,6 +242,9 @@ def resolve_subtypes_forward_refs(typehint):

def has_subtypes(typehint):
typehint_origin = get_typehint_origin(typehint)
if typehint_origin is type and hasattr(typehint, "__args__"):
return True

return (
typehint_origin == Union
or typehint_origin in sequence_origin_types
Expand All @@ -260,7 +265,6 @@ def get_types(obj: Any, logger: Optional[logging.Logger] = None) -> dict:
types = get_type_hints(obj, global_vars)
except Exception as ex1:
types = ex1 # type: ignore[assignment]

if isinstance(types, dict) and all(not type_requires_eval(t) for t in types.values()):
return types

Expand Down
17 changes: 16 additions & 1 deletion jsonargparse_tests/test_postponed_annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,11 @@

from jsonargparse import Namespace
from jsonargparse._parameter_resolvers import get_signature_parameters as get_params
from jsonargparse._postponed_annotations import TypeCheckingVisitor, evaluate_postponed_annotations, get_types
from jsonargparse._postponed_annotations import (
TypeCheckingVisitor,
evaluate_postponed_annotations,
get_types,
)
from jsonargparse.typing import Path_drw
from jsonargparse_tests.conftest import capture_logs, source_unavailable

Expand Down Expand Up @@ -267,6 +271,17 @@ def test_get_types_type_checking_tuple():
assert str(types["p1"]) == f"{tpl}[{__name__}.TypeCheckingClass1, {__name__}.TypeCheckingClass2]"


def function_type_checking_type(p1: Type["TypeCheckingClass2"]):
return p1


def test_get_types_type_checking_type():
types = get_types(function_type_checking_type)
assert list(types.keys()) == ["p1"]
tpl = "typing.Type" if sys.version_info < (3, 10) else "type"
assert str(types["p1"]) == f"{tpl}[{__name__}.TypeCheckingClass2]"


def function_type_checking_dict(p1: Dict[str, Union[TypeCheckingClass1, "TypeCheckingClass2"]]):
return p1

Expand Down

0 comments on commit d539f42

Please sign in to comment.