[Model] Support VLMs with transformers backend #13754

Draft · wants to merge 5 commits into base: main
1 change: 1 addition & 0 deletions vllm/entrypoints/llm.py
@@ -691,6 +691,7 @@ def chat(
]

tokenizer = self.get_tokenizer()

model_config = self.llm_engine.get_model_config()
resolved_content_format = resolve_chat_template_content_format(
chat_template,
268 changes: 245 additions & 23 deletions vllm/model_executor/models/transformers.py
@@ -19,7 +19,7 @@

import torch
from torch import nn
from transformers import AutoModel, PreTrainedModel
from transformers import AutoModel, PreTrainedModel, LlavaConfig
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS

from vllm.attention import Attention, AttentionMetadata
@@ -37,11 +37,16 @@
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.processor import cached_get_processor
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry, MultiModalKwargs
from vllm.multimodal.processing import BaseMultiModalProcessor, BaseProcessingInfo
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalInputs, PlaceholderRange

from .interfaces import SupportsQuant
from .interfaces import SupportsQuant, SupportsMultiModal
from .utils import maybe_prefix

logger = init_logger(__name__)


def vllm_flash_attention_forward(
@@ -119,10 +124,180 @@
)


class TransformersModel(nn.Module, SupportsQuant):
class MultiModalProcessingInfo(BaseProcessingInfo):
def get_hf_config(self):
# NOTE: this means we don't check whether the returned config type matches the requested one.
# vLLM, by contrast, always checks. In which cases can we end up with different config types, though?
return self.ctx.model_config.hf_config

def get_supported_mm_limits(self):
return {"image": None, "video": None}

def get_mm_max_tokens_per_item(self, seq_len, mm_counts):
return {"image": self.get_max_image_tokens(), "video": 100}

def get_max_image_tokens(self) -> int:
# Already an attribute in some VLM configs, and there is no reason not to make it a required attribute
# TODO: @raushan add it for all VLM configs
return self.get_hf_config().image_seq_length

def get_hf_processor(self):
processor = cached_get_processor(self.ctx.model_config.model)
return processor


class MultiModalDummyInputsBuilder(BaseDummyInputsBuilder):
def get_dummy_processor_inputs(
self,
seq_len,
mm_counts,
) -> ProcessorInputs:
num_images = mm_counts.get("image", 0)
num_videos = mm_counts.get("video", 0)
num_frames = 8

processor = self.info.get_hf_processor()
image_token = getattr(processor, "image_token", None)
video_token = getattr(processor, "video_token", None)

# TODO: @raushan, we can have a processor attr `processor.max_output_size` which will infer
# the max features for the model on the HF side. But IMO we can just set a very high resolution
# and the processor will return pixels with the correct max shape. A 3k x 3k resolution is high enough

target_width = target_height = 3000

# NOTE: we can pass videos/images/audio to any processor With the new API used in MLLMs,
# HF processor will take the modality needed for model and ignore all others
mm_data = {
"image": self._get_dummy_images(
width=target_width,
height=target_height,
num_images=num_images
),
"video": self._get_dummy_videos(

width=target_width,
height=target_height,
num_frames=num_frames,
num_videos=num_videos,
)
}

prompt_text = video_token*num_videos if video_token is not None else image_token*num_images
return ProcessorInputs(
prompt_text=prompt_text,
mm_data=mm_data,
)


class MultiModalProcessor(BaseMultiModalProcessor):
def _get_prompt_replacements(
self,
mm_items,
hf_processor_mm_kwargs,
out_mm_kwargs: MultiModalKwargs,
):
return

def _get_mm_fields_config(
self,
hf_inputs,
hf_processor_mm_kwargs,
):
return dict(
pixel_values=MultiModalFieldConfig.batched("image"),
mm_token_type_ids=MultiModalFieldConfig.batched("image"),
pixel_values_videos=MultiModalFieldConfig.batched("video"),
image_embeds=MultiModalFieldConfig.batched("image"),
video_embeds=MultiModalFieldConfig.batched("video"),
)

def _apply_hf_processor_text_mm(
self,
prompt_text,
mm_items,
hf_processor_mm_kwargs,
):
"""
Apply the HF processor on the prompt text and multi-modal data
together.

In addition, return whether prompt replacements have been applied.
"""
processor_data, passthrough_data = self._get_hf_mm_data(mm_items)
processor_data["return_mm_token_type_ids"] = True

processed_data = self._call_hf_processor(
prompt=prompt_text,
mm_data=processor_data,
mm_kwargs=hf_processor_mm_kwargs,
)
processed_data.update(passthrough_data)

prompt_ids, = processed_data.pop("input_ids").tolist()
mm_token_type_ids = processed_data.pop("mm_token_type_ids")

mm_kwargs = MultiModalKwargs.from_hf_inputs(
processed_data,
self._get_mm_fields_config(processed_data, hf_processor_mm_kwargs),
)

return prompt_ids, mm_kwargs, mm_token_type_ids

def apply(
self,
prompt,
mm_data,
hf_processor_mm_kwargs,
) -> MultiModalInputs:
"""
Process multi-modal inputs to be used in vLLM.

Apply HF Processor on prompt text and multi-modal data together,
outputting token IDs and processed tensors.
"""
mm_items = self._to_mm_items(mm_data)
prompt_ids, mm_kwargs, mm_token_type_ids = self._apply_hf_processor_text_mm(
prompt_text=prompt,
mm_items=mm_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
)

# The HF processor will return `mm_token_type_ids`, from which we can infer
# mm_placeholders. Until then, hardcode it to make the code run.
# The code below was tested on Llava. Prompts and `mm_token_type_ids` always have bs=1
mm_positions = torch.where(mm_token_type_ids == 1)[1]
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
mm_tokens_per_modality = hf_processor._get_num_mm_tokens(
image_inputs=mm_kwargs.get_hf_inputs("image"),
video_inputs=mm_kwargs.get_hf_inputs("video"),
)

mm_placeholders = {}
for modality in mm_tokens_per_modality:
split_sizes = mm_tokens_per_modality[modality]
if split_sizes != 0:
chunked_mm_positions = torch.split(mm_positions, split_sizes)
ranges = [
PlaceholderRange(offset=positions[0].item(), length=positions.shape[0])
for positions in chunked_mm_positions
]
mm_placeholders[modality] = ranges

return MultiModalInputs(
type="multimodal",
prompt=prompt,
prompt_token_ids=prompt_ids,
mm_kwargs=mm_kwargs,
mm_hashes=None,
mm_placeholders=mm_placeholders,
)


@MULTIMODAL_REGISTRY.register_processor(MultiModalProcessor,
info=MultiModalProcessingInfo,
dummy_inputs=MultiModalDummyInputsBuilder)
class TransformersModel(nn.Module, SupportsQuant, SupportsMultiModal):
Collaborator:
Using the same class to support both text-only models and multimodal models makes it difficult to maintain V1 compatibility, see: #13157

Contributor:
Arf, in my mind we would only need one "wrapper" to rule them all 😃 But that makes sense. We could also have something like an `is_multimodal` check, but that might not work on your side!
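
For context, a minimal sketch of the two options discussed above, assuming a hypothetical `is_multimodal` helper and a separate multimodal wrapper; neither exists in this PR:

# Hypothetical sketch only; neither the helper nor the class split is part of this PR.
from transformers import PretrainedConfig

from vllm.model_executor.models.transformers import TransformersModel


def is_multimodal(config: PretrainedConfig) -> bool:
    # Treat a config as multimodal when it wraps a nested text config,
    # i.e. get_text_config() returns a sub-config rather than the config itself.
    return config.get_text_config() is not config


class TransformersMultiModalModel(TransformersModel):  # assumed name, not in the PR
    """Separate wrapper so the text-only TransformersModel keeps V1 compatibility."""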

embedding_padding_modules = ["lm_head"]
embedding_modules = ["embed_tokens"
] # TODO transformers will have a util to get it
embedding_modules = ["embed_tokens"] # TODO transformers will have a util to get it

def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
super().__init__()
@@ -132,12 +307,13 @@
cache_config = vllm_config.cache_config

self.config = config
self.vocab_size = config.vocab_size
self.unpadded_vocab_size = config.vocab_size
self.text_config = config.get_text_config()
self.vocab_size = self.text_config.vocab_size
self.unpadded_vocab_size = self.text_config.vocab_size

self.model: PreTrainedModel = AutoModel.from_config(
self.config,
attn_implementation="vllm",
attn_implementation={"text_config": "vllm", "vision_config": "eager"},
Author:
The current way doesn't work with LLMs. Setting `self.config.get_text_config().attn_implementation = "vllm"` would be much more generic, but that needs a fix on the transformers side first.

Contributor:
Indeed! Let's work on refactoring the vision models! (Though it's not "that" important, because paged attention is mostly for text; using flex / sdpa here would work better in general.)

Author:
Yeah, making it a new model will solve the problem by itself, so we can use 'sdpa' in HFMultiModalModel. We can work on the refactoring later; it's been on my list for a long time.

torch_dtype=vllm_config.model_config.dtype,
trust_remote_code=vllm_config.model_config.trust_remote_code,
)
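
A sketch of the more generic setting proposed in the thread above; this is a guess at what the constructor could look like once transformers supports setting attn_implementation on the text sub-config, which it does not today:

# Assumed future pattern, per the discussion above; not valid with transformers today.
text_config = self.config.get_text_config()
text_config.attn_implementation = "vllm"  # proposed generic knob; vision towers keep their default (e.g. sdpa)
self.model = AutoModel.from_config(
    self.config,
    torch_dtype=vllm_config.model_config.dtype,
    trust_remote_code=vllm_config.model_config.trust_remote_code,
)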
@@ -150,47 +326,47 @@
tp_size = get_tensor_model_parallel_world_size()
self.attention_instances = [
Attention(
num_heads=divide(config.num_attention_heads, tp_size),
head_size=config.head_dim,
num_heads=divide(self.text_config.num_attention_heads, tp_size),
head_size=self.text_config.head_dim,
# NOTE: We use Llama scale as default, if it's set by
# Transformers, it's updated in vllm_flash_attention_forward
scale=config.head_dim**-0.5,
num_kv_heads=divide(config.num_key_value_heads, tp_size),
scale=self.text_config.head_dim**-0.5,
num_kv_heads=divide(self.text_config.num_key_value_heads, tp_size),
cache_config=cache_config,
quant_config=self.quant_config,
prefix=f"{i}.attn") for i in range(config.num_hidden_layers)
prefix=f"{i}.attn") for i in range(self.text_config.num_hidden_layers)
]

# Model modifications
self.replace_vocab_embed_class(self.model)

# ForCausalLM modifications
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
self.lm_head = ParallelLMHead(self.text_config.vocab_size,
self.text_config.hidden_size,
quant_config=self.quant_config,
prefix=maybe_prefix(prefix, "lm_head"))
if config.tie_word_embeddings:
if self.text_config.tie_word_embeddings:
self.lm_head.weight = self.model.get_input_embeddings().weight

logit_scale = getattr(config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size, logit_scale)
self.vocab_size, logit_scale)
self.sampler = get_sampler()

def apply_base_model_tp_plan(self, module: nn.Module, prefix: str = ""):
"""
Apply the base model tensor parallelization plan to a module.
Currently only supports linear layers.
"""
if (self.config.base_model_tp_plan is None
if (self.text_config.base_model_tp_plan is None
and get_tensor_model_parallel_world_size() > 1):
raise ValueError(
"Trying to run tensor parallelization but the model does not "
"support it yet!")

for child_name, child_module in module.named_children():
qual_name = maybe_prefix(prefix, child_name)
for pattern, style in self.config.base_model_tp_plan.items():
for pattern, style in self.text_config.base_model_tp_plan.items():
if re.match(pattern, qual_name) and isinstance(
child_module, nn.Linear):
new_module = replace_linear_class(child_module, style,
@@ -204,8 +380,8 @@
# Use native set input embeddings
new_module = VocabParallelEmbedding(
self.vocab_size,
self.config.hidden_size,
org_num_embeddings=self.config.vocab_size,
self.text_config.hidden_size,
org_num_embeddings=self.vocab_size,
quant_config=None,
)
log_replacement("input embedding", self.model.get_input_embeddings(),
@@ -222,7 +398,8 @@
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
model_output = self.model(
input_ids[None, ...],
input_ids[None, ...] if input_ids is not None else None,
inputs_embeds=inputs_embeds[None, ...] if inputs_embeds is not None else None,
use_cache=False,
position_ids=positions[None, ...],
attn_metadata=attn_metadata,
@@ -252,11 +429,56 @@
loaded_params = set[str]()
for name, loaded_weight in weights:
if name not in params_dict:
name = f"{self.model.base_model_prefix}.{name}"
# In MLLMs the head is usually part of the LM, so we might want to strip the prefix
# Very bad workaround; needs something better
if "lm_head" in name:
name = name.replace("language_model.", "")
else:
name = f"{self.model.base_model_prefix}.{name}"
if name in params_dict:
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params

def get_multimodal_embeddings(self, **kwargs):
pixel_values = kwargs.pop("pixel_values", None)
image_embeds = kwargs.pop("image_embeds", None)

if pixel_values is None and image_embeds is None:
return None

if pixel_values is not None:
vision_embeddings = self.model.get_image_features(
# Think about pixels being batched again, adding an extra dim
# TODO: find out whether we really need that extra dim
pixel_values.flatten(0, 1),
vision_feature_layer=self.config.vision_feature_layer,
vision_feature_select_strategy=self.config.vision_feature_select_strategy,
)
return vision_embeddings

if image_embeds is not None:
return image_embeds

def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings = None,
) -> torch.Tensor:
inputs_embeds = self.model.get_input_embeddings()(input_ids)
if multimodal_embeddings is not None:
# Most supported VLMs merge like this; otherwise we can add a special
# `merge_multimodal_embeddings` method on the HF side
mask = (input_ids == self.config.image_token_index)
mask = mask.unsqueeze(-1).expand_as(inputs_embeds)
multimodal_embeddings = torch.cat(multimodal_embeddings)

# FIXME: The returned multimodal_embeddings must be either a 3D torch.Tensor of shape
# (num_items, feature_size, hidden_size), or a list/tuple of 2D torch.Tensors of shape
# (feature_size, hidden_size), so that multimodal_embeddings[i] retrieves the embeddings
# generated from the i-th multimodal data item (e.g. image) of the request.
inputs_embeds = inputs_embeds.masked_scatter(mask, multimodal_embeddings)
return inputs_embeds
Comment on lines +475 to +484
Member:
I know Cyrus has mentioned this, but I'd like to emphasize that for some models we need to consider the padding tokens in between image tokens as part of the feature tokens as well; then, when we merge the embeddings, the mask is created over all of these tokens.

How we handled mistral-format Pixtral is another example of this scenario, in addition to Fuyu.

def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
image_input, image_tokens = self._parse_and_validate_image_input(
**kwargs)
if image_input is None:
return None
vision_embeddings = self._process_image_input(image_input)
# NOTE: We patch the outputs of the vision encoder with embeddings
# from `[IMG_BREAK]` and `[IMG_END]` tokens.
image_embeds = self.language_model.get_input_embeddings(image_tokens)
image_token_mask = image_tokens == self.vision_args.image_token_id
image_embeds[image_token_mask] = vision_embeddings
# NOTE: Image embeddings are split into separate tensors for each image
# by the indices of `[IMG_END]` token.
image_end_mask = image_tokens == self.vision_args.image_end_token_id
split_indices = torch.where(image_end_mask)[0] + 1
if len(split_indices) <= 1:
# Do not split, return as tensor of shape [1, fs, hs]
return image_embeds.unsqueeze(0)
# If the last split index is the last index in image_tokens, we
# ignore it to avoid empty split tensor
if split_indices[-1] == len(image_tokens):
split_indices = split_indices[:-1]
image_embeds = image_embeds.tensor_split(split_indices.cpu())
return image_embeds
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[NestedTensors] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings, [
self.vision_args.image_token_id,
self.vision_args.image_break_token_id,
self.vision_args.image_end_token_id,
])
return inputs_embeds

Author:
I am not familiar with the Fuyu arch, so looking at Pixtral now: I don't totally get why the outputs from get_multimodal_embedding contain image special tokens. Isn't it more straightforward to obtain only the image-related features and merge them using a mask (ids == special_image_token)?

Or maybe I am missing something here.
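
For illustration, a minimal sketch of the mask-based merge described in this comment; it mirrors the get_input_embeddings implementation above, and `image_token_index` plus the tensor shapes are assumptions rather than part of the PR:

import torch


def merge_by_mask(input_ids: torch.Tensor,
                  inputs_embeds: torch.Tensor,
                  image_features: torch.Tensor,
                  image_token_index: int) -> torch.Tensor:
    # input_ids: (seq_len,); inputs_embeds: (seq_len, hidden);
    # image_features: (num_image_tokens, hidden), flattened across all images.
    mask = (input_ids == image_token_index).unsqueeze(-1).expand_as(inputs_embeds)
    # masked_scatter fills the masked positions with image_features in order,
    # so the number of image placeholder tokens must equal num_image_tokens.
    return inputs_embeds.masked_scatter(mask, image_features)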

9 changes: 9 additions & 0 deletions vllm/multimodal/inputs.py
@@ -702,6 +702,15 @@ def get_items(self, modality: str) -> Sequence[MultiModalKwargsItem]:
self._validate_modality("get_items", modality)
return self._items_by_modality[modality]

def get_hf_inputs(self, modality: str) -> dict[str, NestedTensors]:
modality_items = self._items_by_modality.get(modality, None)
hf_inputs = defaultdict[str, list[NestedTensors]](list)
if modality_items is not None:
for mm_kwargs_item in modality_items:
for key, value in mm_kwargs_item.items():
hf_inputs[key].append(value.data)
hf_inputs = {key: torch.stack(value) for key, value in hf_inputs.items()}
return hf_inputs

MultiModalPlaceholderDict = Mapping[str, Sequence[PlaceholderRange]]
"""