Skip to content

Commit

Permalink
- Added support for postponed evaluation of annotations PEP 563 (#120).
Browse files Browse the repository at this point in the history
- Backport types in python<=3.9 to support PEP 585 and 604 for postponed evaluation of annotations (#120).
  • Loading branch information
mauvilsa committed Jun 22, 2023
1 parent 4dc2ff8 commit f36cb9e
Show file tree
Hide file tree
Showing 35 changed files with 509 additions and 133 deletions.
46 changes: 33 additions & 13 deletions .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ jobs:
python3 -m jsonargparse_tests coverage xml coverage_py$py.xml
pip3 install $(ls ./dist/*.whl)[test,all]
python3 -m jsonargparse_tests coverage xml coverage_py${py}_all.xml
tests_dir=$(python3 -c "print(__import__('jsonargparse_tests').__file__)" | sed 's|[^/]*$||')
sed -i '/^from __future__ import annotations$/d' $tests_dir/test_*.py
python3 -m jsonargparse_tests coverage xml coverage_py${py}_types.xml
- persist_to_workspace:
root: .
paths:
Expand All @@ -53,9 +56,28 @@ jobs:
docker:
- image: cimg/python:3.7
test-py36:
<<: *test-py38
docker:
- image: cimg/python:3.6
steps:
- attach_workspace:
at: .
- run:
name: Run unit tests
command: |
py=$(python3 --version | sed -r 's|.* 3\.([0-9]+)\..*|3.\1|')
virtualenv -p python3 venv$py
. venv$py/bin/activate
pip3 install $(ls ./dist/*.whl)[test-no-urls]
tests_dir=$(python3 -c "print(__import__('jsonargparse_tests').__file__)" | sed 's|[^/]*$||')
rm $tests_dir/test_backports.py;
sed -i '/^from __future__ import annotations$/d' $tests_dir/test_*.py
python3 -m jsonargparse_tests coverage xml coverage_py$py.xml
pip3 install $(ls ./dist/*.whl)[test,all]
python3 -m jsonargparse_tests coverage xml coverage_py${py}_all.xml
- persist_to_workspace:
root: .
paths:
- ./coverage_*.xml
codecov:
docker:
- image: cimg/python:3.8
Expand All @@ -66,18 +88,16 @@ jobs:
- run:
name: Code coverage
command: |
#for py in 3.6 3.7 3.8 3.9 3.10 3.11; do
for py in 3.6 3.7 3.8 3.9 3.10; do
bash <(curl -s https://codecov.io/bash) \
-Z \
-t $CODECOV_TOKEN_JSONARGPARSE \
-F py$py \
-f coverage_py${py}.xml
bash <(curl -s https://codecov.io/bash) \
-Z \
-t $CODECOV_TOKEN_JSONARGPARSE \
-F py${py}_all \
-f coverage_py${py}_all.xml
for py in 3.6 3.7 3.8 3.9 3.10 3.11; do
for suffix in "" "_all" "_types"; do
if [ -f coverage_py${py}${suffix}.xml ]; then
bash <(curl -s https://codecov.io/bash) \
-Z \
-t $CODECOV_TOKEN_JSONARGPARSE \
-F py${py}${suffix} \
-f coverage_py${py}${suffix}.xml
fi
done
done
publish-pypi:
docker:
Expand Down
38 changes: 38 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ ci:
skip:
- mypy
- tox
- test-py36
- test-without-future-annotations
- coverage
autofix_prs: true
autoupdate_commit_msg: '[pre-commit.ci] pre-commit suggestions'
Expand Down Expand Up @@ -101,6 +103,42 @@ repos:
pass_filenames: false
verbose: true

- id: test-py36
name: test-py36
entry: bash -c '
if [ "$(which python3.6)" = "" ]; then
echo "$(tput setaf 6) Skipped, python 3.6 not found $(tput sgr0)";
else
TEST_DIR=$(mktemp -d -t _jsonargparse_tests_XXXXXX);
cleanup () { rm -rf "$TEST_DIR"; };
trap cleanup EXIT;
./setup.py bdist_wheel;
python3.6 -m venv "$TEST_DIR/venv36";
. "$TEST_DIR/venv36/bin/activate";
pip3 install "$(ls ./dist/*.whl | tail -n 1)[all,test,test-no-urls]";
pip3 install pytest pytest-subtests;
cp jsonargparse_tests/*.py "$TEST_DIR";
cd "$TEST_DIR";
rm test_backports.py;
sed -i "/^from __future__ import annotations$/d" *.py;
pytest;
fi'
language: system
pass_filenames: false

- id: test-without-future-annotations
name: test-without-future-annotations
entry: bash -c '
TEST_DIR=$(mktemp -d -t _jsonargparse_tests_XXXXXX);
cleanup () { rm -rf "$TEST_DIR"; };
trap cleanup EXIT;
cp jsonargparse_tests/*.py "$TEST_DIR";
cd "$TEST_DIR";
sed -i "/^from __future__ import annotations$/d" *.py;
pytest $TEST_DIR;'
language: system
pass_filenames: false

- id: doctest
name: sphinx-build -M doctest sphinx sphinx/_build sphinx/index.rst
entry: bash -c '
Expand Down
7 changes: 7 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,13 @@ Added
- ``class_from_function`` now supports ``func_return`` parameter to specify the
return type of the function (`lightning-flash#1564 comment
<https://github.com/Lightning-Universe/lightning-flash/pull/1564#discussion_r1218147330>`__).
- Support for postponed evaluation of annotations PEP `563
<https://peps.python.org/pep-0563/>`__ ``from __future__ import annotations``
(`#120 <https://github.com/omni-us/jsonargparse/issues/120>`__).
- Backport types in python<=3.9 to support PEP `585
<https://peps.python.org/pep-0585/>`__ and `604
<https://peps.python.org/pep-0604/>`__ for postponed evaluation of annotations
(`#120 <https://github.com/omni-us/jsonargparse/issues/120>`__).

Fixed
^^^^^
Expand Down
7 changes: 7 additions & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -425,6 +425,13 @@ Some notes about this support are:
nesting it is meant child types inside ``List``, ``Dict``, etc. There is no
limit in nesting depth.

- Postponed evaluation of types PEP `563 <https://peps.python.org/pep-0563/>`__
(i.e. ``from __future__ import annotations``) is supported. Also supported on
``python<=3.9`` are PEP `585 <https://peps.python.org/pep-0585/>`__ (i.e.
``list[<type>], dict[<type>], ...`` instead of ``List[<type>], Dict[<type>],
...``) and `604 <https://peps.python.org/pep-0604/>`__ (i.e. ``<type> |
<type>`` instead of ``Union[<type>, <type>]``).

- Fully supported types are: ``str``, ``bool`` (more details in
:ref:`boolean-arguments`), ``int``, ``float``, ``complex``,
``bytes``/``bytearray`` (Base64 encoding), ``List`` (more details in
Expand Down
149 changes: 144 additions & 5 deletions jsonargparse/_backports.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,17 @@
import ast
import inspect
import logging
import sys
import textwrap
from collections import namedtuple
from copy import deepcopy
from typing import Dict, FrozenSet, List, Set, Tuple, Type, Union
from importlib import import_module
from typing import Any, Dict, FrozenSet, List, Optional, Set, Tuple, Type, Union, get_type_hints

if sys.version_info[:2] > (3, 6):
from typing import ForwardRef

from ._optionals import typing_extensions_import

var_map = namedtuple("var_map", "name value")
none_map = var_map(name="NoneType", value=type(None))
Expand All @@ -17,8 +27,6 @@


class BackportTypeHints(ast.NodeTransformer):
_typing = __import__("typing")

def visit_Subscript(self, node: ast.Subscript) -> ast.Subscript:
if isinstance(node.value, ast.Name) and node.value.id in pep585_map:
value = self.new_name_load(pep585_map[node.value.id])
Expand Down Expand Up @@ -64,10 +72,141 @@ def new_name_load(self, var: var_map) -> ast.Name:
return ast.Name(id=name, ctx=ast.Load())

def backport(self, input_ast: ast.AST, exec_vars: dict) -> ast.AST:
typing = __import__("typing")
for key, value in exec_vars.items():
if getattr(value, "__module__", "") == "collections.abc":
if hasattr(self._typing, key):
exec_vars[key] = getattr(self._typing, key)
if hasattr(typing, key):
exec_vars[key] = getattr(typing, key)
self.exec_vars = exec_vars
backport_ast = self.visit(deepcopy(input_ast))
return ast.fix_missing_locations(backport_ast)


class NamesVisitor(ast.NodeVisitor):
def visit_Name(self, node: ast.Name) -> None:
self.names_found.append(node.id)

def find(self, node: ast.AST) -> list:
from ._util import unique

self.names_found: List[str] = []
self.visit(node)
self.names_found = unique(self.names_found)
return self.names_found


def get_arg_type(arg_ast, aliases):
type_ast = ast.parse("___arg_type___ = 0")
type_ast.body[0].value = arg_ast
exec_vars = {}
bad_aliases = {}
add_asts = False
for name in NamesVisitor().find(arg_ast):
value = aliases[name]
if isinstance(value, tuple):
value = value[1]
if isinstance(value, Exception):
bad_aliases[name] = value
elif isinstance(value, ast.AST):
add_asts = True
else:
exec_vars[name] = value
if add_asts:
body = []
for name, (_, value) in aliases.items():
if isinstance(value, ast.AST):
body.append(ast.fix_missing_locations(value))
elif not isinstance(value, Exception):
exec_vars[name] = value
type_ast.body = body + type_ast.body
if "TypeAlias" not in exec_vars:
type_alias = typing_extensions_import("TypeAlias")
if type_alias:
exec_vars["TypeAlias"] = type_alias
if sys.version_info < (3, 10):
backporter = BackportTypeHints()
type_ast = backporter.backport(type_ast, exec_vars)
try:
exec(compile(type_ast, filename="<ast>", mode="exec"), exec_vars, exec_vars)
except NameError as ex:
ex_from = None
for name, alias_exception in bad_aliases.items():
if str(ex) == f"name '{name}' is not defined":
ex_from = alias_exception
break
raise ex from ex_from
arg_type = exec_vars["___arg_type___"]
if isinstance(arg_type, str) and arg_type in aliases:
arg_type = aliases[arg_type]
return arg_type


def type_requires_eval(typehint):
return isinstance(typehint, (str, ForwardRef))


def get_types(obj: Any, logger: Optional[logging.Logger] = None) -> dict:
global_vars = vars(import_module(obj.__module__))
try:
types = get_type_hints(obj, global_vars)
except Exception as ex1:
types = ex1 # type: ignore

if isinstance(types, dict) and all(not type_requires_eval(t) for t in types.values()):
return types

try:
source = textwrap.dedent(inspect.getsource(obj))
tree = ast.parse(source)
assert isinstance(tree, ast.Module) and len(tree.body) == 1
node = tree.body[0]
assert isinstance(node, ast.FunctionDef)
except Exception as ex2:
if isinstance(types, Exception):
if logger:
logger.debug(f"Failed to parse to source code for {obj}", exc_info=ex2)
raise type(types)(f"{repr(types)} + {repr(ex2)}") from ex2 # type: ignore
return types

aliases = __builtins__.copy() # type: ignore
aliases.update(global_vars)
ex = None
if isinstance(types, Exception):
ex = types
types = {}

for arg_ast in node.args.args + node.args.kwonlyargs:
name = arg_ast.arg
if arg_ast.annotation and (name not in types or type_requires_eval(types[name])):
try:
if isinstance(arg_ast.annotation, ast.Constant) and arg_ast.annotation.value in aliases:
types[name] = aliases[arg_ast.annotation.value]
else:
types[name] = get_arg_type(arg_ast.annotation, aliases)
except Exception as ex3:
types[name] = ex3

if all(isinstance(t, Exception) for t in types.values()):
raise ex or next(iter(types.values()))

return types


def evaluate_postponed_annotations(params, component, logger):
if sys.version_info[:2] == (3, 6) or not (params and any(type_requires_eval(p.annotation) for p in params)):
return
try:
if sys.version_info < (3, 10):
types = get_types(component, logger)
else:
types = get_type_hints(component)
except Exception as ex:
logger.debug(f"Unable to evaluate types for {component}", exc_info=ex)
return
for param in params:
if param.name in types:
param_type = types[param.name]
if isinstance(param_type, Exception):
logger.debug(f"Unable to evaluate type of {param.name} from {component}", exc_info=param_type)
continue
param.annotation = param_type
4 changes: 2 additions & 2 deletions jsonargparse/_link_arguments.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""Code related to argument linking."""

import inspect
import re
from argparse import SUPPRESS
from argparse import Action as ArgparseAction
Expand All @@ -18,6 +17,7 @@
filter_default_actions,
)
from ._namespace import Namespace, split_key_leaf
from ._parameter_resolvers import get_signature_parameters
from ._type_checking import ArgumentParser, _ArgumentGroup

__all__ = ["ArgumentLinking"]
Expand Down Expand Up @@ -275,7 +275,7 @@ def apply_parsing_links(parser: "ArgumentParser", cfg: Namespace) -> None:
value = value.as_dict()
else:
# Automatic namespace to dict based on compute_fn param type hint
params = list(inspect.signature(action.compute_fn).parameters.values())
params = get_signature_parameters(action.compute_fn)
for n, param in enumerate(params):
if (
n < len(args)
Expand Down
3 changes: 3 additions & 0 deletions jsonargparse/_parameter_resolvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union

from ._backports import evaluate_postponed_annotations
from ._common import is_dataclass_like, is_subclass
from ._optionals import parse_docs
from ._stubs_resolver import get_stub_types
Expand Down Expand Up @@ -275,6 +276,7 @@ def get_signature_parameters_and_indexes(component, parent, logger):
component=component,
**{a: getattr(param, a) for a in parameter_attributes},
)
evaluate_postponed_annotations(params, signature_source, logger)
stubs = get_stub_types(params, signature_source, parent, logger)
return params, args_idx, kwargs_idx, doc_params, stubs

Expand Down Expand Up @@ -824,6 +826,7 @@ def get_parameters_from_pydantic_or_attrs(
component=function_or_class,
)
)
evaluate_postponed_annotations(params, function_or_class, logger)
return params


Expand Down
Loading

0 comments on commit f36cb9e

Please sign in to comment.