Skip to content

Commit

Permalink
fix: Recursively evaluate guarded code (#393)
Browse files Browse the repository at this point in the history
* Add test for nested guarded imports

* Handle nested type guards

* Refactor

---------

Co-authored-by: Bernát Gábor <[email protected]>
  • Loading branch information
Mr-Pepe and gaborbernat authored Oct 31, 2023
1 parent 5eb0fcf commit 3eeb664
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 25 deletions.
75 changes: 50 additions & 25 deletions src/sphinx_autodoc_typehints/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from __future__ import annotations

import ast
import importlib
import inspect
import re
import sys
Expand Down Expand Up @@ -404,32 +405,56 @@ def get_all_type_hints(autodoc_mock_imports: list[str], obj: Any, name: str) ->
_TYPE_GUARD_IMPORTS_RESOLVED_GLOBALS_ID = set()


def _resolve_type_guarded_imports(autodoc_mock_imports: list[str], obj: Any) -> None: # noqa: C901
if hasattr(obj, "__module__") and obj.__module__ in _TYPE_GUARD_IMPORTS_RESOLVED:
return # already processed module
if not hasattr(obj, "__globals__"): # classes with __slots__ do not have this
return # if lacks globals nothing we can do
if id(obj.__globals__) in _TYPE_GUARD_IMPORTS_RESOLVED_GLOBALS_ID:
return # already processed object
_TYPE_GUARD_IMPORTS_RESOLVED.add(obj.__module__)
if obj.__module__ not in sys.builtin_module_names:
if hasattr(obj, "__globals__"):
_TYPE_GUARD_IMPORTS_RESOLVED_GLOBALS_ID.add(id(obj.__globals__))

module = inspect.getmodule(obj)
if module:
def _should_skip_guarded_import_resolution(obj: Any) -> bool:
if isinstance(obj, types.ModuleType):
return False # Don't skip modules

if not hasattr(obj, "__globals__"):
return True # Skip objects without __globals__

if hasattr(obj, "__module__"):
return obj.__module__ in _TYPE_GUARD_IMPORTS_RESOLVED or obj.__module__ in sys.builtin_module_names

return id(obj.__globals__) in _TYPE_GUARD_IMPORTS_RESOLVED_GLOBALS_ID


def _execute_guarded_code(autodoc_mock_imports: list[str], obj: Any, module_code: str) -> None:
for _, part in _TYPE_GUARD_IMPORT_RE.findall(module_code):
guarded_code = textwrap.dedent(part)
try:
try:
module_code = inspect.getsource(module)
except (TypeError, OSError):
... # no source code => no type guards
else:
for _, part in _TYPE_GUARD_IMPORT_RE.findall(module_code):
guarded_code = textwrap.dedent(part)
try:
with mock(autodoc_mock_imports):
exec(guarded_code, obj.__globals__) # noqa: S102
except Exception as exc: # noqa: BLE001
_LOGGER.warning("Failed guarded type import with %r", exc)
with mock(autodoc_mock_imports):
exec(guarded_code, getattr(obj, "__globals__", obj.__dict__)) # noqa: S102
except ImportError as exc:
# ImportError might have occurred because the module has guarded code as well,
# so we recurse on the module.
if exc.name:
_resolve_type_guarded_imports(autodoc_mock_imports, importlib.import_module(exc.name))

# Retry the guarded code and see if it works now after resolving all nested type guards.
with mock(autodoc_mock_imports):
exec(guarded_code, getattr(obj, "__globals__", obj.__dict__)) # noqa: S102
except Exception as exc: # noqa: BLE001
_LOGGER.warning("Failed guarded type import with %r", exc)


def _resolve_type_guarded_imports(autodoc_mock_imports: list[str], obj: Any) -> None:
if _should_skip_guarded_import_resolution(obj):
return

if hasattr(obj, "__globals__"):
_TYPE_GUARD_IMPORTS_RESOLVED_GLOBALS_ID.add(id(obj.__globals__))

module = inspect.getmodule(obj)

if module:
try:
module_code = inspect.getsource(module)
except (TypeError, OSError):
... # no source code => no type guards
else:
_TYPE_GUARD_IMPORTS_RESOLVED.add(module.__name__)
_execute_guarded_code(autodoc_mock_imports, obj, module_code)


def _get_type_hint(autodoc_mock_imports: list[str], name: str, obj: Any) -> dict[str, Any]:
Expand Down
6 changes: 6 additions & 0 deletions tests/roots/test-resolve-typing-guard/demo_typing_guard.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
from decimal import Decimal
from typing import Sequence

from demo_typing_guard_dummy import Literal # guarded by another `if TYPE_CHECKING` in demo_typing_guard_dummy


if typing.TYPE_CHECKING:
from typing import AnyStr
Expand Down Expand Up @@ -52,6 +54,10 @@ def guarded(self, item: Decimal) -> None:
"""


def func(_x: Literal) -> None:
...


__all__ = [
"a",
"ValueError",
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
from __future__ import annotations

from typing import TYPE_CHECKING

from viktor import AI # module part of autodoc_mock_imports # noqa: F401

if TYPE_CHECKING:
# Nested type guard
from typing import Literal # noqa: F401


class AnotherClass:
"""Another class is here"""

0 comments on commit 3eeb664

Please sign in to comment.