From 0b2266f6f56ec1cc4a3f19605d3dda6e246bd861 Mon Sep 17 00:00:00 2001
From: yan ma
Date: Tue, 25 Feb 2025 20:47:56 +0800
Subject: [PATCH] fix decode dummy data

Signed-off-by: yan ma
---
 vllm/worker/hpu_enc_dec_model_runner.py | 26 +++++++++-----------------
 1 file changed, 9 insertions(+), 17 deletions(-)

diff --git a/vllm/worker/hpu_enc_dec_model_runner.py b/vllm/worker/hpu_enc_dec_model_runner.py
index f3a5d88563b7e..58fe9cfd5c5ea 100644
--- a/vllm/worker/hpu_enc_dec_model_runner.py
+++ b/vllm/worker/hpu_enc_dec_model_runner.py
@@ -476,21 +476,13 @@ def create_dummy_seq_group_metadata(self,
             num_cross_blocks = min(self.bucketing_ctx.num_hpu_blocks,
                                    max_mm_tokens) // self.block_size
             cross_block_table = [_PAD_BLOCK_ID] * num_cross_blocks
-        prompt_token_ids = [0] * input_len
-        if is_prompt:
-            image_data = encoder_dummy_data.multi_modal_data["image"]
-            if isinstance(image_data, Image.Image):
-                image_data = [image_data]
-            assert is_list_of(image_data, Image.Image)
-            text_prompt_len = input_len - 1 - len(image_data)
-            # for prompt like '<|image|><|image|><|begin_of_text|>...', token
-            # ids will be '128256 128256 128000 ...'
-            prompt_token_ids = [128256] * len(image_data) + [
-                128000
-            ] + [0] * text_prompt_len
         output_token_ids = [1] * output_len
-        prompt_token_ids_array = array('l', prompt_token_ids)  # noqa: F821
-        seq_data = SequenceData(prompt_token_ids_array)
+        decoder_dummy_data = self.input_registry \
+            .dummy_data_for_profiling(self.model_config,
+                                      seq_len,
+                                      self.mm_registry,
+                                      is_encoder_data=False)
+        seq_data = decoder_dummy_data.seq_data
         seq_data.output_token_ids = output_token_ids
 
         return SequenceGroupMetadata(
@@ -500,9 +492,9 @@ def create_dummy_seq_group_metadata(self,
             sampling_params=sampling_params,
             block_tables=block_tables,
             encoder_seq_data=encoder_dummy_data.seq_data,
-            multi_modal_data=encoder_dummy_data.multi_modal_data,
-            multi_modal_placeholders=encoder_dummy_data.
-            multi_modal_placeholders,
+            multi_modal_data=decoder_dummy_data.multi_modal_data or encoder_dummy_data.multi_modal_data,
+            multi_modal_placeholders=decoder_dummy_data.
+            multi_modal_placeholders or encoder_dummy_data.multi_modal_placeholders,
             cross_block_table=cross_block_table)
 
     def trim_attn_metadata(self, metadata: AttentionMetadata) -> object:
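
Reviewer note (illustration, not part of the patch): the change replaces the
hand-built mllama dummy prompt (`128256` image tokens followed by a `128000`
BOS token) with decoder dummy data obtained from the input registry, and the
SequenceGroupMetadata kwargs now prefer the decoder's multimodal fields,
falling back to the encoder's when the decoder side has none. Below is a
minimal, self-contained sketch of that fallback; DummyDataSketch and its
payload strings are hypothetical stand-ins, not vLLM's real DummyData class.

    from typing import NamedTuple, Optional

    class DummyDataSketch(NamedTuple):
        # Hypothetical stand-in for the object returned by
        # InputRegistry.dummy_data_for_profiling().
        seq_data: object
        multi_modal_data: Optional[dict]
        multi_modal_placeholders: Optional[dict]

    encoder_dummy_data = DummyDataSketch(
        seq_data=object(),
        multi_modal_data={"image": "<dummy image>"},
        multi_modal_placeholders={"image": ["<placeholder range>"]})
    decoder_dummy_data = DummyDataSketch(
        seq_data=object(),
        multi_modal_data=None,  # decode side carries no multimodal payload
        multi_modal_placeholders=None)

    # Mirrors the patched call site: use decoder data when present,
    # otherwise fall back to the encoder's.
    mm_data = (decoder_dummy_data.multi_modal_data
               or encoder_dummy_data.multi_modal_data)
    mm_placeholders = (decoder_dummy_data.multi_modal_placeholders
                       or encoder_dummy_data.multi_modal_placeholders)
    assert mm_data == {"image": "<dummy image>"}
    assert mm_placeholders == {"image": ["<placeholder range>"]}

During decode-phase profiling the decoder typically has no multimodal data of
its own, so the `or` expressions resolve to the encoder's fields, preserving
the previous behavior for that case while letting decoder-provided data take
precedence when the registry supplies it.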