Skip to content

Commit

Permalink
Fix custom instantiators not working for nested dependency injection (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
mauvilsa authored Oct 29, 2024
1 parent 0dd8ea8 commit fd359b5
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 1 deletion.
2 changes: 2 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ Fixed
<https://github.com/omni-us/jsonargparse/pull/597>`__).
- Callables that return class not considering previous values (`#603
<https://github.com/omni-us/jsonargparse/pull/603>`__).
- Custom instantiators not working for nested dependency injection (`#608
<https://github.com/omni-us/jsonargparse/pull/608>`__).

Changed
^^^^^^^
Expand Down
2 changes: 1 addition & 1 deletion jsonargparse/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def __call__(self, class_type: Type[ClassType], *args, **kwargs) -> ClassType:
defaults_cache: ContextVar[Optional[Namespace]] = ContextVar("defaults_cache", default=None)
lenient_check: ContextVar[Union[bool, str]] = ContextVar("lenient_check", default=False)
load_value_mode: ContextVar[Optional[str]] = ContextVar("load_value_mode", default=None)
class_instantiators: ContextVar[Optional[InstantiatorsDictType]] = ContextVar("class_instantiators")
class_instantiators: ContextVar[Optional[InstantiatorsDictType]] = ContextVar("class_instantiators", default=None)
nested_links: ContextVar[List[dict]] = ContextVar("nested_links", default=[])


Expand Down
5 changes: 5 additions & 0 deletions jsonargparse/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from ._common import (
InstantiatorCallable,
InstantiatorsDictType,
class_instantiators,
debug_mode_active,
is_dataclass_like,
lenient_check,
Expand Down Expand Up @@ -1139,6 +1140,10 @@ def _get_instantiators(self):
parent_instantiators = self.parent_parser._get_instantiators()
instantiators = instantiators.copy()
instantiators.update({k: v for k, v in parent_instantiators.items() if k not in instantiators})
context_instantiators = class_instantiators.get()
if context_instantiators:
instantiators = instantiators.copy()
instantiators.update({k: v for k, v in context_instantiators.items() if k not in instantiators})
return instantiators

def instantiate_classes(
Expand Down
15 changes: 15 additions & 0 deletions jsonargparse_tests/test_subclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,6 +479,21 @@ def test_custom_instantiation_replace(parser):
assert list(parser._instantiators.values())[0] is second_instantiator


class CustomInstantiationNested:
def __init__(self, sub: CustomInstantiationBase):
self.sub = sub


def test_custom_instantiation_nested(parser):
parser.add_argument("--cls", type=CustomInstantiationNested)
parser.add_instantiator(instantiator("nested"), CustomInstantiationBase, subclasses=True)
cfg = parser.parse_args(["--cls=CustomInstantiationNested", "--cls.sub=CustomInstantiationSub"])
init = parser.instantiate_classes(cfg)
assert isinstance(init.cls, CustomInstantiationNested)
assert isinstance(init.cls.sub, CustomInstantiationSub)
assert init.cls.sub.call == "nested"


# environment tests


Expand Down

0 comments on commit fd359b5

Please sign in to comment.