From 02ed8a1fbe41e3ad1bc04fd29b754facd28e329f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E7=87=83?=
Date: Thu, 13 Feb 2025 22:17:57 +0800
Subject: [PATCH] [Misc] Qwen2.5-VL Optimization (#13155)

---
 vllm/model_executor/models/qwen2_5_vl.py | 61 ++++++++++--------
 vllm/model_executor/models/qwen2_vl.py   | 37 ++++++++------
 2 files changed, 47 insertions(+), 51 deletions(-)

diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py
index d4c48dbdab13c..6aec99b3f9641 100644
--- a/vllm/model_executor/models/qwen2_5_vl.py
+++ b/vllm/model_executor/models/qwen2_5_vl.py
@@ -45,6 +45,7 @@
 from vllm.logger import init_logger
 from vllm.model_executor import SamplingMetadata
 from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY
+from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.quantization import QuantizationConfig
@@ -271,8 +272,13 @@ def forward(
         q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous()
                    for x in (q, k, v))
         if rotary_pos_emb is not None:
-            q = apply_rotary_pos_emb_vision(q, rotary_pos_emb)
-            k = apply_rotary_pos_emb_vision(k, rotary_pos_emb)
+            use_flash_attn = self.attn_backend == _Backend.FLASH_ATTN
+            q = apply_rotary_pos_emb_vision(q,
+                                            rotary_pos_emb,
+                                            use_flash_attn=use_flash_attn)
+            k = apply_rotary_pos_emb_vision(k,
+                                            rotary_pos_emb,
+                                            use_flash_attn=use_flash_attn)
 
         if self.attn_backend == _Backend.FLASH_ATTN:
             # from vllm_flash_attn.flash_attn_interface import (
@@ -296,20 +302,23 @@ def forward(
                                       "(b s) ... -> b s ...",
                                       b=batch_size)
         elif self.attn_backend == _Backend.TORCH_SDPA:
-            seq_length = q.size(1)
-            q, k, v = (rearrange(x, "b s h d -> b h s d") for x in [q, k, v])
-            attention_mask = torch.zeros([1, seq_length, seq_length],
-                                         device=q.device,
-                                         dtype=torch.bool)
+            # Execute attention entry by entry for speed & less VRAM.
+            outputs = []
             for i in range(1, len(cu_seqlens)):
-                attention_mask[..., cu_seqlens[i - 1]:cu_seqlens[i],
-                               cu_seqlens[i - 1]:cu_seqlens[i]] = True
-            output = F.scaled_dot_product_attention(q,
-                                                    k,
-                                                    v,
-                                                    attention_mask,
-                                                    dropout_p=0.0)
-            context_layer = rearrange(output, "b h s d -> b s h d ")
+                start_idx = cu_seqlens[i - 1]
+                end_idx = cu_seqlens[i]
+                q_i = q[:, start_idx:end_idx]
+                k_i = k[:, start_idx:end_idx]
+                v_i = v[:, start_idx:end_idx]
+                q_i, k_i, v_i = (rearrange(x, "b s h d -> b h s d")
+                                 for x in [q_i, k_i, v_i])
+                output_i = F.scaled_dot_product_attention(q_i,
+                                                          k_i,
+                                                          v_i,
+                                                          dropout_p=0.0)
+                output_i = rearrange(output_i, "b h s d -> b s h d ")
+                outputs.append(output_i)
+            context_layer = torch.cat(outputs, dim=1)
         elif self.attn_backend == _Backend.XFORMERS:
             from xformers import ops as xops
             from xformers.ops.fmha.attn_bias import BlockDiagonalMask
@@ -327,25 +336,6 @@ def forward(
         return output
 
 
-class Qwen2RMSNorm(nn.Module):
-
-    def __init__(self, hidden_size, eps=1e-6):
-        super().__init__()
-        self.weight = nn.Parameter(torch.ones(hidden_size))
-        self.variance_epsilon = eps
-
-    def forward(self, hidden_states):
-        input_dtype = hidden_states.dtype
-        hidden_states = hidden_states.to(torch.float32)
-        variance = hidden_states.pow(2).mean(-1, keepdim=True)
-        hidden_states = hidden_states * torch.rsqrt(variance +
-                                                    self.variance_epsilon)
-        return self.weight * hidden_states.to(input_dtype)
-
-    def extra_repr(self):
-        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
-
-
 class Qwen2_5_VisionBlock(nn.Module):
 
     def __init__(
@@ -516,8 +506,7 @@ def __init__(
             hidden_size=self.hidden_size,
         )
 
-        # NOTE: We use torch native RMSNorm here for precision purposes.
-        norm_layer = partial(Qwen2RMSNorm, eps=norm_eps)
+        norm_layer = partial(RMSNorm, eps=norm_eps)
 
         head_dim = self.hidden_size // self.num_heads
         self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2)
diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py
index d3294a4d4a3b6..961f53cef1379 100644
--- a/vllm/model_executor/models/qwen2_vl.py
+++ b/vllm/model_executor/models/qwen2_vl.py
@@ -226,11 +226,15 @@ def apply_rotary_emb_torch(x: torch.Tensor,
 
 
 def apply_rotary_pos_emb_vision(t: torch.Tensor,
-                                freqs: torch.Tensor) -> torch.Tensor:
+                                freqs: torch.Tensor,
+                                use_flash_attn=False) -> torch.Tensor:
     t_ = t.float()
     cos = freqs.cos()
     sin = freqs.sin()
-    output = apply_rotary_emb_torch(t_, cos, sin).type_as(t)
+    apply_rotary_emb = apply_rotary_emb_torch
+    if use_flash_attn:
+        from flash_attn.layers.rotary import apply_rotary_emb
+    output = apply_rotary_emb(t_, cos, sin).type_as(t)
     return output
 
 
@@ -336,20 +340,23 @@ def forward(
                                       "(b s) ... -> b s ...",
                                       b=batch_size)
         elif self.attn_backend == _Backend.TORCH_SDPA:
-            seq_length = q.size(1)
-            q, k, v = (rearrange(x, "b s h d -> b h s d") for x in [q, k, v])
-            attention_mask = torch.zeros([1, seq_length, seq_length],
-                                         device=q.device,
-                                         dtype=torch.bool)
+            # Execute attention entry by entry for speed & less VRAM.
+            outputs = []
             for i in range(1, len(cu_seqlens)):
-                attention_mask[..., cu_seqlens[i - 1]:cu_seqlens[i],
-                               cu_seqlens[i - 1]:cu_seqlens[i]] = True
-            output = F.scaled_dot_product_attention(q,
-                                                    k,
-                                                    v,
-                                                    attention_mask,
-                                                    dropout_p=0.0)
-            context_layer = rearrange(output, "b h s d -> b s h d ")
+                start_idx = cu_seqlens[i - 1]
+                end_idx = cu_seqlens[i]
+                q_i = q[:, start_idx:end_idx]
+                k_i = k[:, start_idx:end_idx]
+                v_i = v[:, start_idx:end_idx]
+                q_i, k_i, v_i = (rearrange(x, "b s h d -> b h s d")
+                                 for x in [q_i, k_i, v_i])
+                output_i = F.scaled_dot_product_attention(q_i,
+                                                          k_i,
+                                                          v_i,
+                                                          dropout_p=0.0)
+                output_i = rearrange(output_i, "b h s d -> b s h d ")
+                outputs.append(output_i)
+            context_layer = torch.cat(outputs, dim=1)
         elif self.attn_backend == _Backend.XFORMERS:
             from xformers import ops as xops
             from xformers.ops.fmha.attn_bias import BlockDiagonalMask
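
For reference, a minimal standalone sketch of the per-sequence SDPA pattern the
TORCH_SDPA branches above switch to: each packed sequence delimited by
cu_seqlens is attended to on its own, so no [seq, seq] block-diagonal boolean
mask has to be materialized. The helper name sdpa_by_cu_seqlens and the toy
shapes are illustrative assumptions, not vLLM APIs, and plain transpose stands
in for the einops.rearrange calls used in the patch.

    import torch
    import torch.nn.functional as F

    def sdpa_by_cu_seqlens(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                           cu_seqlens: torch.Tensor) -> torch.Tensor:
        """Run SDPA per packed sequence (hypothetical helper, not vLLM API).

        q, k, v: [batch, total_seq, num_heads, head_dim]
        cu_seqlens: cumulative sequence lengths, e.g. tensor([0, 4, 10])
        """
        outputs = []
        for i in range(1, len(cu_seqlens)):
            start, end = cu_seqlens[i - 1], cu_seqlens[i]
            # Slice out one sequence and move heads ahead of tokens for SDPA.
            q_i, k_i, v_i = (x[:, start:end].transpose(1, 2)
                             for x in (q, k, v))
            out_i = F.scaled_dot_product_attention(q_i, k_i, v_i,
                                                   dropout_p=0.0)
            # Back to [batch, seq, heads, head_dim] before re-packing.
            outputs.append(out_i.transpose(1, 2))
        return torch.cat(outputs, dim=1)

    # Toy usage: two sequences of lengths 4 and 6 packed into one tensor.
    q = k = v = torch.randn(1, 10, 2, 8)
    out = sdpa_by_cu_seqlens(q, k, v, torch.tensor([0, 4, 10]))
    print(out.shape)  # torch.Size([1, 10, 2, 8])

Splitting the call this way trades one large masked attention for several small
unmasked ones, avoiding the [1, total_seq, total_seq] boolean mask the deleted
code had to allocate.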