🔴 VLM: compile compatibility (#35724)
* llavas

* add more models

* fix `compile_forward` test for all models

* fix copies

* make style

* also doesn't support cache class

* fix some tests

* not copied from

* ci green?

* fix tests

* fix copies

* fix tests

* check with `numel` and remove `item`

* fix copies

* fix copies

* Update src/transformers/models/cohere2/modeling_cohere2.py

Co-authored-by: Arthur <[email protected]>

* opt remove cross attn

* gemma2

* fixup

* fixup

* fix newly added test

* maybe fixed?

* green please?

---------

Co-authored-by: Arthur <[email protected]>
zucchini-nlp and ArthurZucker authored Feb 14, 2025
1 parent b45cf0e commit 0c78ef6
Showing 44 changed files with 461 additions and 1,212 deletions.
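
For context on what "compile compatibility" buys: once a model class sets `_supports_static_cache = True`, its generation loop can run with a pre-allocated, fixed-shape KV cache, which is what `torch.compile` needs to avoid recompiling at every decode step. A hedged usage sketch, not taken from this diff (the checkpoint id, prompt format, and generation settings are illustrative assumptions):

```python
import torch
from PIL import Image
from transformers import AutoProcessor, LlavaForConditionalGeneration

model_id = "llava-hf/llava-1.5-7b-hf"  # assumption: any llava-style checkpoint touched by this PR
processor = AutoProcessor.from_pretrained(model_id)
model = LlavaForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")

image = Image.new("RGB", (336, 336), "white")  # placeholder image
prompt = "USER: <image>\nDescribe the image. ASSISTANT:"
inputs = processor(images=image, text=prompt, return_tensors="pt").to("cuda", torch.float16)

# A static KV cache gives torch.compile fixed shapes across decode steps.
model.generation_config.cache_implementation = "static"
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

out = model.generate(**inputs, max_new_tokens=32, do_sample=False)
print(processor.batch_decode(out, skip_special_tokens=True)[0])
```
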
3 changes: 3 additions & 0 deletions src/transformers/models/blip_2/modeling_blip_2.py
@@ -2016,6 +2016,9 @@ def forward(
class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin):
config_class = Blip2Config
main_input_name = "pixel_values"
_supports_cache_class = True
_supports_static_cache = True
_supports_quantized_cache = False  # not all LM backbones support it (e.g. T5)

def __init__(self, config: Blip2Config):
super().__init__(config)
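
The new `_supports_quantized_cache = False` flag is what `generate()` is expected to consult before building a quantized cache, since a T5-style language backbone cannot use one. A hedged sketch of choosing a cache implementation from these class flags; the selection policy below is illustrative and not part of this commit:

```python
import torch
from transformers import Blip2ForConditionalGeneration, Blip2Processor

model_id = "Salesforce/blip2-opt-2.7b"  # illustrative checkpoint
processor = Blip2Processor.from_pretrained(model_id)
model = Blip2ForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch.float16)

# Illustrative policy: take the most memory-friendly cache the class advertises.
if model._supports_quantized_cache:
    cache_implementation = "quantized"
elif model._supports_static_cache:
    cache_implementation = "static"  # Blip2 ends up here after this commit
else:
    cache_implementation = None      # fall back to the default dynamic cache

gen_kwargs = {"max_new_tokens": 20}
if cache_implementation is not None:
    gen_kwargs["cache_implementation"] = cache_implementation
# inputs = processor(images=some_pil_image, return_tensors="pt")
# out = model.generate(**inputs, **gen_kwargs)
```
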
8 changes: 4 additions & 4 deletions src/transformers/models/chameleon/modeling_chameleon.py
@@ -1284,13 +1284,13 @@ def forward(

if pixel_values is not None:
image_tokens = self.get_image_tokens(pixel_values)
n_image_tokens_in_text = (input_ids == self.vocabulary_mapping.image_token_id).sum().item()
n_image_features = image_tokens.shape[0] * image_tokens.shape[1]
if n_image_tokens_in_text != n_image_features:
special_image_mask = input_ids == self.vocabulary_mapping.image_token_id
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_tokens.numel():
n_image_tokens_in_text = (input_ids == self.vocabulary_mapping.image_token_id).sum()
n_image_features = image_tokens.shape[0] * image_tokens.shape[1]
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens_in_text}, features {n_image_features}"
)
special_image_mask = input_ids == self.vocabulary_mapping.image_token_id
image_tokens = image_tokens.to(input_ids.device, input_ids.dtype)
input_ids = input_ids.masked_scatter(special_image_mask, image_tokens)

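
This hunk shows the pattern repeated throughout the PR (see the commit bullet "check with `numel` and remove `item`"): `.item()` and Python-level comparisons on tensor values force a device sync and break `fullgraph=True` tracing, so the validation now only runs outside of compilation. A standalone, hedged sketch of the pattern with toy tensors and made-up names:

```python
import torch
from transformers.utils import is_torchdynamo_compiling

def scatter_image_tokens(input_ids, image_tokens, image_token_id):
    """Illustrative reduction of the merge above; names and shapes are toy values."""
    special_image_mask = input_ids == image_token_id
    # The check is skipped while tracing: raising based on runtime tensor values
    # would introduce a data-dependent branch that torch.compile cannot capture.
    if not is_torchdynamo_compiling() and special_image_mask.sum() != image_tokens.numel():
        raise ValueError(
            f"Image features and image tokens do not match: "
            f"tokens {int(special_image_mask.sum())}, features {image_tokens.numel()}"
        )
    return input_ids.masked_scatter(special_image_mask, image_tokens.to(input_ids.dtype))

input_ids = torch.tensor([[1, 2, 9, 9, 3]])  # 9 stands in for the image token id
image_tokens = torch.tensor([4, 5])          # two placeholder "image" token ids
print(scatter_image_tokens(input_ids, image_tokens, image_token_id=9))
```
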
4 changes: 2 additions & 2 deletions src/transformers/models/cohere2/modeling_cohere2.py
@@ -25,7 +25,7 @@
import torch.nn as nn

from ...activations import ACT2FN
from ...cache_utils import Cache, HybridCache
from ...cache_utils import Cache, HybridCache, StaticCache
from ...generation import GenerationMixin
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
@@ -701,7 +701,7 @@ def _update_causal_mask(

dtype, device = input_tensor.dtype, input_tensor.device
sequence_length = input_tensor.shape[1]
if isinstance(past_key_values, HybridCache):
if isinstance(past_key_values, (HybridCache, StaticCache)):
target_length = past_key_values.get_max_cache_shape()
else:
target_length = attention_mask.shape[-1] if attention_mask is not None else input_tensor.shape[1]
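
`StaticCache` joins `HybridCache` here because both have a fixed, pre-allocated length, and the causal mask must be built against that size rather than the dynamic attention-mask length. A hedged, standalone sketch of the dispatch (simplified arguments, not the model's exact signature):

```python
import torch
from transformers.cache_utils import HybridCache, StaticCache

def mask_target_length(past_key_values, attention_mask, input_tensor):
    # Fixed-size caches expose their pre-allocated length, keeping the mask
    # shape constant across decode steps (what torch.compile wants to see).
    if isinstance(past_key_values, (HybridCache, StaticCache)):
        return past_key_values.get_max_cache_shape()
    # Dynamic caches: fall back to the 2D attention mask / input length.
    return attention_mask.shape[-1] if attention_mask is not None else input_tensor.shape[1]
```
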
4 changes: 2 additions & 2 deletions src/transformers/models/gemma2/modeling_gemma2.py
@@ -25,7 +25,7 @@
import torch.nn as nn

from ...activations import ACT2FN
from ...cache_utils import Cache, HybridCache
from ...cache_utils import Cache, HybridCache, StaticCache
from ...generation import GenerationMixin
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import (
@@ -713,7 +713,7 @@ def _update_causal_mask(

dtype, device = input_tensor.dtype, input_tensor.device
sequence_length = input_tensor.shape[1]
if isinstance(past_key_values, HybridCache):
if isinstance(past_key_values, (HybridCache, StaticCache)):
target_length = past_key_values.get_max_cache_shape()
else:
target_length = attention_mask.shape[-1] if attention_mask is not None else input_tensor.shape[1]
4 changes: 2 additions & 2 deletions src/transformers/models/gemma2/modular_gemma2.py
@@ -20,7 +20,7 @@
import torch.utils.checkpoint

from ...activations import ACT2FN
from ...cache_utils import Cache, HybridCache
from ...cache_utils import Cache, HybridCache, StaticCache
from ...configuration_utils import PretrainedConfig
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import (
@@ -550,7 +550,7 @@ def _update_causal_mask(

dtype, device = input_tensor.dtype, input_tensor.device
sequence_length = input_tensor.shape[1]
if isinstance(past_key_values, HybridCache):
if isinstance(past_key_values, (HybridCache, StaticCache)):
target_length = past_key_values.get_max_cache_shape()
else:
target_length = attention_mask.shape[-1] if attention_mask is not None else input_tensor.shape[1]
4 changes: 0 additions & 4 deletions src/transformers/models/got_ocr2/configuration_got_ocr2.py
@@ -132,8 +132,6 @@ class GotOcr2Config(PretrainedConfig):
The config object or dictionary of the vision backbone.
text_config (`Union[AutoConfig, dict]`, *optional*, defaults to `LlamaConfig`):
The config object or dictionary of the text backbone.
ignore_index (`int`, *optional*, defaults to -100):
The ignore index for the loss function.
image_token_index (`int`, *optional*, defaults to 151859):
The image token index to encode the image prompt.
image_seq_length (`int`, *optional*, defaults to 576):
@@ -161,13 +159,11 @@ def __init__(
self,
vision_config=None,
text_config=None,
ignore_index=-100,
image_token_index=151859,
image_seq_length=576,
pad_token_id=-1,
**kwargs,
):
self.ignore_index = ignore_index
self.image_token_index = image_token_index
self.image_seq_length = image_seq_length
self.pad_token_id = pad_token_id
85 changes: 2 additions & 83 deletions src/transformers/models/got_ocr2/modeling_got_ocr2.py
@@ -594,6 +594,8 @@ class GotOcr2PreTrainedModel(PreTrainedModel):
_supports_cache_class = True
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_quantized_cache = True
_supports_static_cache = True

def _init_weights(self, module):
# important: this ported version of GotOcr2 isn't meant for training from scratch - only
@@ -748,89 +750,6 @@ def get_image_features(
image_outputs = self.vision_tower(pixel_values).last_hidden_state
return self.multi_modal_projector(image_outputs)

def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids, attention_mask, labels):
num_images, num_image_patches, embed_dim = image_features.shape
batch_size, sequence_length = input_ids.shape
left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.pad_token_id))
# 1. Create a mask to know where special image tokens are
special_image_token_mask = input_ids == self.config.image_token_index
num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)
# Compute the maximum embed dimension
max_embed_dim = (num_special_image_tokens.max() * (num_image_patches - 1)) + sequence_length
batch_indices, non_image_indices = torch.where(input_ids != self.config.image_token_index)

# 2. Compute the positions where text should be written
# Calculate new positions for text tokens in merged image-text sequence.
# `special_image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images - 1` text tokens.
# `torch.cumsum` computes how each image token shifts subsequent text token positions.
# - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one.
new_token_positions = torch.cumsum((special_image_token_mask * (num_image_patches - 1) + 1), -1) - 1
nb_image_pad = max_embed_dim - 1 - new_token_positions[:, -1]
if left_padding:
new_token_positions += nb_image_pad[:, None] # offset for left padding
text_to_overwrite = new_token_positions[batch_indices, non_image_indices]

# 3. Create the full embedding, already padded to the maximum position
final_embedding = torch.zeros(
batch_size, max_embed_dim, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device
)
final_attention_mask = torch.zeros(
batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device
)
if labels is not None:
final_labels = torch.full(
(batch_size, max_embed_dim), self.config.ignore_index, dtype=input_ids.dtype, device=input_ids.device
)
# In case the Vision model or the Language model has been offloaded to CPU, we need to manually
# set the corresponding tensors into their correct target device.
target_device = inputs_embeds.device
batch_indices, non_image_indices, text_to_overwrite = (
batch_indices.to(target_device),
non_image_indices.to(target_device),
text_to_overwrite.to(target_device),
)
attention_mask = attention_mask.to(target_device)

# 4. Fill the embeddings based on the mask. If we have ["hey" "<image>", "how", "are"]
# we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features
final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices]
final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices]
if labels is not None:
final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices]

# 5. Fill the embeddings corresponding to the images. Anything that is not `text_positions` needs filling (#29835)
image_to_overwrite = torch.full(
(batch_size, max_embed_dim), True, dtype=torch.bool, device=inputs_embeds.device
)
image_to_overwrite[batch_indices, text_to_overwrite] = False
if left_padding:
image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device)
else:
mask = torch.ones_like(image_to_overwrite, dtype=torch.bool).cumsum(-1) - 1
padding_mask = mask <= new_token_positions[:, -1:].to(target_device)
image_to_overwrite &= padding_mask

if image_to_overwrite.sum() != image_features.shape[:-1].numel():
raise ValueError(
f"The input provided to the model are wrong. The number of image tokens is {torch.sum(special_image_token_mask)} while"
f" the number of image given to the model is {num_images}. This prevents correct indexing and breaks batch generation."
)

final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim).to(target_device)
final_attention_mask |= image_to_overwrite
position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)

# 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens.
batch_indices, pad_indices = torch.where(input_ids == self.pad_token_id)
indices_to_mask = new_token_positions[batch_indices, pad_indices]

final_embedding[batch_indices, indices_to_mask] = 0

if labels is None:
final_labels = None

return final_embedding, final_attention_mask, final_labels, position_ids

@add_start_docstrings_to_model_forward(GOT_OCR2_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=GotOcr2CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
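
The deleted `_merge_input_ids_with_image_features` helper rebuilt padded sequences with data-dependent lengths, which cannot be traced with static shapes. After this change GotOcr2 presumably relies on the masked_scatter-style placement used by the other llava-style models in this PR: the prompt already carries one placeholder token per image patch, and the image embeddings are written into those positions without changing the sequence length. A toy, hedged sketch of that replacement pattern (sizes and names are made up):

```python
import torch

batch, seq_len, hidden = 1, 6, 4           # toy sizes
image_token_id = 99                        # placeholder id, not GotOcr2's real one
input_ids = torch.tensor([[1, 99, 99, 99, 2, 3]])
inputs_embeds = torch.randn(batch, seq_len, hidden)
image_features = torch.randn(3, hidden)    # one embedding per placeholder token

# Write image embeddings where the placeholders sit; the sequence length never
# changes, so every shape stays static for torch.compile.
special_image_mask = (input_ids == image_token_id).unsqueeze(-1).expand_as(inputs_embeds)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features.to(inputs_embeds.dtype))
print(inputs_embeds.shape)  # torch.Size([1, 6, 4]) (unchanged)
```
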
4 changes: 0 additions & 4 deletions src/transformers/models/got_ocr2/modular_got_ocr2.py
@@ -170,8 +170,6 @@ class GotOcr2Config(PretrainedConfig):
The config object or dictionary of the vision backbone.
text_config (`Union[AutoConfig, dict]`, *optional*, defaults to `LlamaConfig`):
The config object or dictionary of the text backbone.
ignore_index (`int`, *optional*, defaults to -100):
The ignore index for the loss function.
image_token_index (`int`, *optional*, defaults to 151859):
The image token index to encode the image prompt.
image_seq_length (`int`, *optional*, defaults to 576):
@@ -199,13 +197,11 @@ def __init__(
self,
vision_config=None,
text_config=None,
ignore_index=-100,
image_token_index=151859,
image_seq_length=576,
pad_token_id=-1,
**kwargs,
):
self.ignore_index = ignore_index
self.image_token_index = image_token_index
self.image_seq_length = image_seq_length
self.pad_token_id = pad_token_id
src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py
@@ -51,7 +51,7 @@ class GPTNeoXJapanesePreTrainedModel(PreTrainedModel):
_skip_keys_device_placement = "past_key_values"
_supports_cache_class = True
_supports_quantized_cache = True
_supports_static_cache = False # TODO (fix me): compilation fails due to a stide error?
_supports_static_cache = True

def _init_weights(self, module):
"""Initialize the weights"""
@@ -129,8 +129,8 @@ def forward(

cos, sin = position_embeddings
query, key = apply_rotary_pos_emb(query_rot, key_rot, cos, sin)
query = torch.cat((query, query_pass), dim=-1)
key = torch.cat((key, key_pass), dim=-1)
query = torch.cat((query, query_pass), dim=-1).contiguous()
key = torch.cat((key, key_pass), dim=-1).contiguous()

# Cache QKV values
if layer_past is not None:
40 changes: 14 additions & 26 deletions src/transformers/models/granitemoe/modeling_granitemoe.py
@@ -1108,6 +1108,7 @@ def forward(
router_logits=all_router_logits,
)

# Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
def _update_causal_mask(
self,
attention_mask: torch.Tensor,
@@ -1116,13 +1117,8 @@ def _update_causal_mask(
past_key_values: Cache,
output_attentions: bool,
):
# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114

if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
if attention_mask is not None and (attention_mask == 0.0).any():
return attention_mask
return None

@@ -1143,7 +1139,6 @@ def _update_causal_mask(
return None

dtype, device = input_tensor.dtype, input_tensor.device
min_dtype = torch.finfo(dtype).min
sequence_length = input_tensor.shape[1]
if using_static_cache:
target_length = past_key_values.get_max_cache_shape()
@@ -1154,25 +1149,17 @@ def _update_causal_mask(
else past_seen_tokens + sequence_length + 1
)

if attention_mask is not None and attention_mask.dim() == 4:
# in this case we assume that the mask comes already in inverted form and requires no inversion or slicing
causal_mask = attention_mask
else:
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
attention_mask,
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
device=device,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
)

if (
self.config._attn_implementation == "sdpa"
and attention_mask is not None
@@ -1182,6 +1169,7 @@
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
# Details: https://github.com/pytorch/pytorch/issues/110213
min_dtype = torch.finfo(dtype).min
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)

return causal_mask
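
GraniteMoE's `_update_causal_mask` is re-synced with Llama's: the inline mask construction is replaced by the shared `_prepare_4d_causal_attention_mask_with_cache_position` helper, and the `0.0 in attention_mask` membership test (which implicitly synchronizes to a Python bool) becomes a tensor-level `(attention_mask == 0.0).any()`. A standalone sketch of what the 4D mask construction does, simplified from the removed lines above rather than copied from the helper itself:

```python
import torch

def build_4d_causal_mask(attention_mask, sequence_length, target_length, dtype, device, cache_position, batch_size):
    """Simplified sketch: cache-position-aware causal mask where min_dtype means 'masked'."""
    min_dtype = torch.finfo(dtype).min
    mask = torch.full((sequence_length, target_length), min_dtype, dtype=dtype, device=device)
    if sequence_length != 1:
        mask = torch.triu(mask, diagonal=1)  # allow attending to past positions and self
    mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
    mask = mask[None, None, :, :].expand(batch_size, 1, -1, -1).clone()
    if attention_mask is not None:           # fold in the 2D padding mask
        limit = attention_mask.shape[-1]
        padding = mask[:, :, :, :limit] + attention_mask[:, None, None, :]
        mask[:, :, :, :limit] = mask[:, :, :, :limit].masked_fill(padding == 0, min_dtype)
    return mask

mask = build_4d_causal_mask(torch.ones(1, 3), 3, 8, torch.float32, "cpu", torch.arange(3), 1)
print(mask.shape)  # torch.Size([1, 1, 3, 8])
```
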
3 changes: 3 additions & 0 deletions src/transformers/models/instructblip/modeling_instructblip.py
@@ -1290,6 +1290,9 @@ def forward(
class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel, GenerationMixin):
config_class = InstructBlipConfig
main_input_name = "pixel_values"
_supports_cache_class = True
_supports_static_cache = True
_supports_quantized_cache = False  # not all LM backbones support it (e.g. T5)

def __init__(self, config: InstructBlipConfig):
super().__init__(config)
src/transformers/models/instructblipvideo/modeling_instructblipvideo.py
@@ -1284,6 +1284,9 @@ def forward(
class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel, GenerationMixin):
config_class = InstructBlipVideoConfig
main_input_name = "pixel_values"
_supports_cache_class = True
_supports_static_cache = True
_supports_quantized_cache = False  # not all LM backbones support it (e.g. T5)

def __init__(self, config: InstructBlipVideoConfig):
super().__init__(config)
4 changes: 0 additions & 4 deletions src/transformers/models/llava/configuration_llava.py
@@ -37,8 +37,6 @@ class LlavaConfig(PretrainedConfig):
The config object or dictionary of the vision backbone.
text_config (`Union[AutoConfig, dict]`, *optional*, defaults to `LlamaConfig`):
The config object or dictionary of the text backbone.
ignore_index (`int`, *optional*, defaults to -100):
The ignore index for the loss function.
image_token_index (`int`, *optional*, defaults to 32000):
The image token index to encode the image prompt.
projector_hidden_act (`str`, *optional*, defaults to `"gelu"`):
@@ -83,7 +81,6 @@ def __init__(
self,
vision_config=None,
text_config=None,
ignore_index=-100,
image_token_index=32000,
projector_hidden_act="gelu",
vision_feature_select_strategy="default",
@@ -92,7 +89,6 @@ def __init__(
multimodal_projector_bias=True,
**kwargs,
):
self.ignore_index = ignore_index
self.image_token_index = image_token_index
self.projector_hidden_act = projector_hidden_act
self.image_seq_length = image_seq_length