Skip to content

Commit

Permalink
Fix callable protocols failing to parse (#637)
Browse files Browse the repository at this point in the history
  • Loading branch information
mauvilsa authored Nov 29, 2024
1 parent 0ccac3a commit 98e03d6
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 12 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ Fixed
<https://github.com/omni-us/jsonargparse/pull/625>`__).
- ``NotRequired`` incorrectly having ``inspect._empty`` as default (`#625
<https://github.com/omni-us/jsonargparse/pull/625>`__).
- Callable protocols failing to parse (`#637
<https://github.com/omni-us/jsonargparse/pull/637>`__).


v4.34.0 (2024-11-08)
Expand Down
15 changes: 10 additions & 5 deletions jsonargparse/_typehints.py
Original file line number Diff line number Diff line change
Expand Up @@ -1042,8 +1042,9 @@ def adapt_typehints(
prev_val = Namespace(class_path=get_import_path(typehint)) # implicit class_path
val = subclass_spec_as_namespace(val, prev_val)
if not is_subclass_spec(val):
msg = "Does not implement protocol" if is_protocol(typehint) else "Not a valid subclass of"
raise_unexpected_value(
f"Not a valid subclass of {typehint.__name__}. Got value: {val_input}\n"
f"{msg} {typehint.__name__}. Got value: {val_input}\n"
"Subclass types expect one of:\n"
"- a class path (str)\n"
"- a dict with class_path entry\n"
Expand All @@ -1054,6 +1055,8 @@ def adapt_typehints(
val_class = import_object(resolve_class_path_by_name(typehint, val["class_path"]))
if is_instance_or_supports_protocol(val_class, typehint):
return val_class # importable instance
if is_protocol(val_class):
raise_unexpected_value(f"Expected an instantiatable class, but {val['class_path']} is a protocol")
not_subclass = False
if not is_subclass_or_implements_protocol(val_class, typehint):
not_subclass = True
Expand All @@ -1064,9 +1067,8 @@ def adapt_typehints(
if is_subclass_or_implements_protocol(return_type, typehint):
not_subclass = False
if not_subclass:
raise_unexpected_value(
f"Import path {val['class_path']} does not correspond to a subclass of {typehint.__name__}"
)
msg = "implement protocol" if is_protocol(typehint) else "correspond to a subclass of"
raise_unexpected_value(f"Import path {val['class_path']} does not {msg} {typehint.__name__}")
val["class_path"] = get_import_path(val_class)
val = adapt_class_type(val, serialize, instantiate_classes, sub_add_kwargs, prev_val=prev_val)
except (ImportError, AttributeError, AssertionError, ArgumentError) as ex:
Expand Down Expand Up @@ -1111,8 +1113,11 @@ def implements_protocol(value, protocol) -> bool:
if not hasattr(value, name):
return False
members += 1
try:
value_params = get_signature_parameters(value, name)
except ValueError:
return False
proto_params = get_signature_parameters(protocol, name)
value_params = get_signature_parameters(value, name)
if [(p.name, p.annotation) for p in proto_params] != [(p.name, p.annotation) for p in value_params]:
return False
proto_return = get_return_type(inspect.getattr_static(protocol, name))
Expand Down
30 changes: 23 additions & 7 deletions jsonargparse_tests/test_subclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -1492,12 +1492,12 @@ def test_parse_implements_protocol(parser):
assert isinstance(init.cls, ImplementsInterface)
assert init.cls.batch_size == 5
assert init.cls.predict([1.0, 2.0]) == [1.0, 2.0]
with pytest.raises(ArgumentError) as ctx:
with pytest.raises(ArgumentError, match="is a protocol"):
parser.parse_args([f"--cls={__name__}.Interface"])
with pytest.raises(ArgumentError, match="does not implement protocol"):
parser.parse_args([f"--cls={__name__}.NotImplementsInterface1"])
ctx.match("does not correspond to a subclass of")
with pytest.raises(ArgumentError) as ctx:
with pytest.raises(ArgumentError, match="Does not implement protocol Interface"):
parser.parse_args(['--cls={"batch_size": 5}'])
ctx.match("Not a valid subclass of Interface")


# callable protocol tests
Expand All @@ -1507,7 +1507,7 @@ class CallableInterface(Protocol):
def __call__(self, items: List[float]) -> List[float]: ...


class ImplementsCallableInterface1:
class ImplementsCallableInterface:
def __init__(self, batch_size: int):
self.batch_size = batch_size

Expand All @@ -1533,8 +1533,8 @@ def __call__(self, items: List[float]) -> None:
@pytest.mark.parametrize(
"expected, value",
[
(True, ImplementsCallableInterface1),
(False, ImplementsCallableInterface1(1)),
(True, ImplementsCallableInterface),
(False, ImplementsCallableInterface(1)),
(False, NotImplementsCallableInterface1),
(False, NotImplementsCallableInterface2),
(False, NotImplementsCallableInterface3),
Expand All @@ -1545,6 +1545,22 @@ def test_implements_callable_protocol(expected, value):
assert implements_protocol(value, CallableInterface) is expected


def test_parse_implements_callable_protocol(parser):
parser.add_argument("--cls", type=CallableInterface)
cfg = parser.parse_args([f"--cls={__name__}.ImplementsCallableInterface", "--cls.batch_size=7"])
assert cfg.cls.class_path == f"{__name__}.ImplementsCallableInterface"
assert cfg.cls.init_args == Namespace(batch_size=7)
init = parser.instantiate_classes(cfg)
assert isinstance(init.cls, ImplementsCallableInterface)
assert init.cls([1.0, 2.0]) == [1.0, 2.0]
with pytest.raises(ArgumentError, match="is a protocol"):
parser.parse_args([f"--cls={__name__}.CallableInterface"])
with pytest.raises(ArgumentError, match="does not implement protocol"):
parser.parse_args([f"--cls={__name__}.NotImplementsCallableInterface1"])
with pytest.raises(ArgumentError, match="Does not implement protocol CallableInterface"):
parser.parse_args(['--cls={"batch_size": 7}'])


# parameter skip tests


Expand Down

0 comments on commit 98e03d6

Please sign in to comment.