Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Recursively evaluate guarded code #393

Merged
merged 5 commits into from
Oct 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 50 additions & 25 deletions src/sphinx_autodoc_typehints/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from __future__ import annotations

import ast
import importlib
import inspect
import re
import sys
Expand Down Expand Up @@ -404,32 +405,56 @@ def get_all_type_hints(autodoc_mock_imports: list[str], obj: Any, name: str) ->
_TYPE_GUARD_IMPORTS_RESOLVED_GLOBALS_ID = set()


def _resolve_type_guarded_imports(autodoc_mock_imports: list[str], obj: Any) -> None: # noqa: C901
if hasattr(obj, "__module__") and obj.__module__ in _TYPE_GUARD_IMPORTS_RESOLVED:
return # already processed module
if not hasattr(obj, "__globals__"): # classes with __slots__ do not have this
return # if lacks globals nothing we can do
if id(obj.__globals__) in _TYPE_GUARD_IMPORTS_RESOLVED_GLOBALS_ID:
return # already processed object
_TYPE_GUARD_IMPORTS_RESOLVED.add(obj.__module__)
if obj.__module__ not in sys.builtin_module_names:
if hasattr(obj, "__globals__"):
_TYPE_GUARD_IMPORTS_RESOLVED_GLOBALS_ID.add(id(obj.__globals__))

module = inspect.getmodule(obj)
if module:
def _should_skip_guarded_import_resolution(obj: Any) -> bool:
if isinstance(obj, types.ModuleType):
return False # Don't skip modules

if not hasattr(obj, "__globals__"):
return True # Skip objects without __globals__

if hasattr(obj, "__module__"):
return obj.__module__ in _TYPE_GUARD_IMPORTS_RESOLVED or obj.__module__ in sys.builtin_module_names

return id(obj.__globals__) in _TYPE_GUARD_IMPORTS_RESOLVED_GLOBALS_ID


def _execute_guarded_code(autodoc_mock_imports: list[str], obj: Any, module_code: str) -> None:
for _, part in _TYPE_GUARD_IMPORT_RE.findall(module_code):
guarded_code = textwrap.dedent(part)
try:
try:
module_code = inspect.getsource(module)
except (TypeError, OSError):
... # no source code => no type guards
else:
for _, part in _TYPE_GUARD_IMPORT_RE.findall(module_code):
guarded_code = textwrap.dedent(part)
try:
with mock(autodoc_mock_imports):
exec(guarded_code, obj.__globals__) # noqa: S102
except Exception as exc: # noqa: BLE001
_LOGGER.warning("Failed guarded type import with %r", exc)
with mock(autodoc_mock_imports):
exec(guarded_code, getattr(obj, "__globals__", obj.__dict__)) # noqa: S102
except ImportError as exc:
# ImportError might have occurred because the module has guarded code as well,
# so we recurse on the module.
if exc.name:
_resolve_type_guarded_imports(autodoc_mock_imports, importlib.import_module(exc.name))

# Retry the guarded code and see if it works now after resolving all nested type guards.
with mock(autodoc_mock_imports):
exec(guarded_code, getattr(obj, "__globals__", obj.__dict__)) # noqa: S102
except Exception as exc: # noqa: BLE001
_LOGGER.warning("Failed guarded type import with %r", exc)


def _resolve_type_guarded_imports(autodoc_mock_imports: list[str], obj: Any) -> None:
if _should_skip_guarded_import_resolution(obj):
return

if hasattr(obj, "__globals__"):
_TYPE_GUARD_IMPORTS_RESOLVED_GLOBALS_ID.add(id(obj.__globals__))

module = inspect.getmodule(obj)

if module:
try:
module_code = inspect.getsource(module)
except (TypeError, OSError):
... # no source code => no type guards
else:
_TYPE_GUARD_IMPORTS_RESOLVED.add(module.__name__)
_execute_guarded_code(autodoc_mock_imports, obj, module_code)


def _get_type_hint(autodoc_mock_imports: list[str], name: str, obj: Any) -> dict[str, Any]:
Expand Down
6 changes: 6 additions & 0 deletions tests/roots/test-resolve-typing-guard/demo_typing_guard.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
from decimal import Decimal
from typing import Sequence

from demo_typing_guard_dummy import Literal # guarded by another `if TYPE_CHECKING` in demo_typing_guard_dummy


if typing.TYPE_CHECKING:
from typing import AnyStr
Expand Down Expand Up @@ -52,6 +54,10 @@ def guarded(self, item: Decimal) -> None:
"""


def func(_x: Literal) -> None:
...


__all__ = [
"a",
"ValueError",
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
from __future__ import annotations

from typing import TYPE_CHECKING

from viktor import AI # module part of autodoc_mock_imports # noqa: F401

if TYPE_CHECKING:
# Nested type guard
from typing import Literal # noqa: F401


class AnotherClass:
"""Another class is here"""