
[Model] Support VLMs with transformers backend #13754

Draft · wants to merge 5 commits into main
Conversation

zucchini-nlp commented Feb 24, 2025

This PR adds support for multimodal models in the Transformers backend. As a start, I tested with vanilla LLaVA using the demo scripts from the documentation; the generated outputs matched vLLM's outputs.

For this branch to work, we first need a few changes in transformers, starting from huggingface/transformers#36367. For now I want to ask for feedback on whether this aligns with how vLLM sees things.
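
For reference, a minimal sketch of how this could be exercised, based on vLLM's existing model_impl="transformers" switch (the exact flags for this draft may differ, and `image` is assumed to be a PIL.Image loaded elsewhere):

from vllm import LLM

# Load a vanilla LLaVA checkpoint through the Transformers backend.
llm = LLM(model="llava-hf/llava-1.5-7b-hf", model_impl="transformers")

prompt = "USER: <image>\nWhat is in this picture? ASSISTANT:"
outputs = llm.generate({
    "prompt": prompt,
    "multi_modal_data": {"image": image},
})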

cc @Isotr0py @ArthurZucker


👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small and essential subset of tests to quickly catch errors. You can run the other CI tests on top of it by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run full CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

mergify bot added the frontend label Feb 24, 2025

self.model: PreTrainedModel = AutoModel.from_config(
    self.config,
-   attn_implementation="vllm",
+   attn_implementation={"text_config": "vllm", "vision_config": "eager"},
zucchini-nlp (Author):

The current way doesn't work with LLMs. Setting self.config.get_text_config().attn_implementation = "vllm" would be much more generic, but that needs a fix in transformers first.
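
A minimal sketch of the more generic approach described above, assuming the needed transformers fix lands (the private _attn_implementation attribute is used here purely for illustration):

# get_text_config() returns the config itself for text-only LLMs and the
# text sub-config for composite VLM configs, so one line covers both.
self.config.get_text_config()._attn_implementation = "vllm"
self.model: PreTrainedModel = AutoModel.from_config(self.config)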

ArthurZucker (Contributor):

Indeed! Let's work on refactoring the vision models! (Though it's not that important here, because paged attention is mostly for text; using flex / SDPA for vision would generally work better.)

zucchini-nlp (Author):

Yeah, making it a new model class will solve the problem by itself, so we can use 'sdpa' in HFMultiModalModel. We can work on the refactoring later; it's been on my list for a long time.

@MULTIMODAL_REGISTRY.register_processor(MultiModalProcessor,
                                        info=MultiModalProcessingInfo,
                                        dummy_inputs=MultiModalDummyInputsBuilder)
class TransformersModel(nn.Module, SupportsQuant, SupportsMultiModal):
Collaborator:

Using the same class to support both text-only models and multimodal models makes it difficult to maintain V1 compatibility, see: #13157
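
A hedged sketch of the split this suggests; the class names are illustrative, not the actual vLLM API:

class TransformersModel(nn.Module, SupportsQuant):
    # Text-only models: no multimodal processor registration needed.
    ...

@MULTIMODAL_REGISTRY.register_processor(MultiModalProcessor,
                                        info=MultiModalProcessingInfo,
                                        dummy_inputs=MultiModalDummyInputsBuilder)
class TransformersMultiModalModel(TransformersModel, SupportsMultiModal):
    # Multimodal models live in a separate class so V1 compatibility
    # checks can treat the two cases differently.
    ...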

ArthurZucker (Contributor):

Arf, in my mind we would only need one "wrapper" to rule them all 😃 But that makes sense. We could also have something like is_multimodal, but that might not work on your side!

DarkLight1337 self-assigned this Feb 24, 2025

DarkLight1337 (Member) commented Feb 24, 2025

Thanks for working on this! The main difficulty of supporting VLMs is not the model implementation itself, but rather the preprocessing code: vLLM V1 in particular requires precise tracking of the placeholder tokens. I see how generalizing return_token_type_ids to the multimodal context can help. We still have a couple of issues to tackle, though:

  • vLLM doesn't support non-consecutive multimodal placeholder feature tokens, which occur when image tokens are arranged into a grid that is split by other padding tokens. We currently work around this in vLLM by considering those padding tokens as feature tokens as well, as demonstrated in our Fuyu example. To continue supporting this workaround, we need information about the padding tokens when return_mm_token_type_ids=True (see the sketch after this list for how such consecutive spans could be recovered).
  • We need to also distinguish between tokens from different modalities when return_mm_token_type_ids=True. I suggest trying this with LLaVA-OneVision first.
  • After applying HF processor, we split the output BatchFeature entries by their modality in order to cache them. To support this, we can maintain a mapping from the BatchFeature keys to their respective modalities as a class attribute in each HF processor.
  • vLLM supports both text+multimodal and token+multimodal inputs, but current HF processors only support the former. Is it feasible to also support token+multimodal inputs in HF Transformers? There are some cases where running the tokenizer and the HF processor on text and multimodal data separately doesn't yield the same outputs as passing them together to the HF processor, so we need special code in vLLM to handle this by overriding _apply_hf_processor_tokens_only.
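
To make the placeholder-tracking requirement concrete, here is a sketch of recovering consecutive multimodal-token spans from a token-type mask. The name mm_token_type_ids and the (offset, length) span format are illustrative assumptions, not the finalized API:

import torch

def find_mm_placeholder_ranges(mm_token_type_ids: torch.Tensor) -> list[tuple[int, int]]:
    # `mm_token_type_ids`: 1D tensor with 0 for text tokens and 1 for
    # multimodal feature tokens (including any in-grid padding tokens,
    # per the Fuyu-style workaround above).
    is_mm = mm_token_type_ids.bool()
    pad = torch.tensor([False])
    padded = torch.cat([pad, is_mm, pad]).int()
    # +1 where a multimodal span starts, -1 right after it ends.
    diff = padded[1:] - padded[:-1]
    starts = torch.where(diff == 1)[0]
    ends = torch.where(diff == -1)[0]
    return [(int(s), int(e - s)) for s, e in zip(starts, ends)]

# Two images of 3 feature tokens each, separated by text tokens:
ids = torch.tensor([0, 1, 1, 1, 0, 0, 1, 1, 1, 0])
print(find_mm_placeholder_ranges(ids))  # [(1, 3), (6, 3)]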

cc @ywang96

zucchini-nlp (Author):

@DarkLight1337 thanks for the review! Yeah, checking more involved models is a good idea to verify that all edge cases are covered; will do so. A few clarifications before that:

  • For LLaVA-OneVision, do we need to support inputs with video and image within one batch? If so, that would probably complicate things a bit.
  • When doing token+multimodal inputs for the processor, do we expect the tokens to be already expanded or otherwise processed for multimodality? Or do we assume the tokens are simply the tokenizer's output on the text? I am not sure it is a good idea to support this as part of __call__, but we can add a private method for vLLM to use. That would also give us the freedom to change the API later without worrying about backwards compatibility, if we decide to handle tokens differently.

DarkLight1337 (Member) commented Feb 24, 2025

> For LLaVA-OneVision, do we need to support inputs with video and image within one batch? If so, that would probably complicate things a bit.

Yes, we currently support mixed-modality (non-interleaved) inputs and plan to eventually support interleaved-modality inputs as well.

> When doing token+multimodal inputs for the processor, do we expect the tokens to be already expanded or otherwise processed for multimodality? Or do we assume the tokens are simply the tokenizer's output on the text? I am not sure it is a good idea to support this as part of __call__, but we can add a private method for vLLM to use. That would also give us the freedom to change the API later without worrying about backwards compatibility, if we decide to handle tokens differently.

We assume that the tokens have only gone through the tokenizer, so placeholder tokens still have to be inserted into the input tokens. It's fine if we leave this unsolved for now; we can fall back to detokenizing the tokens back into text before passing them through the HF processor.
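
A minimal sketch of that fallback, assuming standard HF tokenizer and processor objects (names illustrative):

# Detokenize the prompt tokens back into text, then run the full
# HF processor on text + images as usual.
prompt_text = tokenizer.decode(prompt_token_ids, skip_special_tokens=False)
inputs = hf_processor(text=prompt_text, images=images, return_tensors="pt")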

ywang96 self-assigned this Feb 24, 2025

ywang96 (Member) commented Feb 24, 2025

Thanks for the PR @zucchini-nlp! I'm a bit occupied at the moment but will take a first pass later tonight.

ArthurZucker (Contributor) left a comment:

I have no say in this, but I am excited! 🚀


Comment on lines +475 to +484
mask = (input_ids == self.config.image_token_index)
mask = mask.unsqueeze(-1).expand_as(inputs_embeds)
multimodal_embeddings = torch.cat(multimodal_embeddings)

# FIXME: The returned multimodal_embeddings must be either a 3D torch.Tensor
# of shape (num_items, feature_size, hidden_size), or a list / tuple of 2D
# torch.Tensors of shape (feature_size, hidden_size), so that
# multimodal_embeddings[i] retrieves the embeddings generated from the i-th
# multimodal data item (e.g., image) of the request.
inputs_embeds = inputs_embeds.masked_scatter(mask, multimodal_embeddings)
return inputs_embeds
Member:

I know Cyrus has mentioned this, but I want to emphasize that for some models we also need to consider the padding tokens in between image tokens as part of the feature tokens; then, when we merge the embeddings, the mask is created over all of these tokens.

How we handled mistral-format Pixtral is another example of this scenario, in addition to Fuyu.

def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
    image_input, image_tokens = self._parse_and_validate_image_input(
        **kwargs)
    if image_input is None:
        return None

    vision_embeddings = self._process_image_input(image_input)

    # NOTE: We patch the outputs of the vision encoder with embeddings
    # from `[IMG_BREAK]` and `[IMG_END]` tokens.
    image_embeds = self.language_model.get_input_embeddings(image_tokens)
    image_token_mask = image_tokens == self.vision_args.image_token_id
    image_embeds[image_token_mask] = vision_embeddings

    # NOTE: Image embeddings are split into separate tensors for each image
    # by the indices of `[IMG_END]` token.
    image_end_mask = image_tokens == self.vision_args.image_end_token_id
    split_indices = torch.where(image_end_mask)[0] + 1
    if len(split_indices) <= 1:
        # Do not split, return as tensor of shape [1, fs, hs]
        return image_embeds.unsqueeze(0)

    # If the last split index is the last index in image_tokens, we
    # ignore it to avoid an empty split tensor.
    if split_indices[-1] == len(image_tokens):
        split_indices = split_indices[:-1]

    image_embeds = image_embeds.tensor_split(split_indices.cpu())
    return image_embeds

def get_input_embeddings(
    self,
    input_ids: torch.Tensor,
    multimodal_embeddings: Optional[NestedTensors] = None,
) -> torch.Tensor:
    inputs_embeds = self.language_model.get_input_embeddings(input_ids)
    if multimodal_embeddings is not None:
        inputs_embeds = merge_multimodal_embeddings(
            input_ids, inputs_embeds, multimodal_embeddings, [
                self.vision_args.image_token_id,
                self.vision_args.image_break_token_id,
                self.vision_args.image_end_token_id,
            ])
    return inputs_embeds

zucchini-nlp (Author):

I am not familiar with the Fuyu architecture, so looking at Pixtral now: I don't totally get why the outputs from get_multimodal_embeddings contain the image special tokens. Isn't it more straightforward to obtain only the image-related features and merge them using a mask (ids == special_image_token)?

Or maybe I am missing something here.
