Skip to content

Commit

Permalink
fix decode dummy data
Browse files (browse the repository at this point in the history)
Signed-off-by: yan ma <[email protected]>
  • Loading branch information
yma11 committed Feb 25, 2025
1 parent dc3c3f3 commit 0b2266f
Showing 1 changed file with 9 additions and 17 deletions.
26 changes: 9 additions & 17 deletions vllm/worker/hpu_enc_dec_model_runner.py
Original file line number | Diff line number | Diff line change
@@ -476,21 +476,13 @@ def create_dummy_seq_group_metadata(self,
num_cross_blocks = min(self.bucketing_ctx.num_hpu_blocks,
max_mm_tokens) // self.block_size
cross_block_table = [_PAD_BLOCK_ID] * num_cross_blocks
prompt_token_ids = [0] * input_len
if is_prompt:
image_data = encoder_dummy_data.multi_modal_data["image"]
if isinstance(image_data, Image.Image):
image_data = [image_data]
assert is_list_of(image_data, Image.Image)
text_prompt_len = input_len - 1 - len(image_data)
# for prompt like '<|image|><|image|><|begin_of_text|>...', token
# ids will be '128256 128256 128000 ...'
prompt_token_ids = [128256] * len(image_data) + [
128000
] + [0] * text_prompt_len
output_token_ids = [1] * output_len
prompt_token_ids_array = array('l', prompt_token_ids) # noqa: F821
seq_data = SequenceData(prompt_token_ids_array)
decoder_dummy_data = self.input_registry \
.dummy_data_for_profiling(self.model_config,
seq_len,
self.mm_registry,
is_encoder_data=False)
seq_data = decoder_dummy_data.seq_data
seq_data.output_token_ids = output_token_ids

return SequenceGroupMetadata(
@@ -500,9 +492,9 @@ def create_dummy_seq_group_metadata(self,
sampling_params=sampling_params,
block_tables=block_tables,
encoder_seq_data=encoder_dummy_data.seq_data,
multi_modal_data=encoder_dummy_data.multi_modal_data,
multi_modal_placeholders=encoder_dummy_data.
multi_modal_placeholders,
multi_modal_data=decoder_dummy_data.multi_modal_data or encoder_dummy_data.multi_modal_data,
multi_modal_placeholders=decoder_dummy_data.
multi_modal_placeholders or encoder_dummy_data.multi_modal_placeholders,
cross_block_table=cross_block_table)

def trim_attn_metadata(self, metadata: AttentionMetadata) -> object:
Expand Down

0 comments on commit 0b2266f

Please sign in to comment.