diff --git a/vllm/multimodal/inputs.py b/vllm/multimodal/inputs.py
index 44923376c8be9..e9d253e692191 100644
--- a/vllm/multimodal/inputs.py
+++ b/vllm/multimodal/inputs.py
@@ -225,8 +225,8 @@ def _build_elem(
     @abstractmethod
     def build_elems(
         self,
-        key: str,
         modality: str,
+        key: str,
         data: NestedTensors,
     ) -> Sequence[MultiModalFieldElem]:
         """
@@ -277,11 +277,11 @@ class MultiModalBatchedField(BaseMultiModalField):
 
     def build_elems(
        self,
-        key: str,
         modality: str,
+        key: str,
         data: NestedTensors,
     ) -> Sequence[MultiModalFieldElem]:
-        return [self._build_elem(key, modality, item) for item in data]
+        return [self._build_elem(modality, key, item) for item in data]
 
     def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
         if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"):
@@ -320,11 +320,11 @@ class MultiModalFlatField(BaseMultiModalField):
 
     def build_elems(
         self,
-        key: str,
         modality: str,
+        key: str,
         data: NestedTensors,
     ) -> Sequence[MultiModalFieldElem]:
-        return [self._build_elem(key, modality, data[s]) for s in self.slices]
+        return [self._build_elem(modality, key, data[s]) for s in self.slices]
 
     def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
         if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"):
@@ -365,11 +365,11 @@ class MultiModalSharedField(BaseMultiModalField):
 
     def build_elems(
         self,
-        key: str,
         modality: str,
+        key: str,
         data: NestedTensors,
     ) -> Sequence[MultiModalFieldElem]:
-        return [self._build_elem(key, modality, data)] * self.batch_size
+        return [self._build_elem(modality, key, data)] * self.batch_size
 
     def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
         return batch[0]
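
For context, a minimal standalone sketch of the calling convention after this change: build_elems() and _build_elem() both take (modality, key, ...) in the same order, so the two arguments can no longer be swapped silently between the call sites. This is not the actual vLLM code; FieldElem, BaseField, and BatchedField below are simplified stand-ins for MultiModalFieldElem, BaseMultiModalField, and MultiModalBatchedField.

    # Illustrative sketch only; simplified stand-ins for the vLLM classes.
    from abc import ABC, abstractmethod
    from dataclasses import dataclass
    from typing import Any, Sequence


    @dataclass
    class FieldElem:
        """Simplified stand-in for MultiModalFieldElem."""
        modality: str
        key: str
        data: Any


    class BaseField(ABC):
        def _build_elem(self, modality: str, key: str, data: Any) -> FieldElem:
            return FieldElem(modality=modality, key=key, data=data)

        @abstractmethod
        def build_elems(
            self,
            modality: str,
            key: str,
            data: Sequence[Any],
        ) -> Sequence[FieldElem]:
            ...


    class BatchedField(BaseField):
        def build_elems(
            self,
            modality: str,
            key: str,
            data: Sequence[Any],
        ) -> Sequence[FieldElem]:
            # Same (modality, key) order as _build_elem, matching the diff.
            return [self._build_elem(modality, key, item) for item in data]


    if __name__ == "__main__":
        elems = BatchedField().build_elems("image", "pixel_values", [1, 2, 3])
        assert all(e.modality == "image" and e.key == "pixel_values" for e in elems)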