From 9d580148104809a272c7cfd43e151b29cacf0b46 Mon Sep 17 00:00:00 2001
From: Mauricio Villegas <5780272+mauvilsa@users.noreply.github.com>
Date: Tue, 6 Jun 2023 06:57:02 +0200
Subject: [PATCH] - Added support for postponed evaluation of annotations PEP
563 (#120). - Backport types in python<=3.9 to support PEP 585 and 604 for
postponed evaluation of annotations (#120).
---
.circleci/config.yml | 47 ++++--
.pre-commit-config.yaml | 36 +++++
CHANGELOG.rst | 14 ++
README.rst | 7 +
jsonargparse/_backports.py | 138 ++++++++++++++++--
jsonargparse/_stubs_resolver.py | 64 +-------
jsonargparse/link_arguments.py | 4 +-
jsonargparse/parameter_resolvers.py | 3 +
jsonargparse/signatures.py | 24 ++-
jsonargparse/util.py | 12 +-
jsonargparse_tests/test_actions.py | 2 +
jsonargparse_tests/test_argcomplete.py | 2 +
jsonargparse_tests/test_backports.py | 105 +++++++++++++
jsonargparse_tests/test_cli.py | 2 +
jsonargparse_tests/test_core.py | 2 +
jsonargparse_tests/test_dataclass_like.py | 33 +++--
jsonargparse_tests/test_deprecated.py | 2 +
jsonargparse_tests/test_formatters.py | 2 +
jsonargparse_tests/test_jsonnet.py | 2 +
jsonargparse_tests/test_jsonschema.py | 2 +
jsonargparse_tests/test_link_arguments.py | 2 +
jsonargparse_tests/test_loaders_dumpers.py | 2 +
jsonargparse_tests/test_namespace.py | 2 +
jsonargparse_tests/test_optionals.py | 2 +
.../test_parameter_resolvers.py | 18 ++-
jsonargparse_tests/test_paths.py | 2 +
jsonargparse_tests/test_signatures.py | 28 ++--
jsonargparse_tests/test_stubs_resolver.py | 4 +-
jsonargparse_tests/test_subclasses.py | 2 +
jsonargparse_tests/test_subcommands.py | 2 +
jsonargparse_tests/test_typehints.py | 8 +
jsonargparse_tests/test_typing.py | 2 +
jsonargparse_tests/test_util.py | 6 +-
pyproject.toml | 2 +-
34 files changed, 448 insertions(+), 137 deletions(-)
create mode 100644 jsonargparse_tests/test_backports.py
diff --git a/.circleci/config.yml b/.circleci/config.yml
index 468ba389..e8b0d7f6 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -32,6 +32,8 @@ jobs:
python3 -m jsonargparse_tests coverage xml coverage_py$py.xml
pip3 install $(ls ./dist/*.whl)[test,all]
python3 -m jsonargparse_tests coverage xml coverage_py${py}_all.xml
+ sed -i '/^from __future__ import annotations$/d' jsonargparse_tests/test_*.py
+ python3 -m jsonargparse_tests coverage xml coverage_py${py}_types.xml
- persist_to_workspace:
root: .
paths:
@@ -53,9 +55,26 @@ jobs:
docker:
- image: cimg/python:3.7
test-py36:
- <<: *test-py38
docker:
- image: cimg/python:3.6
+ steps:
+ - attach_workspace:
+ at: .
+ - run:
+ name: Run unit tests
+ command: |
+ py=$(python3 --version | sed -r 's|.* 3\.([0-9]+)\..*|3.\1|')
+ sed -i '/^from __future__ import annotations$/d' jsonargparse_tests/test_*.py
+ virtualenv -p python3 venv$py
+ . venv$py/bin/activate
+ pip3 install $(ls ./dist/*.whl)[test-no-urls]
+ python3 -m jsonargparse_tests coverage xml coverage_py$py.xml
+ pip3 install $(ls ./dist/*.whl)[test,all]
+ python3 -m jsonargparse_tests coverage xml coverage_py${py}_all.xml
+ - persist_to_workspace:
+ root: .
+ paths:
+ - ./coverage_*.xml
codecov:
docker:
- image: cimg/python:3.8
@@ -66,28 +85,26 @@ jobs:
- run:
name: Code coverage
command: |
- #for py in 3.6 3.7 3.8 3.9 3.10 3.11; do
- for py in 3.6 3.7 3.8 3.9 3.10; do
- bash <(curl -s https://codecov.io/bash) \
- -Z \
- -t $CODECOV_TOKEN_JSONARGPARSE \
- -F py$py \
- -f coverage_py${py}.xml
- bash <(curl -s https://codecov.io/bash) \
- -Z \
- -t $CODECOV_TOKEN_JSONARGPARSE \
- -F py${py}_all \
- -f coverage_py${py}_all.xml
+ for py in 3.6 3.7 3.8 3.9 3.10 3.11; do
+ for suffix in "" "_all" "_types"; do
+ bash <(curl -s https://codecov.io/bash) \
+ -Z \
+ -t $CODECOV_TOKEN_JSONARGPARSE \
+ -F py${py}${suffix} \
+ -f coverage_py${py}${suffix}.xml
+ done
done
publish-pypi:
docker:
- - image: mauvilsa/docker-twine:1.11.0
+ - image: cimg/python:3.10
steps:
- attach_workspace:
at: .
- run:
name: Publish Release on PyPI
- command: twine upload --username __token__ --password "${PYPI_TOKEN}" ./dist/*.whl ./dist/*.tar.gz
+ command: |
+ pip3 install -U twine
+ twine upload --username __token__ --password "${PYPI_TOKEN}" ./dist/*.whl ./dist/*.tar.gz
workflows:
version: 2
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 72a2c8fe..144d196a 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -103,6 +103,42 @@ repos:
pass_filenames: false
verbose: true
+ - id: test-py36
+ name: test-py36
+ entry: bash -c '
+ if [ "$(which python3.6)" = "" ]; then
+ echo "$(tput setaf 6) Skipped, python 3.6 not found $(tput sgr0)";
+ else
+ TEST_DIR=$(mktemp -d -t _jsonargparse_tests_XXXXXX);
+ cleanup () { rm -rf "$TEST_DIR"; };
+ trap cleanup EXIT;
+ ./setup.py bdist_wheel;
+ python3.6 -m venv "$TEST_DIR/venv36";
+ . "$TEST_DIR/venv36/bin/activate";
+ pip3 install "$(ls ./dist/*.whl | tail -n 1)[all,test,test-no-urls]";
+ pip3 install pytest pytest-subtests;
+ rsync -a jsonargparse_tests/*.py "$TEST_DIR";
+ cd "$TEST_DIR";
+ rm test_backports.py;
+ sed -i "/^from __future__ import annotations$/d" *.py;
+ pytest;
+ fi'
+ language: system
+ pass_filenames: false
+
+ - id: test-without-future-annotations
+ name: test-without-future-annotations
+ entry: bash -c '
+ TEST_DIR=$(mktemp -d -t _jsonargparse_tests_XXXXXX);
+ cleanup () { rm -rf "$TEST_DIR"; };
+ trap cleanup EXIT;
+ rsync -a jsonargparse_tests/*.py "$TEST_DIR";
+ cd "$TEST_DIR";
+ sed -i "/^from __future__ import annotations$/d" *.py;
+ pytest $TEST_DIR;'
+ language: system
+ pass_filenames: false
+
- id: doctest
name: sphinx-build -M doctest sphinx sphinx/_build sphinx/index.rst
entry: bash -c '
diff --git a/CHANGELOG.rst b/CHANGELOG.rst
index 666763b4..3b290d64 100644
--- a/CHANGELOG.rst
+++ b/CHANGELOG.rst
@@ -12,6 +12,20 @@ The semantic versioning only considers the public API as described in
paths are considered internals and can change in minor and patch releases.
+v4.22.0 (2023-06-??)
+--------------------
+
+Added
+^^^^^
+- Support for postponed evaluation of annotations PEP `563
+ `__ ``from __future__ import annotations``
+ (`#120 `__).
+- Backport types in python<=3.9 to support PEP `585
+ `__ and `604
+ `__ for postponed evaluation of annotations
+ (`#120 `__).
+
+
v4.21.2 (2023-06-??)
--------------------
diff --git a/README.rst b/README.rst
index fae71275..46dd8cd1 100644
--- a/README.rst
+++ b/README.rst
@@ -425,6 +425,13 @@ Some notes about this support are:
nesting it is meant child types inside ``List``, ``Dict``, etc. There is no
limit in nesting depth.
+- Postponed evaluation of types PEP `563 `__
+ (i.e. ``from __future__ import annotations``) is supported. Also supported on
+ ``python<=3.9`` are PEP `585 `__ (i.e.
+ ``list[], dict[], ...`` instead of ``List[], Dict[],
+ ...``) and `604 `__ (i.e. `` |
+ `` instead of ``Union[, ]``).
+
- Fully supported types are: ``str``, ``bool`` (more details in
:ref:`boolean-arguments`), ``int``, ``float``, ``complex``,
``bytes``/``bytearray`` (Base64 encoding), ``List`` (more details in
diff --git a/jsonargparse/_backports.py b/jsonargparse/_backports.py
index 0e952ddf..94a329f6 100644
--- a/jsonargparse/_backports.py
+++ b/jsonargparse/_backports.py
@@ -1,24 +1,28 @@
import ast
+import inspect
+import sys
+import textwrap
+import typing as T
from collections import namedtuple
from copy import deepcopy
-from typing import Dict, FrozenSet, List, Set, Tuple, Type, Union
+
+from .optionals import typing_extensions_import
+from .util import unique
var_map = namedtuple("var_map", "name value")
none_map = var_map(name="NoneType", value=type(None))
-union_map = var_map(name="Union", value=Union)
+union_map = var_map(name="Union", value=T.Union)
pep585_map = {
- "dict": var_map(name="Dict", value=Dict),
- "frozenset": var_map(name="FrozenSet", value=FrozenSet),
- "list": var_map(name="List", value=List),
- "set": var_map(name="Set", value=Set),
- "tuple": var_map(name="Tuple", value=Tuple),
- "type": var_map(name="Type", value=Type),
+ "dict": var_map(name="Dict", value=T.Dict),
+ "frozenset": var_map(name="FrozenSet", value=T.FrozenSet),
+ "list": var_map(name="List", value=T.List),
+ "set": var_map(name="Set", value=T.Set),
+ "tuple": var_map(name="Tuple", value=T.Tuple),
+ "type": var_map(name="Type", value=T.Type),
}
class BackportTypeHints(ast.NodeTransformer):
- _typing = __import__("typing")
-
def visit_Subscript(self, node: ast.Subscript) -> ast.Subscript:
if isinstance(node.value, ast.Name) and node.value.id in pep585_map:
value = self.new_name_load(pep585_map[node.value.id])
@@ -30,13 +34,13 @@ def visit_Subscript(self, node: ast.Subscript) -> ast.Subscript:
ctx=ast.Load(),
)
- def visit_Constant(self, node: ast.Constant) -> Union[ast.Constant, ast.Name]:
+ def visit_Constant(self, node: ast.Constant) -> T.Union[ast.Constant, ast.Name]:
if node.value is None:
return self.new_name_load(none_map)
return node
- def visit_BinOp(self, node: ast.BinOp) -> Union[ast.BinOp, ast.Subscript]:
- out_node: Union[ast.BinOp, ast.Subscript] = node
+ def visit_BinOp(self, node: ast.BinOp) -> T.Union[ast.BinOp, ast.Subscript]:
+ out_node: T.Union[ast.BinOp, ast.Subscript] = node
if isinstance(node.op, ast.BitOr):
elts: list = []
self.append_union_elts(node.left, elts)
@@ -66,8 +70,112 @@ def new_name_load(self, var: var_map) -> ast.Name:
def backport(self, input_ast: ast.AST, exec_vars: dict) -> ast.AST:
for key, value in exec_vars.items():
if getattr(value, "__module__", "") == "collections.abc":
- if hasattr(self._typing, key):
- exec_vars[key] = getattr(self._typing, key)
+ if hasattr(T, key):
+ exec_vars[key] = getattr(T, key)
self.exec_vars = exec_vars
backport_ast = self.visit(deepcopy(input_ast))
return ast.fix_missing_locations(backport_ast)
+
+
+class NamesVisitor(ast.NodeVisitor):
+ def visit_Name(self, node: ast.Name) -> None:
+ self.names_found.append(node.id)
+
+ def find(self, node: ast.AST) -> list:
+ self.names_found: T.List[str] = []
+ self.visit(node)
+ self.names_found = unique(self.names_found)
+ return self.names_found
+
+
+def get_arg_type(arg_ast, aliases):
+ type_ast = ast.parse("___arg_type___ = 0")
+ type_ast.body[0].value = arg_ast.annotation
+ exec_vars = {}
+ bad_aliases = {}
+ add_asts = False
+ for name in NamesVisitor().find(arg_ast.annotation):
+ value = aliases[name]
+ if isinstance(value, tuple):
+ value = value[1]
+ if isinstance(value, Exception):
+ bad_aliases[name] = value
+ elif isinstance(value, ast.AST):
+ add_asts = True
+ else:
+ exec_vars[name] = value
+ if add_asts:
+ body = []
+ for name, (_, value) in aliases.items():
+ if isinstance(value, ast.AST):
+ body.append(ast.fix_missing_locations(value))
+ elif not isinstance(value, Exception):
+ exec_vars[name] = value
+ type_ast.body = body + type_ast.body
+ if "TypeAlias" not in exec_vars:
+ type_alias = typing_extensions_import("TypeAlias")
+ if type_alias:
+ exec_vars["TypeAlias"] = type_alias
+ if sys.version_info < (3, 10):
+ backporter = BackportTypeHints()
+ type_ast = backporter.backport(type_ast, exec_vars)
+ try:
+ exec(compile(type_ast, filename="", mode="exec"), exec_vars, exec_vars)
+ except NameError as ex:
+ ex_from = None
+ for name, alias_exception in bad_aliases.items():
+ if str(ex) == f"name '{name}' is not defined":
+ ex_from = alias_exception
+ break
+ raise ex from ex_from
+ return exec_vars["___arg_type___"]
+
+
+def get_type_hints(obj: T.Any, globalns: T.Optional[dict] = None, localns: T.Optional[dict] = None) -> dict:
+ try:
+ return T.get_type_hints(obj, globalns, localns)
+ except Exception as ex1:
+ try:
+ source = textwrap.dedent(inspect.getsource(obj))
+ tree = ast.parse(source)
+ assert isinstance(tree, ast.Module) and len(tree.body) == 1
+ node = tree.body[0]
+
+ aliases = __builtins__.copy() # type: ignore
+ if globalns is None:
+ globalns = obj.__globals__
+ aliases.update(globalns)
+ if localns is not None:
+ aliases.update(localns)
+
+ types = {}
+ for arg_ast in node.args.args + node.args.kwonlyargs: # type: ignore
+ if not arg_ast.annotation:
+ continue
+ name = arg_ast.arg
+ types[name] = get_arg_type(arg_ast, aliases)
+
+ return types
+ except Exception as ex2:
+ raise type(ex1)(f"{repr(ex1)} + {repr(ex2)}") from ex2
+
+
+def evaluate_postponed_annotations(params, component, logger):
+ if sys.version_info[:2] == (3, 6) or not (
+ params and any(isinstance(p.annotation, (str, T.ForwardRef)) for p in params)
+ ):
+ return
+ try:
+ if sys.version_info < (3, 10):
+ types = get_type_hints(component)
+ else:
+ types = T.get_type_hints(component)
+ except Exception as ex:
+ logger.debug(f"Unable to evaluate types for {component}", exc_info=ex)
+ return
+ for param in params:
+ if param.name in types:
+ param_type = types[param.name]
+ if isinstance(param_type, T.ForwardRef):
+ param_type = param_type._evaluate(component.__globals__, {})
+ param.annotation = param_type
diff --git a/jsonargparse/_stubs_resolver.py b/jsonargparse/_stubs_resolver.py
index 164ae05a..01f305ed 100644
--- a/jsonargparse/_stubs_resolver.py
+++ b/jsonargparse/_stubs_resolver.py
@@ -5,14 +5,10 @@
from copy import deepcopy
from importlib import import_module
from pathlib import Path
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
-from .optionals import (
- import_typeshed_client,
- typeshed_client_support,
- typing_extensions_import,
-)
-from .util import unique
+from ._backports import NamesVisitor, get_arg_type
+from .optionals import import_typeshed_client, typeshed_client_support
if TYPE_CHECKING: # pragma: no cover
import typeshed_client as tc
@@ -32,17 +28,6 @@ def import_module_or_none(path: str):
return None
-class NamesVisitor(ast.NodeVisitor):
- def visit_Name(self, node: ast.Name) -> None:
- self.names_found.append(node.id)
-
- def find(self, node: ast.AST) -> list:
- self.names_found: List[str] = []
- self.visit(node)
- self.names_found = unique(self.names_found)
- return self.names_found
-
-
class ImportsVisitor(ast.NodeVisitor):
def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
if node.level:
@@ -262,49 +247,6 @@ def alias_is_unique(aliases, name, source, value):
return True
-def get_arg_type(arg_ast, aliases):
- type_ast = ast.parse("___arg_type___ = 0")
- type_ast.body[0].value = arg_ast.annotation
- exec_vars = {}
- bad_aliases = {}
- add_asts = False
- for name in NamesVisitor().find(arg_ast.annotation):
- _, value = aliases[name]
- if isinstance(value, Exception):
- bad_aliases[name] = value
- elif isinstance(value, ast.AST):
- add_asts = True
- else:
- exec_vars[name] = value
- if add_asts:
- body = []
- for name, (_, value) in aliases.items():
- if isinstance(value, ast.AST):
- body.append(ast.fix_missing_locations(value))
- elif not isinstance(value, Exception):
- exec_vars[name] = value
- type_ast.body = body + type_ast.body
- if "TypeAlias" not in exec_vars:
- type_alias = typing_extensions_import("TypeAlias")
- if type_alias:
- exec_vars["TypeAlias"] = type_alias
- if sys.version_info < (3, 10):
- from ._backports import BackportTypeHints
-
- backporter = BackportTypeHints()
- type_ast = backporter.backport(type_ast, exec_vars)
- try:
- exec(compile(type_ast, filename="", mode="exec"), exec_vars, exec_vars)
- except NameError as ex:
- ex_from = None
- for name, alias_exception in bad_aliases.items():
- if str(ex) == f"name '{name}' is not defined":
- ex_from = alias_exception
- break
- raise ex from ex_from
- return exec_vars["___arg_type___"]
-
-
def get_stub_types(params, component, parent, logger) -> Optional[Dict[str, Any]]:
if not typeshed_client_support:
return None
diff --git a/jsonargparse/link_arguments.py b/jsonargparse/link_arguments.py
index a7ee315f..9bd78fdf 100644
--- a/jsonargparse/link_arguments.py
+++ b/jsonargparse/link_arguments.py
@@ -1,6 +1,5 @@
"""Code related to argument linking."""
-import inspect
import re
from argparse import SUPPRESS
from argparse import Action as ArgparseAction
@@ -18,6 +17,7 @@
filter_default_actions,
)
from .namespace import Namespace, split_key_leaf
+from .parameter_resolvers import get_signature_parameters
from .type_checking import ArgumentParser, _ArgumentGroup
__all__ = ["ArgumentLinking"]
@@ -271,7 +271,7 @@ def apply_parsing_links(parser: "ArgumentParser", cfg: Namespace) -> None:
value = value.as_dict()
else:
# Automatic namespace to dict based on compute_fn param type hint
- params = list(inspect.signature(action.compute_fn).parameters.values())
+ params = get_signature_parameters(action.compute_fn)
for n, param in enumerate(params):
if (
n < len(args)
diff --git a/jsonargparse/parameter_resolvers.py b/jsonargparse/parameter_resolvers.py
index e2b286ff..a4e8614f 100644
--- a/jsonargparse/parameter_resolvers.py
+++ b/jsonargparse/parameter_resolvers.py
@@ -10,6 +10,7 @@
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
+from ._backports import evaluate_postponed_annotations
from ._common import is_dataclass_like, is_subclass
from ._stubs_resolver import get_stub_types
from .optionals import parse_docs
@@ -275,6 +276,7 @@ def get_signature_parameters_and_indexes(component, parent, logger):
component=component,
**{a: getattr(param, a) for a in parameter_attributes},
)
+ evaluate_postponed_annotations(params, signature_source, logger)
stubs = get_stub_types(params, signature_source, parent, logger)
return params, args_idx, kwargs_idx, doc_params, stubs
@@ -801,6 +803,7 @@ def get_parameters_from_pydantic(
component=function_or_class,
)
)
+ evaluate_postponed_annotations(params, function_or_class, logger)
return params
diff --git a/jsonargparse/signatures.py b/jsonargparse/signatures.py
index 2b5670fe..7c9aa070 100644
--- a/jsonargparse/signatures.py
+++ b/jsonargparse/signatures.py
@@ -15,7 +15,12 @@
get_parameter_origins,
get_signature_parameters,
)
-from .typehints import ActionTypeHint, LazyInitBaseClass, is_optional
+from .typehints import (
+ ActionTypeHint,
+ LazyInitBaseClass,
+ callable_instances,
+ is_optional,
+)
from .typing import register_pydantic_type
from .util import LoggerProperty, get_import_path, iter_to_set_str
@@ -176,9 +181,14 @@ def add_function_arguments(
if not callable(function):
raise ValueError('Expected "function" argument to be a callable object.')
+ method_name = None
+ if hasattr(function, "__class__") and callable_instances(function.__class__):
+ function = function.__class__
+ method_name = "__call__"
+
return self._add_signature_arguments(
function,
- None,
+ method_name,
nested_key,
as_group,
as_positional,
@@ -366,12 +376,10 @@ def _add_signature_parameter(
action.sub_add_kwargs["skip"] = subclass_skip
added_args.append(dest)
elif is_required and fail_untyped:
- msg = f'With fail_untyped=True, all mandatory parameters must have a supported type. Parameter "{name}" from "{src}" '
- if isinstance(annotation, str):
- msg += "specifies the type as a string. Types as a string and `from __future__ import annotations` is currently not supported."
- else:
- msg += "does not specify a type."
- raise ValueError(msg)
+ raise ValueError(
+ "With fail_untyped=True, all mandatory parameters must have a supported"
+ f" type. Parameter '{name}' from '{src}' does not specify a type."
+ )
def add_dataclass_arguments(
self,
diff --git a/jsonargparse/util.py b/jsonargparse/util.py
index 382f3db9..86e72780 100644
--- a/jsonargparse/util.py
+++ b/jsonargparse/util.py
@@ -28,6 +28,7 @@
Type,
TypeVar,
Union,
+ get_type_hints,
)
from ._common import parser_capture, parser_context
@@ -369,10 +370,13 @@ def class_from_function(func: Callable[..., ClassType]) -> Type[ClassType]:
"""
func_return = inspect.signature(func).return_annotation
if isinstance(func_return, str):
- caller_frame = inspect.currentframe().f_back # type: ignore
- func_return = caller_frame.f_locals.get(func_return) or caller_frame.f_globals.get(func_return) # type: ignore
- if func_return is None:
- raise ValueError(f"Unable to dereference {func_return} the return type of {func}.")
+ try:
+ func_return = get_type_hints(func)["return"]
+ if sys.version_info[:2] != (3, 6) and isinstance(func_return, __import__("typing").ForwardRef):
+ func_return = func_return._evaluate(func.__globals__, {})
+ except Exception as ex:
+ func_return = inspect.signature(func).return_annotation
+ raise ValueError(f"Unable to dereference {func_return}, the return type of {func}: {ex}") from ex
@wraps(func)
def __new__(cls, *args, **kwargs):
diff --git a/jsonargparse_tests/test_actions.py b/jsonargparse_tests/test_actions.py
index e1a26ea2..326f9d6f 100644
--- a/jsonargparse_tests/test_actions.py
+++ b/jsonargparse_tests/test_actions.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
from pathlib import Path
import pytest
diff --git a/jsonargparse_tests/test_argcomplete.py b/jsonargparse_tests/test_argcomplete.py
index 68915888..c6fbd545 100644
--- a/jsonargparse_tests/test_argcomplete.py
+++ b/jsonargparse_tests/test_argcomplete.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import os
import sys
from contextlib import ExitStack, contextmanager
diff --git a/jsonargparse_tests/test_backports.py b/jsonargparse_tests/test_backports.py
new file mode 100644
index 00000000..7c256387
--- /dev/null
+++ b/jsonargparse_tests/test_backports.py
@@ -0,0 +1,105 @@
+from __future__ import annotations # keep
+
+import sys
+from random import Random
+from typing import Dict, FrozenSet, List, Set, Tuple, Type, Union
+
+import pytest
+
+from jsonargparse._backports import get_type_hints
+from jsonargparse.parameter_resolvers import get_signature_parameters as get_params
+
+
+@pytest.fixture(autouse=True)
+def skip_if_python_older_than_3_10():
+ if sys.version_info >= (3, 10, 0):
+ pytest.skip("python<3.10 is required")
+
+
+def function_pep585_dict(p1: dict[str, int], p2: dict[int, str] = {1: "a"}):
+ return p1
+
+
+def function_pep585_list(p1: list[str], p2: list[float] = [0.1, 2.3]):
+ return p1
+
+
+def function_pep585_set(p1: set[str], p2: set[int] = {1, 2}):
+ return p1
+
+
+def function_pep585_frozenset(p1: frozenset[str], p2: frozenset[int] = frozenset(range(3))):
+ return p1
+
+
+def function_pep585_tuple(p1: tuple[str, float], p2: tuple[int, ...] = (1, 2)):
+ return p1
+
+
+def function_pep585_type(p1: type[Random], p2: type[Random] = Random):
+ return p1
+
+
+@pytest.mark.skipif(sys.version_info >= (3, 9, 0), reason="python<3.9 is required")
+@pytest.mark.parametrize(
+ ["function", "expected"],
+ [
+ (function_pep585_dict, {"p1": Dict[str, int], "p2": Dict[int, str]}),
+ (function_pep585_list, {"p1": List[str], "p2": List[float]}),
+ (function_pep585_set, {"p1": Set[str], "p2": Set[int]}),
+ (function_pep585_frozenset, {"p1": FrozenSet[str], "p2": FrozenSet[int]}),
+ (function_pep585_tuple, {"p1": Tuple[str, float], "p2": Tuple[int, ...]}),
+ (function_pep585_type, {"p1": Type[Random], "p2": Type[Random]}),
+ ],
+)
+def test_pep585(function, expected):
+ types = get_type_hints(function)
+ assert types == expected
+
+
+def function_pep604(p1: str | None, p2: int | float | bool = 1):
+ return p1
+
+
+def test_pep604():
+ types = get_type_hints(function_pep604, function_pep604.__globals__, {})
+ assert types == {"p1": Union[str, None], "p2": Union[int, float, bool]}
+
+
+class NeedsBackport:
+ def __init__(self, p1: list | set):
+ self.p1 = p1
+
+ @staticmethod
+ def static_method(p1: str | int):
+ return p1
+
+ @classmethod
+ def class_method(cls, p1: float | None):
+ return p1
+
+
+@pytest.mark.parametrize(
+ ["method", "expected"],
+ [
+ (NeedsBackport.__init__, {"p1": Union[list, set]}),
+ (NeedsBackport.static_method, {"p1": Union[str, int]}),
+ (NeedsBackport.class_method, {"p1": Union[float, None]}),
+ ],
+)
+def test_methods(method, expected):
+ types = get_type_hints(method)
+ assert types == expected
+
+
+def function_undefined_type(p1: not_defined | None): # type: ignore # noqa: F821
+ return p1
+
+
+def test_undefined_type():
+ with pytest.raises(Exception) as ctx:
+ get_type_hints(function_undefined_type)
+ ctx.match("name 'not_defined' is not defined")
+
+ params = get_params(function_undefined_type)
+ assert params[0].annotation == "not_defined | None"
diff --git a/jsonargparse_tests/test_cli.py b/jsonargparse_tests/test_cli.py
index 48db4bc0..6927cd54 100644
--- a/jsonargparse_tests/test_cli.py
+++ b/jsonargparse_tests/test_cli.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import os
import sys
from contextlib import redirect_stderr, redirect_stdout, suppress
diff --git a/jsonargparse_tests/test_core.py b/jsonargparse_tests/test_core.py
index 0b3733be..4b49cc8d 100644
--- a/jsonargparse_tests/test_core.py
+++ b/jsonargparse_tests/test_core.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import json
import os
import pickle
diff --git a/jsonargparse_tests/test_dataclass_like.py b/jsonargparse_tests/test_dataclass_like.py
index c5182c97..f7e346ef 100644
--- a/jsonargparse_tests/test_dataclass_like.py
+++ b/jsonargparse_tests/test_dataclass_like.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import dataclasses
import sys
from typing import Any, Dict, List, Optional, Union
@@ -418,7 +420,7 @@ def none(x):
return x
-@pytest.mark.skipif(sys.version_info == (3, 6), reason="pydantic not supported in python 3.6")
+@pytest.mark.skipif(sys.version_info[:2] == (3, 6), reason="pydantic not supported in python 3.6")
@pytest.mark.skipif(not pydantic_support, reason="pydantic package is required")
class TestPydantic:
def test_dataclass(self, parser):
@@ -448,21 +450,22 @@ def test_field_description(self, parser):
assert "p1 help (required, type: str)" in help_str
assert "p2 help (type: int, default: 2)" in help_str
- def test_pydantic_types(self, subtests):
- for valid_value, invalid_value, cast, pydantic_type in [
- ("abc", "a", none, pydantic.constr(min_length=2, max_length=4)),
- (2, 0, none, pydantic.conint(ge=1)),
- (-1.0, 1.0, none, pydantic.confloat(lt=0.0)),
- ([1], [], none, pydantic.conlist(int, min_items=1)),
- ([], [3, 4], none, pydantic.conlist(int, max_items=1)),
- ([1], "x", list, pydantic.conset(int, min_items=1)),
- ("http://abc.es", "-", none, pydantic.HttpUrl),
- ("127.0.0.1", "0", str, pydantic.IPvAnyAddress),
- ]:
+ def test_pydantic_types(self, subtests, monkeypatch):
+ for num, (valid_value, invalid_value, cast, pydantic_type) in enumerate(
+ [
+ ("abc", "a", none, pydantic.constr(min_length=2, max_length=4)),
+ (2, 0, none, pydantic.conint(ge=1)),
+ (-1.0, 1.0, none, pydantic.confloat(lt=0.0)),
+ ([1], [], none, pydantic.conlist(int, min_items=1)),
+ ([], [3, 4], none, pydantic.conlist(int, max_items=1)),
+ ([1], "x", list, pydantic.conset(int, min_items=1)),
+ ("http://abc.es", "-", none, pydantic.HttpUrl),
+ ("127.0.0.1", "0", str, pydantic.IPvAnyAddress),
+ ]
+ ):
with subtests.test(f"type={pydantic_type.__name__} valid={valid_value} invalid={invalid_value}"):
-
- class Model(pydantic.BaseModel):
- param: pydantic_type
+ Model = pydantic.create_model(f"Model{num}", param=(pydantic_type, ...))
+ monkeypatch.setitem(Model.__init__.__globals__, "pydantic_type", pydantic_type)
parser = ArgumentParser(exit_on_error=False)
parser.add_argument("--model", type=Model)
diff --git a/jsonargparse_tests/test_deprecated.py b/jsonargparse_tests/test_deprecated.py
index 9995111d..c8019cfc 100644
--- a/jsonargparse_tests/test_deprecated.py
+++ b/jsonargparse_tests/test_deprecated.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import os
import pathlib
from calendar import Calendar
diff --git a/jsonargparse_tests/test_formatters.py b/jsonargparse_tests/test_formatters.py
index 01aa2f63..f92bb3f8 100644
--- a/jsonargparse_tests/test_formatters.py
+++ b/jsonargparse_tests/test_formatters.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
from pathlib import Path
from typing import Tuple
diff --git a/jsonargparse_tests/test_jsonnet.py b/jsonargparse_tests/test_jsonnet.py
index fec2949c..9fcf9b7a 100644
--- a/jsonargparse_tests/test_jsonnet.py
+++ b/jsonargparse_tests/test_jsonnet.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import json
import re
from pathlib import Path
diff --git a/jsonargparse_tests/test_jsonschema.py b/jsonargparse_tests/test_jsonschema.py
index 59e121dd..7af247a5 100644
--- a/jsonargparse_tests/test_jsonschema.py
+++ b/jsonargparse_tests/test_jsonschema.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import json
import re
from importlib.util import find_spec
diff --git a/jsonargparse_tests/test_link_arguments.py b/jsonargparse_tests/test_link_arguments.py
index 6513b4c3..c2f4c31e 100644
--- a/jsonargparse_tests/test_link_arguments.py
+++ b/jsonargparse_tests/test_link_arguments.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
from calendar import Calendar, TextCalendar
from typing import Any, List, Mapping, Optional, Union
diff --git a/jsonargparse_tests/test_loaders_dumpers.py b/jsonargparse_tests/test_loaders_dumpers.py
index 36c009aa..95e6ebce 100644
--- a/jsonargparse_tests/test_loaders_dumpers.py
+++ b/jsonargparse_tests/test_loaders_dumpers.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import os
from typing import List
from unittest.mock import patch
diff --git a/jsonargparse_tests/test_namespace.py b/jsonargparse_tests/test_namespace.py
index 3efd7fba..5b7c5b4d 100644
--- a/jsonargparse_tests/test_namespace.py
+++ b/jsonargparse_tests/test_namespace.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import argparse
import platform
diff --git a/jsonargparse_tests/test_optionals.py b/jsonargparse_tests/test_optionals.py
index 4d23cc1c..08a561e6 100644
--- a/jsonargparse_tests/test_optionals.py
+++ b/jsonargparse_tests/test_optionals.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import pytest
from jsonargparse import get_config_read_mode, set_config_read_mode
diff --git a/jsonargparse_tests/test_parameter_resolvers.py b/jsonargparse_tests/test_parameter_resolvers.py
index 585fa1b6..f683d0c8 100644
--- a/jsonargparse_tests/test_parameter_resolvers.py
+++ b/jsonargparse_tests/test_parameter_resolvers.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import calendar
import inspect
import xml.dom
@@ -331,7 +333,7 @@ def __init__(
"""
-def function_no_args_no_kwargs(pk1: str, k2: int = 1):
+def function_no_args_no_kwargs(pk1: str, k2: "int" = 1):
"""
Args:
pk1: help for pk1
@@ -407,6 +409,13 @@ def function_constant_boolean(**kwargs):
return function_with_kwargs(k1=False, **kwargs)
+def function_invalid_type(param: "invalid:" = 1): # type: ignore # noqa: F722
+ """
+ Args:
+ param: help for param
+ """
+
+
def cond_1(kc: int = 1, kn0: str = "x", kn1: str = "-"):
"""
Args:
@@ -662,6 +671,13 @@ def test_get_params_function_constant_boolean():
assert get_params(function_constant_boolean) == []
+def test_get_params_function_invalid_type(logger):
+ with capture_logs(logger):
+ params = get_params(function_invalid_type, logger=logger)
+ assert_params(params, ["param"])
+ assert params[0].annotation.replace("'", "") == "invalid:"
+
+
def test_conditional_calls_kwargs():
assert_params(
get_params(conditional_calls),
diff --git a/jsonargparse_tests/test_paths.py b/jsonargparse_tests/test_paths.py
index 75dafb5f..5b13bf0e 100644
--- a/jsonargparse_tests/test_paths.py
+++ b/jsonargparse_tests/test_paths.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import json
import os
from calendar import Calendar
diff --git a/jsonargparse_tests/test_signatures.py b/jsonargparse_tests/test_signatures.py
index 656bb275..a2d39fdc 100644
--- a/jsonargparse_tests/test_signatures.py
+++ b/jsonargparse_tests/test_signatures.py
@@ -1,3 +1,6 @@
+from __future__ import annotations
+
+import sys
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union
from unittest.mock import patch
@@ -85,7 +88,7 @@ def test_add_class_failure_not_a_class(parser):
def test_add_class_failure_positional_without_type(parser):
with pytest.raises(ValueError) as ctx:
parser.add_class_arguments(Class2)
- ctx.match(f'Parameter "c2_a0" from "{__name__}.Class2.__init__" does not specify a type')
+ ctx.match(f"Parameter 'c2_a0' from '{__name__}.Class2.__init__' does not specify a type")
def test_add_class_without_nesting(parser):
@@ -251,6 +254,7 @@ def test_add_class_with_required_parameters(parser):
assert cfg.model == Namespace(m=0.1, n=3)
+@pytest.mark.skipif(sys.version_info[:2] == (3, 6), reason="import forces future annotations in python 3.6")
def test_add_class_conditional_kwargs(parser):
from jsonargparse_tests.test_parameter_resolvers import ClassG
@@ -510,24 +514,24 @@ def test_add_function_implicit_optional(parser):
assert None is parser.parse_args(["--a1=null"]).a1
-def func_untyped_params(a1, a2=None):
- return a1
+def func_type_as_string(a2: "int"):
+ return a2
-def test_add_function_fail_untyped_true_untyped_params(parser):
- with pytest.raises(ValueError) as ctx:
- parser.add_function_arguments(func_untyped_params, fail_untyped=True)
- ctx.match('Parameter "a1" from .* does not specify a type')
+@pytest.mark.skipif(sys.version_info[:2] == (3, 6), reason="not supported in python 3.6")
+def test_add_function_fail_untyped_true_str_type(parser):
+ added_args = parser.add_function_arguments(func_type_as_string, fail_untyped=True)
+ assert ["a2"] == added_args
-def func_type_as_string(a2: "int"):
- return a2
+def func_untyped_params(a1, a2=None):
+ return a1
-def test_add_function_fail_untyped_true_str_type(parser):
+def test_add_function_fail_untyped_true_untyped_params(parser):
with pytest.raises(ValueError) as ctx:
- parser.add_function_arguments(func_type_as_string, fail_untyped=True)
- ctx.match('Parameter "a2" from .* specifies the type as a string')
+ parser.add_function_arguments(func_untyped_params, fail_untyped=True)
+ ctx.match("Parameter 'a1' from .* does not specify a type")
def test_add_function_fail_untyped_false(parser):
diff --git a/jsonargparse_tests/test_stubs_resolver.py b/jsonargparse_tests/test_stubs_resolver.py
index 30fbe950..f68790b4 100644
--- a/jsonargparse_tests/test_stubs_resolver.py
+++ b/jsonargparse_tests/test_stubs_resolver.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import inspect
import sys
from calendar import Calendar, TextCalendar
@@ -46,7 +48,7 @@ def get_param_names(params):
class WithoutParent:
- ...
+ pass
@pytest.mark.parametrize(
diff --git a/jsonargparse_tests/test_subclasses.py b/jsonargparse_tests/test_subclasses.py
index 9f73d1b2..82ed96fd 100644
--- a/jsonargparse_tests/test_subclasses.py
+++ b/jsonargparse_tests/test_subclasses.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import json
import textwrap
import warnings
diff --git a/jsonargparse_tests/test_subcommands.py b/jsonargparse_tests/test_subcommands.py
index e96af28a..3e4421cf 100644
--- a/jsonargparse_tests/test_subcommands.py
+++ b/jsonargparse_tests/test_subcommands.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import os
import warnings
from pathlib import Path
diff --git a/jsonargparse_tests/test_typehints.py b/jsonargparse_tests/test_typehints.py
index 5b45bcbb..fd73308d 100644
--- a/jsonargparse_tests/test_typehints.py
+++ b/jsonargparse_tests/test_typehints.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import random
import sys
import time
@@ -775,6 +777,12 @@ def test_action_typehint_unsupported_type(typehint):
ctx.match("Unsupported type hint")
+def test_action_typehint_none_type_error():
+ with pytest.raises(ValueError) as ctx:
+ ActionTypeHint(typehint=None)
+ ctx.match("Expected typehint keyword argument")
+
+
@pytest.mark.parametrize(
["typehint", "ref_type", "expected"],
[
diff --git a/jsonargparse_tests/test_typing.py b/jsonargparse_tests/test_typing.py
index d395cb31..d779b441 100644
--- a/jsonargparse_tests/test_typing.py
+++ b/jsonargparse_tests/test_typing.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import pickle
import random
import uuid
diff --git a/jsonargparse_tests/test_util.py b/jsonargparse_tests/test_util.py
index 6de82132..1654fc54 100644
--- a/jsonargparse_tests/test_util.py
+++ b/jsonargparse_tests/test_util.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import logging
import os
import pathlib
@@ -566,7 +568,7 @@ def get_foo(cls) -> "Foo":
def closure_get_foo():
- def get_foo() -> "Foo":
+ def get_foo() -> Foo:
return Foo()
return get_foo
@@ -593,7 +595,7 @@ def get_unknown() -> "Unknown": # type: ignore # noqa: F821
def test_invalid_class_from_function():
with pytest.raises(ValueError) as ctx:
class_from_function(get_unknown)
- ctx.match("Unable to dereference None the return type")
+ ctx.match("Unable to dereference '?Unknown'?, the return type of")
def get_calendar(a1: str, a2: int = 2) -> Calendar:
diff --git a/pyproject.toml b/pyproject.toml
index 32c9f2a9..e9735e79 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -206,7 +206,7 @@ multi_line_output = 3
[tool.tox]
legacy_tox_ini = """
[tox]
-envlist = py{37,38,39,310,311}-{all,no}-extras,pypy3,omegaconf
+envlist = py{37,38,39,310,311}-{all,no}-extras,omegaconf
skip_missing_interpreters = true
[testenv]