From 12e5a9af81c44286ffacc8001e0024734bad040a Mon Sep 17 00:00:00 2001
From: DarkLight1337
Date: Sat, 4 Jan 2025 06:40:08 +0000
Subject: [PATCH] Rename `_get_dummy_mm_inputs`

Signed-off-by: DarkLight1337
---
 tests/multimodal/test_processing.py           |  5 +++-
 vllm/model_executor/models/aria.py            |  2 +-
 vllm/model_executor/models/blip2.py           |  2 +-
 vllm/model_executor/models/chameleon.py       |  2 +-
 vllm/model_executor/models/fuyu.py            |  2 +-
 vllm/model_executor/models/llava.py           |  2 +-
 .../model_executor/models/llava_next_video.py |  2 +-
 vllm/model_executor/models/llava_onevision.py |  2 +-
 vllm/model_executor/models/phi3v.py           |  2 +-
 vllm/model_executor/models/qwen2_audio.py     |  2 +-
 vllm/model_executor/models/qwen2_vl.py        |  2 +-
 vllm/model_executor/models/ultravox.py        |  2 +-
 vllm/multimodal/processing.py                 | 26 ++++++++++++-------
 13 files changed, 32 insertions(+), 21 deletions(-)

diff --git a/tests/multimodal/test_processing.py b/tests/multimodal/test_processing.py
index 40d32e59a9b8b..b32faa699ebf2 100644
--- a/tests/multimodal/test_processing.py
+++ b/tests/multimodal/test_processing.py
@@ -723,7 +723,10 @@ def _test_processing_cache_correctness(
     }
     mm_counts = {k: len(vs) for k, vs in mm_data.items()}
 
-    prompt = baseline_processor._get_dummy_mm_inputs(mm_counts).prompt_text
+    prompt = baseline_processor._get_dummy_processor_inputs(
+        model_config.max_model_len,
+        mm_counts,
+    ).prompt_text
 
     # Drop unnecessary keys and test single -> multi conversion
     if rng.rand() < simplify_rate:
diff --git a/vllm/model_executor/models/aria.py b/vllm/model_executor/models/aria.py
index c7a19267ab14f..2fd4262a9d3b9 100644
--- a/vllm/model_executor/models/aria.py
+++ b/vllm/model_executor/models/aria.py
@@ -488,7 +488,7 @@ def _get_prompt_replacements(
             )
         ]
 
-    def _get_dummy_mm_inputs(
+    def _get_dummy_processor_inputs(
         self,
         seq_len: int,
         mm_counts: Mapping[str, int],
diff --git a/vllm/model_executor/models/blip2.py b/vllm/model_executor/models/blip2.py
index e33efe4ddc9b2..b3ecb2f22dc19 100644
--- a/vllm/model_executor/models/blip2.py
+++ b/vllm/model_executor/models/blip2.py
@@ -457,7 +457,7 @@ def apply(
 
         return result
 
-    def _get_dummy_mm_inputs(
+    def _get_dummy_processor_inputs(
         self,
         seq_len: int,
         mm_counts: Mapping[str, int],
diff --git a/vllm/model_executor/models/chameleon.py b/vllm/model_executor/models/chameleon.py
index 77d4c6e565623..1ad44678a591d 100644
--- a/vllm/model_executor/models/chameleon.py
+++ b/vllm/model_executor/models/chameleon.py
@@ -90,7 +90,7 @@ def _get_prompt_replacements(
             )
         ]
 
-    def _get_dummy_mm_inputs(
+    def _get_dummy_processor_inputs(
         self,
         seq_len: int,
         mm_counts: Mapping[str, int],
diff --git a/vllm/model_executor/models/fuyu.py b/vllm/model_executor/models/fuyu.py
index a7b6f68046309..7cd58fbc7cf21 100644
--- a/vllm/model_executor/models/fuyu.py
+++ b/vllm/model_executor/models/fuyu.py
@@ -208,7 +208,7 @@ def apply(
 
         return result
 
-    def _get_dummy_mm_inputs(
+    def _get_dummy_processor_inputs(
         self,
         seq_len: int,
         mm_counts: Mapping[str, int],
diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py
index 745befcda9f3e..d522378e0bebb 100644
--- a/vllm/model_executor/models/llava.py
+++ b/vllm/model_executor/models/llava.py
@@ -149,7 +149,7 @@ def _get_dummy_image_size(self) -> ImageSize:
     def _get_image_token(self) -> str:
         raise NotImplementedError
 
-    def _get_dummy_mm_inputs(
+    def _get_dummy_processor_inputs(
         self,
         seq_len: int,
         mm_counts: Mapping[str, int],
diff --git a/vllm/model_executor/models/llava_next_video.py b/vllm/model_executor/models/llava_next_video.py
index bafc172a86f9d..66c1b5734efd8 100644
--- a/vllm/model_executor/models/llava_next_video.py
+++ b/vllm/model_executor/models/llava_next_video.py
@@ -162,7 +162,7 @@ def get_replacement(item_idx: int):
             ),
         ]
 
-    def _get_dummy_mm_inputs(
+    def _get_dummy_processor_inputs(
         self,
         seq_len: int,
         mm_counts: Mapping[str, int],
diff --git a/vllm/model_executor/models/llava_onevision.py b/vllm/model_executor/models/llava_onevision.py
index 08b8f095c72ab..59c4d56fbe36d 100644
--- a/vllm/model_executor/models/llava_onevision.py
+++ b/vllm/model_executor/models/llava_onevision.py
@@ -296,7 +296,7 @@ def get_video_replacement(item_idx: int):
             ),
         ]
 
-    def _get_dummy_mm_inputs(
+    def _get_dummy_processor_inputs(
         self,
         seq_len: int,
         mm_counts: Mapping[str, int],
diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py
index b014d75b28aa6..7aa9d58d1d348 100644
--- a/vllm/model_executor/models/phi3v.py
+++ b/vllm/model_executor/models/phi3v.py
@@ -442,7 +442,7 @@ def _apply_prompt_replacements(
 
         return token_ids, text, placeholders
 
-    def _get_dummy_mm_inputs(
+    def _get_dummy_processor_inputs(
         self,
         seq_len: int,
         mm_counts: Mapping[str, int],
diff --git a/vllm/model_executor/models/qwen2_audio.py b/vllm/model_executor/models/qwen2_audio.py
index 10778ff76c884..bc3bb1f79b407 100644
--- a/vllm/model_executor/models/qwen2_audio.py
+++ b/vllm/model_executor/models/qwen2_audio.py
@@ -191,7 +191,7 @@ def _always_apply_prompt_replacements(self) -> bool:
         # tokens than the number of audio items)
         return True
 
-    def _get_dummy_mm_inputs(
+    def _get_dummy_processor_inputs(
         self,
         seq_len: int,
         mm_counts: Mapping[str, int],
diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py
index b1df14de56f1d..f872cebf006fe 100644
--- a/vllm/model_executor/models/qwen2_vl.py
+++ b/vllm/model_executor/models/qwen2_vl.py
@@ -900,7 +900,7 @@ def _get_mm_fields_config(
             video_grid_thw=MultiModalFieldConfig.batched("video"),
         )
 
-    def _get_dummy_mm_inputs(
+    def _get_dummy_processor_inputs(
         self,
         seq_len: int,
         mm_counts: Mapping[str, int],
diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py
index 733d29767eecf..6ad4661e3bb8d 100644
--- a/vllm/model_executor/models/ultravox.py
+++ b/vllm/model_executor/models/ultravox.py
@@ -173,7 +173,7 @@ def get_replacement_ultravox(item_idx: int):
             )
         ]
 
-    def _get_dummy_mm_inputs(
+    def _get_dummy_processor_inputs(
         self,
         seq_len: int,
         mm_counts: Mapping[str, int],
diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py
index 694d70a9e0462..a32e203fdcae6 100644
--- a/vllm/multimodal/processing.py
+++ b/vllm/multimodal/processing.py
@@ -799,7 +799,7 @@ def _apply_hf_processor_missing(
 
         # Some HF processors (e.g. Qwen2-VL) expect corresponding
         # multi-modal tokens to be in the prompt text
-        dummy_inputs = self._get_dummy_mm_inputs(
+        dummy_inputs = self._get_dummy_processor_inputs(
             self.ctx.model_config.max_model_len,
             mm_missing_counts,
         )
@@ -1164,7 +1164,7 @@ def _get_dummy_videos(
         return [video] * num_videos
 
     @abstractmethod
-    def _get_dummy_mm_inputs(
+    def _get_dummy_processor_inputs(
         self,
         seq_len: int,
         mm_counts: Mapping[str, int],
@@ -1194,6 +1194,19 @@ def _get_and_validate_dummy_mm_counts(self) -> Mapping[str, int]:
 
         return mm_limits
 
+    def _get_dummy_mm_inputs(
+        self,
+        seq_len: int,
+        mm_counts: Mapping[str, int],
+    ) -> MultiModalInputsV2:
+        processor_inputs = self._get_dummy_processor_inputs(seq_len, mm_counts)
+
+        return self.apply(
+            prompt_text=processor_inputs.prompt_text,
+            mm_data=processor_inputs.mm_data,
+            hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs,
+        )
+
     def get_dummy_data(self, seq_len: int) -> DummyData:
         # Avoid circular import
         from vllm.sequence import SequenceData
@@ -1207,13 +1220,7 @@ def get_dummy_data(self, seq_len: int) -> DummyData:
                 "returned by `get_mm_max_tokens_per_item` "
                 f"({set(mm_max_tokens_per_item.keys())})")
 
-        processor_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts)
-        mm_inputs = self.apply(
-            prompt_text=processor_inputs.prompt_text,
-            mm_data=processor_inputs.mm_data,
-            hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs,
-        )
-
+        mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts)
         prompt_token_ids = mm_inputs["prompt_token_ids"]
         placeholders_by_modality = mm_inputs["mm_placeholders"]
 
@@ -1243,6 +1250,7 @@ def get_dummy_data(self, seq_len: int) -> DummyData:
                 "short. To avoid this, you should increase `max_model_len`, "
                 "reduce `max_num_seqs`, and/or reduce `mm_counts`.",
                 seq_len, total_len, total_placeholders_by_modality)
+        prompt_token_ids.extend([0] * (seq_len - len(prompt_token_ids)))
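
Note (illustration, not part of the diff): after this patch, `_get_dummy_processor_inputs` is the per-model abstract hook that builds dummy prompt text plus multi-modal data, while the name `_get_dummy_mm_inputs` is reused for a new concrete helper on the base processor that feeds those dummy inputs through `apply()`. A minimal sketch of the resulting structure, with simplified signatures; the class and dataclass names `BaseMultiModalProcessor` and `ProcessorInputs` are stand-ins for vLLM's actual types, which carry more fields and methods:

    from abc import ABC, abstractmethod
    from dataclasses import dataclass, field
    from typing import Any, Mapping

    @dataclass
    class ProcessorInputs:
        # Stand-in container for the dummy inputs fed to the HF processor.
        prompt_text: str
        mm_data: Mapping[str, Any] = field(default_factory=dict)
        hf_processor_mm_kwargs: Mapping[str, Any] = field(default_factory=dict)

    class BaseMultiModalProcessor(ABC):
        @abstractmethod
        def apply(self, prompt_text, mm_data, hf_processor_mm_kwargs):
            # Full processing pipeline (defined elsewhere in the real class).
            raise NotImplementedError

        @abstractmethod
        def _get_dummy_processor_inputs(
            self,
            seq_len: int,
            mm_counts: Mapping[str, int],
        ) -> ProcessorInputs:
            # Per-model hook: build dummy prompt text and media for profiling.
            raise NotImplementedError

        def _get_dummy_mm_inputs(
            self,
            seq_len: int,
            mm_counts: Mapping[str, int],
        ):
            # Shared wrapper: run the dummy inputs through the full pipeline,
            # mirroring the new helper added to vllm/multimodal/processing.py.
            processor_inputs = self._get_dummy_processor_inputs(
                seq_len, mm_counts)
            return self.apply(
                prompt_text=processor_inputs.prompt_text,
                mm_data=processor_inputs.mm_data,
                hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs,
            )

This is why the test now calls `_get_dummy_processor_inputs(model_config.max_model_len, mm_counts)` directly: it wants the raw dummy prompt text, not the fully processed inputs that the renamed wrapper produces.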