Skip to content

Commit

Permalink
Support for Protocol types only accepting exact matching signature of…
Browse files Browse the repository at this point in the history
… public methods (#526)
  • Loading branch information
mauvilsa authored Jun 25, 2024
1 parent 0525fd2 commit 0a4de31
Show file tree
Hide file tree
Showing 5 changed files with 143 additions and 10 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,12 @@ paths are considered internals and can change in minor and patch releases.
v4.31.0 (2024-06-??)
--------------------

Added
^^^^^
- Support for ``Protocol`` types only accepting exact matching signature of
public methods (`#526
<https://github.com/omni-us/jsonargparse/pull/526>`__).

Fixed
^^^^^
- Resolving of import paths for some ``torch`` functions not working (`#535
Expand Down
4 changes: 4 additions & 0 deletions DOCUMENTATION.rst
Original file line number Diff line number Diff line change
Expand Up @@ -421,6 +421,10 @@ Some notes about this support are:
:py:meth:`.ArgumentParser.instantiate_classes` can be used to instantiate all
classes in a config object. For more details see :ref:`sub-classes`.

- ``Protocol`` types are also supported the same as sub-classes. The protocols
are not required to be ``runtime_checkable``. But the accepted classes must
match exactly the signature of the protocol's public methods.

- ``dataclasses`` are supported even when nested. Final classes, attrs'
``define`` decorator, and pydantic's ``dataclass`` decorator and ``BaseModel``
classes are supported and behave like standard dataclasses. For more details
Expand Down
5 changes: 3 additions & 2 deletions jsonargparse/_postponed_annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,7 @@ def evaluate_postponed_annotations(params, component, parent, logger):
param.annotation = param_type


def get_return_type(component, logger):
def get_return_type(component, logger=None):
return_type = inspect.signature(component).return_annotation
if type_requires_eval(return_type):
global_vars = vars(import_module(component.__module__))
Expand All @@ -343,6 +343,7 @@ def get_return_type(component, logger):
if isinstance(return_type, ForwardRef):
return_type = resolve_forward_refs(return_type.__forward_arg__, global_vars, logger)
except Exception as ex:
logger.debug(f"Unable to evaluate types for {component}", exc_info=ex)
if logger:
logger.debug(f"Unable to evaluate types for {component}", exc_info=ex)
return None
return return_type
56 changes: 48 additions & 8 deletions jsonargparse/_typehints.py
Original file line number Diff line number Diff line change
Expand Up @@ -977,17 +977,17 @@ def adapt_typehints(

# Subclass
elif not hasattr(typehint, "__origin__") and inspect.isclass(typehint):
if isinstance(val, typehint):
if is_instance_or_supports_protocol(val, typehint):
if serialize:
val = serialize_class_instance(val)
return val
if serialize and isinstance(val, str):
return val

val_input = val
if prev_val is None and not inspect.isabstract(typehint):
if prev_val is None and not inspect.isabstract(typehint) and not is_protocol(typehint):
with suppress(ValueError):
prev_val = Namespace(class_path=get_import_path(typehint))
prev_val = Namespace(class_path=get_import_path(typehint)) # implicit class_path
val = subclass_spec_as_namespace(val, prev_val)
if not is_subclass_spec(val):
raise_unexpected_value(
Expand All @@ -1000,20 +1000,20 @@ def adapt_typehints(

try:
val_class = import_object(resolve_class_path_by_name(typehint, val["class_path"]))
if isinstance(val_class, typehint):
return val_class
if is_instance_or_supports_protocol(val_class, typehint):
return val_class # importable instance
not_subclass = False
if not is_subclass(val_class, typehint):
if not is_subclass_or_implements_protocol(val_class, typehint):
not_subclass = True
if not inspect.isclass(val_class) and callable(val_class):
from ._postponed_annotations import get_return_type

return_type = get_return_type(val_class, logger)
if is_subclass(return_type, typehint):
if is_subclass_or_implements_protocol(return_type, typehint):
not_subclass = False
if not_subclass:
raise_unexpected_value(
f'Import path {val["class_path"]} does not correspond to a subclass of {typehint}'
f"Import path {val['class_path']} does not correspond to a subclass of {typehint.__name__}"
)
val["class_path"] = get_import_path(val_class)
val = adapt_class_type(val, serialize, instantiate_classes, sub_add_kwargs, prev_val=prev_val)
Expand All @@ -1029,6 +1029,46 @@ def adapt_typehints(
return val


def implements_protocol(value, protocol) -> bool:
from jsonargparse._parameter_resolvers import get_signature_parameters
from jsonargparse._postponed_annotations import get_return_type

if not inspect.isclass(value):
return False
members = 0
for name, _ in inspect.getmembers(protocol, predicate=inspect.isfunction):
if name.startswith("_"):
continue
if not hasattr(value, name):
return False
members += 1
proto_params = get_signature_parameters(protocol, name)
value_params = get_signature_parameters(value, name)
if [(p.name, p.annotation) for p in proto_params] != [(p.name, p.annotation) for p in value_params]:
return False
proto_return = get_return_type(inspect.getattr_static(protocol, name))
value_return = get_return_type(inspect.getattr_static(value, name))
if proto_return != value_return:
return False
return True if members else False


def is_protocol(class_type) -> bool:
return getattr(class_type, "_is_protocol", False)


def is_subclass_or_implements_protocol(value, class_type) -> bool:
if is_protocol(class_type):
return implements_protocol(value, class_type)
return is_subclass(value, class_type)


def is_instance_or_supports_protocol(value, class_type):
if is_protocol(class_type):
return is_subclass_or_implements_protocol(value.__class__, class_type)
return isinstance(value, class_type)


def is_subclass_spec(val):
is_class = isinstance(val, (dict, Namespace)) and "class_path" in val
if is_class:
Expand Down
82 changes: 82 additions & 0 deletions jsonargparse_tests/test_subclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
Namespace,
lazy_instance,
)
from jsonargparse._optionals import typing_extensions_import
from jsonargparse._typehints import implements_protocol, is_instance_or_supports_protocol
from jsonargparse.typing import final
from jsonargparse_tests.conftest import (
capture_logs,
Expand All @@ -32,6 +34,8 @@
source_unavailable,
)

Protocol = typing_extensions_import("Protocol")


@pytest.mark.parametrize("type", [Calendar, Optional[Calendar]])
def test_subclass_basics(parser, type):
Expand Down Expand Up @@ -1407,6 +1411,84 @@ def test_subclass_signature_instance_default(parser):
assert "cal: Unable to serialize instance <calendar.Calendar " in dump


# protocol tests


class Interface(Protocol): # type: ignore[valid-type,misc]
def predict(self, items: List[float]) -> List[float]: ... # type: ignore[empty-body]


class ImplementsInterface:
def __init__(self, batch_size: int):
self.batch_size = batch_size

def predict(self, items: List[float]) -> List[float]:
return items


class NotImplementsInterface1:
def predict(self, items: str) -> List[float]:
return []


class NotImplementsInterface2:
def predict(self, items: List[float], extra: int) -> List[float]:
return items


class NotImplementsInterface3:
def predict(self, items: List[float]) -> None:
return


@pytest.mark.parametrize(
"expected, value",
[
(True, ImplementsInterface),
(False, ImplementsInterface(1)),
(False, NotImplementsInterface1),
(False, NotImplementsInterface2),
(False, NotImplementsInterface3),
(False, object),
],
)
@pytest.mark.skipif(not Protocol, reason="Requires Python 3.8+ or typing_extensions")
def test_implements_protocol(expected, value):
assert implements_protocol(value, Interface) is expected


@pytest.mark.parametrize(
"expected, value",
[
(False, ImplementsInterface),
(True, ImplementsInterface(1)),
(False, NotImplementsInterface1()),
(False, object),
],
)
@pytest.mark.skipif(not Protocol, reason="Requires Python 3.8+ or typing_extensions")
def test_is_instance_or_supports_protocol(expected, value):
assert is_instance_or_supports_protocol(value, Interface) is expected


@pytest.mark.skipif(not Protocol, reason="Requires Python 3.8+ or typing_extensions")
def test_parse_implements_protocol(parser):
parser.add_argument("--cls", type=Interface)
cfg = parser.parse_args([f"--cls={__name__}.ImplementsInterface", "--cls.batch_size=5"])
assert cfg.cls.class_path == f"{__name__}.ImplementsInterface"
assert cfg.cls.init_args == Namespace(batch_size=5)
init = parser.instantiate_classes(cfg)
assert isinstance(init.cls, ImplementsInterface)
assert init.cls.batch_size == 5
assert init.cls.predict([1.0, 2.0]) == [1.0, 2.0]
with pytest.raises(ArgumentError) as ctx:
parser.parse_args([f"--cls={__name__}.NotImplementsInterface1"])
ctx.match("does not correspond to a subclass of")
with pytest.raises(ArgumentError) as ctx:
parser.parse_args(['--cls={"batch_size": 5}'])
ctx.match("Not a valid subclass of Interface")


# parameter skip tests


Expand Down

0 comments on commit 0a4de31

Please sign in to comment.