[Model] Support Gemma2 embedding model #9004

Merged 1 commit on Oct 5, 2024
1 change: 1 addition & 0 deletions tests/conftest.py
@@ -277,6 +277,7 @@ def __init__(
            SentenceTransformer(
                model_name,
                device="cpu",
+               trust_remote_code=True,
            ).to(dtype=torch_dtype))
        else:
            model_kwargs = model_kwargs if model_kwargs is not None else {}
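For reference, the HF side of the test harness goes through sentence-transformers, so the new flag can be checked in isolation. A minimal standalone sketch (model name taken from this PR; sentence-transformers assumed installed):

from sentence_transformers import SentenceTransformer

# trust_remote_code=True mirrors the conftest change above, allowing checkpoints
# that ship custom modeling code to load.
st_model = SentenceTransformer("BAAI/bge-multilingual-gemma2",
                               device="cpu",
                               trust_remote_code=True)
embeddings = st_model.encode(["What is deep learning?"],
                             normalize_embeddings=True)
print(embeddings.shape)  # (1, embedding_dim)
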
11 changes: 10 additions & 1 deletion tests/models/embedding/language/test_embedding.py
@@ -1,13 +1,14 @@
"""Compare the outputs of HF and vLLM for Mistral models using greedy sampling.

Run `pytest tests/models/test_llama_embedding.py`.
Run `pytest tests/models/embedding/language/test_embedding.py`.
"""
import pytest
import torch
import torch.nn.functional as F

MODELS = [
"intfloat/e5-mistral-7b-instruct",
"BAAI/bge-multilingual-gemma2",
]


@@ -28,6 +29,14 @@ def test_models(
    model: str,
    dtype: str,
) -> None:
+   # The example_prompts end with "\n", for example:
+   # "Write a short story about a robot that dreams for the first time.\n"
+   # sentence_transformers strips the input texts, see:
+   # https://github.com/UKPLab/sentence-transformers/blob/v3.1.1/sentence_transformers/models/Transformer.py#L159
+   # This makes the input_ids differ between hf_model and vllm_model.
+   # So we strip the input texts here to keep the test from failing.
+   example_prompts = [str(s).strip() for s in example_prompts]
+
    with hf_runner(model, dtype=dtype, is_embedding_model=True) as hf_model:
        hf_outputs = hf_model.encode(example_prompts)

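The stripping workaround above can be reproduced outside the test suite. A rough sketch of the mismatch (tokenizer checkpoint assumed; not part of the diff):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("BAAI/bge-multilingual-gemma2")
prompt = "Write a short story about a robot that dreams for the first time.\n"

# vLLM tokenizes the prompt as given, while sentence-transformers strips it first,
# so the two sides can produce different input_ids for the same example prompt.
print(tok(prompt).input_ids == tok(prompt.strip()).input_ids)
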
7 changes: 5 additions & 2 deletions vllm/model_executor/models/gemma2.py
@@ -278,11 +278,14 @@ def forward(
    kv_caches: List[torch.Tensor],
    attn_metadata: AttentionMetadata,
    intermediate_tensors: Optional[IntermediateTensors],
+   inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
    if get_pp_group().is_first_rank:
-       hidden_states = self.embed_tokens(input_ids)
+       if inputs_embeds is not None:
+           hidden_states = inputs_embeds
+       else:
+           hidden_states = self.embed_tokens(input_ids)
        hidden_states *= self.normalizer

        residual = None
    else:
        assert intermediate_tensors is not None
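The new inputs_embeds argument lets a caller supply precomputed embeddings and skip the token-embedding lookup; when input_ids are used, Gemma2 still scales the looked-up embeddings by self.normalizer (sqrt of hidden_size in the Gemma family). A toy illustration of that scaling, with made-up shapes rather than the real model:

import torch

hidden_size = 3584                 # illustrative; the real value comes from the model config
normalizer = hidden_size ** 0.5    # Gemma-style embedding scaling
token_embeddings = torch.randn(1, 8, hidden_size)

# Rough equivalent of the embed_tokens branch above: look up, then scale.
hidden_states = token_embeddings * normalizer
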
82 changes: 82 additions & 0 deletions vllm/model_executor/models/gemma2_embedding.py
@@ -0,0 +1,82 @@
from typing import Iterable, List, Optional, Tuple

import torch
from torch import nn

from vllm.attention import AttentionMetadata
from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.gemma2 import Gemma2Model
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.sequence import IntermediateTensors, PoolerOutput


class Gemma2EmbeddingModel(nn.Module):
    """A model that uses Gemma2 with additional embedding functionalities.

    This class encapsulates the Gemma2Model and provides an interface for
    embedding operations and customized pooling functions.

    Attributes:
        model: An instance of Gemma2Model used for forward operations.
        _pooler: An instance of Pooler used for pooling operations.
    """

    def __init__(
        self,
        **kwargs,
    ) -> None:
        super().__init__()
        self.model = Gemma2Model(**kwargs)
        self._pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        return self.model.forward(input_ids, positions, kv_caches,
                                  attn_metadata, intermediate_tensors,
                                  inputs_embeds)

    def pooler(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> Optional[PoolerOutput]:
        return self._pooler(hidden_states, pooling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.model.named_parameters())
        for name, loaded_weight in weights:
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
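
Once the architecture is registered (next file), the embedding path can be exercised through vLLM's offline API. A rough usage sketch, assuming a build that includes this PR and enough GPU memory for the 9B checkpoint:

from vllm import LLM

# "Gemma2Model" in the checkpoint's config resolves to Gemma2EmbeddingModel
# via the registry entry added below.
llm = LLM(model="BAAI/bge-multilingual-gemma2", enforce_eager=True)
outputs = llm.encode(["What is deep learning?"])
print(len(outputs[0].outputs.embedding))  # embedding dimension
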
1 change: 1 addition & 0 deletions vllm/model_executor/models/registry.py
@@ -83,6 +83,7 @@
_EMBEDDING_MODELS = {
    "MistralModel": ("llama_embedding", "LlamaEmbeddingModel"),
    "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
+   "Gemma2Model": ("gemma2_embedding", "Gemma2EmbeddingModel"),
}

_MULTIMODAL_MODELS = {
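The registry keys off the architectures field in the checkpoint's HF config, so a config declaring "Gemma2Model" now maps to the gemma2_embedding module. A simplified sketch of that lookup (the real resolution happens inside vLLM's model loader; the architectures value shown is an assumption about this checkpoint):

from transformers import AutoConfig

_EMBEDDING_MODELS = {
    "MistralModel": ("llama_embedding", "LlamaEmbeddingModel"),
    "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
    "Gemma2Model": ("gemma2_embedding", "Gemma2EmbeddingModel"),
}

config = AutoConfig.from_pretrained("BAAI/bge-multilingual-gemma2")
arch = config.architectures[0]  # expected to be "Gemma2Model" for this checkpoint
module_name, class_name = _EMBEDDING_MODELS[arch]
print(module_name, class_name)  # gemma2_embedding Gemma2EmbeddingModel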