Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix: --print_shtab crashing on failure to get signature parameters from one class #537

Merged
merged 1 commit into from
Jun 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ Fixed
^^^^^
- Resolving of import paths for some ``torch`` functions not working (`#535
<https://github.com/omni-us/jsonargparse/pull/535>`__).
- ``--print_shtab`` crashing on failure to get signature parameters from one
class (`lightning#10858 comment
<https://github.com/Lightning-AI/pytorch-lightning/discussions/10858#discussioncomment-9846252>`__).

Changed
^^^^^^^
Expand Down
18 changes: 11 additions & 7 deletions jsonargparse/_completions.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,6 @@
)
from ._util import NoneType, Path, import_object, unique

shtab_shell: ContextVar = ContextVar("shtab_shell")
shtab_prog: ContextVar = ContextVar("shtab_prog")
shtab_preambles: ContextVar = ContextVar("shtab_preambles")


def handle_completions(parser):
if find_spec("argcomplete") and "_ARGCOMPLETE" in os.environ:
Expand Down Expand Up @@ -76,6 +72,10 @@ def argcomplete_warn_redraw_prompt(prefix, message):

# shtab

shtab_shell: ContextVar = ContextVar("shtab_shell")
shtab_prog: ContextVar = ContextVar("shtab_prog")
shtab_preambles: ContextVar = ContextVar("shtab_preambles")


class ShtabAction(argparse.Action):
def __init__(
Expand Down Expand Up @@ -236,7 +236,7 @@ def get_typehint_choices(typehint, prefix, parser, skip, choices=None, added_sub
origin = get_typehint_origin(typehint)
if origin == Union:
for subtype in typehint.__args__:
if subtype in added_subclasses:
if subtype in added_subclasses or subtype is object:
continue
get_typehint_choices(subtype, prefix, parser, skip, choices, added_subclasses)
elif ActionTypeHint.is_subclass_typehint(typehint):
Expand All @@ -261,8 +261,12 @@ def add_subactions_and_get_subclass_choices(typehint, prefix, parser, skip, adde
subclasses = defaultdict(list)
for path in paths:
choices.append(path)
cls = import_object(path)
params = get_signature_parameters(cls)
try:
cls = import_object(path)
params = get_signature_parameters(cls, None, parser._logger)
except Exception as ex:
parser._logger.debug(f"Unable to get signature parameters for '{path}': {ex}")
continue
num_skip = next((s for s in skip if isinstance(s, int)), 0)
if num_skip > 0:
params = params[num_skip:]
Expand Down
19 changes: 18 additions & 1 deletion jsonargparse_tests/test_shtab.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@

from jsonargparse import ArgumentParser
from jsonargparse._completions import norm_name
from jsonargparse._parameter_resolvers import get_signature_parameters
from jsonargparse._typehints import type_to_str
from jsonargparse.typing import Path_drw, Path_fr
from jsonargparse_tests.conftest import get_parse_args_stdout
from jsonargparse_tests.conftest import capture_logs, get_parse_args_stdout


@pytest.fixture(autouse=True)
Expand Down Expand Up @@ -203,6 +204,22 @@ def __init__(self, p1: int, p3: float):
pass


def test_bash_subclasses_fail_get_perams(parser, logger):
def get_params_patch(cls, method, logger):
if cls == SubB:
raise Exception("test get params failure")
return get_signature_parameters(cls, method, logger)

parser.logger = logger
parser.add_argument("--cls", type=Base)
with capture_logs(logger) as logs, patch("jsonargparse._completions.get_signature_parameters", get_params_patch):
shtab_script = get_shtab_script(parser, "bash")
assert "'--cls' '--cls.p1' '--cls.p2'" in shtab_script
assert f"'{__name__}.SubB'" in shtab_script
assert "'--cls.p3'" not in shtab_script
assert "test_shtab.SubB': test get params failure" in logs.getvalue()


def test_bash_subclasses_help(parser):
parser.add_argument("--cls", type=Base)
shtab_script = get_shtab_script(parser, "bash")
Expand Down