diff --git a/tests/test_utils.py b/tests/test_utils.py index 0285b00d73be1..14d2fbd63b90d 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -7,9 +7,9 @@ import torch from vllm_test_utils import monitor -from vllm.utils import (FlexibleArgumentParser, StoreBoolean, deprecate_kwargs, - get_open_port, memory_profiling, merge_async_iterators, - supports_kw) +from vllm.utils import (FlexibleArgumentParser, PlaceholderModule, + StoreBoolean, deprecate_kwargs, get_open_port, + memory_profiling, merge_async_iterators, supports_kw) from .utils import error_on_warning, fork_new_process_for_each_test @@ -323,3 +323,44 @@ def measure_current_non_torch(): del weights lib.cudaFree(handle1) lib.cudaFree(handle2) + + +def test_placeholder_module_error_handling(): + placeholder = PlaceholderModule("placeholder_1234") + + def build_ctx(): + return pytest.raises(ModuleNotFoundError, + match="No module named") + + with build_ctx(): + int(placeholder) + + with build_ctx(): + placeholder() + + with build_ctx(): + _ = placeholder.some_attr + + with build_ctx(): + # Test conflict with internal __name attribute + _ = placeholder.name + + # OK to print the placeholder or use it in a f-string + _ = repr(placeholder) + _ = str(placeholder) + + # No error yet; only error when it is used downstream + placeholder_attr = placeholder.placeholder_attr("attr") + + with build_ctx(): + int(placeholder_attr) + + with build_ctx(): + placeholder_attr() + + with build_ctx(): + _ = placeholder_attr.some_attr + + with build_ctx(): + # Test conflict with internal __module attribute + _ = placeholder_attr.module diff --git a/vllm/utils.py b/vllm/utils.py index 0b0905e675245..487088591ebc2 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -46,7 +46,7 @@ import zmq.asyncio from packaging.version import Version from torch.library import Library -from typing_extensions import ParamSpec, TypeIs, assert_never +from typing_extensions import Never, ParamSpec, TypeIs, assert_never import vllm.envs as envs from vllm.logger import enable_trace_function_call, init_logger @@ -1627,24 +1627,183 @@ def get_vllm_optional_dependencies(): } -@dataclass(frozen=True) -class PlaceholderModule: +class _PlaceholderBase: + """ + Disallows downstream usage of placeholder modules. + + We need to explicitly override each dunder method because + :meth:`__getattr__` is not called when they are accessed. + + See also: + [Special method lookup](https://docs.python.org/3/reference/datamodel.html#special-lookup) + """ + + def __getattr__(self, key: str) -> Never: + """ + The main class should implement this to throw an error + for attribute accesses representing downstream usage. + """ + raise NotImplementedError + + # [Basic customization] + + def __lt__(self, other: object): + return self.__getattr__("__lt__") + + def __le__(self, other: object): + return self.__getattr__("__le__") + + def __eq__(self, other: object): + return self.__getattr__("__eq__") + + def __ne__(self, other: object): + return self.__getattr__("__ne__") + + def __gt__(self, other: object): + return self.__getattr__("__gt__") + + def __ge__(self, other: object): + return self.__getattr__("__ge__") + + def __hash__(self): + return self.__getattr__("__hash__") + + def __bool__(self): + return self.__getattr__("__bool__") + + # [Callable objects] + + def __call__(self, *args: object, **kwargs: object): + return self.__getattr__("__call__") + + # [Container types] + + def __len__(self): + return self.__getattr__("__len__") + + def __getitem__(self, key: object): + return self.__getattr__("__getitem__") + + def __setitem__(self, key: object, value: object): + return self.__getattr__("__setitem__") + + def __delitem__(self, key: object): + return self.__getattr__("__delitem__") + + # __missing__ is optional according to __getitem__ specification, + # so it is skipped + + # __iter__ and __reversed__ have a default implementation + # based on __len__ and __getitem__, so they are skipped. + + # [Numeric Types] + + def __add__(self, other: object): + return self.__getattr__("__add__") + + def __sub__(self, other: object): + return self.__getattr__("__sub__") + + def __mul__(self, other: object): + return self.__getattr__("__mul__") + + def __matmul__(self, other: object): + return self.__getattr__("__matmul__") + + def __truediv__(self, other: object): + return self.__getattr__("__truediv__") + + def __floordiv__(self, other: object): + return self.__getattr__("__floordiv__") + + def __mod__(self, other: object): + return self.__getattr__("__mod__") + + def __divmod__(self, other: object): + return self.__getattr__("__divmod__") + + def __pow__(self, other: object, modulo: object = ...): + return self.__getattr__("__pow__") + + def __lshift__(self, other: object): + return self.__getattr__("__lshift__") + + def __rshift__(self, other: object): + return self.__getattr__("__rshift__") + + def __and__(self, other: object): + return self.__getattr__("__and__") + + def __xor__(self, other: object): + return self.__getattr__("__xor__") + + def __or__(self, other: object): + return self.__getattr__("__or__") + + # r* and i* methods have lower priority than + # the methods for left operand so they are skipped + + def __neg__(self): + return self.__getattr__("__neg__") + + def __pos__(self): + return self.__getattr__("__pos__") + + def __abs__(self): + return self.__getattr__("__abs__") + + def __invert__(self): + return self.__getattr__("__invert__") + + # __complex__, __int__ and __float__ have a default implementation + # based on __index__, so they are skipped. + + def __index__(self): + return self.__getattr__("__index__") + + def __round__(self, ndigits: object = ...): + return self.__getattr__("__round__") + + def __trunc__(self): + return self.__getattr__("__trunc__") + + def __floor__(self): + return self.__getattr__("__floor__") + + def __ceil__(self): + return self.__getattr__("__ceil__") + + # [Context managers] + + def __enter__(self): + return self.__getattr__("__enter__") + + def __exit__(self, *args: object, **kwargs: object): + return self.__getattr__("__exit__") + + +class PlaceholderModule(_PlaceholderBase): """ A placeholder object to use when a module does not exist. This enables more informative errors when trying to access attributes of a module that does not exists. """ - name: str + + def __init__(self, name: str) -> None: + super().__init__() + + # Apply name mangling to avoid conflicting with module attributes + self.__name = name def placeholder_attr(self, attr_path: str): return _PlaceholderModuleAttr(self, attr_path) def __getattr__(self, key: str): - name = self.name + name = self.__name try: - importlib.import_module(self.name) + importlib.import_module(name) except ImportError as exc: for extra, names in get_vllm_optional_dependencies().items(): if name in names: @@ -1657,17 +1816,21 @@ def __getattr__(self, key: str): "when the original module can be imported") -@dataclass(frozen=True) -class _PlaceholderModuleAttr: - module: PlaceholderModule - attr_path: str +class _PlaceholderModuleAttr(_PlaceholderBase): + + def __init__(self, module: PlaceholderModule, attr_path: str) -> None: + super().__init__() + + # Apply name mangling to avoid conflicting with module attributes + self.__module = module + self.__attr_path = attr_path def placeholder_attr(self, attr_path: str): - return _PlaceholderModuleAttr(self.module, - f"{self.attr_path}.{attr_path}") + return _PlaceholderModuleAttr(self.__module, + f"{self.__attr_path}.{attr_path}") def __getattr__(self, key: str): - getattr(self.module, f"{self.attr_path}.{key}") + getattr(self.__module, f"{self.__attr_path}.{key}") raise AssertionError("PlaceholderModule should not be used " "when the original module can be imported")