From 605379f2b2eb2b5ca8db7edf1ffd99a215275fa3 Mon Sep 17 00:00:00 2001 From: Mauricio Villegas <5780272+mauvilsa@users.noreply.github.com> Date: Wed, 10 Apr 2024 07:34:15 +0200 Subject: [PATCH 1/2] Fix lazy_instance not working for callable classes (#473). --- CHANGELOG.rst | 6 ++++++ jsonargparse/_typehints.py | 12 +++++++++--- jsonargparse_tests/test_typehints.py | 15 +++++++++++++++ 3 files changed, 30 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index c36826e7..50297a40 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -19,6 +19,12 @@ Added - Support for "-" as value for Path class initialization so that user can ask to use standard input/output instead of file. +Fixed +^^^^^ +- ``lazy_instance`` not working for callable classes (`#473 comment + `__). + + v4.27.7 (2024-03-21) -------------------- diff --git a/jsonargparse/_typehints.py b/jsonargparse/_typehints.py index 1c87980a..69a8c200 100644 --- a/jsonargparse/_typehints.py +++ b/jsonargparse/_typehints.py @@ -1292,6 +1292,7 @@ class LazyInitBaseClass: def __init__(self, class_type: Type, lazy_kwargs: dict): assert not issubclass(class_type, LazyInitBaseClass) check_lazy_kwargs(class_type, lazy_kwargs) + self._lazy = type(self) self._lazy_class_type = class_type self._lazy_kwargs = lazy_kwargs self._lazy_methods = {} @@ -1305,17 +1306,22 @@ def __init__(self, class_type: Type, lazy_kwargs: dict): if id(member) in seen_methods: self.__dict__[name] = seen_methods[id(member)] else: - self.__dict__[name] = partial(self._lazy_init_then_call_method, name) - seen_methods[id(member)] = self.__dict__[name] + lazy_method = partial(self._lazy_init_then_call_method, name) + self.__dict__[name] = lazy_method + if name == "__call__": + self._lazy.__call__ = lazy_method # type: ignore[method-assign] + seen_methods[id(member)] = lazy_method def _lazy_init(self): for name in self._lazy_methods: + if name == "__call__": + self._lazy.__call__ = self._lazy_methods[name] del self.__dict__[name] super().__init__(**self._lazy_kwargs) def _lazy_init_then_call_method(self, method_name, *args, **kwargs): self._lazy_init() - return getattr(self, method_name)(*args, **kwargs) + return self._lazy_methods[method_name](*args, **kwargs) def lazy_get_init_args(self) -> Namespace: return Namespace(self._lazy_kwargs) diff --git a/jsonargparse_tests/test_typehints.py b/jsonargparse_tests/test_typehints.py index 26f9c7df..fe623628 100644 --- a/jsonargparse_tests/test_typehints.py +++ b/jsonargparse_tests/test_typehints.py @@ -884,6 +884,21 @@ def test_lazy_instance_pickleable(): assert reloaded.lazy_get_init_data() == instance1.lazy_get_init_data() +class OptimizerCallable: + def __init__(self, lr: float = 0.1): + self.lr = lr + + def __call__(self, params) -> SGD: + return SGD(params, lr=self.lr) + + +def test_lazy_instance_callable(): + lazy_optimizer = lazy_instance(OptimizerCallable, lr=0.2) + optimizer = lazy_optimizer([1, 2]) + assert optimizer.lr == 0.2 + assert optimizer.params == [1, 2] + + # other tests From 55b37c761b1bc24893d9b5b7c59a24fb1ef1465f Mon Sep 17 00:00:00 2001 From: Mauricio Villegas <5780272+mauvilsa@users.noreply.github.com> Date: Mon, 15 Apr 2024 09:05:21 +0200 Subject: [PATCH 2/2] Update changelog. --- CHANGELOG.rst | 4 ---- 1 file changed, 4 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 8ee75508..49516911 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -29,10 +29,6 @@ Fixed - Account for breaking change in ``argparse.ArgumentParser._parse_optional`` affecting python ``3.11.9`` and likely ``>3.13`` (`#484 `__). - - -Fixed -^^^^^ - ``lazy_instance`` not working for callable classes (`#473 comment `__).