diff --git a/src/sphinx_autodoc_typehints/__init__.py b/src/sphinx_autodoc_typehints/__init__.py index caf9926..7220450 100644 --- a/src/sphinx_autodoc_typehints/__init__.py +++ b/src/sphinx_autodoc_typehints/__init__.py @@ -2,6 +2,7 @@ from __future__ import annotations import ast +import importlib import inspect import re import sys @@ -404,32 +405,56 @@ def get_all_type_hints(autodoc_mock_imports: list[str], obj: Any, name: str) -> _TYPE_GUARD_IMPORTS_RESOLVED_GLOBALS_ID = set() -def _resolve_type_guarded_imports(autodoc_mock_imports: list[str], obj: Any) -> None: # noqa: C901 - if hasattr(obj, "__module__") and obj.__module__ in _TYPE_GUARD_IMPORTS_RESOLVED: - return # already processed module - if not hasattr(obj, "__globals__"): # classes with __slots__ do not have this - return # if lacks globals nothing we can do - if id(obj.__globals__) in _TYPE_GUARD_IMPORTS_RESOLVED_GLOBALS_ID: - return # already processed object - _TYPE_GUARD_IMPORTS_RESOLVED.add(obj.__module__) - if obj.__module__ not in sys.builtin_module_names: - if hasattr(obj, "__globals__"): - _TYPE_GUARD_IMPORTS_RESOLVED_GLOBALS_ID.add(id(obj.__globals__)) - - module = inspect.getmodule(obj) - if module: +def _should_skip_guarded_import_resolution(obj: Any) -> bool: + if isinstance(obj, types.ModuleType): + return False # Don't skip modules + + if not hasattr(obj, "__globals__"): + return True # Skip objects without __globals__ + + if hasattr(obj, "__module__"): + return obj.__module__ in _TYPE_GUARD_IMPORTS_RESOLVED or obj.__module__ in sys.builtin_module_names + + return id(obj.__globals__) in _TYPE_GUARD_IMPORTS_RESOLVED_GLOBALS_ID + + +def _execute_guarded_code(autodoc_mock_imports: list[str], obj: Any, module_code: str) -> None: + for _, part in _TYPE_GUARD_IMPORT_RE.findall(module_code): + guarded_code = textwrap.dedent(part) + try: try: - module_code = inspect.getsource(module) - except (TypeError, OSError): - ... # no source code => no type guards - else: - for _, part in _TYPE_GUARD_IMPORT_RE.findall(module_code): - guarded_code = textwrap.dedent(part) - try: - with mock(autodoc_mock_imports): - exec(guarded_code, obj.__globals__) # noqa: S102 - except Exception as exc: # noqa: BLE001 - _LOGGER.warning("Failed guarded type import with %r", exc) + with mock(autodoc_mock_imports): + exec(guarded_code, getattr(obj, "__globals__", obj.__dict__)) # noqa: S102 + except ImportError as exc: + # ImportError might have occurred because the module has guarded code as well, + # so we recurse on the module. + if exc.name: + _resolve_type_guarded_imports(autodoc_mock_imports, importlib.import_module(exc.name)) + + # Retry the guarded code and see if it works now after resolving all nested type guards. + with mock(autodoc_mock_imports): + exec(guarded_code, getattr(obj, "__globals__", obj.__dict__)) # noqa: S102 + except Exception as exc: # noqa: BLE001 + _LOGGER.warning("Failed guarded type import with %r", exc) + + +def _resolve_type_guarded_imports(autodoc_mock_imports: list[str], obj: Any) -> None: + if _should_skip_guarded_import_resolution(obj): + return + + if hasattr(obj, "__globals__"): + _TYPE_GUARD_IMPORTS_RESOLVED_GLOBALS_ID.add(id(obj.__globals__)) + + module = inspect.getmodule(obj) + + if module: + try: + module_code = inspect.getsource(module) + except (TypeError, OSError): + ... # no source code => no type guards + else: + _TYPE_GUARD_IMPORTS_RESOLVED.add(module.__name__) + _execute_guarded_code(autodoc_mock_imports, obj, module_code) def _get_type_hint(autodoc_mock_imports: list[str], name: str, obj: Any) -> dict[str, Any]: diff --git a/tests/roots/test-resolve-typing-guard/demo_typing_guard.py b/tests/roots/test-resolve-typing-guard/demo_typing_guard.py index d24d602..74fd474 100644 --- a/tests/roots/test-resolve-typing-guard/demo_typing_guard.py +++ b/tests/roots/test-resolve-typing-guard/demo_typing_guard.py @@ -12,6 +12,8 @@ from decimal import Decimal from typing import Sequence + from demo_typing_guard_dummy import Literal # guarded by another `if TYPE_CHECKING` in demo_typing_guard_dummy + if typing.TYPE_CHECKING: from typing import AnyStr @@ -52,6 +54,10 @@ def guarded(self, item: Decimal) -> None: """ +def func(_x: Literal) -> None: + ... + + __all__ = [ "a", "ValueError", diff --git a/tests/roots/test-resolve-typing-guard/demo_typing_guard_dummy.py b/tests/roots/test-resolve-typing-guard/demo_typing_guard_dummy.py index 9a61ebd..28dcf07 100644 --- a/tests/roots/test-resolve-typing-guard/demo_typing_guard_dummy.py +++ b/tests/roots/test-resolve-typing-guard/demo_typing_guard_dummy.py @@ -1,7 +1,13 @@ from __future__ import annotations +from typing import TYPE_CHECKING + from viktor import AI # module part of autodoc_mock_imports # noqa: F401 +if TYPE_CHECKING: + # Nested type guard + from typing import Literal # noqa: F401 + class AnotherClass: """Another class is here"""