
Support Roberta embedding models #9387

Merged Nov 14, 2024 · 26 commits

Commits
f7e23fb
support head size 32
maxdebayser Oct 22, 2024
10ebc9e
add support for Roberta models
maxdebayser Oct 15, 2024
b457cc5
fix after refactoring
maxdebayser Nov 11, 2024
3fe28f6
Review suggestions
flaviabeo Nov 12, 2024
5b75f4a
Merge branch 'upstream_main' into roberta
flaviabeo Nov 12, 2024
971acea
Fixes conflicts with new upstream changes
flaviabeo Nov 12, 2024
18a2d58
Merge changes fixes
flaviabeo Nov 12, 2024
40ac579
More fixes related to the upstream merge
flaviabeo Nov 12, 2024
e171896
Adds test for roberta model executor
flaviabeo Nov 12, 2024
55912f9
Asserts for Roberta models instance
flaviabeo Nov 12, 2024
6f06a76
Fix space for linting
flaviabeo Nov 12, 2024
d4c8849
Fix space for linting
flaviabeo Nov 12, 2024
b9e64b1
Modifies test for multilingual-e5-large
flaviabeo Nov 12, 2024
366a992
Fix linting in test
flaviabeo Nov 13, 2024
aed1216
Merge branch 'upstream_main' into roberta
flaviabeo Nov 13, 2024
aae474e
trigger ci
flaviabeo Nov 13, 2024
07c931c
finish generalizing the Bert classes
maxdebayser Nov 13, 2024
4495a50
Skips test for ROCm unsupported platform
flaviabeo Nov 13, 2024
49e8381
fix roberta position_ids
maxdebayser Nov 14, 2024
1267bba
add assert to verify assumption
maxdebayser Nov 14, 2024
49cc57b
improve assert
maxdebayser Nov 14, 2024
0f334ae
add model to embedding test
maxdebayser Nov 14, 2024
f27aae1
Remove encoder embedding model for compile test
maxdebayser Nov 14, 2024
44a9d22
trigger ci
maxdebayser Nov 14, 2024
9f31bd5
trigger ci
maxdebayser Nov 14, 2024
80ead23
trigger ci
maxdebayser Nov 14, 2024
3 changes: 3 additions & 0 deletions csrc/attention/paged_attention_v1.cu
@@ -98,6 +98,9 @@ void paged_attention_v1_launcher(
// NOTE(woosuk): To reduce the compilation time, we only compile for the
// head sizes that we use in the model. However, we can easily extend this
// to support any head size which is a multiple of 16.
case 32:
LAUNCH_PAGED_ATTENTION_V1(32);
break;
case 64:
LAUNCH_PAGED_ATTENTION_V1(64);
break;
3 changes: 3 additions & 0 deletions csrc/attention/paged_attention_v2.cu
@@ -104,6 +104,9 @@ void paged_attention_v2_launcher(
// NOTE(woosuk): To reduce the compilation time, we only compile for the
// head sizes that we use in the model. However, we can easily extend this
// to support any head size which is a multiple of 16.
case 32:
LAUNCH_PAGED_ATTENTION_V2(32);
break;
case 64:
LAUNCH_PAGED_ATTENTION_V2(64);
break;
6 changes: 6 additions & 0 deletions csrc/cpu/attention.cpp
@@ -385,6 +385,9 @@ void paged_attention_v1_impl_launcher(
int* seq_lens_ptr = seq_lens.data_ptr<int>();

switch (head_size) {
case 32:
LAUNCH_V1_ATTENTION_KERNEL(T, 32, BLOCK_SIZE);
break;
case 64:
LAUNCH_V1_ATTENTION_KERNEL(T, 64, BLOCK_SIZE);
break;
@@ -702,6 +705,9 @@ void paged_attention_v2_impl_launcher(
int* seq_lens_ptr = seq_lens.data_ptr<int>();

switch (head_size) {
case 32:
LAUNCH_V2_ATTENTION_KERNEL(T, 32, BLOCK_SIZE);
break;
case 64:
LAUNCH_V2_ATTENTION_KERNEL(T, 64, BLOCK_SIZE);
break;
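The kernel-side change is the same in all three launchers: a new case 32: branch so paged attention can serve models whose per-head dimension is 32. Head size is hidden_size / num_attention_heads; a minimal sketch of the arithmetic (the checkpoint and its config values here are illustrative assumptions, not taken from this PR):

from transformers import AutoConfig

# BERT-small style checkpoints (assumed: hidden_size=384, 12 attention
# heads) end up with a per-head dimension of 32, which the paged
# attention kernels above now accept.
config = AutoConfig.from_pretrained("BAAI/bge-small-en-v1.5")
head_size = config.hidden_size // config.num_attention_heads
print(head_size)  # expected: 32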
44 changes: 44 additions & 0 deletions tests/model_executor/test_model_load_with_params.py
@@ -4,12 +4,17 @@

from vllm.model_executor.layers.pooler import PoolingType
from vllm.model_executor.models.bert import BertEmbeddingModel
from vllm.model_executor.models.roberta import RobertaEmbeddingModel
from vllm.platforms import current_platform

MAX_MODEL_LEN = 128
MODEL_NAME = os.environ.get("MODEL_NAME", "BAAI/bge-base-en-v1.5")
REVISION = os.environ.get("REVISION", "main")

MODEL_NAME_ROBERTA = os.environ.get("MODEL_NAME",
"intfloat/multilingual-e5-large")
REVISION_ROBERTA = os.environ.get("REVISION", "main")


@pytest.mark.skipif(current_platform.is_rocm(),
reason="Xformers backend is not supported on ROCm.")
@@ -48,3 +53,42 @@ def test_model_loading_with_params(vllm_runner):
assert model._pooler.normalize
# assert output
assert output


@pytest.mark.skipif(current_platform.is_rocm(),
reason="Xformers backend is not supported on ROCm.")
def test_roberta_model_loading_with_params(vllm_runner):
"""
Test parameter weight loading for a Roberta embedding model.
"""
with vllm_runner(model_name=MODEL_NAME_ROBERTA,
revision=REVISION_ROBERTA,
dtype="float16",
max_model_len=MAX_MODEL_LEN) as model:
output = model.encode("Write a short story about a robot that"
" dreams for the first time.\n")

model_config = model.model.llm_engine.model_config

model_tokenizer = model.model.llm_engine.tokenizer

# asserts on the roberta model config file
assert model_config.encoder_config["max_seq_length"] == 512
assert not model_config.encoder_config["do_lower_case"]

# asserts on the pooling config files
assert model_config.pooler_config.pooling_type == PoolingType.MEAN.name
assert model_config.pooler_config.pooling_norm

# asserts on the tokenizer loaded
assert model_tokenizer.tokenizer_id == "intfloat/multilingual-e5-large"
assert not model_tokenizer.tokenizer_config["do_lower_case"]

model = model.model.llm_engine.model_executor\
.driver_worker.model_runner.model
assert isinstance(model, RobertaEmbeddingModel)
assert model._pooler.pooling_type == PoolingType.MEAN
assert model._pooler.normalize

# assert output
assert output
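For context, a minimal end-user sketch of what the test above exercises, assuming vLLM's LLM.encode embedding API as it existed around this PR (the output field names are best-effort assumptions, not taken from this diff):

from vllm import LLM

llm = LLM(model="intfloat/multilingual-e5-large",
          dtype="float16",
          max_model_len=128)
outputs = llm.encode(["query: a robot dreams for the first time"])
# Each result carries one embedding vector (1024-dim for this model).
print(len(outputs[0].outputs.embedding))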
2 changes: 2 additions & 0 deletions tests/models/embedding/language/test_embedding.py
@@ -13,10 +13,12 @@
"intfloat/e5-mistral-7b-instruct",
"BAAI/bge-base-en-v1.5",
"BAAI/bge-multilingual-gemma2",
"intfloat/multilingual-e5-large",
]

ENCODER_ONLY = [
"BAAI/bge-base-en-v1.5",
"intfloat/multilingual-e5-large",
]


2 changes: 1 addition & 1 deletion vllm/attention/ops/ipex_attn.py
@@ -10,7 +10,7 @@ class PagedAttention:

@staticmethod
def get_supported_head_sizes() -> List[int]:
return [64, 80, 96, 112, 128, 256]
return [32, 64, 80, 96, 112, 128, 256]

@staticmethod
def get_kv_cache_shape(
2 changes: 1 addition & 1 deletion vllm/attention/ops/paged_attn.py
@@ -34,7 +34,7 @@ class PagedAttention:

@staticmethod
def get_supported_head_sizes() -> List[int]:
return [64, 80, 96, 112, 120, 128, 192, 256]
return [32, 64, 80, 96, 112, 120, 128, 192, 256]

@staticmethod
def get_kv_cache_shape(
35 changes: 23 additions & 12 deletions vllm/model_executor/models/bert.py
@@ -5,7 +5,7 @@
from transformers import BertConfig

from vllm.attention import Attention, AttentionMetadata, AttentionType
from vllm.config import CacheConfig, VllmConfig
from vllm.config import CacheConfig, PoolerConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@@ -305,14 +305,16 @@ def forward(self, hidden_states: torch.Tensor,

class BertModel(nn.Module):

def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
def __init__(self,
*,
vllm_config: VllmConfig,
prefix: str = "",
embedding_class: type = BertEmbedding):
super().__init__()

config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config

self.embeddings = BertEmbedding(config)
self.embeddings = embedding_class(config)
self.encoder = BertEncoder(config,
cache_config,
quant_config,
@@ -382,13 +384,9 @@ class BertEmbeddingModel(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
pooler_config = vllm_config.model_config.pooler_config
self.model = BertModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self._pooler = Pooler.from_config_with_defaults(
pooler_config,
pooling_type=PoolingType.CLS,
normalize=True,
softmax=False)
self.model = self._build_model(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self._pooler = self._build_pooler(pooler_config)

def forward(
self,
@@ -415,3 +413,16 @@ def pooler(

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
self.model.load_weights(weights)

def _build_model(self,
vllm_config: VllmConfig,
prefix: str = "") -> BertModel:
return BertModel(vllm_config=vllm_config,
prefix=prefix,
embedding_class=BertEmbedding)

def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler:
return Pooler.from_config_with_defaults(pooler_config,
pooling_type=PoolingType.CLS,
normalize=True,
softmax=False)
2 changes: 2 additions & 0 deletions vllm/model_executor/models/registry.py
@@ -94,6 +94,8 @@
_EMBEDDING_MODELS = {
# [Text-only]
"BertModel": ("bert", "BertEmbeddingModel"),
"RobertaModel": ("roberta", "RobertaEmbeddingModel"),
"XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
"DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"),
"Gemma2Model": ("gemma2", "Gemma2EmbeddingModel"),
"LlamaModel": ("llama", "LlamaEmbeddingModel"),
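Each registry entry maps an architectures name from the HF config to a (module, class) pair under vllm.model_executor.models. A simplified sketch of the resolution (vLLM's actual loader adds lazy imports and error handling on top of this):

import importlib

module_name, class_name = ("roberta", "RobertaEmbeddingModel")
module = importlib.import_module(
    f"vllm.model_executor.models.{module_name}")
model_cls = getattr(module, class_name)  # -> RobertaEmbeddingModel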
117 changes: 117 additions & 0 deletions vllm/model_executor/models/roberta.py
@@ -0,0 +1,117 @@
from typing import List, Optional

import torch
from torch import nn
from transformers import RobertaConfig

from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.models.bert import BertEmbeddingModel, BertModel
from vllm.sequence import IntermediateTensors


class RobertaEmbedding(nn.Module):

def __init__(self, config: RobertaConfig):
super().__init__()
self.size = config.hidden_size
self.word_embeddings = VocabParallelEmbedding(config.vocab_size,
config.hidden_size)
self.padding_idx = config.pad_token_id
self.position_embeddings = nn.Embedding(config.max_position_embeddings,
config.hidden_size,
padding_idx=self.padding_idx)

self.token_type_embeddings = nn.Embedding(config.type_vocab_size,
config.hidden_size)
self.LayerNorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
self.position_ids = nn.Parameter(
torch.empty((1, config.max_position_embeddings)), )

self.position_embedding_type = config.position_embedding_type
if self.position_embedding_type != "absolute":
raise ValueError("Only 'absolute' position_embedding_type" +
" is supported")

def forward(
self,
input_ids: torch.Tensor,
position_ids: Optional[torch.Tensor] = None,
) -> torch.Tensor:
input_shape = input_ids.size()

# Input embeddings.
inputs_embeds = self.word_embeddings(input_ids)

# TODO: figure out if there is a better way
# to make position ids start at padding_idx + 1
# References:
# - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L133
# - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L1669
position_ids += self.padding_idx + 1

# Position embeddings.
position_embeddings = self.position_embeddings(position_ids)

# Token type embeddings. (TODO: move off hotpath?)
token_type_embeddings = self.token_type_embeddings(
torch.zeros(input_shape,
dtype=torch.long,
device=inputs_embeds.device))

embeddings = inputs_embeds + token_type_embeddings + position_embeddings
embeddings = self.LayerNorm(embeddings)
return embeddings
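The padding_idx + 1 offset mirrors RoBERTa's convention in transformers, where position ids are a cumulative count of non-padding tokens shifted by padding_idx (see create_position_ids_from_input_ids in the files referenced above). A small equivalence check, assuming the no-internal-padding, 0..N-1 positions that vLLM supplies:

import torch

pad = 1  # RoBERTa's pad_token_id
input_ids = torch.tensor([[10, 11, 12, 13]])

# HF convention: cumulative count of non-pad tokens, shifted by padding_idx.
mask = input_ids.ne(pad).int()
hf_positions = torch.cumsum(mask, dim=1) * mask + pad    # [[2, 3, 4, 5]]

# vLLM convention: 0..N-1 positions plus the flat offset applied above.
vllm_positions = torch.arange(4).unsqueeze(0) + pad + 1  # [[2, 3, 4, 5]]

assert torch.equal(hf_positions, vllm_positions)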


class RobertaEmbeddingModel(BertEmbeddingModel):
"""A model that uses Roberta to provide embedding functionalities.

This class encapsulates the BertModel and provides an interface for
embedding operations and customized pooling functions.

Attributes:
model: An instance of BertModel used for forward operations.
_pooler: An instance of Pooler used for pooling operations.
"""

def _build_model(self,
vllm_config: VllmConfig,
prefix: str = "") -> BertModel:
return BertModel(vllm_config=vllm_config,
prefix=prefix,
embedding_class=RobertaEmbedding)

def forward(
self,
input_ids: Optional[torch.Tensor],
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:

        # Verify the assumption that positions are always a sequence
        # from 0 to N (here we only check 0 and N, for simplicity).
        # This matters because the positions are fixed up to start from
        # padding_idx + 1 instead of 0, as the Roberta models assume.
assert hasattr(attn_metadata, "seq_lens_tensor")
cumulative = attn_metadata.seq_lens_tensor.cumsum(dim=0)
start_pos = torch.cat(
(torch.tensor([0], device=attn_metadata.seq_lens_tensor.device),
cumulative[:-1]))
assert len(torch.nonzero(positions[start_pos])) == 0
end_pos = cumulative - 1
last_tokens = attn_metadata.seq_lens_tensor - 1
assert len(torch.nonzero(positions[end_pos] - last_tokens)) == 0

return super().forward(input_ids=input_ids,
positions=positions,
kv_caches=kv_caches,
attn_metadata=attn_metadata,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds)
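A small worked example of the check above, with synthetic tensors standing in for attn_metadata (seq_lens [3, 2] means two flattened sequences whose positions must each restart at 0):

import torch

seq_lens = torch.tensor([3, 2])
positions = torch.tensor([0, 1, 2, 0, 1])  # two sequences, flattened

cumulative = seq_lens.cumsum(dim=0)                          # [3, 5]
start_pos = torch.cat((torch.tensor([0]), cumulative[:-1]))  # [0, 3]
end_pos = cumulative - 1                                     # [2, 4]
last_tokens = seq_lens - 1                                   # [2, 1]

assert len(torch.nonzero(positions[start_pos])) == 0  # all starts are 0
assert len(torch.nonzero(positions[end_pos] - last_tokens)) == 0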