Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for Protocol types only accepting exact matching signature of public methods #526

Merged
merged 4 commits into from
Jun 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,12 @@ paths are considered internals and can change in minor and patch releases.
v4.31.0 (2024-06-??)
--------------------

Added
^^^^^
- Support for ``Protocol`` types only accepting exact matching signature of
public methods (`#526
<https://github.com/omni-us/jsonargparse/pull/526>`__).

Fixed
^^^^^
- Resolving of import paths for some ``torch`` functions not working (`#535
Expand Down
4 changes: 4 additions & 0 deletions DOCUMENTATION.rst
Original file line number Diff line number Diff line change
Expand Up @@ -421,6 +421,10 @@ Some notes about this support are:
:py:meth:`.ArgumentParser.instantiate_classes` can be used to instantiate all
classes in a config object. For more details see :ref:`sub-classes`.

- ``Protocol`` types are also supported the same as sub-classes. The protocols
are not required to be ``runtime_checkable``. But the accepted classes must
match exactly the signature of the protocol's public methods.

- ``dataclasses`` are supported even when nested. Final classes, attrs'
``define`` decorator, and pydantic's ``dataclass`` decorator and ``BaseModel``
classes are supported and behave like standard dataclasses. For more details
Expand Down
5 changes: 3 additions & 2 deletions jsonargparse/_postponed_annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,7 @@ def evaluate_postponed_annotations(params, component, parent, logger):
param.annotation = param_type


def get_return_type(component, logger):
def get_return_type(component, logger=None):
return_type = inspect.signature(component).return_annotation
if type_requires_eval(return_type):
global_vars = vars(import_module(component.__module__))
Expand All @@ -343,6 +343,7 @@ def get_return_type(component, logger):
if isinstance(return_type, ForwardRef):
return_type = resolve_forward_refs(return_type.__forward_arg__, global_vars, logger)
except Exception as ex:
logger.debug(f"Unable to evaluate types for {component}", exc_info=ex)
if logger:
logger.debug(f"Unable to evaluate types for {component}", exc_info=ex)
return None
return return_type
56 changes: 48 additions & 8 deletions jsonargparse/_typehints.py
Original file line number Diff line number Diff line change
Expand Up @@ -977,17 +977,17 @@ def adapt_typehints(

# Subclass
elif not hasattr(typehint, "__origin__") and inspect.isclass(typehint):
if isinstance(val, typehint):
if is_instance_or_supports_protocol(val, typehint):
if serialize:
val = serialize_class_instance(val)
return val
if serialize and isinstance(val, str):
return val

val_input = val
if prev_val is None and not inspect.isabstract(typehint):
if prev_val is None and not inspect.isabstract(typehint) and not is_protocol(typehint):
with suppress(ValueError):
prev_val = Namespace(class_path=get_import_path(typehint))
prev_val = Namespace(class_path=get_import_path(typehint)) # implicit class_path
val = subclass_spec_as_namespace(val, prev_val)
if not is_subclass_spec(val):
raise_unexpected_value(
Expand All @@ -1000,20 +1000,20 @@ def adapt_typehints(

try:
val_class = import_object(resolve_class_path_by_name(typehint, val["class_path"]))
if isinstance(val_class, typehint):
return val_class
if is_instance_or_supports_protocol(val_class, typehint):
return val_class # importable instance
not_subclass = False
if not is_subclass(val_class, typehint):
if not is_subclass_or_implements_protocol(val_class, typehint):
not_subclass = True
if not inspect.isclass(val_class) and callable(val_class):
from ._postponed_annotations import get_return_type

return_type = get_return_type(val_class, logger)
if is_subclass(return_type, typehint):
if is_subclass_or_implements_protocol(return_type, typehint):
not_subclass = False
if not_subclass:
raise_unexpected_value(
f'Import path {val["class_path"]} does not correspond to a subclass of {typehint}'
f"Import path {val['class_path']} does not correspond to a subclass of {typehint.__name__}"
)
val["class_path"] = get_import_path(val_class)
val = adapt_class_type(val, serialize, instantiate_classes, sub_add_kwargs, prev_val=prev_val)
Expand All @@ -1029,6 +1029,46 @@ def adapt_typehints(
return val


def implements_protocol(value, protocol) -> bool:
from jsonargparse._parameter_resolvers import get_signature_parameters
from jsonargparse._postponed_annotations import get_return_type

if not inspect.isclass(value):
return False
members = 0
for name, _ in inspect.getmembers(protocol, predicate=inspect.isfunction):
if name.startswith("_"):
continue
if not hasattr(value, name):
return False
members += 1
proto_params = get_signature_parameters(protocol, name)
value_params = get_signature_parameters(value, name)
if [(p.name, p.annotation) for p in proto_params] != [(p.name, p.annotation) for p in value_params]:
return False
proto_return = get_return_type(inspect.getattr_static(protocol, name))
value_return = get_return_type(inspect.getattr_static(value, name))
if proto_return != value_return:
return False
return True if members else False


def is_protocol(class_type) -> bool:
return getattr(class_type, "_is_protocol", False)


def is_subclass_or_implements_protocol(value, class_type) -> bool:
if is_protocol(class_type):
return implements_protocol(value, class_type)
return is_subclass(value, class_type)


def is_instance_or_supports_protocol(value, class_type):
if is_protocol(class_type):
return is_subclass_or_implements_protocol(value.__class__, class_type)
return isinstance(value, class_type)


def is_subclass_spec(val):
is_class = isinstance(val, (dict, Namespace)) and "class_path" in val
if is_class:
Expand Down
82 changes: 82 additions & 0 deletions jsonargparse_tests/test_subclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
Namespace,
lazy_instance,
)
from jsonargparse._optionals import typing_extensions_import
from jsonargparse._typehints import implements_protocol, is_instance_or_supports_protocol
from jsonargparse.typing import final
from jsonargparse_tests.conftest import (
capture_logs,
Expand All @@ -32,6 +34,8 @@
source_unavailable,
)

Protocol = typing_extensions_import("Protocol")


@pytest.mark.parametrize("type", [Calendar, Optional[Calendar]])
def test_subclass_basics(parser, type):
Expand Down Expand Up @@ -1407,6 +1411,84 @@ def test_subclass_signature_instance_default(parser):
assert "cal: Unable to serialize instance <calendar.Calendar " in dump


# protocol tests


class Interface(Protocol): # type: ignore[valid-type,misc]
def predict(self, items: List[float]) -> List[float]: ... # type: ignore[empty-body]
Dismissed Show dismissed Hide dismissed


class ImplementsInterface:
def __init__(self, batch_size: int):
self.batch_size = batch_size

def predict(self, items: List[float]) -> List[float]:
return items


class NotImplementsInterface1:
def predict(self, items: str) -> List[float]:
return []


class NotImplementsInterface2:
def predict(self, items: List[float], extra: int) -> List[float]:
return items


class NotImplementsInterface3:
def predict(self, items: List[float]) -> None:
return


@pytest.mark.parametrize(
"expected, value",
[
(True, ImplementsInterface),
(False, ImplementsInterface(1)),
(False, NotImplementsInterface1),
(False, NotImplementsInterface2),
(False, NotImplementsInterface3),
(False, object),
],
)
@pytest.mark.skipif(not Protocol, reason="Requires Python 3.8+ or typing_extensions")
def test_implements_protocol(expected, value):
assert implements_protocol(value, Interface) is expected


@pytest.mark.parametrize(
"expected, value",
[
(False, ImplementsInterface),
(True, ImplementsInterface(1)),
(False, NotImplementsInterface1()),
(False, object),
],
)
@pytest.mark.skipif(not Protocol, reason="Requires Python 3.8+ or typing_extensions")
def test_is_instance_or_supports_protocol(expected, value):
assert is_instance_or_supports_protocol(value, Interface) is expected


@pytest.mark.skipif(not Protocol, reason="Requires Python 3.8+ or typing_extensions")
def test_parse_implements_protocol(parser):
parser.add_argument("--cls", type=Interface)
cfg = parser.parse_args([f"--cls={__name__}.ImplementsInterface", "--cls.batch_size=5"])
assert cfg.cls.class_path == f"{__name__}.ImplementsInterface"
assert cfg.cls.init_args == Namespace(batch_size=5)
init = parser.instantiate_classes(cfg)
assert isinstance(init.cls, ImplementsInterface)
assert init.cls.batch_size == 5
assert init.cls.predict([1.0, 2.0]) == [1.0, 2.0]
with pytest.raises(ArgumentError) as ctx:
parser.parse_args([f"--cls={__name__}.NotImplementsInterface1"])
ctx.match("does not correspond to a subclass of")
with pytest.raises(ArgumentError) as ctx:
parser.parse_args(['--cls={"batch_size": 5}'])
ctx.match("Not a valid subclass of Interface")


# parameter skip tests


Expand Down