From 3a3c6a0ebfc18901d80c862040deda2d03a81f9d Mon Sep 17 00:00:00 2001 From: Mauricio Villegas <5780272+mauvilsa@users.noreply.github.com> Date: Wed, 12 Jun 2024 07:27:49 +0200 Subject: [PATCH] Fix: Callable that returns class not using required parameter default from lambda (#523) --- CHANGELOG.rst | 2 ++ jsonargparse/_core.py | 12 +++++---- jsonargparse/_typehints.py | 15 +++++++++++ jsonargparse_tests/test_typehints.py | 39 ++++++++++++++++++++++++++-- 4 files changed, 61 insertions(+), 7 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index d1dbe87b..232e5aad 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -30,6 +30,8 @@ Fixed `__). - ``--print_config`` failing in some cases (`#517 `__). +- Callable that returns class not using required parameter default from lambda + (`#523 `__). v4.29.0 (2024-05-24) diff --git a/jsonargparse/_core.py b/jsonargparse/_core.py index 5e499608..546c69aa 100644 --- a/jsonargparse/_core.py +++ b/jsonargparse/_core.py @@ -1326,12 +1326,14 @@ def merge_config(self, cfg_from: Namespace, cfg_to: Namespace) -> Namespace: Returns: A new object with the merged configuration. """ - cfg = cfg_to.clone() + cfg_from = cfg_from.clone() + cfg_to = cfg_to.clone() with parser_context(parent_parser=self): - ActionTypeHint.discard_init_args_on_class_path_change(self, cfg, cfg_from) - cfg.update(cfg_from) - ActionTypeHint.apply_appends(self, cfg) - return cfg + ActionTypeHint.discard_init_args_on_class_path_change(self, cfg_to, cfg_from) + ActionTypeHint.delete_init_args_required_none(cfg_from, cfg_to) + cfg_to.update(cfg_from) + ActionTypeHint.apply_appends(self, cfg_to) + return cfg_to def _check_value_key(self, action: argparse.Action, value: Any, key: str, cfg: Optional[Namespace]) -> Any: """Checks the value for a given action. diff --git a/jsonargparse/_typehints.py b/jsonargparse/_typehints.py index fee45a73..06aa92bf 100644 --- a/jsonargparse/_typehints.py +++ b/jsonargparse/_typehints.py @@ -396,6 +396,21 @@ def discard_init_args_on_class_path_change(parser_or_action, prev_cfg, cfg): keys = keys[: num + 1] + [k for k in keys[num + 1 :] if not k.startswith(key + ".")] num += 1 + @staticmethod + def delete_init_args_required_none(cfg_from, cfg_to): + for key, val in cfg_from.items(branches=True): + if isinstance(val, Namespace) and val.get("class_path") and val.get("init_args"): + skip_keys = [ + k + for k, v in val.init_args.__dict__.items() + if v is None and cfg_to.get(f"{key}.init_args.{k}") is not None + ] + if skip_keys: + parser = ActionTypeHint.get_class_parser(val.class_path) + for skip_key in skip_keys: + if skip_key in parser.required_args: + del val.init_args[skip_key] + @staticmethod @contextmanager def subclass_arg_context(parser): diff --git a/jsonargparse_tests/test_typehints.py b/jsonargparse_tests/test_typehints.py index b4d237d6..be5161fe 100644 --- a/jsonargparse_tests/test_typehints.py +++ b/jsonargparse_tests/test_typehints.py @@ -816,9 +816,10 @@ def __init__(self, optimizer: Optimizer, last_epoch: int = -1): class ReduceLROnPlateau: - def __init__(self, optimizer: Optimizer, monitor: str): + def __init__(self, optimizer: Optimizer, monitor: str, factor: float = 0.1): self.optimizer = optimizer self.monitor = monitor + self.factor = factor def test_callable_args_return_type_union_of_classes(parser, subtests): @@ -846,7 +847,7 @@ def test_callable_args_return_type_union_of_classes(parser, subtests): } cfg = parser.parse_args([f"--scheduler={value}"]) assert f"{__name__}.ReduceLROnPlateau" == cfg.scheduler.class_path - assert Namespace(monitor="loss") == cfg.scheduler.init_args + assert Namespace(monitor="loss", factor=0.1) == cfg.scheduler.init_args init = parser.instantiate_classes(cfg) scheduler = init.scheduler(optimizer) assert isinstance(scheduler, ReduceLROnPlateau) @@ -948,6 +949,40 @@ def test_callable_zero_args_return_type_class(parser): assert activation.negative_slope == 0.05 +class ModelRequiredCallableArg: + def __init__( + self, + scheduler: Callable[[Optimizer], ReduceLROnPlateau] = lambda o: ReduceLROnPlateau(o, monitor="acc"), + ): + self.scheduler = scheduler + + +def test_callable_return_class_required_arg_from_default(parser): + parser.add_argument("--cfg", action="config") + parser.add_argument("--model", type=ModelRequiredCallableArg) + + cfg = parser.parse_args(["--model=ModelRequiredCallableArg"]) + assert cfg.model.init_args.scheduler.class_path == f"{__name__}.ReduceLROnPlateau" + assert cfg.model.init_args.scheduler.init_args == Namespace(monitor="acc", factor=0.1) + + config = { + "model": { + "class_path": f"{__name__}.ModelRequiredCallableArg", + "init_args": { + "scheduler": { + "class_path": f"{__name__}.ReduceLROnPlateau", + "init_args": { + "factor": 0.5, + }, + }, + }, + } + } + cfg = parser.parse_args([f"--cfg={config}"]) + assert cfg.model.init_args.scheduler.class_path == f"{__name__}.ReduceLROnPlateau" + assert cfg.model.init_args.scheduler.init_args == Namespace(monitor="acc", factor=0.5) + + # lazy_instance tests