Skip to content

Commit

Permalink
- AST resolver now supports cls() class instantiation in classmethod #…
Browse files Browse the repository at this point in the history
…146.

- AST resolver now supports pop and get from **kwargs.
- Added AST resolver unit test using class as attribute of module.
  • Loading branch information
mauvilsa committed Jul 21, 2022
1 parent 1074755 commit 681c7e9
Show file tree
Hide file tree
Showing 4 changed files with 180 additions and 24 deletions.
5 changes: 4 additions & 1 deletion CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ v4.12.0 (2022-07-??)
Added
^^^^^
- Instantiation links now support multiple sources.
- AST resolver now supports ``cls()`` class instantiation in ``classmethod``
`#146 <https://github.com/omni-us/jsonargparse/issues/146>`__.
- AST resolver now supports ``pop`` and ``get`` from ``**kwargs``.

Fixed
^^^^^
Expand All @@ -38,7 +41,7 @@ Changed
compute function.
- Instantiation links no longer restricted to first nesting level.
- AST parameter resolver now only logs debug messages instead of failing `#146
<https://github.com/omni-us/jsonargparse/pull/146>`__.
<https://github.com/omni-us/jsonargparse/issues/146>`__.
- Documented AST resolver support for ``**kwargs`` use in property.


Expand Down
18 changes: 13 additions & 5 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1238,7 +1238,7 @@ unrelated to these variables.

class BaseClass: pass

**Cases for functions**
**Cases for statements in functions or methods**

.. testcode:: ast_resolver

Expand All @@ -1256,6 +1256,12 @@ unrelated to these variables.
def calls_a_class_method(*args, **kwargs):
SomeClass.a_class_method(*args, **kwargs)
def pops_from_kwargs(**kwargs):
val = kwargs.pop('name', 'default')
def gets_from_kwargs(**kwargs):
val = kwargs.get('name', 'default')
**Cases for classes**

.. testcode:: ast_resolver
Expand All @@ -1268,10 +1274,6 @@ unrelated to these variables.
def __init__(self, *args, **kwargs):
self.a_method(*args, **kwargs)
class CallCallable:
def __init__(self, *args, **kwargs):
a_callable(*args, **kwargs)
class AttributeUseInMethod:
def __init__(self, **kwargs):
self._kwargs = kwargs
Expand All @@ -1296,6 +1298,12 @@ unrelated to these variables.
def a_method(self):
a_callable(**self._kwargs)
class InstanceInClassmethod:
@classmethod
def get_instance(cls, **kwargs):
return cls(**kwargs)

There can be other parameters apart from ``*args`` and ``**kwargs``, thus in the
cases above the signatures can be for example like ``name(p1: int, k1: str =
'a', **kws)``. Also when internally calling some function or instantiating a
Expand Down
122 changes: 106 additions & 16 deletions jsonargparse/parameter_resolvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,10 @@ def is_method_or_property(attr) -> bool:
return is_method(attr) or is_property(attr)


def is_classmethod(parent, component) -> bool:
return parent and isinstance(inspect.getattr_static(parent, component.__name__), classmethod)


def ast_str(node):
return getattr(ast, 'unparse', ast.dump)(node)

Expand Down Expand Up @@ -86,16 +90,42 @@ def ast_is_dict_assign_with_value(node, value):
return False


def ast_is_call_with_value(node, value) -> bool:
if isinstance(node, ast.Call):
value_dump = ast.dump(value)
for argtype in ['args', 'keywords']:
for arg in getattr(node, argtype):
if isinstance(getattr(arg, 'value', None), ast.AST) and ast.dump(arg.value) == value_dump:
return True
def ast_is_call_with_value(node, value_dump) -> bool:
for argtype in ['args', 'keywords']:
for arg in getattr(node, argtype):
if isinstance(getattr(arg, 'value', None), ast.AST) and ast.dump(arg.value) == value_dump:
return True
return False


ast_constant_attr = {
ast.Constant: 'value',
# python <= 3.7:
ast.NameConstant: 'value',
ast.Num: 'n',
ast.Str: 's',
}


def ast_is_constant(node):
return isinstance(node, (ast.Str, ast.Num, ast.NameConstant, ast.Constant))


def ast_get_constant_value(node):
assert ast_is_constant(node)
return getattr(node, ast_constant_attr[node.__class__])


def ast_is_kwargs_pop_or_get(node, value_dump) -> bool:
return (
isinstance(node.func, ast.Attribute) and
value_dump == ast.dump(node.func.value) and
node.func.attr in {'pop', 'get'} and
len(node.args) == 2 and
isinstance(ast_get_constant_value(node.args[0]), str)
)


def ast_is_super_call(node) -> bool:
return (
isinstance(node, ast.Call) and
Expand Down Expand Up @@ -152,7 +182,10 @@ def get_arg_kind_index(params, kind):


def get_signature_parameters_and_indexes(component, parent, logger):
params = list(inspect.signature(component).parameters.values())
if is_classmethod(parent, component):
params = list(inspect.signature(component.__func__).parameters.values())
else:
params = list(inspect.signature(component).parameters.values())
if parent:
params = params[1:]
args_idx = get_arg_kind_index(params, kinds.VAR_POSITIONAL)
Expand All @@ -168,7 +201,35 @@ def get_signature_parameters_and_indexes(component, parent, logger):
component=component,
**{a: getattr(param, a) for a in parameter_attributes},
)
return params, args_idx, kwargs_idx
return params, args_idx, kwargs_idx, doc_params


ast_literals = {
ast.dump(ast.parse(v, mode='eval').body): lambda: ast.literal_eval(v)
for v in ['{}', '[]']
}


def get_kwargs_pop_or_get_parameter(node, component, parent, doc_params, logger):
name = ast_get_constant_value(node.args[0])
if ast_is_constant(node.args[1]):
default = ast_get_constant_value(node.args[1])
else:
default = ast.dump(node.args[1])
if default in ast_literals:
default = ast_literals[default]()
else:
default = None
logger.debug(f'Unsupported kwargs pop/get default: {ast_str(node)}')
return ParamData(
name=name,
annotation=inspect._empty,
default=default,
kind=inspect._ParameterKind.KEYWORD_ONLY,
doc=doc_params.get(name),
parent=parent,
component=component,
)


def split_args_and_kwargs(params: ParamList) -> Tuple[ParamList, ParamList]:
Expand Down Expand Up @@ -209,6 +270,13 @@ def common_parameters(params_list: List[ParamList]) -> ParamList:
return common


def merge_parameters(source: Union[ParamData, ParamList], target: ParamList) -> ParamList:
if not isinstance(source, list):
source = [source]
target_names = set(t.name for t in target)
return target + [s for s in source if s.name not in target_names]


def has_dunder_new_method(cls, attr_name):
classes = inspect.getmro(cls)[1:]
return (
Expand Down Expand Up @@ -271,6 +339,8 @@ def get_component_and_parent(
component = attr
elif is_property(attr):
component = attr.fget
elif isinstance(attr, classmethod):
component = getattr(function_or_class, method_or_property)
elif attr is not object.__init__:
raise ValueError(f'Invalid or unsupported input: class={function_or_class}, method_or_property={method_or_property}')
else:
Expand Down Expand Up @@ -329,13 +399,16 @@ def visit_AnnAssign(self, node):

def visit_Call(self, node):
for key, value in self.find_values.items():
if ast_is_call_with_value(node, value):
value_dump = ast.dump(value)
if ast_is_call_with_value(node, value_dump):
if isinstance(node.func, ast.Attribute):
value_dump = ast.dump(node.func.value)
if value_dump in self.dict_assigns:
self.values_found.append((key, self.dict_assigns[value_dump]))
continue
self.values_found.append((key, node))
elif ast_is_kwargs_pop_or_get(node, value_dump):
self.values_found.append((key, node))
self.generic_visit(node)

def find_values_usage(self, values):
Expand All @@ -349,14 +422,21 @@ def get_node_component(self, node) -> Optional[Tuple[Type, Optional[str]]]:
function_or_class = method_or_property = None
module = inspect.getmodule(self.component)
if isinstance(node.func, ast.Name):
function_or_class = getattr(module, node.func.id)
if is_classmethod(self.parent, self.component) and node.func.id == self.self_name:
function_or_class = self.parent
else:
function_or_class = getattr(module, node.func.id)
elif isinstance(node.func, ast.Attribute) and isinstance(node.func.value, ast.Name):
if self.parent and ast.dump(node.func.value) == ast.dump(ast_variable_load(self.self_name)):
function_or_class = self.parent
method_or_property = node.func.attr
else:
container = getattr(module, node.func.value.id)
function_or_class = getattr(container, node.func.attr)
if inspect.isclass(container):
function_or_class = container
method_or_property = node.func.attr
else:
function_or_class = getattr(container, node.func.attr)
if not function_or_class:
self.logger.debug(f'Component not supported: {ast_str(node)}')
return None
Expand Down Expand Up @@ -393,8 +473,13 @@ def get_parameters_args_and_kwargs(self) -> Tuple[ParamList, ParamList]:
if not values_found:
return args, kwargs

kwargs_value_dump = ast.dump(kwargs_value)
for node in [v for k, v in values_found if k == kwargs_name]:
if isinstance(node, ast.Call):
if ast_is_kwargs_pop_or_get(node, kwargs_value_dump):
param = get_kwargs_pop_or_get_parameter(node, self.component, self.parent, self.doc_params, self.logger)
kwargs = merge_parameters(param, kwargs)
continue
kwarg = ast_get_call_kwarg_with_value(node, kwargs_value)
params = []
if kwarg.arg:
Expand All @@ -408,11 +493,15 @@ def get_parameters_args_and_kwargs(self) -> Tuple[ParamList, ParamList]:
get_param_args = self.get_node_component(node)
if get_param_args:
params = get_signature_parameters(*get_param_args, logger=self.logger)
args, kwargs = split_args_and_kwargs(remove_given_parameters(node, params))
args, kwargs_ = split_args_and_kwargs(remove_given_parameters(node, params))
kwargs = merge_parameters(kwargs_, kwargs)
break
elif isinstance(node, ast_assign_type):
self_attr = self.parent and ast_is_attr_assign(node, self.self_name)
if self_attr:
kwargs = self.get_parameters_attr_use_in_members(self_attr)
params = self.get_parameters_attr_use_in_members(self_attr)
kwargs = merge_parameters(params, kwargs)
break
else:
self.logger.debug(f'Unsupported type of assign: {ast_str(node)}')

Expand Down Expand Up @@ -448,8 +537,9 @@ def get_parameters_call_attr(self, attr_name: str, attr_value: ast.AST) -> Optio
def get_parameters(self) -> ParamList:
if self.component is None:
return []
params, args_idx, kwargs_idx = get_signature_parameters_and_indexes(self.component, self.parent, self.logger)
params, args_idx, kwargs_idx, doc_params = get_signature_parameters_and_indexes(self.component, self.parent, self.logger)
if args_idx >= 0 or kwargs_idx >= 0:
self.doc_params = doc_params
with mro_context(self.parent):
args, kwargs = self.get_parameters_args_and_kwargs()
params = replace_args_and_kwargs(params, args, kwargs)
Expand All @@ -462,7 +552,7 @@ def get_parameters_by_assumptions(
logger: Union[bool, str, dict, logging.Logger] = True,
) -> ParamList:
component, parent, method_name = get_component_and_parent(function_or_class, method_name)
params, args_idx, kwargs_idx = get_signature_parameters_and_indexes(component, parent, logger)
params, args_idx, kwargs_idx, _ = get_signature_parameters_and_indexes(component, parent, logger)

if parent and (args_idx >= 0 or kwargs_idx >= 0):
with mro_context(parent):
Expand Down
59 changes: 57 additions & 2 deletions jsonargparse_tests/test_parameter_resolvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,15 +178,15 @@ def __init__(self, km2: int = 2, **kwargs):
super().__init__(**kwargs)

class ClassM3(ClassM1):
def __init__(self, km3: int=0, **kwargs):
def __init__(self, km3: int = 0, **kwargs):
"""
Args:
km3: help for km3
"""
super().__init__(**kwargs)

class ClassP:
def __init__(self, kp1: int=1, **kw):
def __init__(self, kp1: int = 1, **kw):
"""
Args:
kp1: help for kp1
Expand All @@ -197,6 +197,25 @@ def __init__(self, kp1: int=1, **kw):
def data(self):
return function_no_args_no_kwargs(**self._kw)

class ClassS1:
def __init__(self, ks1: int = 2, **kw):
"""
Args:
ks1: help for ks1
"""
self.ks1 = ks1

@classmethod
def classmethod_s(cls, **kwargs):
return cls(**kwargs)

class ClassS2:
def __init__(self, **kwargs):
self.kwargs = kwargs

def run_classmethod_s(self):
return ClassS1.classmethod_s(**self.kwargs)

class ClassM(ClassM2, ClassM3):
def __init__(self, **kwargs):
super().__init__(**kwargs)
Expand Down Expand Up @@ -260,6 +279,19 @@ def function_make_class_b(*args, k1: str = '-', **kwargs):
"""
return ClassB.make(*args, **kwargs)

def function_pop_get_from_kwargs(kn1: int = 0, **kw):
"""
Args:
kn1: help for kn1
kn2: help for kn2
kn3: help for kn3
k2: help for k2
"""
k2 = kw.pop('k2', [1])
kn2 = kw.pop('kn2', 0.5)
kn3 = kw.get('kn3', {})
return function_no_args_no_kwargs(**kw)

def function_with_bug(**kws):
return does_not_exist(**kws) # pylint: disable=undefined-variable

Expand All @@ -268,6 +300,9 @@ def function_unsupported_component(**kwds):
shuffle(select)
getattr(calendar, f'{select[0]}Calendar')(**kwds)

def function_module_class(**kwds):
return calendar.Calendar(**kwds)


@contextmanager
def source_unavailable():
Expand Down Expand Up @@ -373,6 +408,12 @@ def test_get_params_classmethod_make_class(self):
with source_unavailable():
assert_params(self, get_params(ClassB.make), ['pkcm1', 'kcm1'])

def test_get_params_classmethod_instantiate_from_cls(self):
assert_params(self, get_params(ClassS1, 'classmethod_s'), ['ks1'])
assert_params(self, get_params(ClassS2), ['ks1'])
with source_unavailable():
assert_params(self, get_params(ClassS1, 'classmethod_s'), [])


class GetFunctionParametersTests(unittest.TestCase):

Expand All @@ -397,6 +438,20 @@ def test_get_params_function_call_classmethod(self):
with source_unavailable():
assert_params(self, get_params(function_make_class_b), ['k1'])

def test_get_params_function_pop_get_from_kwargs(self):
with self.assertLogs(logger, level='DEBUG') as log:
params = get_params(function_pop_get_from_kwargs, logger=logger)
assert_params(self, params, ['kn1', 'k2', 'kn2', 'kn3', 'pk1'])
self.assertIsNone(params[1].default)
self.assertIn('Unsupported kwargs pop/get default', log.output[0])
with source_unavailable():
assert_params(self, get_params(function_pop_get_from_kwargs), ['kn1'])

def test_get_params_function_module_class(self):
params = get_params(function_module_class)
self.assertEqual(['firstweekday'], [p.name for p in params])


class OtherTests(unittest.TestCase):

def test_unsupported_component(self):
Expand Down

0 comments on commit 681c7e9

Please sign in to comment.