From 9d580148104809a272c7cfd43e151b29cacf0b46 Mon Sep 17 00:00:00 2001 From: Mauricio Villegas <5780272+mauvilsa@users.noreply.github.com> Date: Tue, 6 Jun 2023 06:57:02 +0200 Subject: [PATCH] - Added support for postponed evaluation of annotations PEP 563 (#120). - Backport types in python<=3.9 to support PEP 585 and 604 for postponed evaluation of annotations (#120). --- .circleci/config.yml | 47 ++++-- .pre-commit-config.yaml | 36 +++++ CHANGELOG.rst | 14 ++ README.rst | 7 + jsonargparse/_backports.py | 138 ++++++++++++++++-- jsonargparse/_stubs_resolver.py | 64 +------- jsonargparse/link_arguments.py | 4 +- jsonargparse/parameter_resolvers.py | 3 + jsonargparse/signatures.py | 24 ++- jsonargparse/util.py | 12 +- jsonargparse_tests/test_actions.py | 2 + jsonargparse_tests/test_argcomplete.py | 2 + jsonargparse_tests/test_backports.py | 105 +++++++++++++ jsonargparse_tests/test_cli.py | 2 + jsonargparse_tests/test_core.py | 2 + jsonargparse_tests/test_dataclass_like.py | 33 +++-- jsonargparse_tests/test_deprecated.py | 2 + jsonargparse_tests/test_formatters.py | 2 + jsonargparse_tests/test_jsonnet.py | 2 + jsonargparse_tests/test_jsonschema.py | 2 + jsonargparse_tests/test_link_arguments.py | 2 + jsonargparse_tests/test_loaders_dumpers.py | 2 + jsonargparse_tests/test_namespace.py | 2 + jsonargparse_tests/test_optionals.py | 2 + .../test_parameter_resolvers.py | 18 ++- jsonargparse_tests/test_paths.py | 2 + jsonargparse_tests/test_signatures.py | 28 ++-- jsonargparse_tests/test_stubs_resolver.py | 4 +- jsonargparse_tests/test_subclasses.py | 2 + jsonargparse_tests/test_subcommands.py | 2 + jsonargparse_tests/test_typehints.py | 8 + jsonargparse_tests/test_typing.py | 2 + jsonargparse_tests/test_util.py | 6 +- pyproject.toml | 2 +- 34 files changed, 448 insertions(+), 137 deletions(-) create mode 100644 jsonargparse_tests/test_backports.py diff --git a/.circleci/config.yml b/.circleci/config.yml index 468ba389..e8b0d7f6 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -32,6 +32,8 @@ jobs: python3 -m jsonargparse_tests coverage xml coverage_py$py.xml pip3 install $(ls ./dist/*.whl)[test,all] python3 -m jsonargparse_tests coverage xml coverage_py${py}_all.xml + sed -i '/^from __future__ import annotations$/d' jsonargparse_tests/test_*.py + python3 -m jsonargparse_tests coverage xml coverage_py${py}_types.xml - persist_to_workspace: root: . paths: @@ -53,9 +55,26 @@ jobs: docker: - image: cimg/python:3.7 test-py36: - <<: *test-py38 docker: - image: cimg/python:3.6 + steps: + - attach_workspace: + at: . + - run: + name: Run unit tests + command: | + py=$(python3 --version | sed -r 's|.* 3\.([0-9]+)\..*|3.\1|') + sed -i '/^from __future__ import annotations$/d' jsonargparse_tests/test_*.py + virtualenv -p python3 venv$py + . venv$py/bin/activate + pip3 install $(ls ./dist/*.whl)[test-no-urls] + python3 -m jsonargparse_tests coverage xml coverage_py$py.xml + pip3 install $(ls ./dist/*.whl)[test,all] + python3 -m jsonargparse_tests coverage xml coverage_py${py}_all.xml + - persist_to_workspace: + root: . + paths: + - ./coverage_*.xml codecov: docker: - image: cimg/python:3.8 @@ -66,28 +85,26 @@ jobs: - run: name: Code coverage command: | - #for py in 3.6 3.7 3.8 3.9 3.10 3.11; do - for py in 3.6 3.7 3.8 3.9 3.10; do - bash <(curl -s https://codecov.io/bash) \ - -Z \ - -t $CODECOV_TOKEN_JSONARGPARSE \ - -F py$py \ - -f coverage_py${py}.xml - bash <(curl -s https://codecov.io/bash) \ - -Z \ - -t $CODECOV_TOKEN_JSONARGPARSE \ - -F py${py}_all \ - -f coverage_py${py}_all.xml + for py in 3.6 3.7 3.8 3.9 3.10 3.11; do + for suffix in "" "_all" "_types"; do + bash <(curl -s https://codecov.io/bash) \ + -Z \ + -t $CODECOV_TOKEN_JSONARGPARSE \ + -F py${py}${suffix} \ + -f coverage_py${py}${suffix}.xml + done done publish-pypi: docker: - - image: mauvilsa/docker-twine:1.11.0 + - image: cimg/python:3.10 steps: - attach_workspace: at: . - run: name: Publish Release on PyPI - command: twine upload --username __token__ --password "${PYPI_TOKEN}" ./dist/*.whl ./dist/*.tar.gz + command: | + pip3 install -U twine + twine upload --username __token__ --password "${PYPI_TOKEN}" ./dist/*.whl ./dist/*.tar.gz workflows: version: 2 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 72a2c8fe..144d196a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -103,6 +103,42 @@ repos: pass_filenames: false verbose: true + - id: test-py36 + name: test-py36 + entry: bash -c ' + if [ "$(which python3.6)" = "" ]; then + echo "$(tput setaf 6) Skipped, python 3.6 not found $(tput sgr0)"; + else + TEST_DIR=$(mktemp -d -t _jsonargparse_tests_XXXXXX); + cleanup () { rm -rf "$TEST_DIR"; }; + trap cleanup EXIT; + ./setup.py bdist_wheel; + python3.6 -m venv "$TEST_DIR/venv36"; + . "$TEST_DIR/venv36/bin/activate"; + pip3 install "$(ls ./dist/*.whl | tail -n 1)[all,test,test-no-urls]"; + pip3 install pytest pytest-subtests; + rsync -a jsonargparse_tests/*.py "$TEST_DIR"; + cd "$TEST_DIR"; + rm test_backports.py; + sed -i "/^from __future__ import annotations$/d" *.py; + pytest; + fi' + language: system + pass_filenames: false + + - id: test-without-future-annotations + name: test-without-future-annotations + entry: bash -c ' + TEST_DIR=$(mktemp -d -t _jsonargparse_tests_XXXXXX); + cleanup () { rm -rf "$TEST_DIR"; }; + trap cleanup EXIT; + rsync -a jsonargparse_tests/*.py "$TEST_DIR"; + cd "$TEST_DIR"; + sed -i "/^from __future__ import annotations$/d" *.py; + pytest $TEST_DIR;' + language: system + pass_filenames: false + - id: doctest name: sphinx-build -M doctest sphinx sphinx/_build sphinx/index.rst entry: bash -c ' diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 666763b4..3b290d64 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -12,6 +12,20 @@ The semantic versioning only considers the public API as described in paths are considered internals and can change in minor and patch releases. +v4.22.0 (2023-06-??) +-------------------- + +Added +^^^^^ +- Support for postponed evaluation of annotations PEP `563 + `__ ``from __future__ import annotations`` + (`#120 `__). +- Backport types in python<=3.9 to support PEP `585 + `__ and `604 + `__ for postponed evaluation of annotations + (`#120 `__). + + v4.21.2 (2023-06-??) -------------------- diff --git a/README.rst b/README.rst index fae71275..46dd8cd1 100644 --- a/README.rst +++ b/README.rst @@ -425,6 +425,13 @@ Some notes about this support are: nesting it is meant child types inside ``List``, ``Dict``, etc. There is no limit in nesting depth. +- Postponed evaluation of types PEP `563 `__ + (i.e. ``from __future__ import annotations``) is supported. Also supported on + ``python<=3.9`` are PEP `585 `__ (i.e. + ``list[], dict[], ...`` instead of ``List[], Dict[], + ...``) and `604 `__ (i.e. `` | + `` instead of ``Union[, ]``). + - Fully supported types are: ``str``, ``bool`` (more details in :ref:`boolean-arguments`), ``int``, ``float``, ``complex``, ``bytes``/``bytearray`` (Base64 encoding), ``List`` (more details in diff --git a/jsonargparse/_backports.py b/jsonargparse/_backports.py index 0e952ddf..94a329f6 100644 --- a/jsonargparse/_backports.py +++ b/jsonargparse/_backports.py @@ -1,24 +1,28 @@ import ast +import inspect +import sys +import textwrap +import typing as T from collections import namedtuple from copy import deepcopy -from typing import Dict, FrozenSet, List, Set, Tuple, Type, Union + +from .optionals import typing_extensions_import +from .util import unique var_map = namedtuple("var_map", "name value") none_map = var_map(name="NoneType", value=type(None)) -union_map = var_map(name="Union", value=Union) +union_map = var_map(name="Union", value=T.Union) pep585_map = { - "dict": var_map(name="Dict", value=Dict), - "frozenset": var_map(name="FrozenSet", value=FrozenSet), - "list": var_map(name="List", value=List), - "set": var_map(name="Set", value=Set), - "tuple": var_map(name="Tuple", value=Tuple), - "type": var_map(name="Type", value=Type), + "dict": var_map(name="Dict", value=T.Dict), + "frozenset": var_map(name="FrozenSet", value=T.FrozenSet), + "list": var_map(name="List", value=T.List), + "set": var_map(name="Set", value=T.Set), + "tuple": var_map(name="Tuple", value=T.Tuple), + "type": var_map(name="Type", value=T.Type), } class BackportTypeHints(ast.NodeTransformer): - _typing = __import__("typing") - def visit_Subscript(self, node: ast.Subscript) -> ast.Subscript: if isinstance(node.value, ast.Name) and node.value.id in pep585_map: value = self.new_name_load(pep585_map[node.value.id]) @@ -30,13 +34,13 @@ def visit_Subscript(self, node: ast.Subscript) -> ast.Subscript: ctx=ast.Load(), ) - def visit_Constant(self, node: ast.Constant) -> Union[ast.Constant, ast.Name]: + def visit_Constant(self, node: ast.Constant) -> T.Union[ast.Constant, ast.Name]: if node.value is None: return self.new_name_load(none_map) return node - def visit_BinOp(self, node: ast.BinOp) -> Union[ast.BinOp, ast.Subscript]: - out_node: Union[ast.BinOp, ast.Subscript] = node + def visit_BinOp(self, node: ast.BinOp) -> T.Union[ast.BinOp, ast.Subscript]: + out_node: T.Union[ast.BinOp, ast.Subscript] = node if isinstance(node.op, ast.BitOr): elts: list = [] self.append_union_elts(node.left, elts) @@ -66,8 +70,112 @@ def new_name_load(self, var: var_map) -> ast.Name: def backport(self, input_ast: ast.AST, exec_vars: dict) -> ast.AST: for key, value in exec_vars.items(): if getattr(value, "__module__", "") == "collections.abc": - if hasattr(self._typing, key): - exec_vars[key] = getattr(self._typing, key) + if hasattr(T, key): + exec_vars[key] = getattr(T, key) self.exec_vars = exec_vars backport_ast = self.visit(deepcopy(input_ast)) return ast.fix_missing_locations(backport_ast) + + +class NamesVisitor(ast.NodeVisitor): + def visit_Name(self, node: ast.Name) -> None: + self.names_found.append(node.id) + + def find(self, node: ast.AST) -> list: + self.names_found: T.List[str] = [] + self.visit(node) + self.names_found = unique(self.names_found) + return self.names_found + + +def get_arg_type(arg_ast, aliases): + type_ast = ast.parse("___arg_type___ = 0") + type_ast.body[0].value = arg_ast.annotation + exec_vars = {} + bad_aliases = {} + add_asts = False + for name in NamesVisitor().find(arg_ast.annotation): + value = aliases[name] + if isinstance(value, tuple): + value = value[1] + if isinstance(value, Exception): + bad_aliases[name] = value + elif isinstance(value, ast.AST): + add_asts = True + else: + exec_vars[name] = value + if add_asts: + body = [] + for name, (_, value) in aliases.items(): + if isinstance(value, ast.AST): + body.append(ast.fix_missing_locations(value)) + elif not isinstance(value, Exception): + exec_vars[name] = value + type_ast.body = body + type_ast.body + if "TypeAlias" not in exec_vars: + type_alias = typing_extensions_import("TypeAlias") + if type_alias: + exec_vars["TypeAlias"] = type_alias + if sys.version_info < (3, 10): + backporter = BackportTypeHints() + type_ast = backporter.backport(type_ast, exec_vars) + try: + exec(compile(type_ast, filename="", mode="exec"), exec_vars, exec_vars) + except NameError as ex: + ex_from = None + for name, alias_exception in bad_aliases.items(): + if str(ex) == f"name '{name}' is not defined": + ex_from = alias_exception + break + raise ex from ex_from + return exec_vars["___arg_type___"] + + +def get_type_hints(obj: T.Any, globalns: T.Optional[dict] = None, localns: T.Optional[dict] = None) -> dict: + try: + return T.get_type_hints(obj, globalns, localns) + except Exception as ex1: + try: + source = textwrap.dedent(inspect.getsource(obj)) + tree = ast.parse(source) + assert isinstance(tree, ast.Module) and len(tree.body) == 1 + node = tree.body[0] + + aliases = __builtins__.copy() # type: ignore + if globalns is None: + globalns = obj.__globals__ + aliases.update(globalns) + if localns is not None: + aliases.update(localns) + + types = {} + for arg_ast in node.args.args + node.args.kwonlyargs: # type: ignore + if not arg_ast.annotation: + continue + name = arg_ast.arg + types[name] = get_arg_type(arg_ast, aliases) + + return types + except Exception as ex2: + raise type(ex1)(f"{repr(ex1)} + {repr(ex2)}") from ex2 + + +def evaluate_postponed_annotations(params, component, logger): + if sys.version_info[:2] == (3, 6) or not ( + params and any(isinstance(p.annotation, (str, T.ForwardRef)) for p in params) + ): + return + try: + if sys.version_info < (3, 10): + types = get_type_hints(component) + else: + types = T.get_type_hints(component) + except Exception as ex: + logger.debug(f"Unable to evaluate types for {component}", exc_info=ex) + return + for param in params: + if param.name in types: + param_type = types[param.name] + if isinstance(param_type, T.ForwardRef): + param_type = param_type._evaluate(component.__globals__, {}) + param.annotation = param_type diff --git a/jsonargparse/_stubs_resolver.py b/jsonargparse/_stubs_resolver.py index 164ae05a..01f305ed 100644 --- a/jsonargparse/_stubs_resolver.py +++ b/jsonargparse/_stubs_resolver.py @@ -5,14 +5,10 @@ from copy import deepcopy from importlib import import_module from pathlib import Path -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple -from .optionals import ( - import_typeshed_client, - typeshed_client_support, - typing_extensions_import, -) -from .util import unique +from ._backports import NamesVisitor, get_arg_type +from .optionals import import_typeshed_client, typeshed_client_support if TYPE_CHECKING: # pragma: no cover import typeshed_client as tc @@ -32,17 +28,6 @@ def import_module_or_none(path: str): return None -class NamesVisitor(ast.NodeVisitor): - def visit_Name(self, node: ast.Name) -> None: - self.names_found.append(node.id) - - def find(self, node: ast.AST) -> list: - self.names_found: List[str] = [] - self.visit(node) - self.names_found = unique(self.names_found) - return self.names_found - - class ImportsVisitor(ast.NodeVisitor): def visit_ImportFrom(self, node: ast.ImportFrom) -> None: if node.level: @@ -262,49 +247,6 @@ def alias_is_unique(aliases, name, source, value): return True -def get_arg_type(arg_ast, aliases): - type_ast = ast.parse("___arg_type___ = 0") - type_ast.body[0].value = arg_ast.annotation - exec_vars = {} - bad_aliases = {} - add_asts = False - for name in NamesVisitor().find(arg_ast.annotation): - _, value = aliases[name] - if isinstance(value, Exception): - bad_aliases[name] = value - elif isinstance(value, ast.AST): - add_asts = True - else: - exec_vars[name] = value - if add_asts: - body = [] - for name, (_, value) in aliases.items(): - if isinstance(value, ast.AST): - body.append(ast.fix_missing_locations(value)) - elif not isinstance(value, Exception): - exec_vars[name] = value - type_ast.body = body + type_ast.body - if "TypeAlias" not in exec_vars: - type_alias = typing_extensions_import("TypeAlias") - if type_alias: - exec_vars["TypeAlias"] = type_alias - if sys.version_info < (3, 10): - from ._backports import BackportTypeHints - - backporter = BackportTypeHints() - type_ast = backporter.backport(type_ast, exec_vars) - try: - exec(compile(type_ast, filename="", mode="exec"), exec_vars, exec_vars) - except NameError as ex: - ex_from = None - for name, alias_exception in bad_aliases.items(): - if str(ex) == f"name '{name}' is not defined": - ex_from = alias_exception - break - raise ex from ex_from - return exec_vars["___arg_type___"] - - def get_stub_types(params, component, parent, logger) -> Optional[Dict[str, Any]]: if not typeshed_client_support: return None diff --git a/jsonargparse/link_arguments.py b/jsonargparse/link_arguments.py index a7ee315f..9bd78fdf 100644 --- a/jsonargparse/link_arguments.py +++ b/jsonargparse/link_arguments.py @@ -1,6 +1,5 @@ """Code related to argument linking.""" -import inspect import re from argparse import SUPPRESS from argparse import Action as ArgparseAction @@ -18,6 +17,7 @@ filter_default_actions, ) from .namespace import Namespace, split_key_leaf +from .parameter_resolvers import get_signature_parameters from .type_checking import ArgumentParser, _ArgumentGroup __all__ = ["ArgumentLinking"] @@ -271,7 +271,7 @@ def apply_parsing_links(parser: "ArgumentParser", cfg: Namespace) -> None: value = value.as_dict() else: # Automatic namespace to dict based on compute_fn param type hint - params = list(inspect.signature(action.compute_fn).parameters.values()) + params = get_signature_parameters(action.compute_fn) for n, param in enumerate(params): if ( n < len(args) diff --git a/jsonargparse/parameter_resolvers.py b/jsonargparse/parameter_resolvers.py index e2b286ff..a4e8614f 100644 --- a/jsonargparse/parameter_resolvers.py +++ b/jsonargparse/parameter_resolvers.py @@ -10,6 +10,7 @@ from functools import partial from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union +from ._backports import evaluate_postponed_annotations from ._common import is_dataclass_like, is_subclass from ._stubs_resolver import get_stub_types from .optionals import parse_docs @@ -275,6 +276,7 @@ def get_signature_parameters_and_indexes(component, parent, logger): component=component, **{a: getattr(param, a) for a in parameter_attributes}, ) + evaluate_postponed_annotations(params, signature_source, logger) stubs = get_stub_types(params, signature_source, parent, logger) return params, args_idx, kwargs_idx, doc_params, stubs @@ -801,6 +803,7 @@ def get_parameters_from_pydantic( component=function_or_class, ) ) + evaluate_postponed_annotations(params, function_or_class, logger) return params diff --git a/jsonargparse/signatures.py b/jsonargparse/signatures.py index 2b5670fe..7c9aa070 100644 --- a/jsonargparse/signatures.py +++ b/jsonargparse/signatures.py @@ -15,7 +15,12 @@ get_parameter_origins, get_signature_parameters, ) -from .typehints import ActionTypeHint, LazyInitBaseClass, is_optional +from .typehints import ( + ActionTypeHint, + LazyInitBaseClass, + callable_instances, + is_optional, +) from .typing import register_pydantic_type from .util import LoggerProperty, get_import_path, iter_to_set_str @@ -176,9 +181,14 @@ def add_function_arguments( if not callable(function): raise ValueError('Expected "function" argument to be a callable object.') + method_name = None + if hasattr(function, "__class__") and callable_instances(function.__class__): + function = function.__class__ + method_name = "__call__" + return self._add_signature_arguments( function, - None, + method_name, nested_key, as_group, as_positional, @@ -366,12 +376,10 @@ def _add_signature_parameter( action.sub_add_kwargs["skip"] = subclass_skip added_args.append(dest) elif is_required and fail_untyped: - msg = f'With fail_untyped=True, all mandatory parameters must have a supported type. Parameter "{name}" from "{src}" ' - if isinstance(annotation, str): - msg += "specifies the type as a string. Types as a string and `from __future__ import annotations` is currently not supported." - else: - msg += "does not specify a type." - raise ValueError(msg) + raise ValueError( + "With fail_untyped=True, all mandatory parameters must have a supported" + f" type. Parameter '{name}' from '{src}' does not specify a type." + ) def add_dataclass_arguments( self, diff --git a/jsonargparse/util.py b/jsonargparse/util.py index 382f3db9..86e72780 100644 --- a/jsonargparse/util.py +++ b/jsonargparse/util.py @@ -28,6 +28,7 @@ Type, TypeVar, Union, + get_type_hints, ) from ._common import parser_capture, parser_context @@ -369,10 +370,13 @@ def class_from_function(func: Callable[..., ClassType]) -> Type[ClassType]: """ func_return = inspect.signature(func).return_annotation if isinstance(func_return, str): - caller_frame = inspect.currentframe().f_back # type: ignore - func_return = caller_frame.f_locals.get(func_return) or caller_frame.f_globals.get(func_return) # type: ignore - if func_return is None: - raise ValueError(f"Unable to dereference {func_return} the return type of {func}.") + try: + func_return = get_type_hints(func)["return"] + if sys.version_info[:2] != (3, 6) and isinstance(func_return, __import__("typing").ForwardRef): + func_return = func_return._evaluate(func.__globals__, {}) + except Exception as ex: + func_return = inspect.signature(func).return_annotation + raise ValueError(f"Unable to dereference {func_return}, the return type of {func}: {ex}") from ex @wraps(func) def __new__(cls, *args, **kwargs): diff --git a/jsonargparse_tests/test_actions.py b/jsonargparse_tests/test_actions.py index e1a26ea2..326f9d6f 100644 --- a/jsonargparse_tests/test_actions.py +++ b/jsonargparse_tests/test_actions.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from pathlib import Path import pytest diff --git a/jsonargparse_tests/test_argcomplete.py b/jsonargparse_tests/test_argcomplete.py index 68915888..c6fbd545 100644 --- a/jsonargparse_tests/test_argcomplete.py +++ b/jsonargparse_tests/test_argcomplete.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import sys from contextlib import ExitStack, contextmanager diff --git a/jsonargparse_tests/test_backports.py b/jsonargparse_tests/test_backports.py new file mode 100644 index 00000000..7c256387 --- /dev/null +++ b/jsonargparse_tests/test_backports.py @@ -0,0 +1,105 @@ +from __future__ import annotations # keep + +import sys +from random import Random +from typing import Dict, FrozenSet, List, Set, Tuple, Type, Union + +import pytest + +from jsonargparse._backports import get_type_hints +from jsonargparse.parameter_resolvers import get_signature_parameters as get_params + + +@pytest.fixture(autouse=True) +def skip_if_python_older_than_3_10(): + if sys.version_info >= (3, 10, 0): + pytest.skip("python<3.10 is required") + + +def function_pep585_dict(p1: dict[str, int], p2: dict[int, str] = {1: "a"}): + return p1 + + +def function_pep585_list(p1: list[str], p2: list[float] = [0.1, 2.3]): + return p1 + + +def function_pep585_set(p1: set[str], p2: set[int] = {1, 2}): + return p1 + + +def function_pep585_frozenset(p1: frozenset[str], p2: frozenset[int] = frozenset(range(3))): + return p1 + + +def function_pep585_tuple(p1: tuple[str, float], p2: tuple[int, ...] = (1, 2)): + return p1 + + +def function_pep585_type(p1: type[Random], p2: type[Random] = Random): + return p1 + + +@pytest.mark.skipif(sys.version_info >= (3, 9, 0), reason="python<3.9 is required") +@pytest.mark.parametrize( + ["function", "expected"], + [ + (function_pep585_dict, {"p1": Dict[str, int], "p2": Dict[int, str]}), + (function_pep585_list, {"p1": List[str], "p2": List[float]}), + (function_pep585_set, {"p1": Set[str], "p2": Set[int]}), + (function_pep585_frozenset, {"p1": FrozenSet[str], "p2": FrozenSet[int]}), + (function_pep585_tuple, {"p1": Tuple[str, float], "p2": Tuple[int, ...]}), + (function_pep585_type, {"p1": Type[Random], "p2": Type[Random]}), + ], +) +def test_pep585(function, expected): + types = get_type_hints(function) + assert types == expected + + +def function_pep604(p1: str | None, p2: int | float | bool = 1): + return p1 + + +def test_pep604(): + types = get_type_hints(function_pep604, function_pep604.__globals__, {}) + assert types == {"p1": Union[str, None], "p2": Union[int, float, bool]} + + +class NeedsBackport: + def __init__(self, p1: list | set): + self.p1 = p1 + + @staticmethod + def static_method(p1: str | int): + return p1 + + @classmethod + def class_method(cls, p1: float | None): + return p1 + + +@pytest.mark.parametrize( + ["method", "expected"], + [ + (NeedsBackport.__init__, {"p1": Union[list, set]}), + (NeedsBackport.static_method, {"p1": Union[str, int]}), + (NeedsBackport.class_method, {"p1": Union[float, None]}), + ], +) +def test_methods(method, expected): + types = get_type_hints(method) + assert types == expected + + +def function_undefined_type(p1: not_defined | None): # type: ignore # noqa: F821 + return p1 + + +def test_undefined_type(): + with pytest.raises(Exception) as ctx: + get_type_hints(function_undefined_type) + ctx.match("name 'not_defined' is not defined") + + params = get_params(function_undefined_type) + assert params[0].annotation == "not_defined | None" diff --git a/jsonargparse_tests/test_cli.py b/jsonargparse_tests/test_cli.py index 48db4bc0..6927cd54 100644 --- a/jsonargparse_tests/test_cli.py +++ b/jsonargparse_tests/test_cli.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import sys from contextlib import redirect_stderr, redirect_stdout, suppress diff --git a/jsonargparse_tests/test_core.py b/jsonargparse_tests/test_core.py index 0b3733be..4b49cc8d 100644 --- a/jsonargparse_tests/test_core.py +++ b/jsonargparse_tests/test_core.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import json import os import pickle diff --git a/jsonargparse_tests/test_dataclass_like.py b/jsonargparse_tests/test_dataclass_like.py index c5182c97..f7e346ef 100644 --- a/jsonargparse_tests/test_dataclass_like.py +++ b/jsonargparse_tests/test_dataclass_like.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import dataclasses import sys from typing import Any, Dict, List, Optional, Union @@ -418,7 +420,7 @@ def none(x): return x -@pytest.mark.skipif(sys.version_info == (3, 6), reason="pydantic not supported in python 3.6") +@pytest.mark.skipif(sys.version_info[:2] == (3, 6), reason="pydantic not supported in python 3.6") @pytest.mark.skipif(not pydantic_support, reason="pydantic package is required") class TestPydantic: def test_dataclass(self, parser): @@ -448,21 +450,22 @@ def test_field_description(self, parser): assert "p1 help (required, type: str)" in help_str assert "p2 help (type: int, default: 2)" in help_str - def test_pydantic_types(self, subtests): - for valid_value, invalid_value, cast, pydantic_type in [ - ("abc", "a", none, pydantic.constr(min_length=2, max_length=4)), - (2, 0, none, pydantic.conint(ge=1)), - (-1.0, 1.0, none, pydantic.confloat(lt=0.0)), - ([1], [], none, pydantic.conlist(int, min_items=1)), - ([], [3, 4], none, pydantic.conlist(int, max_items=1)), - ([1], "x", list, pydantic.conset(int, min_items=1)), - ("http://abc.es", "-", none, pydantic.HttpUrl), - ("127.0.0.1", "0", str, pydantic.IPvAnyAddress), - ]: + def test_pydantic_types(self, subtests, monkeypatch): + for num, (valid_value, invalid_value, cast, pydantic_type) in enumerate( + [ + ("abc", "a", none, pydantic.constr(min_length=2, max_length=4)), + (2, 0, none, pydantic.conint(ge=1)), + (-1.0, 1.0, none, pydantic.confloat(lt=0.0)), + ([1], [], none, pydantic.conlist(int, min_items=1)), + ([], [3, 4], none, pydantic.conlist(int, max_items=1)), + ([1], "x", list, pydantic.conset(int, min_items=1)), + ("http://abc.es", "-", none, pydantic.HttpUrl), + ("127.0.0.1", "0", str, pydantic.IPvAnyAddress), + ] + ): with subtests.test(f"type={pydantic_type.__name__} valid={valid_value} invalid={invalid_value}"): - - class Model(pydantic.BaseModel): - param: pydantic_type + Model = pydantic.create_model(f"Model{num}", param=(pydantic_type, ...)) + monkeypatch.setitem(Model.__init__.__globals__, "pydantic_type", pydantic_type) parser = ArgumentParser(exit_on_error=False) parser.add_argument("--model", type=Model) diff --git a/jsonargparse_tests/test_deprecated.py b/jsonargparse_tests/test_deprecated.py index 9995111d..c8019cfc 100644 --- a/jsonargparse_tests/test_deprecated.py +++ b/jsonargparse_tests/test_deprecated.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import pathlib from calendar import Calendar diff --git a/jsonargparse_tests/test_formatters.py b/jsonargparse_tests/test_formatters.py index 01aa2f63..f92bb3f8 100644 --- a/jsonargparse_tests/test_formatters.py +++ b/jsonargparse_tests/test_formatters.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from pathlib import Path from typing import Tuple diff --git a/jsonargparse_tests/test_jsonnet.py b/jsonargparse_tests/test_jsonnet.py index fec2949c..9fcf9b7a 100644 --- a/jsonargparse_tests/test_jsonnet.py +++ b/jsonargparse_tests/test_jsonnet.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import json import re from pathlib import Path diff --git a/jsonargparse_tests/test_jsonschema.py b/jsonargparse_tests/test_jsonschema.py index 59e121dd..7af247a5 100644 --- a/jsonargparse_tests/test_jsonschema.py +++ b/jsonargparse_tests/test_jsonschema.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import json import re from importlib.util import find_spec diff --git a/jsonargparse_tests/test_link_arguments.py b/jsonargparse_tests/test_link_arguments.py index 6513b4c3..c2f4c31e 100644 --- a/jsonargparse_tests/test_link_arguments.py +++ b/jsonargparse_tests/test_link_arguments.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from calendar import Calendar, TextCalendar from typing import Any, List, Mapping, Optional, Union diff --git a/jsonargparse_tests/test_loaders_dumpers.py b/jsonargparse_tests/test_loaders_dumpers.py index 36c009aa..95e6ebce 100644 --- a/jsonargparse_tests/test_loaders_dumpers.py +++ b/jsonargparse_tests/test_loaders_dumpers.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os from typing import List from unittest.mock import patch diff --git a/jsonargparse_tests/test_namespace.py b/jsonargparse_tests/test_namespace.py index 3efd7fba..5b7c5b4d 100644 --- a/jsonargparse_tests/test_namespace.py +++ b/jsonargparse_tests/test_namespace.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import argparse import platform diff --git a/jsonargparse_tests/test_optionals.py b/jsonargparse_tests/test_optionals.py index 4d23cc1c..08a561e6 100644 --- a/jsonargparse_tests/test_optionals.py +++ b/jsonargparse_tests/test_optionals.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import pytest from jsonargparse import get_config_read_mode, set_config_read_mode diff --git a/jsonargparse_tests/test_parameter_resolvers.py b/jsonargparse_tests/test_parameter_resolvers.py index 585fa1b6..f683d0c8 100644 --- a/jsonargparse_tests/test_parameter_resolvers.py +++ b/jsonargparse_tests/test_parameter_resolvers.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import calendar import inspect import xml.dom @@ -331,7 +333,7 @@ def __init__( """ -def function_no_args_no_kwargs(pk1: str, k2: int = 1): +def function_no_args_no_kwargs(pk1: str, k2: "int" = 1): """ Args: pk1: help for pk1 @@ -407,6 +409,13 @@ def function_constant_boolean(**kwargs): return function_with_kwargs(k1=False, **kwargs) +def function_invalid_type(param: "invalid:" = 1): # type: ignore # noqa: F722 + """ + Args: + param: help for param + """ + + def cond_1(kc: int = 1, kn0: str = "x", kn1: str = "-"): """ Args: @@ -662,6 +671,13 @@ def test_get_params_function_constant_boolean(): assert get_params(function_constant_boolean) == [] +def test_get_params_function_invalid_type(logger): + with capture_logs(logger): + params = get_params(function_invalid_type, logger=logger) + assert_params(params, ["param"]) + assert params[0].annotation.replace("'", "") == "invalid:" + + def test_conditional_calls_kwargs(): assert_params( get_params(conditional_calls), diff --git a/jsonargparse_tests/test_paths.py b/jsonargparse_tests/test_paths.py index 75dafb5f..5b13bf0e 100644 --- a/jsonargparse_tests/test_paths.py +++ b/jsonargparse_tests/test_paths.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import json import os from calendar import Calendar diff --git a/jsonargparse_tests/test_signatures.py b/jsonargparse_tests/test_signatures.py index 656bb275..a2d39fdc 100644 --- a/jsonargparse_tests/test_signatures.py +++ b/jsonargparse_tests/test_signatures.py @@ -1,3 +1,6 @@ +from __future__ import annotations + +import sys from pathlib import Path from typing import Any, Dict, List, Optional, Tuple, Union from unittest.mock import patch @@ -85,7 +88,7 @@ def test_add_class_failure_not_a_class(parser): def test_add_class_failure_positional_without_type(parser): with pytest.raises(ValueError) as ctx: parser.add_class_arguments(Class2) - ctx.match(f'Parameter "c2_a0" from "{__name__}.Class2.__init__" does not specify a type') + ctx.match(f"Parameter 'c2_a0' from '{__name__}.Class2.__init__' does not specify a type") def test_add_class_without_nesting(parser): @@ -251,6 +254,7 @@ def test_add_class_with_required_parameters(parser): assert cfg.model == Namespace(m=0.1, n=3) +@pytest.mark.skipif(sys.version_info[:2] == (3, 6), reason="import forces future annotations in python 3.6") def test_add_class_conditional_kwargs(parser): from jsonargparse_tests.test_parameter_resolvers import ClassG @@ -510,24 +514,24 @@ def test_add_function_implicit_optional(parser): assert None is parser.parse_args(["--a1=null"]).a1 -def func_untyped_params(a1, a2=None): - return a1 +def func_type_as_string(a2: "int"): + return a2 -def test_add_function_fail_untyped_true_untyped_params(parser): - with pytest.raises(ValueError) as ctx: - parser.add_function_arguments(func_untyped_params, fail_untyped=True) - ctx.match('Parameter "a1" from .* does not specify a type') +@pytest.mark.skipif(sys.version_info[:2] == (3, 6), reason="not supported in python 3.6") +def test_add_function_fail_untyped_true_str_type(parser): + added_args = parser.add_function_arguments(func_type_as_string, fail_untyped=True) + assert ["a2"] == added_args -def func_type_as_string(a2: "int"): - return a2 +def func_untyped_params(a1, a2=None): + return a1 -def test_add_function_fail_untyped_true_str_type(parser): +def test_add_function_fail_untyped_true_untyped_params(parser): with pytest.raises(ValueError) as ctx: - parser.add_function_arguments(func_type_as_string, fail_untyped=True) - ctx.match('Parameter "a2" from .* specifies the type as a string') + parser.add_function_arguments(func_untyped_params, fail_untyped=True) + ctx.match("Parameter 'a1' from .* does not specify a type") def test_add_function_fail_untyped_false(parser): diff --git a/jsonargparse_tests/test_stubs_resolver.py b/jsonargparse_tests/test_stubs_resolver.py index 30fbe950..f68790b4 100644 --- a/jsonargparse_tests/test_stubs_resolver.py +++ b/jsonargparse_tests/test_stubs_resolver.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import inspect import sys from calendar import Calendar, TextCalendar @@ -46,7 +48,7 @@ def get_param_names(params): class WithoutParent: - ... + pass @pytest.mark.parametrize( diff --git a/jsonargparse_tests/test_subclasses.py b/jsonargparse_tests/test_subclasses.py index 9f73d1b2..82ed96fd 100644 --- a/jsonargparse_tests/test_subclasses.py +++ b/jsonargparse_tests/test_subclasses.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import json import textwrap import warnings diff --git a/jsonargparse_tests/test_subcommands.py b/jsonargparse_tests/test_subcommands.py index e96af28a..3e4421cf 100644 --- a/jsonargparse_tests/test_subcommands.py +++ b/jsonargparse_tests/test_subcommands.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import warnings from pathlib import Path diff --git a/jsonargparse_tests/test_typehints.py b/jsonargparse_tests/test_typehints.py index 5b45bcbb..fd73308d 100644 --- a/jsonargparse_tests/test_typehints.py +++ b/jsonargparse_tests/test_typehints.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import random import sys import time @@ -775,6 +777,12 @@ def test_action_typehint_unsupported_type(typehint): ctx.match("Unsupported type hint") +def test_action_typehint_none_type_error(): + with pytest.raises(ValueError) as ctx: + ActionTypeHint(typehint=None) + ctx.match("Expected typehint keyword argument") + + @pytest.mark.parametrize( ["typehint", "ref_type", "expected"], [ diff --git a/jsonargparse_tests/test_typing.py b/jsonargparse_tests/test_typing.py index d395cb31..d779b441 100644 --- a/jsonargparse_tests/test_typing.py +++ b/jsonargparse_tests/test_typing.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import pickle import random import uuid diff --git a/jsonargparse_tests/test_util.py b/jsonargparse_tests/test_util.py index 6de82132..1654fc54 100644 --- a/jsonargparse_tests/test_util.py +++ b/jsonargparse_tests/test_util.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import logging import os import pathlib @@ -566,7 +568,7 @@ def get_foo(cls) -> "Foo": def closure_get_foo(): - def get_foo() -> "Foo": + def get_foo() -> Foo: return Foo() return get_foo @@ -593,7 +595,7 @@ def get_unknown() -> "Unknown": # type: ignore # noqa: F821 def test_invalid_class_from_function(): with pytest.raises(ValueError) as ctx: class_from_function(get_unknown) - ctx.match("Unable to dereference None the return type") + ctx.match("Unable to dereference '?Unknown'?, the return type of") def get_calendar(a1: str, a2: int = 2) -> Calendar: diff --git a/pyproject.toml b/pyproject.toml index 32c9f2a9..e9735e79 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -206,7 +206,7 @@ multi_line_output = 3 [tool.tox] legacy_tox_ini = """ [tox] -envlist = py{37,38,39,310,311}-{all,no}-extras,pypy3,omegaconf +envlist = py{37,38,39,310,311}-{all,no}-extras,omegaconf skip_missing_interpreters = true [testenv]