From e57a370f902991c5cd85c9021ff614322ebbeaab Mon Sep 17 00:00:00 2001 From: Mauricio Villegas <5780272+mauvilsa@users.noreply.github.com> Date: Tue, 29 Oct 2024 05:41:56 +0100 Subject: [PATCH 1/2] Fix custom instantiators not working for nested dependency injection (#606). --- CHANGELOG.rst | 2 ++ jsonargparse/_common.py | 2 +- jsonargparse/_core.py | 5 +++++ jsonargparse_tests/test_subclasses.py | 15 +++++++++++++++ 4 files changed, 23 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index c5db0c5f..f00d7950 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -29,6 +29,8 @@ Fixed `__). - Callables that return class not considering previous values (`#603 `__). +- Custom instantiators not working for nested dependency injection (`#??? + `__). Changed ^^^^^^^ diff --git a/jsonargparse/_common.py b/jsonargparse/_common.py index f4e99bad..9af90e53 100644 --- a/jsonargparse/_common.py +++ b/jsonargparse/_common.py @@ -59,7 +59,7 @@ def __call__(self, class_type: Type[ClassType], *args, **kwargs) -> ClassType: defaults_cache: ContextVar[Optional[Namespace]] = ContextVar("defaults_cache", default=None) lenient_check: ContextVar[Union[bool, str]] = ContextVar("lenient_check", default=False) load_value_mode: ContextVar[Optional[str]] = ContextVar("load_value_mode", default=None) -class_instantiators: ContextVar[Optional[InstantiatorsDictType]] = ContextVar("class_instantiators") +class_instantiators: ContextVar[Optional[InstantiatorsDictType]] = ContextVar("class_instantiators", default=None) nested_links: ContextVar[List[dict]] = ContextVar("nested_links", default=[]) diff --git a/jsonargparse/_core.py b/jsonargparse/_core.py index ce89003e..86a575a0 100644 --- a/jsonargparse/_core.py +++ b/jsonargparse/_core.py @@ -39,6 +39,7 @@ from ._common import ( InstantiatorCallable, InstantiatorsDictType, + class_instantiators, debug_mode_active, is_dataclass_like, lenient_check, @@ -1139,6 +1140,10 @@ def _get_instantiators(self): parent_instantiators = self.parent_parser._get_instantiators() instantiators = instantiators.copy() instantiators.update({k: v for k, v in parent_instantiators.items() if k not in instantiators}) + context_instantiators = class_instantiators.get() + if context_instantiators: + instantiators = instantiators.copy() + instantiators.update({k: v for k, v in context_instantiators.items() if k not in instantiators}) return instantiators def instantiate_classes( diff --git a/jsonargparse_tests/test_subclasses.py b/jsonargparse_tests/test_subclasses.py index f85fe50f..e433fc2a 100644 --- a/jsonargparse_tests/test_subclasses.py +++ b/jsonargparse_tests/test_subclasses.py @@ -479,6 +479,21 @@ def test_custom_instantiation_replace(parser): assert list(parser._instantiators.values())[0] is second_instantiator +class CustomInstantiationNested: + def __init__(self, sub: CustomInstantiationBase): + self.sub = sub + + +def test_custom_instantiation_nested(parser): + parser.add_argument("--cls", type=CustomInstantiationNested) + parser.add_instantiator(instantiator("nested"), CustomInstantiationBase, subclasses=True) + cfg = parser.parse_args(["--cls=CustomInstantiationNested", "--cls.sub=CustomInstantiationSub"]) + init = parser.instantiate_classes(cfg) + assert isinstance(init.cls, CustomInstantiationNested) + assert isinstance(init.cls.sub, CustomInstantiationSub) + assert init.cls.sub.call == "nested" + + # environment tests From 480c8bd83568242b4441140f4af5a46343a23fa1 Mon Sep 17 00:00:00 2001 From: Mauricio Villegas <5780272+mauvilsa@users.noreply.github.com> Date: Tue, 29 Oct 2024 05:49:11 +0100 Subject: [PATCH 2/2] Update changelog --- CHANGELOG.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index f00d7950..489125ef 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -29,8 +29,8 @@ Fixed `__). - Callables that return class not considering previous values (`#603 `__). -- Custom instantiators not working for nested dependency injection (`#??? - `__). +- Custom instantiators not working for nested dependency injection (`#608 + `__). Changed ^^^^^^^