From 090685d92692e0023079580375fc61576628551f Mon Sep 17 00:00:00 2001 From: Mauricio Villegas <5780272+mauvilsa@users.noreply.github.com> Date: Tue, 25 Jun 2024 05:44:55 +0200 Subject: [PATCH] Fix: --print_shtab crashing on failure to get signature parameters from one class (lightning#10858 comment). --- CHANGELOG.rst | 3 +++ jsonargparse/_completions.py | 18 +++++++++++------- jsonargparse_tests/test_shtab.py | 19 ++++++++++++++++++- 3 files changed, 32 insertions(+), 8 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 207108c4..e62c4338 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -19,6 +19,9 @@ Fixed ^^^^^ - Resolving of import paths for some ``torch`` functions not working (`#535 `__). +- ``--print_shtab`` crashing on failure to get signature parameters from one + class (`lightning#10858 comment + `__). Changed ^^^^^^^ diff --git a/jsonargparse/_completions.py b/jsonargparse/_completions.py index c693916a..ee722e5a 100644 --- a/jsonargparse/_completions.py +++ b/jsonargparse/_completions.py @@ -25,10 +25,6 @@ ) from ._util import NoneType, Path, import_object, unique -shtab_shell: ContextVar = ContextVar("shtab_shell") -shtab_prog: ContextVar = ContextVar("shtab_prog") -shtab_preambles: ContextVar = ContextVar("shtab_preambles") - def handle_completions(parser): if find_spec("argcomplete") and "_ARGCOMPLETE" in os.environ: @@ -76,6 +72,10 @@ def argcomplete_warn_redraw_prompt(prefix, message): # shtab +shtab_shell: ContextVar = ContextVar("shtab_shell") +shtab_prog: ContextVar = ContextVar("shtab_prog") +shtab_preambles: ContextVar = ContextVar("shtab_preambles") + class ShtabAction(argparse.Action): def __init__( @@ -236,7 +236,7 @@ def get_typehint_choices(typehint, prefix, parser, skip, choices=None, added_sub origin = get_typehint_origin(typehint) if origin == Union: for subtype in typehint.__args__: - if subtype in added_subclasses: + if subtype in added_subclasses or subtype is object: continue get_typehint_choices(subtype, prefix, parser, skip, choices, added_subclasses) elif ActionTypeHint.is_subclass_typehint(typehint): @@ -261,8 +261,12 @@ def add_subactions_and_get_subclass_choices(typehint, prefix, parser, skip, adde subclasses = defaultdict(list) for path in paths: choices.append(path) - cls = import_object(path) - params = get_signature_parameters(cls) + try: + cls = import_object(path) + params = get_signature_parameters(cls, None, parser._logger) + except Exception as ex: + parser._logger.debug(f"Unable to get signature parameters for '{path}': {ex}") + continue num_skip = next((s for s in skip if isinstance(s, int)), 0) if num_skip > 0: params = params[num_skip:] diff --git a/jsonargparse_tests/test_shtab.py b/jsonargparse_tests/test_shtab.py index 83704e31..ae463a98 100644 --- a/jsonargparse_tests/test_shtab.py +++ b/jsonargparse_tests/test_shtab.py @@ -13,9 +13,10 @@ from jsonargparse import ArgumentParser from jsonargparse._completions import norm_name +from jsonargparse._parameter_resolvers import get_signature_parameters from jsonargparse._typehints import type_to_str from jsonargparse.typing import Path_drw, Path_fr -from jsonargparse_tests.conftest import get_parse_args_stdout +from jsonargparse_tests.conftest import capture_logs, get_parse_args_stdout @pytest.fixture(autouse=True) @@ -203,6 +204,22 @@ def __init__(self, p1: int, p3: float): pass +def test_bash_subclasses_fail_get_perams(parser, logger): + def get_params_patch(cls, method, logger): + if cls == SubB: + raise Exception("test get params failure") + return get_signature_parameters(cls, method, logger) + + parser.logger = logger + parser.add_argument("--cls", type=Base) + with capture_logs(logger) as logs, patch("jsonargparse._completions.get_signature_parameters", get_params_patch): + shtab_script = get_shtab_script(parser, "bash") + assert "'--cls' '--cls.p1' '--cls.p2'" in shtab_script + assert f"'{__name__}.SubB'" in shtab_script + assert "'--cls.p3'" not in shtab_script + assert "test_shtab.SubB': test get params failure" in logs.getvalue() + + def test_bash_subclasses_help(parser): parser.add_argument("--cls", type=Base) shtab_script = get_shtab_script(parser, "bash")