From 12836c521a864cb9fe8bf9893268bb543b6582d1 Mon Sep 17 00:00:00 2001
From: Michael Goin
Date: Wed, 23 Oct 2024 22:03:44 -0400
Subject: [PATCH] [Bugfix][Model] Fix Mllama SDPA illegal memory access for
 batched multi-image (#9626)

Signed-off-by: mgoin
Signed-off-by: qishuai
---
 vllm/model_executor/models/mllama.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py
index 23e2b520e5b40..475364f322c62 100644
--- a/vllm/model_executor/models/mllama.py
+++ b/vllm/model_executor/models/mllama.py
@@ -795,17 +795,19 @@ def attention_with_mask(
         kv_len = k.shape[0]
         q = q.transpose(0, 1).view(self.num_local_key_value_heads,
                                    self.num_key_value_groups, q_len,
-                                   self.head_dim)
+                                   self.head_dim).contiguous()
         k = k.transpose(0,
                         1)[:,
                            None, :, :].expand(self.num_local_key_value_heads,
                                               self.num_key_value_groups,
-                                              kv_len, self.head_dim)
+                                              kv_len,
+                                              self.head_dim).contiguous()
         v = v.transpose(0,
                         1)[:,
                            None, :, :].expand(self.num_local_key_value_heads,
                                               self.num_key_value_groups,
-                                              kv_len, self.head_dim)
+                                              kv_len,
+                                              self.head_dim).contiguous()
         attention_mask = attention_mask.view(1, 1, q_len, kv_len)
         output = F.scaled_dot_product_attention(q,
                                                 k,
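
Note (not part of the patch itself): expand() returns a broadcast view with
stride 0 along the expanded dimension, and the patch suggests that such
non-dense k/v layouts are what trip scaled_dot_product_attention's fused CUDA
backends into the illegal memory access; calling .contiguous() materializes
dense copies before the kernel runs, at the cost of extra memory. Below is a
minimal self-contained sketch of the same reshaping pattern. The sizes
(num_kv_heads, num_groups, q_len, kv_len, head_dim) are illustrative
stand-ins, not Mllama's real dimensions.

    import torch
    import torch.nn.functional as F

    # Illustrative sizes only.
    num_kv_heads, num_groups = 8, 4   # num_heads = num_kv_heads * num_groups
    q_len, kv_len, head_dim = 16, 32, 64

    q = torch.randn(q_len, num_kv_heads * num_groups, head_dim)
    k = torch.randn(kv_len, num_kv_heads, head_dim)
    v = torch.randn(kv_len, num_kv_heads, head_dim)

    # Same reshaping pattern as the patched code path: split the query heads
    # into KV groups, then broadcast each KV head across its query group.
    q = q.transpose(0, 1).view(num_kv_heads, num_groups, q_len,
                               head_dim).contiguous()
    k = k.transpose(0, 1)[:, None, :, :].expand(num_kv_heads, num_groups,
                                                kv_len, head_dim).contiguous()
    v = v.transpose(0, 1)[:, None, :, :].expand(num_kv_heads, num_groups,
                                                kv_len, head_dim).contiguous()

    # Before .contiguous(), k and v have stride 0 in dim 1; after it they
    # are dense tensors, which is what the fused SDPA kernels expect.
    mask = torch.zeros(1, 1, q_len, kv_len)
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
    print(out.shape)  # torch.Size([8, 4, 16, 64])

The tradeoff is the copy: .contiguous() on an expanded tensor allocates
num_groups times the original k/v storage, which this fix evidently accepts
in exchange for avoiding the crash on batched multi-image inputs.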