Skip to content

Commit

Permalink
Fix function-valued arguments (#119)
Browse files Browse the repository at this point in the history
  • Loading branch information
sizmailov authored Aug 30, 2023
1 parent 4d9f245 commit a387874
Show file tree
Hide file tree
Showing 8 changed files with 119 additions and 2 deletions.
18 changes: 18 additions & 0 deletions pybind11_stubgen/parser/mixins/fix.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,12 @@ def handle_type(self, type_: type) -> QualifiedName:
self._add_import(result)
return result

def handle_value(self, value: Any) -> Value:
result = super().handle_value(value)
if inspect.isroutine(value) and result.is_print_safe:
self._add_import(QualifiedName.from_str(result.repr))
return result

def parse_annotation_str(
self, annotation_str: str
) -> ResolvedType | InvalidExpression | Value:
Expand Down Expand Up @@ -237,6 +243,10 @@ def handle_type(self, type_: type) -> QualifiedName:
if result[0] == "builtins":
if result[1] == "NoneType":
return QualifiedName((Identifier("None"),))
if result[1] in ("function", "builtin_function_or_method"):
callable_t = self.parse_annotation_str("typing.Callable")
assert isinstance(callable_t, ResolvedType)
return callable_t.name
return QualifiedName(result[1:])

return result
Expand Down Expand Up @@ -363,6 +373,14 @@ def handle_type(self, type_: type) -> QualifiedName:
result = super().handle_type(type_)
return self._strip_current_module(result)

def handle_value(self, value: Any) -> Value:
result = super().handle_value(value)
if inspect.isroutine(value):
result.repr = str(
self._strip_current_module(QualifiedName.from_str(result.repr))
)
return result

def parse_annotation_str(
self, annotation_str: str
) -> ResolvedType | InvalidExpression | Value:
Expand Down
10 changes: 9 additions & 1 deletion pybind11_stubgen/parser/mixins/parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def _is_descriptor(self, member: Any) -> bool:


class BaseParser(IParser):
def is_print_safe(self, value: Value) -> bool:
def is_print_safe(self, value: Any) -> bool:
value_type = type(value)
# Use exact type match, not `isinstance()` that allows inherited types pass
if value is None or value_type in (int, str):
Expand All @@ -180,6 +180,14 @@ def is_print_safe(self, value: Value) -> bool:
if not self.is_print_safe(k) or not self.is_print_safe(v):
return False
return True
if inspect.isfunction(value):
if (
(module_name := getattr(value, "__module__", None)) is not None
and "<" not in module_name
and (qual_name := getattr(value, "__qualname__", None)) is not None
and "<" not in qual_name
):
return True
if inspect.ismodule(value):
return True
return False
Expand Down
1 change: 1 addition & 0 deletions tests/py-demo/demo/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Mock common layout library mixed of C/C++ core and pure python
# C++-based modules are prefixed with underscore

from . import pure_python
from .core import *

#
Expand Down
1 change: 1 addition & 0 deletions tests/py-demo/demo/pure_python/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from . import functions
43 changes: 43 additions & 0 deletions tests/py-demo/demo/pure_python/functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
class _Dummy:
@staticmethod
def foo():
return 42


def search(a: int, b: list[int]) -> int:
...


def builtin_function_as_default_arg(func: type(len) = len):
...


def function_as_default_arg(func: type(search) = search):
...


def lambda_as_default_arg(callback=lambda val: 0):
...


def static_method_as_default_arg(callback=_Dummy.foo):
...


def arg_mix(
a: int,
b: float = 0.5,
/,
c: str = "",
*args: int,
x: int = 1,
y=search,
**kwargs: dict[int, str],
):
"""Mix of positional, kw and variadic args
Note:
The `inspect.getfullargspec` does not reflect presence
of pos-only args separator (/)
"""
...
3 changes: 2 additions & 1 deletion tests/stubs/demo/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ from demo._bindings import (
values,
)

from . import _bindings, core
from . import _bindings, core, pure_python

__all__ = [
"aliases",
Expand All @@ -31,6 +31,7 @@ __all__ = [
"methods",
"numpy",
"properties",
"pure_python",
"stl",
"stl_bind",
"typing",
Expand Down
5 changes: 5 additions & 0 deletions tests/stubs/demo/pure_python/__init__.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from __future__ import annotations

from . import functions

__all__ = ["functions"]
40 changes: 40 additions & 0 deletions tests/stubs/demo/pure_python/functions.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from __future__ import annotations

import typing

__all__ = [
"arg_mix",
"builtin_function_as_default_arg",
"function_as_default_arg",
"lambda_as_default_arg",
"search",
"static_method_as_default_arg",
]

class _Dummy(object):
@staticmethod
def foo(): ...

def arg_mix(
a: int,
b: float = 0.5,
c: str = "",
*args: int,
x: int = 1,
y=search,
**kwargs: dict[int, str],
):
"""
Mix of positional, kw and variadic args
Note:
The `inspect.getfullargspec` does not reflect presence
of pos-only args separator (/)
"""

def builtin_function_as_default_arg(func: typing.Callable = ...): ...
def function_as_default_arg(func: typing.Callable = search): ...
def lambda_as_default_arg(callback=...): ...
def search(a: int, b: list[int]) -> int: ...
def static_method_as_default_arg(callback=_Dummy.foo): ...

0 comments on commit a387874

Please sign in to comment.