diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 7b47285e..d07f09b4 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -34,9 +34,19 @@ Fixed - ``default_config_files`` making parse fail for subcommands and nested subclass types (`lightning-forums#5963 `__). -- Import path of inherited classmethod not resolving correctly (`lightning#19863 - comment - `__). +- Fixes related to transformers ``PreTrainedModel.from_pretrained`` + (`lightning#19863 comment + `__). + - Import path of inherited classmethod not resolving correctly (`#548 + `__). + - Resolved parameters leading to multiple values for keyword argument (`#551 + `__). + - Function with return type a class in ``class_path`` in some cases fails with + unexpected ``instantiate`` parameter error (`#551 + `__). + - Ignore incorrectly resolved ``config_file_name`` parameter for transformers + model ``from_pretrained`` (`#551 + `__). v4.31.0 (2024-06-27) diff --git a/jsonargparse/_parameter_resolvers.py b/jsonargparse/_parameter_resolvers.py index 7ef82c3c..29712ead 100644 --- a/jsonargparse/_parameter_resolvers.py +++ b/jsonargparse/_parameter_resolvers.py @@ -11,7 +11,7 @@ from functools import partial from importlib import import_module from types import MethodType -from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union from ._common import ( LoggerProperty, @@ -52,6 +52,10 @@ class ParamData: ast_assign_type: Tuple[Type[ast.AST], ...] = (ast.AnnAssign, ast.Assign) param_kwargs_pop_or_get = "**.pop|get():" +ignore_params = { + "transformers.BertModel.from_pretrained": {"config_file_name"}, +} + class SourceNotAvailable(Exception): "Raised when the source code for some component is not available." 
@@ -268,11 +272,14 @@ def ast_get_call_keyword_names(node): return [kw_node.arg for kw_node in node.keywords if kw_node.arg] -def remove_given_parameters(node, params): +def remove_given_parameters(node, params, removed_params: Optional[set] = None): given_args = set(ast_get_call_positional_indexes(node)) given_kwargs = set(ast_get_call_keyword_names(node)) + input_params = params params = [p for n, p in enumerate(params) if n not in given_args] params = [p for p in params if p.name not in given_kwargs] + if removed_params is not None and len(params) < len(input_params): + removed_params.update(p.name for p in input_params if p.name in given_kwargs) return params @@ -745,6 +752,7 @@ def get_parameters_args_and_kwargs(self) -> Tuple[ParamList, ParamList]: return [], [] params_list = [] + removed_params: Set[str] = set() kwargs_value = kwargs_name and values_to_find[kwargs_name] kwargs_value_dump = kwargs_value and ast.dump(kwargs_value) for node, source in [(v, s) for k, v, s in values_found if k == kwargs_name]: @@ -768,7 +776,7 @@ def get_parameters_args_and_kwargs(self) -> Tuple[ParamList, ParamList]: get_param_args = self.get_node_component(node, source) if get_param_args: params = get_signature_parameters(*get_param_args, logger=self.logger) - params = remove_given_parameters(node, params) + params = remove_given_parameters(node, params, removed_params) if params: self.add_node_origins(params, node) params_list.append(params) @@ -783,6 +791,7 @@ def get_parameters_args_and_kwargs(self) -> Tuple[ParamList, ParamList]: self.log_debug(f"unsupported type of assign: {ast_str(node)}") params = group_parameters(params_list) + params = [p for p in params if p.name not in removed_params] return split_args_and_kwargs(params) def get_parameters_attr_use_in_members(self, attr_name) -> ParamList: @@ -826,6 +835,12 @@ def get_parameters_call_attr(self, attr_name: str, attr_value: ast.AST) -> Optio matched = group_parameters(matched) return matched or None + def 
remove_ignore_parameters(self, params: ParamList) -> ParamList: + import_path = get_import_path(self.component) + if import_path in ignore_params: + params = [p for p in params if p.name not in ignore_params[import_path]] + return params + def get_parameters(self) -> ParamList: if self.component is None: return [] @@ -839,6 +854,7 @@ def get_parameters(self) -> ParamList: args, kwargs = self.get_parameters_args_and_kwargs() params = replace_args_and_kwargs(params, args, kwargs) add_stub_types(stubs, params, self.component) + params = self.remove_ignore_parameters(params) return params diff --git a/jsonargparse/_typehints.py b/jsonargparse/_typehints.py index 19aeb9ec..aec416d1 100644 --- a/jsonargparse/_typehints.py +++ b/jsonargparse/_typehints.py @@ -620,6 +620,7 @@ def get_class_parser(val_class, sub_add_kwargs=None, skip_args=0): if inspect.isclass(val_class): parser.add_class_arguments(val_class, **kwargs) else: + kwargs = {k: v for k, v in kwargs.items() if k != "instantiate"} parser.add_function_arguments(val_class, **kwargs) if "linked_targets" in kwargs and parser.required_args: diff --git a/jsonargparse_tests/test_parameter_resolvers.py b/jsonargparse_tests/test_parameter_resolvers.py index 9cc6f800..f894eaeb 100644 --- a/jsonargparse_tests/test_parameter_resolvers.py +++ b/jsonargparse_tests/test_parameter_resolvers.py @@ -549,10 +549,11 @@ def function_optional_callable(p1: Optional[Callable] = None, **kw): function_no_args_no_kwargs(**kw) -def assert_params(params, expected, origins={}): +def assert_params(params, expected, origins={}, help=True): assert expected == [p.name for p in params] - docs = [f"help for {p.name}" for p in params] if docstring_parser_support else [None] * len(params) - assert docs == [p.doc for p in params] + if help: + docs = [f"help for {p.name}" for p in params] if docstring_parser_support else [None] * len(params) + assert docs == [p.doc for p in params] assert all(isinstance(params[n].default, ConditionalDefault) for n in 
origins.keys()) param_origins = { n: [o.split(f"{__name__}.", 1)[1] for o in p.origin] for n, p in enumerate(params) if p.origin is not None @@ -876,6 +877,28 @@ def test_get_params_optional_callable(): assert_params(get_params(function_optional_callable), ["p1", "pk1", "k2"]) +def func_several_params(p1: int = 1, p2: int = 2, p3: int = 3, p4: int = 4): + pass + + +def func_given_kwargs(p: int, **kwargs): + func_several_params(p2=0, **kwargs) + func_several_params(p4=0, **kwargs) + + +def test_get_params_given_kwargs(): + assert_params(get_params(func_given_kwargs), ["p", "p1", "p3"], help=False) + + +def test_get_params_some_ignored(): + with patch.dict( + "jsonargparse._parameter_resolvers.ignore_params", {f"{__name__}.func_several_params": {"p2", "p3"}} + ): + assert_params(get_params(func_several_params), ["p1", "p4"], help=False) + with patch.dict("jsonargparse._parameter_resolvers.ignore_params", {f"{__name__}.func_given_kwargs": {"p3"}}): + assert_params(get_params(func_given_kwargs), ["p", "p1"], help=False) + + # unsupported cases