From fcf22e85bcc3f7e0c76b5dfc8eb088e6ba7a26a3 Mon Sep 17 00:00:00 2001 From: Mauricio Villegas <5780272+mauvilsa@users.noreply.github.com> Date: Mon, 20 May 2024 07:25:03 +0200 Subject: [PATCH] Fix not able to modify init args for callable with class return and default class. --- CHANGELOG.rst | 2 ++ jsonargparse/_typehints.py | 2 ++ jsonargparse_tests/test_typehints.py | 7 +++++++ 3 files changed, 11 insertions(+) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index a88cb63b..6635a646 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -28,6 +28,8 @@ Fixed space (`#499 `__). - ``format_usage()`` not working (`#501 `__). +- Not able to modify init args for callable with class return and default class + (`#5?? `__). v4.28.0 (2024-04-17) diff --git a/jsonargparse/_typehints.py b/jsonargparse/_typehints.py index f945b6d9..c2ad0ba9 100644 --- a/jsonargparse/_typehints.py +++ b/jsonargparse/_typehints.py @@ -980,6 +980,8 @@ def subclass_spec_as_namespace(val, prev_val=None): val = Namespace({root_key: val}) if isinstance(prev_val, str): prev_val = Namespace(class_path=prev_val) + elif inspect.isclass(prev_val): + prev_val = Namespace(class_path=get_import_path(prev_val)) if isinstance(val, dict): val = Namespace(val) if "init_args" in val and isinstance(val["init_args"], dict): diff --git a/jsonargparse_tests/test_typehints.py b/jsonargparse_tests/test_typehints.py index f5d284e6..5e6ce9cb 100644 --- a/jsonargparse_tests/test_typehints.py +++ b/jsonargparse_tests/test_typehints.py @@ -792,6 +792,13 @@ def test_callable_multiple_args_return_type_class(parser, subtests): assert f"{__name__}.{name}" in help_str +def test_callable_return_class_default_class_override_init_arg(parser): + parser.add_argument("--optimizer", type=Callable[[List[float]], Optimizer], default=SGD) + cfg = parser.parse_args(["--optimizer.momentum=0.5", "--optimizer.lr=0.05"]) + assert cfg.optimizer.class_path == f"{__name__}.SGD" + assert cfg.optimizer.init_args == Namespace(lr=0.05, momentum=0.5) + + class StepLR: def __init__(self, optimizer: Optimizer, last_epoch: int = -1): self.optimizer = optimizer