Skip to content

Commit

Permalink
[Misc] Qwen2.5-VL Optimization (vllm-project#13155)
Browse files Browse the repository at this point in the history
  • Loading branch information
wulipc authored and kerthcet committed Feb 21, 2025
1 parent 3e3b72c commit a1406b3
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 51 deletions.
61 changes: 25 additions & 36 deletions vllm/model_executor/models/qwen2_5_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
Expand Down Expand Up @@ -271,8 +272,13 @@ def forward(
q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous()
for x in (q, k, v))
if rotary_pos_emb is not None:
q = apply_rotary_pos_emb_vision(q, rotary_pos_emb)
k = apply_rotary_pos_emb_vision(k, rotary_pos_emb)
use_flash_attn = self.attn_backend == _Backend.FLASH_ATTN
q = apply_rotary_pos_emb_vision(q,
rotary_pos_emb,
use_flash_attn=use_flash_attn)
k = apply_rotary_pos_emb_vision(k,
rotary_pos_emb,
use_flash_attn=use_flash_attn)

if self.attn_backend == _Backend.FLASH_ATTN:
# from vllm_flash_attn.flash_attn_interface import (
Expand All @@ -296,20 +302,23 @@ def forward(
"(b s) ... -> b s ...",
b=batch_size)
elif self.attn_backend == _Backend.TORCH_SDPA:
seq_length = q.size(1)
q, k, v = (rearrange(x, "b s h d -> b h s d") for x in [q, k, v])
attention_mask = torch.zeros([1, seq_length, seq_length],
device=q.device,
dtype=torch.bool)
# Execute attention entry by entry for speed & less VRAM.
outputs = []
for i in range(1, len(cu_seqlens)):
attention_mask[..., cu_seqlens[i - 1]:cu_seqlens[i],
cu_seqlens[i - 1]:cu_seqlens[i]] = True
output = F.scaled_dot_product_attention(q,
k,
v,
attention_mask,
dropout_p=0.0)
context_layer = rearrange(output, "b h s d -> b s h d ")
start_idx = cu_seqlens[i - 1]
end_idx = cu_seqlens[i]
q_i = q[:, start_idx:end_idx]
k_i = k[:, start_idx:end_idx]
v_i = v[:, start_idx:end_idx]
q_i, k_i, v_i = (rearrange(x, "b s h d -> b h s d")
for x in [q_i, k_i, v_i])
output_i = F.scaled_dot_product_attention(q_i,
k_i,
v_i,
dropout_p=0.0)
output_i = rearrange(output_i, "b h s d -> b s h d ")
outputs.append(output_i)
context_layer = torch.cat(outputs, dim=1)
elif self.attn_backend == _Backend.XFORMERS:
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalMask
Expand All @@ -327,25 +336,6 @@ def forward(
return output


class Qwen2RMSNorm(nn.Module):

def __init__(self, hidden_size, eps=1e-6):
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps

def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance +
self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)

def extra_repr(self):
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"


class Qwen2_5_VisionBlock(nn.Module):

def __init__(
Expand Down Expand Up @@ -516,8 +506,7 @@ def __init__(
hidden_size=self.hidden_size,
)

# NOTE: We use torch native RMSNorm here for precision purposes.
norm_layer = partial(Qwen2RMSNorm, eps=norm_eps)
norm_layer = partial(RMSNorm, eps=norm_eps)
head_dim = self.hidden_size // self.num_heads
self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2)

Expand Down
37 changes: 22 additions & 15 deletions vllm/model_executor/models/qwen2_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,11 +226,15 @@ def apply_rotary_emb_torch(x: torch.Tensor,


def apply_rotary_pos_emb_vision(t: torch.Tensor,
freqs: torch.Tensor) -> torch.Tensor:
freqs: torch.Tensor,
use_flash_attn=False) -> torch.Tensor:
t_ = t.float()
cos = freqs.cos()
sin = freqs.sin()
output = apply_rotary_emb_torch(t_, cos, sin).type_as(t)
apply_rotary_emb = apply_rotary_emb_torch
if use_flash_attn:
from flash_attn.layers.rotary import apply_rotary_emb
output = apply_rotary_emb(t_, cos, sin).type_as(t)
return output


Expand Down Expand Up @@ -336,20 +340,23 @@ def forward(
"(b s) ... -> b s ...",
b=batch_size)
elif self.attn_backend == _Backend.TORCH_SDPA:
seq_length = q.size(1)
q, k, v = (rearrange(x, "b s h d -> b h s d") for x in [q, k, v])
attention_mask = torch.zeros([1, seq_length, seq_length],
device=q.device,
dtype=torch.bool)
# Execute attention entry by entry for speed & less VRAM.
outputs = []
for i in range(1, len(cu_seqlens)):
attention_mask[..., cu_seqlens[i - 1]:cu_seqlens[i],
cu_seqlens[i - 1]:cu_seqlens[i]] = True
output = F.scaled_dot_product_attention(q,
k,
v,
attention_mask,
dropout_p=0.0)
context_layer = rearrange(output, "b h s d -> b s h d ")
start_idx = cu_seqlens[i - 1]
end_idx = cu_seqlens[i]
q_i = q[:, start_idx:end_idx]
k_i = k[:, start_idx:end_idx]
v_i = v[:, start_idx:end_idx]
q_i, k_i, v_i = (rearrange(x, "b s h d -> b h s d")
for x in [q_i, k_i, v_i])
output_i = F.scaled_dot_product_attention(q_i,
k_i,
v_i,
dropout_p=0.0)
output_i = rearrange(output_i, "b h s d -> b s h d ")
outputs.append(output_i)
context_layer = torch.cat(outputs, dim=1)
elif self.attn_backend == _Backend.XFORMERS:
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalMask
Expand Down

0 comments on commit a1406b3

Please sign in to comment.