diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py
index e351d602189e2..2c09ca2c1407c 100644
--- a/vllm/_custom_ops.py
+++ b/vllm/_custom_ops.py
@@ -1,6 +1,6 @@
 import contextlib
 import functools
-from typing import List, Optional, Tuple, Type
+from typing import List, Optional, Tuple, Union
 
 import torch
 
@@ -336,7 +336,7 @@ def scaled_fp8_quant(
     """
     # This code assumes batch_dim and num_tokens are flattened
     assert (input.ndim == 2)
-    shape = input.shape
+    shape: Union[Tuple[int, int], torch.Size] = input.shape
     if num_token_padding:
         shape = (max(num_token_padding, input.shape[0]), shape[1])
     output = torch.empty(shape, device=input.device, dtype=torch.float8_e4m3fn)
diff --git a/vllm/multimodal/base.py b/vllm/multimodal/base.py
index f13885ef0dab0..aefb5f438c5ad 100644
--- a/vllm/multimodal/base.py
+++ b/vllm/multimodal/base.py
@@ -53,9 +53,7 @@ class MultiModalInputs(_MultiModalInputsBase):
     """
 
     @staticmethod
-    def _try_concat(
-        tensors: List[NestedTensors],
-    ) -> Union[GenericSequence[NestedTensors], NestedTensors]:
+    def _try_concat(tensors: List[NestedTensors]) -> BatchedTensors:
         """
         If each input tensor in the batch has the same shape, return a single
         batched tensor; otherwise, return a list of :class:`NestedTensors` with
@@ -105,7 +103,7 @@ def batch(inputs_list: List["MultiModalInputs"]) -> BatchedTensorInputs:
         return {
             k: MultiModalInputs._try_concat(item_list)
             for k, item_list in item_lists.items()
-        }  # type: ignore
+        }
 
     @staticmethod
     def as_kwargs(
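
Note on the _custom_ops.py hunk: the point of the new annotation is that `input.shape` is a `torch.Size`, so without an explicit type mypy infers `shape: torch.Size` from the first assignment and then rejects re-binding `shape` to a plain `(rows, cols)` tuple on the padding path. A minimal, self-contained sketch of the pattern follows; the function name `pad_shape_sketch` is hypothetical, not vLLM API.

from typing import Optional, Tuple, Union

import torch


def pad_shape_sketch(input: torch.Tensor,
                     num_token_padding: Optional[int] = None) -> torch.Tensor:
    # Mirrors the annotated assignment in scaled_fp8_quant: the Union
    # admits both the original torch.Size and the tuple built below.
    assert input.ndim == 2
    shape: Union[Tuple[int, int], torch.Size] = input.shape
    if num_token_padding:
        # Pad the token dimension up to at least num_token_padding rows.
        shape = (max(num_token_padding, input.shape[0]), shape[1])
    return torch.empty(shape, device=input.device, dtype=torch.float8_e4m3fn)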
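Note on the base.py hunk: the narrowed `BatchedTensors` return annotation names the two shapes of output that `_try_concat`'s docstring already describes, which is what lets the `# type: ignore` on `batch` be dropped. Below is a simplified sketch of that contract; the `BatchedTensors` alias here is a stand-in assumption (the real alias in vllm/multimodal/base.py is built from `NestedTensors`), and plain tensors are used purely for illustration.

from typing import List, Union

import torch

# Stand-in for vLLM's BatchedTensors: one stacked tensor, or a ragged list.
BatchedTensors = Union[torch.Tensor, List[torch.Tensor]]


def try_concat_sketch(tensors: List[torch.Tensor]) -> BatchedTensors:
    # All shapes equal -> a single batched tensor with a leading batch dim.
    if len({t.shape for t in tensors}) == 1:
        return torch.stack(tensors)
    # Ragged inputs stay as a plain list for the caller to handle.
    return tensors


batched = try_concat_sketch([torch.zeros(2, 3), torch.zeros(2, 3)])
assert isinstance(batched, torch.Tensor) and batched.shape == (2, 2, 3)

ragged = try_concat_sketch([torch.zeros(2, 3), torch.zeros(4, 3)])
assert isinstance(ragged, list)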