diff --git a/.circleci/config.yml b/.circleci/config.yml index dee92b65..9454e771 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -32,6 +32,9 @@ jobs: python3 -m jsonargparse_tests coverage xml coverage_py$py.xml pip3 install $(ls ./dist/*.whl)[test,all] python3 -m jsonargparse_tests coverage xml coverage_py${py}_all.xml + tests_dir=$(python3 -c "print(__import__('jsonargparse_tests').__file__)" | sed 's|[^/]*$||') + sed -i '/^from __future__ import annotations$/d' $tests_dir/test_*.py + python3 -m jsonargparse_tests coverage xml coverage_py${py}_types.xml - persist_to_workspace: root: . paths: @@ -53,9 +56,28 @@ jobs: docker: - image: cimg/python:3.7 test-py36: - <<: *test-py38 docker: - image: cimg/python:3.6 + steps: + - attach_workspace: + at: . + - run: + name: Run unit tests + command: | + py=$(python3 --version | sed -r 's|.* 3\.([0-9]+)\..*|3.\1|') + virtualenv -p python3 venv$py + . venv$py/bin/activate + pip3 install $(ls ./dist/*.whl)[test-no-urls] + tests_dir=$(python3 -c "print(__import__('jsonargparse_tests').__file__)" | sed 's|[^/]*$||') + rm $tests_dir/test_backports.py; + sed -i '/^from __future__ import annotations$/d' $tests_dir/test_*.py + python3 -m jsonargparse_tests coverage xml coverage_py$py.xml + pip3 install $(ls ./dist/*.whl)[test,all] + python3 -m jsonargparse_tests coverage xml coverage_py${py}_all.xml + - persist_to_workspace: + root: . + paths: + - ./coverage_*.xml codecov: docker: - image: cimg/python:3.8 @@ -66,18 +88,16 @@ jobs: - run: name: Code coverage command: | - #for py in 3.6 3.7 3.8 3.9 3.10 3.11; do - for py in 3.6 3.7 3.8 3.9 3.10; do - bash <(curl -s https://codecov.io/bash) \ - -Z \ - -t $CODECOV_TOKEN_JSONARGPARSE \ - -F py$py \ - -f coverage_py${py}.xml - bash <(curl -s https://codecov.io/bash) \ - -Z \ - -t $CODECOV_TOKEN_JSONARGPARSE \ - -F py${py}_all \ - -f coverage_py${py}_all.xml + for py in 3.6 3.7 3.8 3.9 3.10 3.11; do + for suffix in "" "_all" "_types"; do + if [ -f coverage_py${py}${suffix}.xml ]; then + bash <(curl -s https://codecov.io/bash) \ + -Z \ + -t $CODECOV_TOKEN_JSONARGPARSE \ + -F py${py}${suffix} \ + -f coverage_py${py}${suffix}.xml + fi + done done publish-pypi: docker: diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 2ed0701c..5f958a63 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -4,6 +4,8 @@ ci: skip: - mypy - tox + - test-py36 + - test-without-future-annotations - coverage autofix_prs: true autoupdate_commit_msg: '[pre-commit.ci] pre-commit suggestions' @@ -101,6 +103,42 @@ repos: pass_filenames: false verbose: true + - id: test-py36 + name: test-py36 + entry: bash -c ' + if [ "$(which python3.6)" = "" ]; then + echo "$(tput setaf 6) Skipped, python 3.6 not found $(tput sgr0)"; + else + TEST_DIR=$(mktemp -d -t _jsonargparse_tests_XXXXXX); + cleanup () { rm -rf "$TEST_DIR"; }; + trap cleanup EXIT; + ./setup.py bdist_wheel; + python3.6 -m venv "$TEST_DIR/venv36"; + . "$TEST_DIR/venv36/bin/activate"; + pip3 install "$(ls ./dist/*.whl | tail -n 1)[all,test,test-no-urls]"; + pip3 install pytest pytest-subtests; + cp jsonargparse_tests/*.py "$TEST_DIR"; + cd "$TEST_DIR"; + rm test_backports.py; + sed -i "/^from __future__ import annotations$/d" *.py; + pytest; + fi' + language: system + pass_filenames: false + + - id: test-without-future-annotations + name: test-without-future-annotations + entry: bash -c ' + TEST_DIR=$(mktemp -d -t _jsonargparse_tests_XXXXXX); + cleanup () { rm -rf "$TEST_DIR"; }; + trap cleanup EXIT; + cp jsonargparse_tests/*.py "$TEST_DIR"; + cd "$TEST_DIR"; + sed -i "/^from __future__ import annotations$/d" *.py; + pytest $TEST_DIR;' + language: system + pass_filenames: false + - id: doctest name: sphinx-build -M doctest sphinx sphinx/_build sphinx/index.rst entry: bash -c ' diff --git a/CHANGELOG.rst b/CHANGELOG.rst index e9b416d2..5e2a81a7 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -21,6 +21,13 @@ Added - ``class_from_function`` now supports ``func_return`` parameter to specify the return type of the function (`lightning-flash#1564 comment `__). +- Support for postponed evaluation of annotations PEP `563 + `__ ``from __future__ import annotations`` + (`#120 `__). +- Backport types in python<=3.9 to support PEP `585 + `__ and `604 + `__ for postponed evaluation of annotations + (`#120 `__). Fixed ^^^^^ diff --git a/README.rst b/README.rst index d58d9ee5..b4bdb0f5 100644 --- a/README.rst +++ b/README.rst @@ -425,6 +425,13 @@ Some notes about this support are: nesting it is meant child types inside ``List``, ``Dict``, etc. There is no limit in nesting depth. +- Postponed evaluation of types PEP `563 `__ + (i.e. ``from __future__ import annotations``) is supported. Also supported on + ``python<=3.9`` are PEP `585 `__ (i.e. + ``list[], dict[], ...`` instead of ``List[], Dict[], + ...``) and `604 `__ (i.e. `` | + `` instead of ``Union[, ]``). + - Fully supported types are: ``str``, ``bool`` (more details in :ref:`boolean-arguments`), ``int``, ``float``, ``complex``, ``bytes``/``bytearray`` (Base64 encoding), ``List`` (more details in diff --git a/jsonargparse/_backports.py b/jsonargparse/_backports.py index 0e952ddf..8bd54fde 100644 --- a/jsonargparse/_backports.py +++ b/jsonargparse/_backports.py @@ -1,7 +1,17 @@ import ast +import inspect +import logging +import sys +import textwrap from collections import namedtuple from copy import deepcopy -from typing import Dict, FrozenSet, List, Set, Tuple, Type, Union +from importlib import import_module +from typing import Any, Dict, FrozenSet, List, Optional, Set, Tuple, Type, Union, get_type_hints + +if sys.version_info[:2] > (3, 6): + from typing import ForwardRef + +from ._optionals import typing_extensions_import var_map = namedtuple("var_map", "name value") none_map = var_map(name="NoneType", value=type(None)) @@ -17,8 +27,6 @@ class BackportTypeHints(ast.NodeTransformer): - _typing = __import__("typing") - def visit_Subscript(self, node: ast.Subscript) -> ast.Subscript: if isinstance(node.value, ast.Name) and node.value.id in pep585_map: value = self.new_name_load(pep585_map[node.value.id]) @@ -64,10 +72,142 @@ def new_name_load(self, var: var_map) -> ast.Name: return ast.Name(id=name, ctx=ast.Load()) def backport(self, input_ast: ast.AST, exec_vars: dict) -> ast.AST: + typing = __import__("typing") for key, value in exec_vars.items(): if getattr(value, "__module__", "") == "collections.abc": - if hasattr(self._typing, key): - exec_vars[key] = getattr(self._typing, key) + if hasattr(typing, key): + exec_vars[key] = getattr(typing, key) self.exec_vars = exec_vars backport_ast = self.visit(deepcopy(input_ast)) return ast.fix_missing_locations(backport_ast) + + +class NamesVisitor(ast.NodeVisitor): + def visit_Name(self, node: ast.Name) -> None: + self.names_found.append(node.id) + + def find(self, node: ast.AST) -> list: + from ._util import unique + + self.names_found: List[str] = [] + self.visit(node) + self.names_found = unique(self.names_found) + return self.names_found + + +def get_arg_type(arg_ast, aliases): + type_ast = ast.parse("___arg_type___ = 0") + type_ast.body[0].value = arg_ast + exec_vars = {} + bad_aliases = {} + add_asts = False + for name in NamesVisitor().find(arg_ast): + value = aliases[name] + if isinstance(value, tuple): + value = value[1] + if isinstance(value, Exception): + bad_aliases[name] = value + elif isinstance(value, ast.AST): + add_asts = True + else: + exec_vars[name] = value + if add_asts: + body = [] + for name, (_, value) in aliases.items(): + if isinstance(value, ast.AST): + body.append(ast.fix_missing_locations(value)) + elif not isinstance(value, Exception): + exec_vars[name] = value + type_ast.body = body + type_ast.body + if "TypeAlias" not in exec_vars: + type_alias = typing_extensions_import("TypeAlias") + if type_alias: + exec_vars["TypeAlias"] = type_alias + if sys.version_info < (3, 10): + backporter = BackportTypeHints() + type_ast = backporter.backport(type_ast, exec_vars) + try: + exec(compile(type_ast, filename="", mode="exec"), exec_vars, exec_vars) + except NameError as ex: + ex_from = None + for name, alias_exception in bad_aliases.items(): + if str(ex) == f"name '{name}' is not defined": + ex_from = alias_exception + break + raise ex from ex_from + arg_type = exec_vars["___arg_type___"] + if isinstance(arg_type, str) and arg_type in aliases: + arg_type = aliases[arg_type] + return arg_type + + +def type_requires_eval(typehint): + return isinstance(typehint, (str, ForwardRef)) + + +def get_types(obj: Any, logger: Optional[logging.Logger] = None) -> dict: + global_vars = vars(import_module(obj.__module__)) + try: + types = get_type_hints(obj, global_vars) + except Exception as ex1: + types = ex1 # type: ignore + + if isinstance(types, dict) and all(not type_requires_eval(t) for t in types.values()): + return types + + try: + source = textwrap.dedent(inspect.getsource(obj)) + tree = ast.parse(source) + assert isinstance(tree, ast.Module) and len(tree.body) == 1 + node = tree.body[0] + assert isinstance(node, ast.FunctionDef) + except Exception as ex2: + if isinstance(types, Exception): + if logger: + logger.debug(f"Failed to parse to source code for {obj}", exc_info=ex2) + raise type(types)(f"{repr(types)} + {repr(ex2)}") from ex2 # type: ignore + return types + + aliases = __builtins__.copy() # type: ignore + aliases.update(global_vars) + ex = None + if isinstance(types, Exception): + ex = types + types = {} + + for arg_ast in node.args.args + node.args.kwonlyargs: + name = arg_ast.arg + if not arg_ast.annotation or (name in types and not type_requires_eval(types[name])): + continue + try: + if isinstance(arg_ast.annotation, ast.Constant) and arg_ast.annotation.value in aliases: + types[name] = aliases[arg_ast.annotation.value] + else: + types[name] = get_arg_type(arg_ast.annotation, aliases) + except Exception as ex3: + types[name] = ex3 + + if all(isinstance(t, Exception) for t in types.values()): + raise ex or next(types.values()) # type: ignore + + return types + + +def evaluate_postponed_annotations(params, component, logger): + if sys.version_info[:2] == (3, 6) or not (params and any(type_requires_eval(p.annotation) for p in params)): + return + try: + if sys.version_info < (3, 10): + types = get_types(component, logger) + else: + types = get_type_hints(component) + except Exception as ex: + logger.debug(f"Unable to evaluate types for {component}", exc_info=ex) + return + for param in params: + if param.name in types: + param_type = types[param.name] + if isinstance(param_type, Exception): + logger.debug(f"Unable to evaluate type of {param.name} from {component}", exc_info=param_type) + continue + param.annotation = param_type diff --git a/jsonargparse/_link_arguments.py b/jsonargparse/_link_arguments.py index 9572046c..88cf8e0a 100644 --- a/jsonargparse/_link_arguments.py +++ b/jsonargparse/_link_arguments.py @@ -1,6 +1,5 @@ """Code related to argument linking.""" -import inspect import re from argparse import SUPPRESS from argparse import Action as ArgparseAction @@ -18,6 +17,7 @@ filter_default_actions, ) from ._namespace import Namespace, split_key_leaf +from ._parameter_resolvers import get_signature_parameters from ._type_checking import ArgumentParser, _ArgumentGroup __all__ = ["ArgumentLinking"] @@ -275,7 +275,7 @@ def apply_parsing_links(parser: "ArgumentParser", cfg: Namespace) -> None: value = value.as_dict() else: # Automatic namespace to dict based on compute_fn param type hint - params = list(inspect.signature(action.compute_fn).parameters.values()) + params = get_signature_parameters(action.compute_fn) for n, param in enumerate(params): if ( n < len(args) diff --git a/jsonargparse/_parameter_resolvers.py b/jsonargparse/_parameter_resolvers.py index 92a9690f..461c0b03 100644 --- a/jsonargparse/_parameter_resolvers.py +++ b/jsonargparse/_parameter_resolvers.py @@ -10,6 +10,7 @@ from functools import partial from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union +from ._backports import evaluate_postponed_annotations from ._common import is_dataclass_like, is_subclass from ._optionals import parse_docs from ._stubs_resolver import get_stub_types @@ -275,6 +276,7 @@ def get_signature_parameters_and_indexes(component, parent, logger): component=component, **{a: getattr(param, a) for a in parameter_attributes}, ) + evaluate_postponed_annotations(params, signature_source, logger) stubs = get_stub_types(params, signature_source, parent, logger) return params, args_idx, kwargs_idx, doc_params, stubs @@ -824,6 +826,7 @@ def get_parameters_from_pydantic_or_attrs( component=function_or_class, ) ) + evaluate_postponed_annotations(params, function_or_class, logger) return params diff --git a/jsonargparse/_signatures.py b/jsonargparse/_signatures.py index 9dd19952..edcadb8d 100644 --- a/jsonargparse/_signatures.py +++ b/jsonargparse/_signatures.py @@ -15,7 +15,12 @@ get_parameter_origins, get_signature_parameters, ) -from ._typehints import ActionTypeHint, LazyInitBaseClass, is_optional +from ._typehints import ( + ActionTypeHint, + LazyInitBaseClass, + callable_instances, + is_optional, +) from ._util import LoggerProperty, get_import_path, iter_to_set_str from .typing import register_pydantic_type @@ -176,9 +181,14 @@ def add_function_arguments( if not callable(function): raise ValueError('Expected "function" argument to be a callable object.') + method_name = None + if hasattr(function, "__class__") and callable_instances(function.__class__): + function = function.__class__ + method_name = "__call__" + return self._add_signature_arguments( function, - None, + method_name, nested_key, as_group, as_positional, @@ -366,12 +376,10 @@ def _add_signature_parameter( action.sub_add_kwargs["skip"] = subclass_skip added_args.append(dest) elif is_required and fail_untyped: - msg = f'With fail_untyped=True, all mandatory parameters must have a supported type. Parameter "{name}" from "{src}" ' - if isinstance(annotation, str): - msg += "specifies the type as a string. Types as a string and `from __future__ import annotations` is currently not supported." - else: - msg += "does not specify a type." - raise ValueError(msg) + raise ValueError( + "With fail_untyped=True, all mandatory parameters must have a supported" + f" type. Parameter '{name}' from '{src}' does not specify a type." + ) def add_dataclass_arguments( self, diff --git a/jsonargparse/_stubs_resolver.py b/jsonargparse/_stubs_resolver.py index 1211f85c..1af5bdff 100644 --- a/jsonargparse/_stubs_resolver.py +++ b/jsonargparse/_stubs_resolver.py @@ -5,14 +5,10 @@ from copy import deepcopy from importlib import import_module from pathlib import Path -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple -from ._optionals import ( - import_typeshed_client, - typeshed_client_support, - typing_extensions_import, -) -from ._util import unique +from ._backports import NamesVisitor, get_arg_type +from ._optionals import import_typeshed_client, typeshed_client_support if TYPE_CHECKING: # pragma: no cover import typeshed_client as tc @@ -32,17 +28,6 @@ def import_module_or_none(path: str): return None -class NamesVisitor(ast.NodeVisitor): - def visit_Name(self, node: ast.Name) -> None: - self.names_found.append(node.id) - - def find(self, node: ast.AST) -> list: - self.names_found: List[str] = [] - self.visit(node) - self.names_found = unique(self.names_found) - return self.names_found - - class ImportsVisitor(ast.NodeVisitor): def visit_ImportFrom(self, node: ast.ImportFrom) -> None: if node.level: @@ -262,49 +247,6 @@ def alias_is_unique(aliases, name, source, value): return True -def get_arg_type(arg_ast, aliases): - type_ast = ast.parse("___arg_type___ = 0") - type_ast.body[0].value = arg_ast.annotation - exec_vars = {} - bad_aliases = {} - add_asts = False - for name in NamesVisitor().find(arg_ast.annotation): - _, value = aliases[name] - if isinstance(value, Exception): - bad_aliases[name] = value - elif isinstance(value, ast.AST): - add_asts = True - else: - exec_vars[name] = value - if add_asts: - body = [] - for name, (_, value) in aliases.items(): - if isinstance(value, ast.AST): - body.append(ast.fix_missing_locations(value)) - elif not isinstance(value, Exception): - exec_vars[name] = value - type_ast.body = body + type_ast.body - if "TypeAlias" not in exec_vars: - type_alias = typing_extensions_import("TypeAlias") - if type_alias: - exec_vars["TypeAlias"] = type_alias - if sys.version_info < (3, 10): - from ._backports import BackportTypeHints - - backporter = BackportTypeHints() - type_ast = backporter.backport(type_ast, exec_vars) - try: - exec(compile(type_ast, filename="", mode="exec"), exec_vars, exec_vars) - except NameError as ex: - ex_from = None - for name, alias_exception in bad_aliases.items(): - if str(ex) == f"name '{name}' is not defined": - ex_from = alias_exception - break - raise ex from ex_from - return exec_vars["___arg_type___"] - - def get_stub_types(params, component, parent, logger) -> Optional[Dict[str, Any]]: if not typeshed_client_support: return None @@ -327,7 +269,7 @@ def get_stub_types(params, component, parent, logger) -> Optional[Dict[str, Any] name = arg_ast.arg if arg_ast.annotation and (name in missing_types or name not in known_params): try: - types[name] = get_arg_type(arg_ast, aliases) + types[name] = get_arg_type(arg_ast.annotation, aliases) except Exception as ex: logger.debug( f"Failed to parse type stub for {component.__qualname__!r} parameter {name!r}", exc_info=ex diff --git a/jsonargparse/_util.py b/jsonargparse/_util.py index 7130bd53..b6a98ba2 100644 --- a/jsonargparse/_util.py +++ b/jsonargparse/_util.py @@ -28,6 +28,7 @@ Type, TypeVar, Union, + get_type_hints, ) from ._common import parser_capture, parser_context @@ -375,10 +376,13 @@ def class_from_function( if func_return is None: func_return = inspect.signature(func).return_annotation if isinstance(func_return, str): - caller_frame = inspect.currentframe().f_back # type: ignore - func_return = caller_frame.f_locals.get(func_return) or caller_frame.f_globals.get(func_return) # type: ignore - if func_return is None: - raise ValueError(f"Unable to dereference {func_return} the return type of {func}.") + try: + func_return = get_type_hints(func)["return"] + if sys.version_info[:2] != (3, 6) and isinstance(func_return, __import__("typing").ForwardRef): + func_return = func_return._evaluate(func.__globals__, {}) + except Exception as ex: + func_return = inspect.signature(func).return_annotation + raise ValueError(f"Unable to dereference {func_return}, the return type of {func}: {ex}") from ex @wraps(func) def __new__(cls, *args, **kwargs): diff --git a/jsonargparse_tests/conftest.py b/jsonargparse_tests/conftest.py index 656efe33..89c79d9b 100644 --- a/jsonargparse_tests/conftest.py +++ b/jsonargparse_tests/conftest.py @@ -131,6 +131,12 @@ def capture_logs(logger: logging.Logger) -> Iterator[StringIO]: yield captured +@contextmanager +def source_unavailable(): + with patch("inspect.getsource", side_effect=OSError("could not get source code")): + yield + + def get_parser_help(parser: ArgumentParser) -> str: out = StringIO() with patch.dict(os.environ, {"COLUMNS": "200"}): diff --git a/jsonargparse_tests/test_actions.py b/jsonargparse_tests/test_actions.py index e2ca6f1d..72421779 100644 --- a/jsonargparse_tests/test_actions.py +++ b/jsonargparse_tests/test_actions.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from pathlib import Path import pytest diff --git a/jsonargparse_tests/test_argcomplete.py b/jsonargparse_tests/test_argcomplete.py index 68915888..c6fbd545 100644 --- a/jsonargparse_tests/test_argcomplete.py +++ b/jsonargparse_tests/test_argcomplete.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import sys from contextlib import ExitStack, contextmanager diff --git a/jsonargparse_tests/test_backports.py b/jsonargparse_tests/test_backports.py new file mode 100644 index 00000000..07854c22 --- /dev/null +++ b/jsonargparse_tests/test_backports.py @@ -0,0 +1,123 @@ +from __future__ import annotations # keep + +import sys +from random import Random +from typing import Dict, FrozenSet, List, Set, Tuple, Type, Union + +import pytest + +from jsonargparse._backports import get_types +from jsonargparse._parameter_resolvers import get_signature_parameters as get_params +from jsonargparse_tests.conftest import capture_logs, source_unavailable + + +@pytest.fixture(autouse=True) +def skip_if_python_older_than_3_10(): + if sys.version_info >= (3, 10, 0): + pytest.skip("python<3.10 is required") + + +def function_pep585_dict(p1: dict[str, int], p2: dict[int, str] = {1: "a"}): + return p1 + + +def function_pep585_list(p1: list[str], p2: list[float] = [0.1, 2.3]): + return p1 + + +def function_pep585_set(p1: set[str], p2: set[int] = {1, 2}): + return p1 + + +def function_pep585_frozenset(p1: frozenset[str], p2: frozenset[int] = frozenset(range(3))): + return p1 + + +def function_pep585_tuple(p1: tuple[str, float], p2: tuple[int, ...] = (1, 2)): + return p1 + + +def function_pep585_type(p1: type[Random], p2: type[Random] = Random): + return p1 + + +@pytest.mark.skipif(sys.version_info >= (3, 9, 0), reason="python<3.9 is required") +@pytest.mark.parametrize( + ["function", "expected"], + [ + (function_pep585_dict, {"p1": Dict[str, int], "p2": Dict[int, str]}), + (function_pep585_list, {"p1": List[str], "p2": List[float]}), + (function_pep585_set, {"p1": Set[str], "p2": Set[int]}), + (function_pep585_frozenset, {"p1": FrozenSet[str], "p2": FrozenSet[int]}), + (function_pep585_tuple, {"p1": Tuple[str, float], "p2": Tuple[int, ...]}), + (function_pep585_type, {"p1": Type[Random], "p2": Type[Random]}), + ], +) +def test_get_types_pep585(function, expected): + types = get_types(function) + assert types == expected + + +def function_pep604(p1: str | None, p2: int | float | bool = 1): + return p1 + + +def test_get_types_pep604(): + types = get_types(function_pep604) + assert types == {"p1": Union[str, None], "p2": Union[int, float, bool]} + + +def test_get_types_source_unavailable(logger): + with source_unavailable(), pytest.raises(TypeError) as ctx, capture_logs(logger) as logs: + get_types(function_pep604, logger) + ctx.match("could not get source code") + assert "Failed to parse to source code" in logs.getvalue() + + +class NeedsBackport: + def __init__(self, p1: list | set): + self.p1 = p1 + + @staticmethod + def static_method(p1: str | int): + return p1 + + @classmethod + def class_method(cls, p1: float | None): + return p1 + + +@pytest.mark.parametrize( + ["method", "expected"], + [ + (NeedsBackport.__init__, {"p1": Union[list, set]}), + (NeedsBackport.static_method, {"p1": Union[str, int]}), + (NeedsBackport.class_method, {"p1": Union[float, None]}), + ], +) +def test_get_types_methods(method, expected): + types = get_types(method) + assert types == expected + + +def function_forward_ref(cls: "NeedsBackport", p1: "int"): + return cls + + +def test_get_types_forward_ref(): + types = get_types(function_forward_ref) + assert types == {"cls": NeedsBackport, "p1": int} + + +def function_undefined_type(p1: not_defined | None, p2: int): # type: ignore # noqa: F821 + return p1 + + +def test_get_types_undefined_type(): + types = get_types(function_undefined_type) + assert types["p2"] is int + assert isinstance(types["p1"], KeyError) + assert "not_defined" in str(types["p1"]) + + params = get_params(function_undefined_type) + assert params[0].annotation == "not_defined | None" diff --git a/jsonargparse_tests/test_cli.py b/jsonargparse_tests/test_cli.py index d6a123c3..77abd84e 100644 --- a/jsonargparse_tests/test_cli.py +++ b/jsonargparse_tests/test_cli.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import sys from contextlib import redirect_stderr, redirect_stdout, suppress diff --git a/jsonargparse_tests/test_core.py b/jsonargparse_tests/test_core.py index daf05788..9b3ac003 100644 --- a/jsonargparse_tests/test_core.py +++ b/jsonargparse_tests/test_core.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import json import os import pickle diff --git a/jsonargparse_tests/test_dataclass_like.py b/jsonargparse_tests/test_dataclass_like.py index 68d84c07..75da2897 100644 --- a/jsonargparse_tests/test_dataclass_like.py +++ b/jsonargparse_tests/test_dataclass_like.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import dataclasses import sys from typing import Any, Dict, List, Optional, Union @@ -421,7 +423,7 @@ def none(x): return x -@pytest.mark.skipif(sys.version_info == (3, 6), reason="pydantic not supported in python 3.6") +@pytest.mark.skipif(sys.version_info[:2] == (3, 6), reason="pydantic not supported in python 3.6") @pytest.mark.skipif(not pydantic_support, reason="pydantic package is required") class TestPydantic: def test_dataclass(self, parser): @@ -456,21 +458,22 @@ def test_field_description(self, parser): assert "p1 help (required, type: str)" in help_str assert "p2 help (type: int, default: 2)" in help_str - def test_pydantic_types(self, subtests): - for valid_value, invalid_value, cast, pydantic_type in [ - ("abc", "a", none, pydantic.constr(min_length=2, max_length=4)), - (2, 0, none, pydantic.conint(ge=1)), - (-1.0, 1.0, none, pydantic.confloat(lt=0.0)), - ([1], [], none, pydantic.conlist(int, min_items=1)), - ([], [3, 4], none, pydantic.conlist(int, max_items=1)), - ([1], "x", list, pydantic.conset(int, min_items=1)), - ("http://abc.es", "-", none, pydantic.HttpUrl), - ("127.0.0.1", "0", str, pydantic.IPvAnyAddress), - ]: + def test_pydantic_types(self, subtests, monkeypatch): + for num, (valid_value, invalid_value, cast, pydantic_type) in enumerate( + [ + ("abc", "a", none, pydantic.constr(min_length=2, max_length=4)), + (2, 0, none, pydantic.conint(ge=1)), + (-1.0, 1.0, none, pydantic.confloat(lt=0.0)), + ([1], [], none, pydantic.conlist(int, min_items=1)), + ([], [3, 4], none, pydantic.conlist(int, max_items=1)), + ([1], "x", list, pydantic.conset(int, min_items=1)), + ("http://abc.es", "-", none, pydantic.HttpUrl), + ("127.0.0.1", "0", str, pydantic.IPvAnyAddress), + ] + ): with subtests.test(f"type={pydantic_type.__name__} valid={valid_value} invalid={invalid_value}"): - - class Model(pydantic.BaseModel): - param: pydantic_type + Model = pydantic.create_model(f"Model{num}", param=(pydantic_type, ...)) + monkeypatch.setitem(Model.__init__.__globals__, "pydantic_type", pydantic_type) parser = ArgumentParser(exit_on_error=False) parser.add_argument("--model", type=Model) diff --git a/jsonargparse_tests/test_deprecated.py b/jsonargparse_tests/test_deprecated.py index b2926c36..ae74e202 100644 --- a/jsonargparse_tests/test_deprecated.py +++ b/jsonargparse_tests/test_deprecated.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import pathlib import sys diff --git a/jsonargparse_tests/test_formatters.py b/jsonargparse_tests/test_formatters.py index 848e86fe..cf9c70a5 100644 --- a/jsonargparse_tests/test_formatters.py +++ b/jsonargparse_tests/test_formatters.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from pathlib import Path from typing import Tuple diff --git a/jsonargparse_tests/test_jsonnet.py b/jsonargparse_tests/test_jsonnet.py index 59c86e5d..910d5151 100644 --- a/jsonargparse_tests/test_jsonnet.py +++ b/jsonargparse_tests/test_jsonnet.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import json import re from pathlib import Path diff --git a/jsonargparse_tests/test_jsonschema.py b/jsonargparse_tests/test_jsonschema.py index f43fef32..fc756dea 100644 --- a/jsonargparse_tests/test_jsonschema.py +++ b/jsonargparse_tests/test_jsonschema.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import json import re from importlib.util import find_spec diff --git a/jsonargparse_tests/test_link_arguments.py b/jsonargparse_tests/test_link_arguments.py index 04f8a2fc..defa06e6 100644 --- a/jsonargparse_tests/test_link_arguments.py +++ b/jsonargparse_tests/test_link_arguments.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from calendar import Calendar, TextCalendar from typing import Any, List, Mapping, Optional, Union diff --git a/jsonargparse_tests/test_loaders_dumpers.py b/jsonargparse_tests/test_loaders_dumpers.py index 24f42faa..b403b22b 100644 --- a/jsonargparse_tests/test_loaders_dumpers.py +++ b/jsonargparse_tests/test_loaders_dumpers.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os from typing import List from unittest.mock import patch diff --git a/jsonargparse_tests/test_namespace.py b/jsonargparse_tests/test_namespace.py index 49880bf9..b3ccc12e 100644 --- a/jsonargparse_tests/test_namespace.py +++ b/jsonargparse_tests/test_namespace.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import argparse import platform diff --git a/jsonargparse_tests/test_optionals.py b/jsonargparse_tests/test_optionals.py index b742639a..36bf1436 100644 --- a/jsonargparse_tests/test_optionals.py +++ b/jsonargparse_tests/test_optionals.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import pytest from jsonargparse import get_config_read_mode, set_config_read_mode diff --git a/jsonargparse_tests/test_parameter_resolvers.py b/jsonargparse_tests/test_parameter_resolvers.py index 1e601a82..ada632db 100644 --- a/jsonargparse_tests/test_parameter_resolvers.py +++ b/jsonargparse_tests/test_parameter_resolvers.py @@ -1,8 +1,9 @@ +from __future__ import annotations + import calendar import inspect import xml.dom from calendar import Calendar -from contextlib import contextmanager from random import shuffle from typing import Any, Callable, Dict, List from unittest.mock import patch @@ -13,7 +14,7 @@ from jsonargparse._optionals import docstring_parser_support from jsonargparse._parameter_resolvers import get_signature_parameters as get_params from jsonargparse._parameter_resolvers import is_lambda -from jsonargparse_tests.conftest import capture_logs +from jsonargparse_tests.conftest import capture_logs, source_unavailable class ClassA: @@ -407,6 +408,13 @@ def function_constant_boolean(**kwargs): return function_with_kwargs(k1=False, **kwargs) +def function_invalid_type(param: "invalid:" = 1): # type: ignore # noqa: F722 + """ + Args: + param: help for param + """ + + def cond_1(kc: int = 1, kn0: str = "x", kn1: str = "-"): """ Args: @@ -439,12 +447,6 @@ def conditional_calls(**kwargs): cond_3(**kwargs) -@contextmanager -def source_unavailable(): - with patch("inspect.getsource", side_effect=OSError("could not get source code")): - yield - - def assert_params(params, expected, origins={}): assert expected == [p.name for p in params] docs = [f"help for {p.name}" for p in params] if docstring_parser_support else [None] * len(params) @@ -662,6 +664,13 @@ def test_get_params_function_constant_boolean(): assert get_params(function_constant_boolean) == [] +def test_get_params_function_invalid_type(logger): + with capture_logs(logger): + params = get_params(function_invalid_type, logger=logger) + assert_params(params, ["param"]) + assert params[0].annotation.replace("'", "") == "invalid:" + + def test_conditional_calls_kwargs(): assert_params( get_params(conditional_calls), diff --git a/jsonargparse_tests/test_paths.py b/jsonargparse_tests/test_paths.py index 63e1f446..7843e3d9 100644 --- a/jsonargparse_tests/test_paths.py +++ b/jsonargparse_tests/test_paths.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import json import os from calendar import Calendar diff --git a/jsonargparse_tests/test_signatures.py b/jsonargparse_tests/test_signatures.py index bca27022..a280f393 100644 --- a/jsonargparse_tests/test_signatures.py +++ b/jsonargparse_tests/test_signatures.py @@ -1,3 +1,6 @@ +from __future__ import annotations + +import sys from pathlib import Path from typing import Any, Dict, List, Optional, Tuple, Union from unittest.mock import patch @@ -85,7 +88,7 @@ def test_add_class_failure_not_a_class(parser): def test_add_class_failure_positional_without_type(parser): with pytest.raises(ValueError) as ctx: parser.add_class_arguments(Class2) - ctx.match(f'Parameter "c2_a0" from "{__name__}.Class2.__init__" does not specify a type') + ctx.match(f"Parameter 'c2_a0' from '{__name__}.Class2.__init__' does not specify a type") def test_add_class_without_nesting(parser): @@ -251,6 +254,7 @@ def test_add_class_with_required_parameters(parser): assert cfg.model == Namespace(m=0.1, n=3) +@pytest.mark.skipif(sys.version_info[:2] == (3, 6), reason="import forces future annotations in python 3.6") def test_add_class_conditional_kwargs(parser): from jsonargparse_tests.test_parameter_resolvers import ClassG @@ -510,24 +514,24 @@ def test_add_function_implicit_optional(parser): assert None is parser.parse_args(["--a1=null"]).a1 -def func_untyped_params(a1, a2=None): - return a1 +def func_type_as_string(a2: "int"): + return a2 -def test_add_function_fail_untyped_true_untyped_params(parser): - with pytest.raises(ValueError) as ctx: - parser.add_function_arguments(func_untyped_params, fail_untyped=True) - ctx.match('Parameter "a1" from .* does not specify a type') +@pytest.mark.skipif(sys.version_info[:2] == (3, 6), reason="not supported in python 3.6") +def test_add_function_fail_untyped_true_str_type(parser): + added_args = parser.add_function_arguments(func_type_as_string, fail_untyped=True) + assert ["a2"] == added_args -def func_type_as_string(a2: "int"): - return a2 +def func_untyped_params(a1, a2=None): + return a1 -def test_add_function_fail_untyped_true_str_type(parser): +def test_add_function_fail_untyped_true_untyped_params(parser): with pytest.raises(ValueError) as ctx: - parser.add_function_arguments(func_type_as_string, fail_untyped=True) - ctx.match('Parameter "a2" from .* specifies the type as a string') + parser.add_function_arguments(func_untyped_params, fail_untyped=True) + ctx.match("Parameter 'a1' from .* does not specify a type") def test_add_function_fail_untyped_false(parser): diff --git a/jsonargparse_tests/test_stubs_resolver.py b/jsonargparse_tests/test_stubs_resolver.py index ea4b5f4b..fed1419f 100644 --- a/jsonargparse_tests/test_stubs_resolver.py +++ b/jsonargparse_tests/test_stubs_resolver.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import inspect import sys from calendar import Calendar, TextCalendar @@ -46,7 +48,7 @@ def get_param_names(params): class WithoutParent: - ... + pass @pytest.mark.parametrize( diff --git a/jsonargparse_tests/test_subclasses.py b/jsonargparse_tests/test_subclasses.py index 7b84fd79..c6fc4194 100644 --- a/jsonargparse_tests/test_subclasses.py +++ b/jsonargparse_tests/test_subclasses.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import json import textwrap import warnings diff --git a/jsonargparse_tests/test_subcommands.py b/jsonargparse_tests/test_subcommands.py index e96af28a..3e4421cf 100644 --- a/jsonargparse_tests/test_subcommands.py +++ b/jsonargparse_tests/test_subcommands.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import warnings from pathlib import Path diff --git a/jsonargparse_tests/test_typehints.py b/jsonargparse_tests/test_typehints.py index fbdb5004..cd1717b1 100644 --- a/jsonargparse_tests/test_typehints.py +++ b/jsonargparse_tests/test_typehints.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import random import sys import time @@ -783,6 +785,12 @@ def test_action_typehint_unsupported_type(typehint): ctx.match("Unsupported type hint") +def test_action_typehint_none_type_error(): + with pytest.raises(ValueError) as ctx: + ActionTypeHint(typehint=None) + ctx.match("Expected typehint keyword argument") + + @pytest.mark.parametrize( ["typehint", "ref_type", "expected"], [ diff --git a/jsonargparse_tests/test_typing.py b/jsonargparse_tests/test_typing.py index d395cb31..d779b441 100644 --- a/jsonargparse_tests/test_typing.py +++ b/jsonargparse_tests/test_typing.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import pickle import random import uuid diff --git a/jsonargparse_tests/test_util.py b/jsonargparse_tests/test_util.py index d129459b..90e86fd4 100644 --- a/jsonargparse_tests/test_util.py +++ b/jsonargparse_tests/test_util.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import logging import os import pathlib @@ -566,7 +568,7 @@ def get_foo(cls) -> "Foo": def closure_get_foo(): - def get_foo() -> "Foo": + def get_foo() -> Foo: return Foo() return get_foo @@ -593,7 +595,7 @@ def get_unknown() -> "Unknown": # type: ignore # noqa: F821 def test_invalid_class_from_function(): with pytest.raises(ValueError) as ctx: class_from_function(get_unknown) - ctx.match("Unable to dereference None the return type") + ctx.match("Unable to dereference '?Unknown'?, the return type of") def get_random_untyped(): diff --git a/pyproject.toml b/pyproject.toml index d23fa602..238837bd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -202,7 +202,7 @@ exclude = "(.eggs|.git|.hg|.mypy_cache|.venv|_build|buck-out|build|dist)" [tool.tox] legacy_tox_ini = """ [tox] -envlist = py{37,38,39,310,311}-{all,no}-extras,pypy3,omegaconf +envlist = py{37,38,39,310,311}-{all,no}-extras,omegaconf skip_missing_interpreters = true [testenv]