Skip to content

Commit

Permalink
Fixes related to transformers PreTrainedModel.from_pretrained (#551)
Browse files Browse the repository at this point in the history
  • Loading branch information
mauvilsa authored Jul 17, 2024
1 parent 45044f2 commit b8d6e58
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 9 deletions.
16 changes: 13 additions & 3 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,19 @@ Fixed
- ``default_config_files`` making parse fail for subcommands and nested subclass
types (`lightning-forums#5963
<https://lightning.ai/forums/t/problem-lightningcli-with-default-config-files/5963>`__).
- Import path of inherited classmethod not resolving correctly (`lightning#19863
comment
<https://github.com/Lightning-AI/pytorch-lightning/discussions/19863#discussioncomment-10010226>`__).
- Fixes related to transformers ``PreTrainedModel.from_pretrained``
(`lightning#19863 comment
<https://github.com/Lightning-AI/pytorch-lightning/discussions/19863#discussioncomment-9821765>`__).
- Import path of inherited classmethod not resolving correctly (`#548
<https://github.com/omni-us/jsonargparse/pull/548>`__).
- Resolved parameters leading to multiple values for keyword argument (`#551
<https://github.com/omni-us/jsonargparse/pull/551>`__).
- Function with return type a class in ``class_path`` in some cases fails with
unexpected ``instantiate`` parameter error (`#551
<https://github.com/omni-us/jsonargparse/pull/551>`__).
- Ignore incorrectly resolved ``config_file_name`` parameter for transformers
model ``from_pretrained`` (`#551
<https://github.com/omni-us/jsonargparse/pull/551>`__).
v4.31.0 (2024-06-27)
Expand Down
22 changes: 19 additions & 3 deletions jsonargparse/_parameter_resolvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from functools import partial
from importlib import import_module
from types import MethodType
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union

from ._common import (
LoggerProperty,
Expand Down Expand Up @@ -52,6 +52,10 @@ class ParamData:
ast_assign_type: Tuple[Type[ast.AST], ...] = (ast.AnnAssign, ast.Assign)
param_kwargs_pop_or_get = "**.pop|get():"

# Maps a component's import path to the set of parameter names that must be
# dropped from its resolved parameters (applied in ``remove_ignore_parameters``).
# The entry below discards ``config_file_name``, which is incorrectly resolved
# for the transformers model ``from_pretrained``.
ignore_params = {
"transformers.BertModel.from_pretrained": {"config_file_name"},
}


class SourceNotAvailable(Exception):
"Raised when the source code for some component is not available."
Expand Down Expand Up @@ -268,11 +272,14 @@ def ast_get_call_keyword_names(node):
return [kw_node.arg for kw_node in node.keywords if kw_node.arg]


def remove_given_parameters(node, params, removed_params: Optional[set] = None):
    """Return ``params`` without the ones already supplied by the call ``node``.

    A parameter is dropped when its position matches a positional argument of
    the call or when its name matches one of the call's keyword arguments.

    Args:
        node: AST call node whose given arguments are inspected.
        params: Parameters of the callable being invoked by ``node``.
        removed_params: Optional set that, when at least one parameter was
            dropped, is updated in place with the names of the parameters
            that the call gives by keyword.

    Returns:
        A new list with the parameters that the call does not provide.
    """
    positional_given = set(ast_get_call_positional_indexes(node))
    keyword_given = set(ast_get_call_keyword_names(node))

    remaining = []
    for index, param in enumerate(params):
        if index in positional_given or param.name in keyword_given:
            continue
        remaining.append(param)

    # Only keyword-given names are recorded; positional removals are not.
    if removed_params is not None and len(remaining) < len(params):
        for param in params:
            if param.name in keyword_given:
                removed_params.add(param.name)
    return remaining


Expand Down Expand Up @@ -745,6 +752,7 @@ def get_parameters_args_and_kwargs(self) -> Tuple[ParamList, ParamList]:
return [], []

params_list = []
removed_params: Set[str] = set()
kwargs_value = kwargs_name and values_to_find[kwargs_name]
kwargs_value_dump = kwargs_value and ast.dump(kwargs_value)
for node, source in [(v, s) for k, v, s in values_found if k == kwargs_name]:
Expand All @@ -768,7 +776,7 @@ def get_parameters_args_and_kwargs(self) -> Tuple[ParamList, ParamList]:
get_param_args = self.get_node_component(node, source)
if get_param_args:
params = get_signature_parameters(*get_param_args, logger=self.logger)
params = remove_given_parameters(node, params)
params = remove_given_parameters(node, params, removed_params)
if params:
self.add_node_origins(params, node)
params_list.append(params)
Expand All @@ -783,6 +791,7 @@ def get_parameters_args_and_kwargs(self) -> Tuple[ParamList, ParamList]:
self.log_debug(f"unsupported type of assign: {ast_str(node)}")

params = group_parameters(params_list)
params = [p for p in params if p.name not in removed_params]
return split_args_and_kwargs(params)

def get_parameters_attr_use_in_members(self, attr_name) -> ParamList:
Expand Down Expand Up @@ -826,6 +835,12 @@ def get_parameters_call_attr(self, attr_name: str, attr_value: ast.AST) -> Optio
matched = group_parameters(matched)
return matched or None

def remove_ignore_parameters(self, params: ParamList) -> ParamList:
    """Filter out parameters listed in ``ignore_params`` for this component.

    The component's import path is looked up in ``ignore_params``; when there
    is no entry for it, ``params`` is returned unchanged.

    Args:
        params: Parameters resolved for ``self.component``.

    Returns:
        The parameters whose names are not marked as ignored.
    """
    ignored_names = ignore_params.get(get_import_path(self.component))
    if ignored_names is None:
        return params
    return [p for p in params if p.name not in ignored_names]

def get_parameters(self) -> ParamList:
if self.component is None:
return []
Expand All @@ -839,6 +854,7 @@ def get_parameters(self) -> ParamList:
args, kwargs = self.get_parameters_args_and_kwargs()
params = replace_args_and_kwargs(params, args, kwargs)
add_stub_types(stubs, params, self.component)
params = self.remove_ignore_parameters(params)
return params


Expand Down
1 change: 1 addition & 0 deletions jsonargparse/_typehints.py
Original file line number Diff line number Diff line change
Expand Up @@ -620,6 +620,7 @@ def get_class_parser(val_class, sub_add_kwargs=None, skip_args=0):
if inspect.isclass(val_class):
parser.add_class_arguments(val_class, **kwargs)
else:
kwargs = {k: v for k, v in kwargs.items() if k != "instantiate"}
parser.add_function_arguments(val_class, **kwargs)

if "linked_targets" in kwargs and parser.required_args:
Expand Down
29 changes: 26 additions & 3 deletions jsonargparse_tests/test_parameter_resolvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -549,10 +549,11 @@ def function_optional_callable(p1: Optional[Callable] = None, **kw):
function_no_args_no_kwargs(**kw)


def assert_params(params, expected, origins={}):
def assert_params(params, expected, origins={}, help=True):
assert expected == [p.name for p in params]
docs = [f"help for {p.name}" for p in params] if docstring_parser_support else [None] * len(params)
assert docs == [p.doc for p in params]
if help:
docs = [f"help for {p.name}" for p in params] if docstring_parser_support else [None] * len(params)
assert docs == [p.doc for p in params]
assert all(isinstance(params[n].default, ConditionalDefault) for n in origins.keys())
param_origins = {
n: [o.split(f"{__name__}.", 1)[1] for o in p.origin] for n, p in enumerate(params) if p.origin is not None
Expand Down Expand Up @@ -876,6 +877,28 @@ def test_get_params_optional_callable():
assert_params(get_params(function_optional_callable), ["p1", "pk1", "k2"])


# Test fixture: a function with four defaulted ``int`` parameters and no body.
# It is the target of the ``**kwargs`` forwarding in ``func_given_kwargs`` and is
# also resolved directly in ``test_get_params_some_ignored``.
def func_several_params(p1: int = 1, p2: int = 2, p3: int = 3, p4: int = 4):
pass


# Test fixture: forwards ``**kwargs`` to ``func_several_params`` twice, each call
# fixing a different keyword (``p2`` in the first call, ``p4`` in the second).
def func_given_kwargs(p: int, **kwargs):
func_several_params(p2=0, **kwargs)
func_several_params(p4=0, **kwargs)


# ``p2`` and ``p4`` are each given by keyword in one of the forwarding calls of
# ``func_given_kwargs``, so only ``p``, ``p1`` and ``p3`` are expected to be resolved.
# ``help=False`` skips the docstring check since the fixtures have no docstrings.
def test_get_params_given_kwargs():
assert_params(get_params(func_given_kwargs), ["p", "p1", "p3"], help=False)


# Checks that entries in ``ignore_params`` remove parameters from the result.
# ``patch.dict`` adds the entries only for the duration of each ``with`` block.
def test_get_params_some_ignored():
# Ignoring ``p2`` and ``p3`` of the function itself leaves ``p1`` and ``p4``.
with patch.dict(
"jsonargparse._parameter_resolvers.ignore_params", {f"{__name__}.func_several_params": {"p2", "p3"}}
):
assert_params(get_params(func_several_params), ["p1", "p4"], help=False)
# Ignoring ``p3`` for the forwarding function leaves only ``p`` and ``p1``.
with patch.dict("jsonargparse._parameter_resolvers.ignore_params", {f"{__name__}.func_given_kwargs": {"p3"}}):
assert_params(get_params(func_given_kwargs), ["p", "p1"], help=False)


# unsupported cases


Expand Down

0 comments on commit b8d6e58

Please sign in to comment.