Skip to content

Commit

Permalink
Fix: Add function signature failing when conditionally calling differ…
Browse files Browse the repository at this point in the history
…ent functions (#467).
  • Loading branch information
mauvilsa committed Mar 14, 2024
1 parent bca1588 commit 8310ff3
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 9 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ Fixed
produces an invalid string default.
- dataclass single parameter change incorrectly resetting previous values (`#464
<https://github.com/omni-us/jsonargparse/issues/464>`__).
- Add function signature failing when conditionally calling different functions
(`#467 <https://github.com/omni-us/jsonargparse/issues/467>`__).


v4.27.5 (2024-02-12)
Expand Down
21 changes: 12 additions & 9 deletions jsonargparse/_signatures.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import dataclasses
import inspect
import re
from argparse import SUPPRESS
from argparse import SUPPRESS, ArgumentParser
from contextlib import suppress
from typing import Any, Callable, List, Optional, Set, Tuple, Type, Union

Expand Down Expand Up @@ -255,7 +255,7 @@ def _add_signature_arguments(
## Create group if requested ##
doc_group = get_doc_short_description(function_or_class, method_name, logger=self.logger)
component = getattr(function_or_class, method_name) if method_name else function_or_class
group = self._create_group_if_requested(
container = self._create_group_if_requested(
component,
nested_key,
as_group,
Expand All @@ -268,7 +268,7 @@ def _add_signature_arguments(
added_args: List[str] = []
for param in params:
self._add_signature_parameter(
group,
container,
nested_key,
param,
added_args,
Expand All @@ -283,7 +283,7 @@ def _add_signature_arguments(

def _add_signature_parameter(
self,
group,
container,
nested_key: Optional[str],
param,
added_args: List[str],
Expand Down Expand Up @@ -339,11 +339,14 @@ def _add_signature_parameter(
dest = (nested_key + "." if nested_key else "") + name
args = [dest if is_required and as_positional else "--" + dest]
if param.origin:
parser = container
if not isinstance(container, ArgumentParser):
parser = getattr(container, "parser")
group_name = "; ".join(str(o) for o in param.origin)
if group_name in group.parser.groups:
group = group.parser.groups[group_name]
if group_name in parser.groups:
container = parser.groups[group_name]
else:
group = group.parser.add_argument_group(
container = parser.add_argument_group(
f"Conditional arguments [origins: {group_name}]",
name=group_name,
)
Expand Down Expand Up @@ -372,7 +375,7 @@ def _add_signature_parameter(
args=args,
kwargs=kwargs,
enable_path=enable_path,
container=group,
container=container,
logger=self.logger,
sub_add_kwargs=sub_add_kwargs,
)
Expand All @@ -387,7 +390,7 @@ def _add_signature_parameter(
if is_dataclass_like_typehint:
kwargs.update(sub_add_kwargs)
with ActionTypeHint.allow_default_instance_context():
action = group.add_argument(*args, **kwargs)
action = container.add_argument(*args, **kwargs)
action.sub_add_kwargs = sub_add_kwargs
if is_subclass_typehint and len(subclass_skip) > 0:
action.sub_add_kwargs["skip"] = subclass_skip
Expand Down
26 changes: 26 additions & 0 deletions jsonargparse_tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from jsonargparse import CLI, capture_parser, lazy_instance
from jsonargparse._optionals import docstring_parser_support, ruyaml_support
from jsonargparse._typehints import Literal
from jsonargparse.typing import final
from jsonargparse_tests.conftest import skip_if_docstring_parser_unavailable

Expand Down Expand Up @@ -120,6 +121,31 @@ def test_multiple_functions_subcommand_help():
assert "--a2 A2" in out


def conditionalA(foo: int = 1):
return foo


def conditionalB(bar: int = 2):
return bar


def conditional_function(fn: "Literal['A', 'B']", *args, **kwargs):
if fn == "A":
return conditionalA(*args, **kwargs)
elif fn == "B":
return conditionalB(*args, **kwargs)
raise NotImplementedError(fn)


@pytest.mark.skipif(condition=sys.version_info < (3, 9), reason="python>=3.9 is required")
@pytest.mark.skipif(condition=not Literal, reason="Literal is required")
def test_literal_conditional_function():
out = get_cli_stdout(conditional_function, args=["--help"])
assert "Conditional arguments" in out
assert "--foo FOO (type: int, default: Conditional<ast-resolver> {1, NOT_ACCEPTED})" in out
assert "--bar BAR (type: int, default: Conditional<ast-resolver> {2, NOT_ACCEPTED})" in out


# single class tests


Expand Down

0 comments on commit 8310ff3

Please sign in to comment.