Skip to content

Commit

Permalink
Fix: Callable that returns class not using required parameter default…
Browse files Browse the repository at this point in the history
… from lambda (#523)
  • Loading branch information
mauvilsa authored Jun 12, 2024
1 parent 54fde77 commit 3a3c6a0
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 7 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ Fixed
<https://github.com/omni-us/jsonargparse/issues/516>`__).
- ``--print_config`` failing in some cases (`#517
<https://github.com/omni-us/jsonargparse/issues/517>`__).
- Callable that returns class not using required parameter default from lambda
(`#523 <https://github.com/omni-us/jsonargparse/pull/523>`__).


v4.29.0 (2024-05-24)
Expand Down
12 changes: 7 additions & 5 deletions jsonargparse/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1326,12 +1326,14 @@ def merge_config(self, cfg_from: Namespace, cfg_to: Namespace) -> Namespace:
Returns:
A new object with the merged configuration.
"""
cfg = cfg_to.clone()
cfg_from = cfg_from.clone()
cfg_to = cfg_to.clone()
with parser_context(parent_parser=self):
ActionTypeHint.discard_init_args_on_class_path_change(self, cfg, cfg_from)
cfg.update(cfg_from)
ActionTypeHint.apply_appends(self, cfg)
return cfg
ActionTypeHint.discard_init_args_on_class_path_change(self, cfg_to, cfg_from)
ActionTypeHint.delete_init_args_required_none(cfg_from, cfg_to)
cfg_to.update(cfg_from)
ActionTypeHint.apply_appends(self, cfg_to)
return cfg_to

def _check_value_key(self, action: argparse.Action, value: Any, key: str, cfg: Optional[Namespace]) -> Any:
"""Checks the value for a given action.
Expand Down
15 changes: 15 additions & 0 deletions jsonargparse/_typehints.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,6 +396,21 @@ def discard_init_args_on_class_path_change(parser_or_action, prev_cfg, cfg):
keys = keys[: num + 1] + [k for k in keys[num + 1 :] if not k.startswith(key + ".")]
num += 1

@staticmethod
def delete_init_args_required_none(cfg_from, cfg_to):
for key, val in cfg_from.items(branches=True):
if isinstance(val, Namespace) and val.get("class_path") and val.get("init_args"):
skip_keys = [
k
for k, v in val.init_args.__dict__.items()
if v is None and cfg_to.get(f"{key}.init_args.{k}") is not None
]
if skip_keys:
parser = ActionTypeHint.get_class_parser(val.class_path)
for skip_key in skip_keys:
if skip_key in parser.required_args:
del val.init_args[skip_key]

@staticmethod
@contextmanager
def subclass_arg_context(parser):
Expand Down
39 changes: 37 additions & 2 deletions jsonargparse_tests/test_typehints.py
Original file line number Diff line number Diff line change
Expand Up @@ -816,9 +816,10 @@ def __init__(self, optimizer: Optimizer, last_epoch: int = -1):


class ReduceLROnPlateau:
def __init__(self, optimizer: Optimizer, monitor: str):
def __init__(self, optimizer: Optimizer, monitor: str, factor: float = 0.1):
self.optimizer = optimizer
self.monitor = monitor
self.factor = factor


def test_callable_args_return_type_union_of_classes(parser, subtests):
Expand Down Expand Up @@ -846,7 +847,7 @@ def test_callable_args_return_type_union_of_classes(parser, subtests):
}
cfg = parser.parse_args([f"--scheduler={value}"])
assert f"{__name__}.ReduceLROnPlateau" == cfg.scheduler.class_path
assert Namespace(monitor="loss") == cfg.scheduler.init_args
assert Namespace(monitor="loss", factor=0.1) == cfg.scheduler.init_args
init = parser.instantiate_classes(cfg)
scheduler = init.scheduler(optimizer)
assert isinstance(scheduler, ReduceLROnPlateau)
Expand Down Expand Up @@ -948,6 +949,40 @@ def test_callable_zero_args_return_type_class(parser):
assert activation.negative_slope == 0.05


class ModelRequiredCallableArg:
def __init__(
self,
scheduler: Callable[[Optimizer], ReduceLROnPlateau] = lambda o: ReduceLROnPlateau(o, monitor="acc"),
):
self.scheduler = scheduler


def test_callable_return_class_required_arg_from_default(parser):
parser.add_argument("--cfg", action="config")
parser.add_argument("--model", type=ModelRequiredCallableArg)

cfg = parser.parse_args(["--model=ModelRequiredCallableArg"])
assert cfg.model.init_args.scheduler.class_path == f"{__name__}.ReduceLROnPlateau"
assert cfg.model.init_args.scheduler.init_args == Namespace(monitor="acc", factor=0.1)

config = {
"model": {
"class_path": f"{__name__}.ModelRequiredCallableArg",
"init_args": {
"scheduler": {
"class_path": f"{__name__}.ReduceLROnPlateau",
"init_args": {
"factor": 0.5,
},
},
},
}
}
cfg = parser.parse_args([f"--cfg={config}"])
assert cfg.model.init_args.scheduler.class_path == f"{__name__}.ReduceLROnPlateau"
assert cfg.model.init_args.scheduler.init_args == Namespace(monitor="acc", factor=0.5)


# lazy_instance tests


Expand Down

0 comments on commit 3a3c6a0

Please sign in to comment.