Skip to content

Commit

Permalink
Fixed override of Callable init_args without passing the class_path #174
Browse files Browse the repository at this point in the history
.
  • Loading branch information
mauvilsa committed Oct 14, 2022
1 parent babb93c commit 04383f9
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 4 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ Fixed
- ``default_env`` not forwarded to subcommand parsers, causing environment
variable names to not be shown in subcommand help `pytorch-lightning#12790
<https://github.com/Lightning-AI/lightning/issues/12790>`__.
- Cannot override Callable ``init_args`` without passing the ``class_path``
`#174 <https://github.com/omni-us/jsonargparse/issues/174>`__.


v4.15.1 (2022-10-07)
Expand Down
17 changes: 13 additions & 4 deletions jsonargparse/typehints.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,12 @@ def is_mapping_typehint(typehint):
return False


@staticmethod
def is_callable_typehint(typehint):
typehint_origin = get_typehint_origin(typehint)
return typehint_origin in callable_origin_types or typehint in callable_origin_types


def is_init_arg_mapping_typehint(self, key, cfg):
result = False
class_path = cfg.get(f'{self.dest}.class_path')
Expand All @@ -259,8 +265,9 @@ def parse_argv_item(arg_string):
action = _find_parent_action(parser, arg_base[2:])

typehint = typehint_from_action(action)
if (
if typehint and (
ActionTypeHint.is_subclass_typehint(typehint, all_subtypes=False) or
ActionTypeHint.is_callable_typehint(typehint) or
ActionTypeHint.is_mapping_typehint(typehint)
):
return action, arg_base, explicit_arg
Expand Down Expand Up @@ -629,14 +636,16 @@ def adapt_typehints(val, typehint, serialize=False, instantiate_classes=False, p
if isinstance(val, str):
val_obj = import_object(val)
if inspect.isclass(val_obj):
val = {'class_path': val}
val = Namespace(class_path=val)
elif callable(val_obj):
val = val_obj
else:
raise ImportError(f'Unexpected import object {val_obj}')
if isinstance(val, (dict, Namespace)):
if isinstance(val, (dict, Namespace, NestedArg)):
val_input = val
val = subclass_spec_as_namespace(val, prev_val)
if not is_subclass_spec(val):
raise ImportError(f'Dict must include a class_path and optionally init_args, but got {val}')
raise ImportError(f'Dict must include a class_path and optionally init_args, but got {val_input}')
val_class = import_object(val['class_path'])
if not (inspect.isclass(val_class) and callable_instances(val_class)):
raise ImportError(f'{val["class_path"]!r} is not a callable class.')
Expand Down
18 changes: 18 additions & 0 deletions jsonargparse_tests/test_typehints.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,6 +479,24 @@ class MyFunc2(MyFunc1):
self.assertRaises(ParserError, lambda: parser.parse_args([f'--callable={json.dumps(value)}']))


def test_callable_with_class_path_short_init_args(self):
class MyCallable:
def __init__(self, name: str):
self.name = name
def __call__(self):
return self.name

parser = ArgumentParser()
parser.add_argument('--call', type=Callable)

with mock_module(MyCallable) as module:
cfg = parser.parse_args([f'--call={module}.MyCallable', '--call.name=Bob'])
self.assertEqual(cfg.call.class_path, f'{module}.MyCallable')
self.assertEqual(cfg.call.init_args, Namespace(name='Bob'))
init = parser.instantiate_classes(cfg)
self.assertEqual(init.call(), 'Bob')


def test_typed_Callable_with_function_path(self):
def my_func_1(p: int) -> str:
return str(p)
Expand Down

0 comments on commit 04383f9

Please sign in to comment.