diff --git a/CHANGELOG.rst b/CHANGELOG.rst index c5db0c5f..489125ef 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -29,6 +29,8 @@ Fixed `__). - Callables that return class not considering previous values (`#603 `__). +- Custom instantiators not working for nested dependency injection (`#608 + `__). Changed ^^^^^^^ diff --git a/jsonargparse/_common.py b/jsonargparse/_common.py index f4e99bad..9af90e53 100644 --- a/jsonargparse/_common.py +++ b/jsonargparse/_common.py @@ -59,7 +59,7 @@ def __call__(self, class_type: Type[ClassType], *args, **kwargs) -> ClassType: defaults_cache: ContextVar[Optional[Namespace]] = ContextVar("defaults_cache", default=None) lenient_check: ContextVar[Union[bool, str]] = ContextVar("lenient_check", default=False) load_value_mode: ContextVar[Optional[str]] = ContextVar("load_value_mode", default=None) -class_instantiators: ContextVar[Optional[InstantiatorsDictType]] = ContextVar("class_instantiators") +class_instantiators: ContextVar[Optional[InstantiatorsDictType]] = ContextVar("class_instantiators", default=None) nested_links: ContextVar[List[dict]] = ContextVar("nested_links", default=[]) diff --git a/jsonargparse/_core.py b/jsonargparse/_core.py index ce89003e..86a575a0 100644 --- a/jsonargparse/_core.py +++ b/jsonargparse/_core.py @@ -39,6 +39,7 @@ from ._common import ( InstantiatorCallable, InstantiatorsDictType, + class_instantiators, debug_mode_active, is_dataclass_like, lenient_check, @@ -1139,6 +1140,10 @@ def _get_instantiators(self): parent_instantiators = self.parent_parser._get_instantiators() instantiators = instantiators.copy() instantiators.update({k: v for k, v in parent_instantiators.items() if k not in instantiators}) + context_instantiators = class_instantiators.get() + if context_instantiators: + instantiators = instantiators.copy() + instantiators.update({k: v for k, v in context_instantiators.items() if k not in instantiators}) return instantiators def instantiate_classes( diff --git a/jsonargparse_tests/test_subclasses.py b/jsonargparse_tests/test_subclasses.py index f85fe50f..e433fc2a 100644 --- a/jsonargparse_tests/test_subclasses.py +++ b/jsonargparse_tests/test_subclasses.py @@ -479,6 +479,21 @@ def test_custom_instantiation_replace(parser): assert list(parser._instantiators.values())[0] is second_instantiator +class CustomInstantiationNested: + def __init__(self, sub: CustomInstantiationBase): + self.sub = sub + + +def test_custom_instantiation_nested(parser): + parser.add_argument("--cls", type=CustomInstantiationNested) + parser.add_instantiator(instantiator("nested"), CustomInstantiationBase, subclasses=True) + cfg = parser.parse_args(["--cls=CustomInstantiationNested", "--cls.sub=CustomInstantiationSub"]) + init = parser.instantiate_classes(cfg) + assert isinstance(init.cls, CustomInstantiationNested) + assert isinstance(init.cls.sub, CustomInstantiationSub) + assert init.cls.sub.call == "nested" + + # environment tests