[Model][LoRA] LoRA support added for MiniCPMV2.6 #8943

Merged: 33 commits, merged on Sep 30, 2024
Changes from 32 commits

Commits
2137917
init
jeejeelee Aug 6, 2024
5edda37
optimize minicpmv implementation
jeejeelee Aug 6, 2024
2ea5006
delete comment
jeejeelee Aug 6, 2024
42e846a
Trigger LoRA test
jeejeelee Aug 7, 2024
f881b65
Merge branch 'vllm-project:main' into minicpmv25-lora
jeejeelee Aug 7, 2024
414533b
Merge branch 'vllm-project:main' into minicpmv25-lora
jeejeelee Aug 8, 2024
2ee35d5
Merge branch 'vllm-project:main' into minicpmv25-lora
jeejeelee Aug 8, 2024
1261a58
Merge branch 'vllm-project:main' into minicpmv25-lora
jeejeelee Aug 12, 2024
d53aeb9
Address branch conflict
jeejeelee Sep 25, 2024
9eed235
Modify code
jeejeelee Sep 25, 2024
e4e3f46
Complete VL supports lora
jeejeelee Sep 25, 2024
65b5b08
Format code
jeejeelee Sep 25, 2024
9bf92d5
Clean code
jeejeelee Sep 25, 2024
561b4b7
Clean code
jeejeelee Sep 26, 2024
578deba
Clean code
jeejeelee Sep 26, 2024
9d0c4fd
Merge branch 'vllm-project:main' into minicpmv25-lora
jeejeelee Sep 26, 2024
910b650
Merge branch 'vllm-project:main' into minicpmv25-lora
jeejeelee Sep 27, 2024
99dacdf
Add unit test for minicpmv25
jeejeelee Sep 27, 2024
9b85373
Format code
jeejeelee Sep 27, 2024
bf4ee9d
Modify code
jeejeelee Sep 27, 2024
a9e724c
Modify module_mapping logic
jeejeelee Sep 27, 2024
be6c928
Add unit test
jeejeelee Sep 27, 2024
c9db73e
Modify unit test
jeejeelee Sep 28, 2024
bbfd3e0
Delete mincpmv25 distributed test
jeejeelee Sep 29, 2024
acc836a
Fix lora bug and modify minicpmv lora tests
jeejeelee Sep 29, 2024
27a7be4
Minicpmv26 support LoRA done
jeejeelee Sep 29, 2024
d65a4dc
Merge branch 'vllm-project:main' into minicpmv26-lora
jeejeelee Sep 29, 2024
114c4e0
Update minicpmv26 vpm
jeejeelee Sep 29, 2024
1b7b0ec
Done
jeejeelee Sep 29, 2024
f7f3172
Use the common `BaseResampler`
DarkLight1337 Sep 30, 2024
9ec1c65
Fix type annotation
DarkLight1337 Sep 30, 2024
28a8653
Remove comment
DarkLight1337 Sep 30, 2024
6f4cfd7
format
DarkLight1337 Sep 30, 2024
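
With this PR, MiniCPM-V 2.6 can load LoRA adapters the same way as other LoRA-enabled models in vLLM. Below is a minimal offline-inference sketch; the adapter path, image file, and sampling settings are illustrative placeholders and are not part of this change.

from PIL import Image
from transformers import AutoTokenizer

from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

MODEL = "openbmb/MiniCPM-V-2_6"
LORA_PATH = "/path/to/minicpmv26-lora-adapter"  # hypothetical adapter path

# enable_lora activates the LoRA path added for MiniCPMV2_6 in this PR;
# max_lora_rank should match the rank the adapter was trained with.
llm = LLM(
    model=MODEL,
    trust_remote_code=True,
    enable_lora=True,
    max_lora_rank=64,
)

# MiniCPM-V expects its image placeholder inside the chat template.
tokenizer = AutoTokenizer.from_pretrained(MODEL, trust_remote_code=True)
messages = [{"role": "user",
             "content": "(<image>./</image>)\nWhat is in this image?"}]
prompt = tokenizer.apply_chat_template(messages,
                                       tokenize=False,
                                       add_generation_prompt=True)

outputs = llm.generate(
    {"prompt": prompt, "multi_modal_data": {"image": Image.open("example.jpg")}},
    SamplingParams(temperature=0.0, max_tokens=128),
    lora_request=LoRARequest("minicpmv26-adapter", 1, LORA_PATH),
)
print(outputs[0].outputs[0].text)

The same adapter can also be exposed through the OpenAI-compatible server by starting it with --enable-lora and --lora-modules <name>=<path>.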
24 changes: 15 additions & 9 deletions vllm/model_executor/models/idefics2_vision_model.py
@@ -65,11 +65,10 @@ def __init__(self, config: Idefics2VisionConfig):
self.position_embedding = nn.Embedding(self.num_positions,
self.embed_dim)

def forward(
self,
pixel_values: torch.FloatTensor,
patch_attention_mask: torch.BoolTensor,
) -> torch.Tensor:
def forward(self,
pixel_values: torch.FloatTensor,
patch_attention_mask: torch.BoolTensor,
tgt_sizes: Optional[torch.IntTensor] = None) -> torch.Tensor:
batch_size, _, max_im_h, max_im_w = pixel_values.shape
patch_embeds = self.patch_embedding(pixel_values)
embeddings = patch_embeds.flatten(2).transpose(1, 2)
@@ -84,8 +83,13 @@ def forward(
fill_value=0)

for batch_idx, p_attn_mask in enumerate(patch_attention_mask):
nb_patches_h = p_attn_mask[:, 0].sum()
nb_patches_w = p_attn_mask[0].sum()

if tgt_sizes is not None:
nb_patches_h = tgt_sizes[batch_idx][0]
nb_patches_w = tgt_sizes[batch_idx][1]
else:
nb_patches_h = p_attn_mask[:, 0].sum()
nb_patches_w = p_attn_mask[0].sum()
fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h)
fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w)
bucket_coords_h = torch.bucketize(fractional_coords_h,
@@ -287,10 +291,12 @@ def forward(
self,
pixel_values,
patch_attention_mask: Optional[torch.BoolTensor] = None,
) -> torch.tensor:
tgt_sizes: Optional[torch.IntTensor] = None,
) -> torch.Tensor:
hidden_states = self.embeddings(
pixel_values=pixel_values,
patch_attention_mask=patch_attention_mask)
patch_attention_mask=patch_attention_mask,
tgt_sizes=tgt_sizes)
encoder_outputs = self.encoder(hidden_states)
last_hidden_state = self.post_layernorm(encoder_outputs)
return last_hidden_state
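
The change above threads tgt_sizes through Idefics2VisionEmbeddings so that, when MiniCPM-V supplies per-image target sizes, the patch grid is taken directly from tgt_sizes instead of being re-derived from the patch attention mask. A small sketch of the resulting position-id bucketing, with illustrative grid sizes (not values from this PR), is shown below.

import torch

# Sketch of the bucketing the embeddings layer performs per image:
# with tgt_sizes given, (nb_patches_h, nb_patches_w) comes straight from
# tgt_sizes rather than from summing patch_attention_mask rows/columns.
max_nb_patches_h, max_nb_patches_w = 32, 32          # padded grid of the batch (assumed)
boundaries = torch.arange(1 / max_nb_patches_h, 1.0, 1 / max_nb_patches_h)

nb_patches_h, nb_patches_w = 20, 14                  # one row of tgt_sizes (assumed)

fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h)
fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w)
bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True)
bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)

# Flattened indices into the (max_h * max_w) learned position-embedding table.
pos_ids = (bucket_coords_h[:, None] * max_nb_patches_w + bucket_coords_w).flatten()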
100 changes: 34 additions & 66 deletions vllm/model_executor/models/minicpmv.py
@@ -31,14 +31,13 @@
import torch.types
from PIL import Image
from torch import nn
from torch.nn.init import trunc_normal_
from transformers import PretrainedConfig
from typing_extensions import NotRequired

from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.resampler import BaseResampler
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.resampler import (Resampler2,
@@ -106,58 +105,6 @@ class MiniCPMVImagePixelInputs(TypedDict):
DEFAULT_LN = partial(nn.LayerNorm, eps=1e-6)


class BaseResampler(nn.Module):
"""
A 2D perceiver-resampler network with one cross attention layers by
(grid_size**2) learnable queries and 2d sincos pos_emb
Outputs:
A tensor with the shape of (grid_size**2, embed_dim)
"""

def __init__(
self,
num_queries: int,
embed_dim: int,
num_heads: int,
kv_dim: Optional[int] = None,
norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN,
) -> None:
super().__init__()

self.num_queries = num_queries
self.embed_dim = embed_dim
self.num_heads = num_heads

self.query = nn.Parameter(torch.zeros(self.num_queries, embed_dim))
trunc_normal_(self.query, std=0.02)
if kv_dim is not None and kv_dim != embed_dim:
self.kv_proj = ReplicatedLinear(kv_dim, embed_dim, bias=False)
else:
# Maintain the same return value with ReplicatedLinear.forward
self.kv_proj = lambda *args, **kwargs: (
nn.Identity()(*args, **kwargs),
None,
)
self.attn = nn.MultiheadAttention(embed_dim, num_heads)
self.ln_q = norm_layer(embed_dim)
self.ln_kv = norm_layer(embed_dim)
self.ln_post = norm_layer(embed_dim)
self.proj = nn.Parameter(
(embed_dim**-0.5) * torch.randn(embed_dim, embed_dim))

def _init_weights(self, m: nn.Module) -> None:
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=0.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)

def _repeat(self, query, N: int):
return query.unsqueeze(1).repeat(1, N, 1)


class Resampler2_5(BaseResampler):

def __init__(
@@ -869,7 +816,35 @@ def is_default_weight_loading(self, name: str) -> bool:
return "resampler" in name


class MiniCPMV2_6(MiniCPMVBaseModel):
class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}
# LoRA specific attributes
supported_lora_modules = [
# vision encoder
"fc1",
"fc2",
"out_proj",
# language model
"qkv_proj", # same name with vision encoder
"o_proj",
"gate_up_proj",
"down_proj",
# resampler
"kv_proj",
]

embedding_modules = {}
embedding_padding_modules = []

def __init__(
self,
@@ -894,15 +869,8 @@ def init_llm(
name="model")

def init_vision_module(self) -> nn.Module:
# A custom version of SiglipVisionTransformer, won't work with TP
from vllm.model_executor.models.na_vit import SiglipVisionTransformer

if self.config._attn_implementation == "flash_attention_2":
self.config.vision_config._attn_implementation = "flash_attention_2"
else:
# not support sdpa
self.config.vision_config._attn_implementation = "eager"
model = SiglipVisionTransformer(self.config.vision_config)
model = Idefics2VisionTransformer(self.config.vision_config)
if self.config.drop_vision_last_layer:
model.encoder.layers = model.encoder.layers[:-1]
return model
@@ -928,7 +896,7 @@ def get_vision_embedding(
pixel_values,
patch_attention_mask=patch_attn_mask,
tgt_sizes=tgt_sizes,
).last_hidden_state
)
return vision_embedding

def get_vision_hidden_states(
@@ -960,12 +928,12 @@ def get_vision_hidden_states(
all_pixel_values.type(dtype),
patch_attention_mask=patch_attn_mask,
tgt_sizes=tgt_sizes,
).last_hidden_state
)

return self.resampler(vision_embedding, tgt_sizes)

def is_default_weight_loading(self, name: str) -> bool:
return "resampler" in name or "vpm" in name
return "resampler" in name


_SUPPORT_VERSION = {