Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for PEP 692 #586

Merged
merged 13 commits into from
Oct 11, 2024
7 changes: 6 additions & 1 deletion CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,14 @@ The semantic versioning only considers the public API as described in
paths are considered internals and can change in minor and patch releases.


v4.33.3 (2024-10-??)
v4.34.0 (2024-10-??)
--------------------

Added
^^^^^
- Support for PEP 692 (i.e., ``Unpack[TypedDict]`` annotations for ``**kwargs``)
(`#586 <https://github.com/omni-us/jsonargparse/pull/586>`)

Fixed
^^^^^
- Empty tuples are now parsed correctly instead of raising an error.
Expand Down
2 changes: 2 additions & 0 deletions DOCUMENTATION.rst
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,8 @@ Some notes about this support are:
``OrderedDict``, and ``TypedDict`` are supported but only with ``str`` or
``int`` keys. ``Required`` and ``NotRequired`` are also supported for
fine-grained specification of required/optional ``TypedDict`` keys.
``Unpack`` is supported with ``TypedDict`` for more precise ``**kwargs``
typing as described in PEP `692 <https://peps.python.org/pep-0692/>`__.
For more details see :ref:`dict-items`.

- ``Tuple``, ``Set`` and ``MutableSet`` are supported even though they can't be
Expand Down
13 changes: 13 additions & 0 deletions jsonargparse/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,14 @@

from ._namespace import Namespace
from ._optionals import (
capture_typing_extension_shadows,
get_alias_target,
get_annotated_base_type,
import_reconplogger,
is_alias_type,
is_annotated,
reconplogger_support,
typing_extensions_import,
)
from ._type_checking import ArgumentParser

Expand All @@ -36,6 +38,13 @@

ClassType = TypeVar("ClassType")

_UnpackGenericAlias = typing_extensions_import("_UnpackAlias")

unpack_meta_types = set()
if _UnpackGenericAlias:
unpack_meta_types.add(_UnpackGenericAlias)
capture_typing_extension_shadows(_UnpackGenericAlias, "_UnpackGenericAlias", unpack_meta_types)


class InstantiatorCallable(Protocol):
def __call__(self, class_type: Type[ClassType], *args, **kwargs) -> ClassType:
Expand Down Expand Up @@ -96,6 +105,10 @@ def is_generic_class(cls) -> bool:
return isinstance(cls, _GenericAlias) and getattr(cls, "__module__", "") != "typing"


def is_unpack_typehint(cls) -> bool:
return any(isinstance(cls, unpack_type) for unpack_type in unpack_meta_types)


def get_generic_origin(cls):
return cls.__origin__ if is_generic_class(cls) else cls

Expand Down
1 change: 1 addition & 0 deletions jsonargparse/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1337,6 +1337,7 @@ def merge_config(self, cfg_from: Namespace, cfg_to: Namespace) -> Namespace:
with parser_context(parent_parser=self):
ActionTypeHint.discard_init_args_on_class_path_change(self, cfg_to, cfg_from)
ActionTypeHint.delete_init_args_required_none(cfg_from, cfg_to)
ActionTypeHint.delete_not_required_args(cfg_from, cfg_to)
cfg_to.update(cfg_from)
ActionTypeHint.apply_appends(self, cfg_to)
return cfg_to
Expand Down
11 changes: 11 additions & 0 deletions jsonargparse/_optionals.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,17 @@ def typing_extensions_import(name):
return getattr(__import__("typing"), name, False)


def capture_typing_extension_shadows(typehint, name: str, *collections) -> None:
"""
Ensure different origins for types in typing_extensions are captured.
"""
if (typehint is False or getattr(typehint, "__module__", None) == "typing_extensions") and hasattr(
__import__("typing"), name
):
for collection in collections:
collection.add(getattr(__import__("typing"), name))


def final(cls):
"""Decorator to make a class ``final``, i.e., it shouldn't be subclassed.

Expand Down
33 changes: 33 additions & 0 deletions jsonargparse/_parameter_resolvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
is_dataclass_like,
is_generic_class,
is_subclass,
is_unpack_typehint,
parse_logger,
)
from ._optionals import get_annotated_base_type, is_annotated, is_pydantic_model, parse_docs
Expand All @@ -28,6 +29,7 @@
from ._util import (
ClassFromFunctionBase,
get_import_path,
get_typehint_args,
get_typehint_origin,
iter_to_set_str,
unique,
Expand Down Expand Up @@ -318,6 +320,35 @@ def replace_type_vars(annotation):
param.annotation = replace_type_vars(param.annotation)


def unpack_typed_dict_kwargs(params: ParamList, kwargs_idx: int) -> int:
kwargs = params[kwargs_idx]
annotation = kwargs.annotation
if is_unpack_typehint(annotation):
params.pop(kwargs_idx)
annotation_args = get_typehint_args(annotation)
assert len(annotation_args) == 1, "Unpack requires a single type argument"
dict_annotations = annotation_args[0].__annotations__
new_params = []
for nm, annot in dict_annotations.items():
new_params.append(
ParamData(
name=nm,
annotation=annot,
default=inspect._empty,
kind=inspect._ParameterKind.KEYWORD_ONLY,
doc=None,
component=kwargs.component,
parent=kwargs.parent,
origin=kwargs.origin,
)
)
# insert in-place
assert kwargs_idx == len(params), "trailing params should yield a syntax error"
params.extend(new_params)
return -1
return kwargs_idx


def add_stub_types(stubs: Optional[Dict[str, Any]], params: ParamList, component) -> None:
if not stubs:
return
Expand Down Expand Up @@ -838,6 +869,8 @@ def get_parameters(self) -> ParamList:
self.component, self.parent, self.logger
)
self.replace_param_default_subclass_specs(params)
if kwargs_idx >= 0:
kwargs_idx = unpack_typed_dict_kwargs(params, kwargs_idx)
if args_idx >= 0 or kwargs_idx >= 0:
self.doc_params = doc_params
with mro_context(self.parent):
Expand Down
5 changes: 3 additions & 2 deletions jsonargparse/_signatures.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,9 @@
callable_instances,
get_subclass_names,
is_optional,
not_required_types,
)
from ._util import NoneType, get_private_kwargs, iter_to_set_str
from ._util import NoneType, get_private_kwargs, get_typehint_origin, iter_to_set_str
from .typing import register_pydantic_type

__all__ = [
Expand Down Expand Up @@ -322,7 +323,7 @@ def _add_signature_parameter(
default = param.default
if default == inspect_empty and is_optional(annotation):
default = None
is_required = default == inspect_empty
is_required = default == inspect_empty and get_typehint_origin(annotation) not in not_required_types
src = get_parameter_origins(param.component, param.parent)
skip_message = f'Skipping parameter "{name}" from "{src}" because of: '
if not fail_untyped and annotation == inspect_empty:
Expand Down
16 changes: 13 additions & 3 deletions jsonargparse/_typehints.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
)
from ._namespace import Namespace
from ._optionals import (
capture_typing_extension_shadows,
get_alias_target,
is_alias_type,
is_annotated,
Expand Down Expand Up @@ -93,6 +94,7 @@
NotRequired = typing_extensions_import("NotRequired")
Required = typing_extensions_import("Required")
_TypedDictMeta = typing_extensions_import("_TypedDictMeta")
Unpack = typing_extensions_import("Unpack")


def _capture_typing_extension_shadows(name: str, *collections) -> None:
Expand All @@ -101,9 +103,7 @@ def _capture_typing_extension_shadows(name: str, *collections) -> None:
"""
current_module = sys.modules[__name__]
typehint = getattr(current_module, name)
if getattr(typehint, "__module__", None) == "typing_extensions" and hasattr(__import__("typing"), name):
for collection in collections:
collection.add(getattr(__import__("typing"), name))
return capture_typing_extension_shadows(typehint, name, *collections)


root_types = {
Expand Down Expand Up @@ -142,6 +142,7 @@ def _capture_typing_extension_shadows(name: str, *collections) -> None:
abc.Callable,
NotRequired,
Required,
Unpack,
}

leaf_types = {
Expand Down Expand Up @@ -193,6 +194,9 @@ def _capture_typing_extension_shadows(name: str, *collections) -> None:
typed_dict_meta_types = {_TypedDictMeta}
_capture_typing_extension_shadows("_TypedDictMeta", typed_dict_meta_types)

unpack_types = {Unpack}
_capture_typing_extension_shadows("Unpack", unpack_types)

subclass_arg_parser: ContextVar = ContextVar("subclass_arg_parser")
allow_default_instance: ContextVar = ContextVar("allow_default_instance", default=False)
sub_defaults: ContextVar = ContextVar("sub_defaults", default=False)
Expand Down Expand Up @@ -441,6 +445,12 @@ def delete_init_args_required_none(cfg_from, cfg_to):
if skip_key in parser.required_args:
del val.init_args[skip_key]

@staticmethod
def delete_not_required_args(cfg_from, cfg_to):
for key, val in list(cfg_to.items(branches=True)):
if val == inspect._empty and key not in cfg_from:
del cfg_to[key]

@staticmethod
@contextmanager
def subclass_arg_context(parser):
Expand Down
4 changes: 4 additions & 0 deletions jsonargparse/_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,10 @@ def object_path_serializer(value):
raise ValueError(f"Only possible to serialize an importable object, given {value}: {ex}") from ex


def get_typehint_args(typehint):
return getattr(typehint, "__args__", tuple())


def get_typehint_origin(typehint):
if not hasattr(typehint, "__origin__"):
typehint_class = get_import_path(typehint.__class__)
Expand Down
74 changes: 74 additions & 0 deletions jsonargparse_tests/test_typehints.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
from __future__ import annotations

import json
import pickle
import random
import sys
import time
import uuid
from calendar import Calendar, TextCalendar
from collections import OrderedDict
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from types import MappingProxyType
Expand Down Expand Up @@ -37,6 +39,7 @@
ActionTypeHint,
NotRequired,
Required,
Unpack,
get_all_subclass_paths,
get_subclass_types,
is_optional,
Expand Down Expand Up @@ -605,6 +608,77 @@ def test_typeddict_with_required_arg(parser):
ctx.match("Expected a <class 'int'>")


@pytest.mark.skipif(not Unpack, reason="Unpack introduced in python 3.11 or backported in typing_extensions")
def test_unpack_support(parser):
assert ActionTypeHint.is_supported_typehint(Unpack[Any])


if Unpack: # and Required and NotRequired
MyTestUnpackDict = TypedDict("MyTestUnpackDict", {"a": Required[int], "b": NotRequired[int]}, total=True)

class UnpackClass:

def __init__(self, **kwargs: Unpack[MyTestUnpackDict]) -> None:
self.a = kwargs["a"]
self.b = kwargs.get("b")

@dataclass
class MyTestUnpackClass:

test: UnpackClass

class MyTestInheritedUnpackClass(UnpackClass):

def __init__(self, **kwargs) -> None:
super().__init__(**kwargs)


@pytest.mark.skipif(not Unpack, reason="Unpack introduced in python 3.11 or backported in typing_extensions")
@pytest.mark.parametrize(["init_args"], [({"a": 1},), ({"a": 2, "b": None},), ({"a": 3, "b": 1},)])
def test_valid_unpack_typeddict(parser, init_args):
parser.add_argument("--testclass", type=MyTestUnpackClass)
test_config = {"test": {"class_path": f"{__name__}.UnpackClass", "init_args": init_args}}
cfg = parser.parse_args([f"--testclass={json.dumps(test_config)}"])
assert test_config == cfg["testclass"].as_dict()
# also assert no issues with dumping
if test_config["test"]["init_args"].get("b") is None:
# parser.dump does not dump null b
test_config["test"]["init_args"].pop("b", None)
assert json.dumps({"testclass": test_config}).replace(" ", "") == parser.dump(cfg, format="json")


@pytest.mark.skipif(not Unpack, reason="Unpack introduced in python 3.11 or backported in typing_extensions")
@pytest.mark.parametrize(["init_args"], [({},), ({"b": None},), ({"b": 1},)])
def test_invalid_unpack_typeddict(parser, init_args):
parser.add_argument("--testclass", type=MyTestUnpackClass)
test_config = {"test": {"class_path": f"{__name__}.UnpackClass", "init_args": init_args}}
with pytest.raises(ArgumentError):
parser.parse_args([f"--testclass={json.dumps(test_config)}"])


@pytest.mark.skipif(not Unpack, reason="Unpack introduced in python 3.11 or backported in typing_extensions")
@pytest.mark.parametrize(["init_args"], [({"a": 1},), ({"a": 2, "b": None},), ({"a": 3, "b": 1},)])
def test_valid_inherited_unpack_typeddict(parser, init_args):
parser.add_argument("--testclass", type=MyTestInheritedUnpackClass)
test_config = {"class_path": f"{__name__}.MyTestInheritedUnpackClass", "init_args": init_args}
cfg = parser.parse_args([f"--testclass={json.dumps(test_config)}"])
assert test_config == cfg["testclass"].as_dict()
# also assert no issues with dumping
if test_config["init_args"].get("b") is None:
# parser.dump does not dump null b
test_config["init_args"].pop("b", None)
assert json.dumps({"testclass": test_config}).replace(" ", "") == parser.dump(cfg, format="json")


@pytest.mark.skipif(not Unpack, reason="Unpack introduced in python 3.11 or backported in typing_extensions")
@pytest.mark.parametrize(["init_args"], [({},), ({"b": None},), ({"b": 1},)])
def test_invalid_inherited_unpack_typeddict(parser, init_args):
parser.add_argument("--testclass", type=MyTestInheritedUnpackClass)
test_config = {"class_path": f"{__name__}.MyTestInheritedUnpackClass", "init_args": init_args}
with pytest.raises(ArgumentError):
parser.parse_args([f"--testclass={json.dumps(test_config)}"])


def test_mapping_proxy_type(parser):
parser.add_argument("--mapping", type=MappingProxyType)
cfg = parser.parse_args(['--mapping={"x":1}'])
Expand Down
Loading