diff --git a/vllm/compilation/decorators.py b/vllm/compilation/decorators.py
index 3ae74cc5cb7dd..0449f9354d0a2 100644
--- a/vllm/compilation/decorators.py
+++ b/vllm/compilation/decorators.py
@@ -1,24 +1,58 @@
 import inspect
-from typing import Dict, List, Union
+from typing import Dict, List, Optional, Union
 
 import torch
 
 import vllm.envs as envs
 from vllm.compilation.levels import CompilationLevel
 from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
+from vllm.logger import init_logger
 from vllm.sequence import IntermediateTensors
 from vllm.utils import supports_dynamo
 
+logger = init_logger(__name__)
+
 
-def support_torch_compile(dynamic_arg_dims: Dict[str, Union[int, List[int]]]):
+def support_torch_compile(
+        cls: Optional[type] = None,
+        dynamic_arg_dims: Optional[Dict[str, Union[int, List[int]]]] = None):
     """
     A decorator to add support for compiling the forward method of a class.
 
+    Usage 1: use directly as a decorator without arguments:
+
+    ```python
+    @support_torch_compile
+    class MyModel(nn.Module):
+        def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]):
+            ...
+    ```
+
+    Usage 2: use as a decorator with arguments:
+
+    ```python
+    @support_torch_compile(dynamic_arg_dims={"x": 0, "y": 0})
+    class MyModel(nn.Module):
+        def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]):
+            ...
+    ```
+
     `dynamic_arg_dims` is a dictionary that maps argument names to the dynamic
     dimensions of the argument. The dynamic dimensions can be either a single
     integer or a list of integers.
 
-    Depending on the value of arguments:
+    if `dynamic_arg_dims` is `None`, it is inferred from the type annotation
+    of the `forward` method, based on the following default rules:
+
+    - if the argument is annotated as `torch.Tensor` or
+        `Optional[torch.Tensor]`, the first dimension will be
+        marked as dynamic.
+    - if the argument is annotated as `IntermediateTensors`, the first
+        dimension of all the tensors in the intermediate tensors
+        will be marked as dynamic.
+
+    During runtime, when we actually mark dimensions of tensors,
+    it depends on the value of arguments:
 
     - if it is a single integer, the corresponding dimension of the argument
         will be marked as dynamic.
@@ -38,11 +72,35 @@ def cls_decorator_helper(cls: type):
         if not hasattr(cls, 'forward'):
             raise TypeError("decorated class should have a forward method.")
         sig = inspect.signature(cls.forward)
-        for k in dynamic_arg_dims:
+        inferred_dynamic_arg_dims = dynamic_arg_dims
+        if inferred_dynamic_arg_dims is None:
+            inferred_dynamic_arg_dims = {}
+            for k, v in sig.parameters.items():
+                if v.annotation in [
+                        torch.Tensor, Optional[torch.Tensor],
+                        IntermediateTensors, Optional[IntermediateTensors]
+                ]:
+                    inferred_dynamic_arg_dims[k] = 0
+
+            logger.debug(("Inferred dynamic dimensions for "
+                          "forward method of %s: %s"), cls,
+                         list(inferred_dynamic_arg_dims.keys()))
+
+        if len(inferred_dynamic_arg_dims) == 0:
+            raise ValueError(
+                "No dynamic dimensions found in the forward method of "
+                f"{cls}. Please provide dynamic_arg_dims explicitly.")
+
+        for k in inferred_dynamic_arg_dims:
             if k not in sig.parameters:
                 raise ValueError(
                     f"Argument {k} not found in the forward method of {cls}")
-        return _support_torch_compile(cls, dynamic_arg_dims)
+        return _support_torch_compile(cls, inferred_dynamic_arg_dims)
+
+    if cls is not None:
+        # use `support_torch_compile` as a decorator without arguments
+        assert isinstance(cls, type)
+        return cls_decorator_helper(cls)
 
     return cls_decorator_helper
diff --git a/vllm/model_executor/models/gemma2.py b/vllm/model_executor/models/gemma2.py
index f958268741cd5..d79248f93f5ae 100644
--- a/vllm/model_executor/models/gemma2.py
+++ b/vllm/model_executor/models/gemma2.py
@@ -241,13 +241,7 @@ def forward(
         return hidden_states, residual
 
 
-@support_torch_compile(
-    dynamic_arg_dims={
-        "input_ids": 0,
-        "positions": 0,
-        "inputs_embeds": 0,
-        "intermediate_tensors": 0,
-    })
+@support_torch_compile
 class Gemma2Model(nn.Module):
 
     def __init__(
diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py
index fd88ae8b50402..c346e3e808e3f 100644
--- a/vllm/model_executor/models/llama.py
+++ b/vllm/model_executor/models/llama.py
@@ -268,13 +268,7 @@ def forward(
         return hidden_states, residual
 
 
-@support_torch_compile(
-    dynamic_arg_dims={
-        "input_ids": 0,
-        "positions": 0,
-        "inputs_embeds": 0,
-        "intermediate_tensors": 0,
-    })
+@support_torch_compile
 class LlamaModel(nn.Module):
 
     def __init__(
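Note: with this change, the hand-written `dynamic_arg_dims` dicts removed from `gemma2.py` and `llama.py` are recovered automatically from the `forward` type annotations. A minimal sketch of what the argument-less form infers (`ToyModel` is hypothetical and for illustration only; it is not part of this diff):

```python
# Illustrative sketch: ToyModel is hypothetical, not part of this diff.
# It shows what the argument-less decorator infers from annotations.
from typing import Optional

import torch
import torch.nn as nn

from vllm.compilation.decorators import support_torch_compile
from vllm.sequence import IntermediateTensors


@support_torch_compile  # no dynamic_arg_dims needed
class ToyModel(nn.Module):

    def forward(
        self,
        input_ids: torch.Tensor,  # annotated Tensor -> dim 0 dynamic
        positions: torch.Tensor,  # annotated Tensor -> dim 0 dynamic
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        ...
```

Here the inferred mapping would be `{"input_ids": 0, "positions": 0, "intermediate_tensors": 0, "inputs_embeds": 0}`, i.e. exactly the dict that `Gemma2Model` and `LlamaModel` previously spelled out by hand.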