diff --git a/src/transformers/models/blip_2/modeling_blip_2.py b/src/transformers/models/blip_2/modeling_blip_2.py index 84f0356cecb2..916631da7e8f 100644 --- a/src/transformers/models/blip_2/modeling_blip_2.py +++ b/src/transformers/models/blip_2/modeling_blip_2.py @@ -2016,6 +2016,9 @@ def forward( class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin): config_class = Blip2Config main_input_name = "pixel_values" + _supports_cache_class = True + _supports_static_cache = True + _supports_quantized_cache = False # not all LM bacbones support (e.g. T5) def __init__(self, config: Blip2Config): super().__init__(config) diff --git a/src/transformers/models/chameleon/modeling_chameleon.py b/src/transformers/models/chameleon/modeling_chameleon.py index 0a9421409e25..65322e236ca0 100644 --- a/src/transformers/models/chameleon/modeling_chameleon.py +++ b/src/transformers/models/chameleon/modeling_chameleon.py @@ -1284,13 +1284,13 @@ def forward( if pixel_values is not None: image_tokens = self.get_image_tokens(pixel_values) - n_image_tokens_in_text = (input_ids == self.vocabulary_mapping.image_token_id).sum().item() - n_image_features = image_tokens.shape[0] * image_tokens.shape[1] - if n_image_tokens_in_text != n_image_features: + special_image_mask = input_ids == self.vocabulary_mapping.image_token_id + if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_tokens.numel(): + n_image_tokens_in_text = (input_ids == self.vocabulary_mapping.image_token_id).sum() + n_image_features = image_tokens.shape[0] * image_tokens.shape[1] raise ValueError( f"Image features and image tokens do not match: tokens: {n_image_tokens_in_text}, features {n_image_features}" ) - special_image_mask = input_ids == self.vocabulary_mapping.image_token_id image_tokens = image_tokens.to(input_ids.device, input_ids.dtype) input_ids = input_ids.masked_scatter(special_image_mask, image_tokens) diff --git a/src/transformers/models/cohere2/modeling_cohere2.py b/src/transformers/models/cohere2/modeling_cohere2.py index 11353a0a990c..75144c65ecff 100644 --- a/src/transformers/models/cohere2/modeling_cohere2.py +++ b/src/transformers/models/cohere2/modeling_cohere2.py @@ -25,7 +25,7 @@ import torch.nn as nn from ...activations import ACT2FN -from ...cache_utils import Cache, HybridCache +from ...cache_utils import Cache, HybridCache, StaticCache from ...generation import GenerationMixin from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast @@ -701,7 +701,7 @@ def _update_causal_mask( dtype, device = input_tensor.dtype, input_tensor.device sequence_length = input_tensor.shape[1] - if isinstance(past_key_values, HybridCache): + if isinstance(past_key_values, (HybridCache, StaticCache)): target_length = past_key_values.get_max_cache_shape() else: target_length = attention_mask.shape[-1] if attention_mask is not None else input_tensor.shape[1] diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index e9fd43c49000..c977f873dc8c 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -25,7 +25,7 @@ import torch.nn as nn from ...activations import ACT2FN -from ...cache_utils import Cache, HybridCache +from ...cache_utils import Cache, HybridCache, StaticCache from ...generation import GenerationMixin from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import ( @@ -713,7 +713,7 @@ def _update_causal_mask( dtype, device = input_tensor.dtype, input_tensor.device sequence_length = input_tensor.shape[1] - if isinstance(past_key_values, HybridCache): + if isinstance(past_key_values, (HybridCache, StaticCache)): target_length = past_key_values.get_max_cache_shape() else: target_length = attention_mask.shape[-1] if attention_mask is not None else input_tensor.shape[1] diff --git a/src/transformers/models/gemma2/modular_gemma2.py b/src/transformers/models/gemma2/modular_gemma2.py index 4e3c8487c4d8..805e6ba0d2a3 100644 --- a/src/transformers/models/gemma2/modular_gemma2.py +++ b/src/transformers/models/gemma2/modular_gemma2.py @@ -20,7 +20,7 @@ import torch.utils.checkpoint from ...activations import ACT2FN -from ...cache_utils import Cache, HybridCache +from ...cache_utils import Cache, HybridCache, StaticCache from ...configuration_utils import PretrainedConfig from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import ( @@ -550,7 +550,7 @@ def _update_causal_mask( dtype, device = input_tensor.dtype, input_tensor.device sequence_length = input_tensor.shape[1] - if isinstance(past_key_values, HybridCache): + if isinstance(past_key_values, (HybridCache, StaticCache)): target_length = past_key_values.get_max_cache_shape() else: target_length = attention_mask.shape[-1] if attention_mask is not None else input_tensor.shape[1] diff --git a/src/transformers/models/got_ocr2/configuration_got_ocr2.py b/src/transformers/models/got_ocr2/configuration_got_ocr2.py index 480252ab1471..fb9a1fb68889 100644 --- a/src/transformers/models/got_ocr2/configuration_got_ocr2.py +++ b/src/transformers/models/got_ocr2/configuration_got_ocr2.py @@ -132,8 +132,6 @@ class GotOcr2Config(PretrainedConfig): The config object or dictionary of the vision backbone. text_config (`Union[AutoConfig, dict]`, *optional*, defaults to `LlamaConfig`): The config object or dictionary of the text backbone. - ignore_index (`int`, *optional*, defaults to -100): - The ignore index for the loss function. image_token_index (`int`, *optional*, defaults to 151859): The image token index to encode the image prompt. image_seq_length (`int`, *optional*, defaults to 576): @@ -161,13 +159,11 @@ def __init__( self, vision_config=None, text_config=None, - ignore_index=-100, image_token_index=151859, image_seq_length=576, pad_token_id=-1, **kwargs, ): - self.ignore_index = ignore_index self.image_token_index = image_token_index self.image_seq_length = image_seq_length self.pad_token_id = pad_token_id diff --git a/src/transformers/models/got_ocr2/modeling_got_ocr2.py b/src/transformers/models/got_ocr2/modeling_got_ocr2.py index 957e05bea75a..86598ac08965 100644 --- a/src/transformers/models/got_ocr2/modeling_got_ocr2.py +++ b/src/transformers/models/got_ocr2/modeling_got_ocr2.py @@ -594,6 +594,8 @@ class GotOcr2PreTrainedModel(PreTrainedModel): _supports_cache_class = True _supports_flash_attn_2 = True _supports_sdpa = True + _supports_quantized_cache = True + _supports_static_cache = True def _init_weights(self, module): # important: this ported version of GotOcr2 isn't meant for training from scratch - only @@ -748,89 +750,6 @@ def get_image_features( image_outputs = self.vision_tower(pixel_values).last_hidden_state return self.multi_modal_projector(image_outputs) - def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids, attention_mask, labels): - num_images, num_image_patches, embed_dim = image_features.shape - batch_size, sequence_length = input_ids.shape - left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.pad_token_id)) - # 1. Create a mask to know where special image tokens are - special_image_token_mask = input_ids == self.config.image_token_index - num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1) - # Compute the maximum embed dimension - max_embed_dim = (num_special_image_tokens.max() * (num_image_patches - 1)) + sequence_length - batch_indices, non_image_indices = torch.where(input_ids != self.config.image_token_index) - - # 2. Compute the positions where text should be written - # Calculate new positions for text tokens in merged image-text sequence. - # `special_image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images - 1` text tokens. - # `torch.cumsum` computes how each image token shifts subsequent text token positions. - # - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one. - new_token_positions = torch.cumsum((special_image_token_mask * (num_image_patches - 1) + 1), -1) - 1 - nb_image_pad = max_embed_dim - 1 - new_token_positions[:, -1] - if left_padding: - new_token_positions += nb_image_pad[:, None] # offset for left padding - text_to_overwrite = new_token_positions[batch_indices, non_image_indices] - - # 3. Create the full embedding, already padded to the maximum position - final_embedding = torch.zeros( - batch_size, max_embed_dim, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device - ) - final_attention_mask = torch.zeros( - batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device - ) - if labels is not None: - final_labels = torch.full( - (batch_size, max_embed_dim), self.config.ignore_index, dtype=input_ids.dtype, device=input_ids.device - ) - # In case the Vision model or the Language model has been offloaded to CPU, we need to manually - # set the corresponding tensors into their correct target device. - target_device = inputs_embeds.device - batch_indices, non_image_indices, text_to_overwrite = ( - batch_indices.to(target_device), - non_image_indices.to(target_device), - text_to_overwrite.to(target_device), - ) - attention_mask = attention_mask.to(target_device) - - # 4. Fill the embeddings based on the mask. If we have ["hey" "", "how", "are"] - # we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features - final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices] - final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices] - if labels is not None: - final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices] - - # 5. Fill the embeddings corresponding to the images. Anything that is not `text_positions` needs filling (#29835) - image_to_overwrite = torch.full( - (batch_size, max_embed_dim), True, dtype=torch.bool, device=inputs_embeds.device - ) - image_to_overwrite[batch_indices, text_to_overwrite] = False - if left_padding: - image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device) - else: - mask = torch.ones_like(image_to_overwrite, dtype=torch.bool).cumsum(-1) - 1 - padding_mask = mask <= new_token_positions[:, -1:].to(target_device) - image_to_overwrite &= padding_mask - - if image_to_overwrite.sum() != image_features.shape[:-1].numel(): - raise ValueError( - f"The input provided to the model are wrong. The number of image tokens is {torch.sum(special_image_token_mask)} while" - f" the number of image given to the model is {num_images}. This prevents correct indexing and breaks batch generation." - ) - - final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim).to(target_device) - final_attention_mask |= image_to_overwrite - position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1) - - # 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens. - batch_indices, pad_indices = torch.where(input_ids == self.pad_token_id) - indices_to_mask = new_token_positions[batch_indices, pad_indices] - - final_embedding[batch_indices, indices_to_mask] = 0 - - if labels is None: - final_labels = None - - return final_embedding, final_attention_mask, final_labels, position_ids - @add_start_docstrings_to_model_forward(GOT_OCR2_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=GotOcr2CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( diff --git a/src/transformers/models/got_ocr2/modular_got_ocr2.py b/src/transformers/models/got_ocr2/modular_got_ocr2.py index 899075683eb4..fff434ead2e9 100644 --- a/src/transformers/models/got_ocr2/modular_got_ocr2.py +++ b/src/transformers/models/got_ocr2/modular_got_ocr2.py @@ -170,8 +170,6 @@ class GotOcr2Config(PretrainedConfig): The config object or dictionary of the vision backbone. text_config (`Union[AutoConfig, dict]`, *optional*, defaults to `LlamaConfig`): The config object or dictionary of the text backbone. - ignore_index (`int`, *optional*, defaults to -100): - The ignore index for the loss function. image_token_index (`int`, *optional*, defaults to 151859): The image token index to encode the image prompt. image_seq_length (`int`, *optional*, defaults to 576): @@ -199,13 +197,11 @@ def __init__( self, vision_config=None, text_config=None, - ignore_index=-100, image_token_index=151859, image_seq_length=576, pad_token_id=-1, **kwargs, ): - self.ignore_index = ignore_index self.image_token_index = image_token_index self.image_seq_length = image_seq_length self.pad_token_id = pad_token_id diff --git a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py index d5153fb3f828..10b6efbc5943 100755 --- a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +++ b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py @@ -51,7 +51,7 @@ class GPTNeoXJapanesePreTrainedModel(PreTrainedModel): _skip_keys_device_placement = "past_key_values" _supports_cache_class = True _supports_quantized_cache = True - _supports_static_cache = False # TODO (fix me): compilation fails due to a stide error? + _supports_static_cache = True def _init_weights(self, module): """Initialize the weights""" @@ -129,8 +129,8 @@ def forward( cos, sin = position_embeddings query, key = apply_rotary_pos_emb(query_rot, key_rot, cos, sin) - query = torch.cat((query, query_pass), dim=-1) - key = torch.cat((key, key_pass), dim=-1) + query = torch.cat((query, query_pass), dim=-1).contiguous() + key = torch.cat((key, key_pass), dim=-1).contiguous() # Cache QKV values if layer_past is not None: diff --git a/src/transformers/models/granitemoe/modeling_granitemoe.py b/src/transformers/models/granitemoe/modeling_granitemoe.py index d877b8323b3b..546e78eac148 100644 --- a/src/transformers/models/granitemoe/modeling_granitemoe.py +++ b/src/transformers/models/granitemoe/modeling_granitemoe.py @@ -1108,6 +1108,7 @@ def forward( router_logits=all_router_logits, ) + # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask def _update_causal_mask( self, attention_mask: torch.Tensor, @@ -1116,13 +1117,8 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static - # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. - # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using - # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114 - if self.config._attn_implementation == "flash_attention_2": - if attention_mask is not None and 0.0 in attention_mask: + if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None @@ -1143,7 +1139,6 @@ def _update_causal_mask( return None dtype, device = input_tensor.dtype, input_tensor.device - min_dtype = torch.finfo(dtype).min sequence_length = input_tensor.shape[1] if using_static_cache: target_length = past_key_values.get_max_cache_shape() @@ -1154,25 +1149,17 @@ def _update_causal_mask( else past_seen_tokens + sequence_length + 1 ) - if attention_mask is not None and attention_mask.dim() == 4: - # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing - causal_mask = attention_mask - else: - causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device - ) - if sequence_length != 1: - causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) - causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1) - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - padding_mask, min_dtype - ) + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + if ( self.config._attn_implementation == "sdpa" and attention_mask is not None @@ -1182,6 +1169,7 @@ def _update_causal_mask( # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) return causal_mask diff --git a/src/transformers/models/instructblip/modeling_instructblip.py b/src/transformers/models/instructblip/modeling_instructblip.py index b705da44eba4..ea42d65b845c 100644 --- a/src/transformers/models/instructblip/modeling_instructblip.py +++ b/src/transformers/models/instructblip/modeling_instructblip.py @@ -1290,6 +1290,9 @@ def forward( class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel, GenerationMixin): config_class = InstructBlipConfig main_input_name = "pixel_values" + _supports_cache_class = True + _supports_static_cache = True + _supports_quantized_cache = False # not all LM bacbones support (e.g. T5) def __init__(self, config: InstructBlipConfig): super().__init__(config) diff --git a/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py b/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py index dcf77863a149..5183a3c22faf 100644 --- a/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py +++ b/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py @@ -1284,6 +1284,9 @@ def forward( class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel, GenerationMixin): config_class = InstructBlipVideoConfig main_input_name = "pixel_values" + _supports_cache_class = True + _supports_static_cache = True + _supports_quantized_cache = False # not all LM bacbones support (e.g. T5) def __init__(self, config: InstructBlipVideoConfig): super().__init__(config) diff --git a/src/transformers/models/llava/configuration_llava.py b/src/transformers/models/llava/configuration_llava.py index d2a3e9747b66..f476591b2eb6 100644 --- a/src/transformers/models/llava/configuration_llava.py +++ b/src/transformers/models/llava/configuration_llava.py @@ -37,8 +37,6 @@ class LlavaConfig(PretrainedConfig): The config object or dictionary of the vision backbone. text_config (`Union[AutoConfig, dict]`, *optional*, defaults to `LlamaConfig`): The config object or dictionary of the text backbone. - ignore_index (`int`, *optional*, defaults to -100): - The ignore index for the loss function. image_token_index (`int`, *optional*, defaults to 32000): The image token index to encode the image prompt. projector_hidden_act (`str`, *optional*, defaults to `"gelu"`): @@ -83,7 +81,6 @@ def __init__( self, vision_config=None, text_config=None, - ignore_index=-100, image_token_index=32000, projector_hidden_act="gelu", vision_feature_select_strategy="default", @@ -92,7 +89,6 @@ def __init__( multimodal_projector_bias=True, **kwargs, ): - self.ignore_index = ignore_index self.image_token_index = image_token_index self.projector_hidden_act = projector_hidden_act self.image_seq_length = image_seq_length diff --git a/src/transformers/models/llava/modeling_llava.py b/src/transformers/models/llava/modeling_llava.py index 36f212e76844..610ab417d92b 100644 --- a/src/transformers/models/llava/modeling_llava.py +++ b/src/transformers/models/llava/modeling_llava.py @@ -28,6 +28,7 @@ from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, + is_torchdynamo_compiling, logging, replace_return_docstrings, ) @@ -136,6 +137,8 @@ class LlavaPreTrainedModel(PreTrainedModel): _supports_cache_class = True _supports_flash_attn_2 = True _supports_sdpa = True + _supports_quantized_cache = True + _supports_static_cache = True def _init_weights(self, module): # important: this ported version of Llava isn't meant for training from scratch - only @@ -321,89 +324,6 @@ def get_image_features( image_features = self.multi_modal_projector(selected_image_feature) return image_features - def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids, attention_mask, labels): - num_images, num_image_patches, embed_dim = image_features.shape - batch_size, sequence_length = input_ids.shape - left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.pad_token_id)) - # 1. Create a mask to know where special image tokens are - special_image_token_mask = input_ids == self.config.image_token_index - num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1) - # Compute the maximum embed dimension - max_embed_dim = (num_special_image_tokens.max() * (num_image_patches - 1)) + sequence_length - batch_indices, non_image_indices = torch.where(input_ids != self.config.image_token_index) - - # 2. Compute the positions where text should be written - # Calculate new positions for text tokens in merged image-text sequence. - # `special_image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images - 1` text tokens. - # `torch.cumsum` computes how each image token shifts subsequent text token positions. - # - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one. - new_token_positions = torch.cumsum((special_image_token_mask * (num_image_patches - 1) + 1), -1) - 1 - nb_image_pad = max_embed_dim - 1 - new_token_positions[:, -1] - if left_padding: - new_token_positions += nb_image_pad[:, None] # offset for left padding - text_to_overwrite = new_token_positions[batch_indices, non_image_indices] - - # 3. Create the full embedding, already padded to the maximum position - final_embedding = torch.zeros( - batch_size, max_embed_dim, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device - ) - final_attention_mask = torch.zeros( - batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device - ) - if labels is not None: - final_labels = torch.full( - (batch_size, max_embed_dim), self.config.ignore_index, dtype=input_ids.dtype, device=input_ids.device - ) - # In case the Vision model or the Language model has been offloaded to CPU, we need to manually - # set the corresponding tensors into their correct target device. - target_device = inputs_embeds.device - batch_indices, non_image_indices, text_to_overwrite = ( - batch_indices.to(target_device), - non_image_indices.to(target_device), - text_to_overwrite.to(target_device), - ) - attention_mask = attention_mask.to(target_device) - - # 4. Fill the embeddings based on the mask. If we have ["hey" "", "how", "are"] - # we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features - final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices] - final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices] - if labels is not None: - final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices] - - # 5. Fill the embeddings corresponding to the images. Anything that is not `text_positions` needs filling (#29835) - image_to_overwrite = torch.full( - (batch_size, max_embed_dim), True, dtype=torch.bool, device=inputs_embeds.device - ) - image_to_overwrite[batch_indices, text_to_overwrite] = False - if left_padding: - image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device) - else: - mask = torch.ones_like(image_to_overwrite, dtype=torch.bool).cumsum(-1) - 1 - padding_mask = mask <= new_token_positions[:, -1:].to(target_device) - image_to_overwrite &= padding_mask - - if image_to_overwrite.sum() != image_features.shape[:-1].numel(): - raise ValueError( - f"The input provided to the model are wrong. The number of image tokens is {torch.sum(special_image_token_mask)} while" - f" the number of image given to the model is {num_images}. This prevents correct indexing and breaks batch generation." - ) - - final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim).to(target_device) - final_attention_mask |= image_to_overwrite - position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1) - - # 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens. - batch_indices, pad_indices = torch.where(input_ids == self.pad_token_id) - indices_to_mask = new_token_positions[batch_indices, pad_indices] - - final_embedding[batch_indices, indices_to_mask] = 0 - - if labels is None: - final_labels = None - - return final_embedding, final_attention_mask, final_labels, position_ids - @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(LLAVA_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=LlavaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) @@ -499,14 +419,14 @@ def forward( image_sizes=image_sizes, ) - n_image_tokens = (input_ids == self.config.image_token_index).sum() - n_image_features = image_features.shape[0] * image_features.shape[1] - if n_image_tokens != n_image_features: + special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1) + special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) + if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): + n_image_tokens = (input_ids == self.config.image_token_index).sum() + n_image_features = image_features.shape[0] * image_features.shape[1] raise ValueError( f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" ) - special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1) - special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) diff --git a/src/transformers/models/llava_next/configuration_llava_next.py b/src/transformers/models/llava_next/configuration_llava_next.py index 2610275cedfd..3836dbf71cd2 100644 --- a/src/transformers/models/llava_next/configuration_llava_next.py +++ b/src/transformers/models/llava_next/configuration_llava_next.py @@ -36,8 +36,6 @@ class LlavaNextConfig(PretrainedConfig): The config object or dictionary of the vision backbone. text_config (`Union[AutoConfig, dict]`, *optional*, defaults to `LlamaConfig`): The config object or dictionary of the text backbone. - ignore_index (`int`, *optional*, defaults to -100): - The ignore index for the loss function. image_token_index (`int`, *optional*, defaults to 32000): The image token index to encode the image prompt. projector_hidden_act (`str`, *optional*, defaults to `"gelu"`): @@ -88,7 +86,6 @@ def __init__( self, vision_config=None, text_config=None, - ignore_index=-100, image_token_index=32000, projector_hidden_act="gelu", vision_feature_select_strategy="default", @@ -99,7 +96,6 @@ def __init__( multimodal_projector_bias=True, **kwargs, ): - self.ignore_index = ignore_index self.image_token_index = image_token_index self.projector_hidden_act = projector_hidden_act self.image_seq_length = image_seq_length diff --git a/src/transformers/models/llava_next/modeling_llava_next.py b/src/transformers/models/llava_next/modeling_llava_next.py index 06e1cc63940f..3cdf1b348404 100644 --- a/src/transformers/models/llava_next/modeling_llava_next.py +++ b/src/transformers/models/llava_next/modeling_llava_next.py @@ -31,6 +31,7 @@ from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, + is_torchdynamo_compiling, logging, replace_return_docstrings, ) @@ -245,6 +246,8 @@ class LlavaNextPreTrainedModel(PreTrainedModel): _supports_cache_class = True _supports_flash_attn_2 = True _supports_sdpa = True + _supports_quantized_cache = True + _supports_static_cache = True def _init_weights(self, module): # important: this ported version of LlavaNext isn't meant for training from scratch - only @@ -405,245 +408,6 @@ def set_decoder(self, decoder): def get_decoder(self): return self.language_model.get_decoder() - def _merge_input_ids_with_image_features( - self, - image_features, - feature_lens, - inputs_embeds, - input_ids, - attention_mask, - position_ids=None, - labels=None, - image_token_index=None, - ignore_index=-100, - ): - """ - Merge input_ids with with image features into final embeddings - - Args: - image_features (`torch.Tensor` of shape `(all_feature_lens, embed_dim)`): - All vision vectors of all images in the batch - feature_lens (`torch.LongTensor` of shape `(num_images)`): - The length of visual embeddings of each image as stacked in `image_features` - inputs_embeds (`torch.Tensor` of shape `(batch_size, sequence_length, embed_dim)`): - Token embeddings before merging with visual embeddings - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - Input_ids of tokens, possibly filled with image token - attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - Mask to avoid performing attention on padding token indices. - position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.n_positions - 1]`. - labels (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*) - :abels need to be recalculated to support training (if provided) - image_token_index (`int`, *optional*) - Token id used to indicate the special "image" token. Defaults to `config.image_token_index` - ignore_index (`int`, *optional*) - Value that is used to pad `labels` and will be ignored when calculated loss. Default: -100. - Returns: - final_embedding, final_attention_mask, position_ids, final_labels - - Explanation: - each image has variable length embeddings, with length specified by feature_lens - image_features is concatenation of all visual embed vectors - task: fill each with the correct number of visual embeddings - Example: - X (5 patches), Y (3 patches), Z (8) - X, Y are in the same sequence (in-context learning) - if right padding - input_ids: [ - a b c d e f X g h i j k Y l m - o p q r Z s t u v _ _ _ _ _ _ - ] - input_ids should be: [ - a b c d e f X X X X X g h i j k Y Y Y l m - o p q r Z Z Z Z Z Z Z Z s t u v _ _ _ _ _ - ] - labels should be: [ - a b c d e f _ _ _ _ _ g h i j k _ _ _ l m - o p q r _ _ _ _ _ _ _ _ s t u v _ _ _ _ _ - ] - elif left padding - input_ids: [ - a b c d e f X g h i j k Y l m - _ _ _ _ _ _ o p q r Z s t u v - ] - input_ids should be: [ - a b c d e f X X X X X g h i j k Y Y Y l m - _ _ _ _ _ o p q r Z Z Z Z Z Z Z Z s t u v - ] - labels should be: [ - a b c d e f _ _ _ _ _ g h i j k _ _ _ l m - _ _ _ _ _ o p q r _ _ _ _ _ _ _ _ s t u v - ] - Edge cases: - * If tokens are same but image token sizes are different, then cannot infer left or right padding - ```python - cat_img = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw) - chart_img = Image.open(requests.get("https://github.com/haotian-liu/LLaVA/blob/1a91fc274d7c35a9b50b3cb29c4247ae5837ce39/images/llava_v1_5_radar.jpg?raw=true", stream=True).raw) - prompts = [ - "[INST] \nWhat is shown in this image? [/INST]", - "[INST] \nWhat is shown in this image? [/INST]", - ] - inputs = processor(prompts, [chart_img, cat_img], return_tensors='pt', padding=True).to("cuda") - chart_img has 2634 tokens, while cat_img has 2340 tokens - ``` - - input_ids: [ - a b c d X g h - i j Y k l m n - ] - where X is 3 tokens while Y is 5, this mean after merge - if left-padding (batched generation) - input_ids should be: [ - _ _ a b c d X X X g h - i j Y Y Y Y Y k l m n - ] - elif (right padding) (training) - input_ids should be: [ - a b c d X X X g h _ _ - i j Y Y Y Y Y k l m n - ] - """ - image_token_index = image_token_index if image_token_index is not None else self.config.image_token_index - ignore_index = ignore_index if ignore_index is not None else self.config.ignore_index - - if self.training and self.padding_side == "left": - logger.warning_once( - "Padding side is set to 'left' but the model is in training mode. For training " - "it is recommended to set `model.padding_side='right' and `processor.tokenizer.padding_side='right'`. " - "If that's intended, ignore this warning" - ) - if not self.training and self.padding_side == "right": - logger.warning_once( - "Padding side is set to 'right' but the model is in inference mode. For correct " - "generation results, please set `model.padding_side='left'` and `processor.tokenizer.padding_side='left'`. " - "If that's intended, ignore this warning" - ) - - with torch.no_grad(): - # ! in llava 1.6, number of patches is variable - num_images = feature_lens.size(0) - num_image_features, embed_dim = image_features.shape - if feature_lens.sum() != num_image_features: - raise ValueError(f"{feature_lens=} / {feature_lens.sum()} != {image_features.shape=}") - batch_size = input_ids.shape[0] - _left_padding = torch.any(attention_mask[:, 0] == 0) - _right_padding = torch.any(attention_mask[:, -1] == 0) - - left_padding = self.padding_side == "left" - if batch_size > 1: - if _left_padding and _right_padding: - raise ValueError(f"both side of attention_mask has zero, invalid. {attention_mask}") - elif _right_padding and left_padding: - left_padding = False - elif _left_padding and not left_padding: - left_padding = True - - # Whether to turn off right padding - # 1. Create a mask to know where special image tokens are - special_image_token_mask = input_ids == image_token_index - # special_image_token_mask: [bsz, seqlen] - num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1) - # num_special_image_tokens: [bsz] - # Reserve for padding of num_images - total_num_special_image_tokens = torch.sum(special_image_token_mask) - if total_num_special_image_tokens != num_images: - raise ValueError( - f"Number of image tokens in input_ids ({total_num_special_image_tokens}) different from num_images ({num_images})." - ) - # Compute the maximum embed dimension - # max_image_feature_lens is max_feature_lens per batch - feature_lens = feature_lens.to(input_ids.device) - feature_lens_batch = feature_lens.split(num_special_image_tokens.tolist(), dim=0) - feature_lens_batch_sum = torch.tensor([x.sum() for x in feature_lens_batch], device=input_ids.device) - embed_sequence_lengths = ( - (attention_mask == 1).long().sum(-1) - num_special_image_tokens + feature_lens_batch_sum - ) - max_embed_dim = embed_sequence_lengths.max() - - batch_indices, non_image_indices = torch.where((input_ids != image_token_index) & (attention_mask == 1)) - # 2. Compute the positions where text should be written - # Calculate new positions for text tokens in merged image-text sequence. - # `special_image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images` text tokens. - # `torch.cumsum` computes how each image token shifts subsequent text token positions. - # - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one. - # ! instead of special_image_token_mask * (num_image_patches - 1) - # special_image_token_mask * (num_feature_len - 1) - special_image_token_mask = special_image_token_mask.long() - special_image_token_mask[special_image_token_mask == 1] = feature_lens - 1 - new_token_positions = torch.cumsum((special_image_token_mask + 1), -1) - 1 - if left_padding: - # shift right token positions so that they are ending at the same number - # the below here was incorrect? new_token_positions += new_token_positions[:, -1].max() - new_token_positions[:, -1:] - new_token_positions += max_embed_dim - 1 - new_token_positions[:, -1:] - - text_to_overwrite = new_token_positions[batch_indices, non_image_indices] - - # 3. Create the full embedding, already padded to the maximum position - final_embedding = torch.zeros( - batch_size, max_embed_dim, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device - ) - final_attention_mask = torch.zeros( - batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device - ) - final_input_ids = torch.full( - (batch_size, max_embed_dim), self.pad_token_id, dtype=input_ids.dtype, device=inputs_embeds.device - ) - # In case the Vision model or the Language model has been offloaded to CPU, we need to manually - # set the corresponding tensors into their correct target device. - target_device = inputs_embeds.device - batch_indices, non_image_indices, text_to_overwrite = ( - batch_indices.to(target_device), - non_image_indices.to(target_device), - text_to_overwrite.to(target_device), - ) - attention_mask = attention_mask.to(target_device) - input_ids = input_ids.to(target_device) - - # 4. Fill the embeddings based on the mask. If we have ["hey" "", "how", "are"] - # we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features - final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices] - final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices] - final_input_ids[batch_indices, text_to_overwrite] = input_ids[batch_indices, non_image_indices] - final_labels = None - if labels is not None: - labels = labels.to(target_device) - final_labels = torch.full_like(final_attention_mask, ignore_index).to(torch.long) - final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices] - - # 5. Fill the embeddings corresponding to the images. Anything that is not `text_positions` needs filling (#29835) - with torch.no_grad(): - image_to_overwrite = torch.full( - (batch_size, max_embed_dim), True, dtype=torch.bool, device=inputs_embeds.device - ) - image_to_overwrite[batch_indices, text_to_overwrite] = False - embed_indices = torch.arange(max_embed_dim).unsqueeze(0).to(target_device) - embed_indices = embed_indices.expand(batch_size, max_embed_dim) - embed_seq_lens = embed_sequence_lengths[:, None].to(target_device) - - if left_padding: - # exclude padding on the left - max_embed_dim = max_embed_dim.to(target_device) - val = (max_embed_dim - embed_indices) <= embed_seq_lens - else: - # exclude padding on the right - val = embed_indices < embed_seq_lens - image_to_overwrite &= val - - if image_to_overwrite.sum() != num_image_features: - raise ValueError( - f"{image_to_overwrite.sum()=} != {num_image_features=} The input provided to the model are wrong. " - f"The number of image tokens is {torch.sum(special_image_token_mask)} while" - f" the number of image given to the model is {num_images}. " - f"This prevents correct indexing and breaks batch generation." - ) - final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim).to(target_device) - final_attention_mask |= image_to_overwrite - position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1) - - return final_embedding, final_attention_mask, position_ids, final_labels, final_input_ids - def pack_image_features(self, image_features, image_sizes, vision_feature_select_strategy, image_newline=None): """ Reshape, unpad and then pack each image_feature into a single image_features tensor containing all visual vectors. @@ -875,14 +639,14 @@ def forward( image_newline=self.image_newline, ) - n_image_tokens = (input_ids == self.config.image_token_index).sum().item() - n_image_features = image_features.shape[0] - if n_image_tokens != n_image_features: + special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1) + special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) + if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): + n_image_tokens = (input_ids == self.config.image_token_index).sum() + n_image_features = image_features.shape[0] raise ValueError( f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" ) - special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1) - special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) diff --git a/src/transformers/models/llava_next_video/configuration_llava_next_video.py b/src/transformers/models/llava_next_video/configuration_llava_next_video.py index 6b85ebb4455e..01450f6b587c 100644 --- a/src/transformers/models/llava_next_video/configuration_llava_next_video.py +++ b/src/transformers/models/llava_next_video/configuration_llava_next_video.py @@ -38,8 +38,6 @@ class LlavaNextVideoConfig(PretrainedConfig): The config object or dictionary of the vision backbone. text_config (`Union[AutoConfig, dict]`, *optional*, defaults to `LlamaConfig`): The config object or dictionary of the text backbone. - ignore_index (`int`, *optional*, defaults to -100): - The ignore index for the loss function. image_token_index (`int`, *optional*, defaults to 32001): The image token index to encode the image prompt. projector_hidden_act (`str`, *optional*, defaults to `"gelu"`): @@ -96,7 +94,6 @@ def __init__( self, vision_config=None, text_config=None, - ignore_index=-100, image_token_index=32001, projector_hidden_act="gelu", multimodal_projector_bias=True, @@ -116,7 +113,6 @@ def __init__( self.spatial_pool_stride = spatial_pool_stride self.image_seq_length = image_seq_length self.video_seq_length = video_seq_length - self.ignore_index = ignore_index self.image_token_index = image_token_index self.projector_hidden_act = projector_hidden_act self.multimodal_projector_bias = multimodal_projector_bias diff --git a/src/transformers/models/llava_next_video/modeling_llava_next_video.py b/src/transformers/models/llava_next_video/modeling_llava_next_video.py index f62824947ddf..9ce88c541231 100644 --- a/src/transformers/models/llava_next_video/modeling_llava_next_video.py +++ b/src/transformers/models/llava_next_video/modeling_llava_next_video.py @@ -32,7 +32,13 @@ from ...image_processing_utils import select_best_resolution from ...modeling_outputs import ModelOutput from ...modeling_utils import PreTrainedModel -from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings +from ...utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_torchdynamo_compiling, + logging, + replace_return_docstrings, +) from ...utils.deprecation import deprecate_kwarg from ..auto import AutoModel, AutoModelForCausalLM from .configuration_llava_next_video import LlavaNextVideoConfig @@ -153,6 +159,8 @@ class LlavaNextVideoPreTrainedModel(PreTrainedModel): _supports_cache_class = True _supports_flash_attn_2 = True _supports_sdpa = True + _supports_quantized_cache = True + _supports_static_cache = True def _init_weights(self, module): # important: this ported version of LlavaNextVideo isn't meant for training from scratch - only @@ -440,245 +448,6 @@ def set_decoder(self, decoder): def get_decoder(self): return self.language_model.get_decoder() - def _merge_input_ids_with_image_features( - self, - image_features, - feature_lens, - inputs_embeds, - input_ids, - attention_mask, - position_ids=None, - labels=None, - image_token_index=None, - ignore_index=-100, - ): - """ - Merge input_ids with with image features into final embeddings - - Args: - image_features (`torch.Tensor` of shape `(all_feature_lens, embed_dim)`): - All vision vectors of all images in the batch - feature_lens (`torch.LongTensor` of shape `(num_images)`): - The length of visual embeddings of each image as stacked in `image_features` - inputs_embeds (`torch.Tensor` of shape `(batch_size, sequence_length, embed_dim)`): - Token embeddings before merging with visual embeddings - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - Input_ids of tokens, possibly filled with image token - attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - Mask to avoid performing attention on padding token indices. - position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.n_positions - 1]`. - labels (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*) - :abels need to be recalculated to support training (if provided) - image_token_index (`int`, *optional*) - Token id used to indicate the special "image" token. Defaults to `config.image_token_index` - ignore_index (`int`, *optional*) - Value that is used to pad `labels` and will be ignored when calculated loss. Default: -100. - Returns: - final_embedding, final_attention_mask, position_ids, final_labels - - Explanation: - each image has variable length embeddings, with length specified by feature_lens - image_features is concatenation of all visual embed vectors - task: fill each with the correct number of visual embeddings - Example: - X (5 patches), Y (3 patches), Z (8) - X, Y are in the same sequence (in-context learning) - if right padding - input_ids: [ - a b c d e f X g h i j k Y l m - o p q r Z s t u v _ _ _ _ _ _ - ] - input_ids should be: [ - a b c d e f X X X X X g h i j k Y Y Y l m - o p q r Z Z Z Z Z Z Z Z s t u v _ _ _ _ _ - ] - labels should be: [ - a b c d e f _ _ _ _ _ g h i j k _ _ _ l m - o p q r _ _ _ _ _ _ _ _ s t u v _ _ _ _ _ - ] - elif left padding - input_ids: [ - a b c d e f X g h i j k Y l m - _ _ _ _ _ _ o p q r Z s t u v - ] - input_ids should be: [ - a b c d e f X X X X X g h i j k Y Y Y l m - _ _ _ _ _ o p q r Z Z Z Z Z Z Z Z s t u v - ] - labels should be: [ - a b c d e f _ _ _ _ _ g h i j k _ _ _ l m - _ _ _ _ _ o p q r _ _ _ _ _ _ _ _ s t u v - ] - Edge cases: - * If tokens are same but image token sizes are different, then cannot infer left or right padding - ```python - cat_img = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw) - chart_img = Image.open(requests.get("https://github.com/haotian-liu/LLaVA/blob/1a91fc274d7c35a9b50b3cb29c4247ae5837ce39/images/llava_v1_5_radar.jpg?raw=true", stream=True).raw) - prompts = [ - "[INST] \nWhat is shown in this image? [/INST]", - "[INST] \nWhat is shown in this image? [/INST]", - ] - inputs = processor(prompts, [chart_img, cat_img], return_tensors='pt', padding=True).to("cuda") - chart_img has 2634 tokens, while cat_img has 2340 tokens - ``` - - input_ids: [ - a b c d X g h - i j Y k l m n - ] - where X is 3 tokens while Y is 5, this mean after merge - if left-padding (batched generation) - input_ids should be: [ - _ _ a b c d X X X g h - i j Y Y Y Y Y k l m n - ] - elif (right padding) (training) - input_ids should be: [ - a b c d X X X g h _ _ - i j Y Y Y Y Y k l m n - ] - """ - image_token_index = image_token_index if image_token_index is not None else self.config.image_token_index - ignore_index = ignore_index if ignore_index is not None else self.config.ignore_index - - if self.training and self.padding_side == "left": - logger.warning_once( - "Padding side is set to 'left' but the model is in training mode. For training " - "it is recommended to set `model.padding_side='right' and `processor.tokenizer.padding_side='right'`. " - "If that's intended, ignore this warning" - ) - if not self.training and self.padding_side == "right": - logger.warning_once( - "Padding side is set to 'right' but the model is in inference mode. For correct " - "generation results, please set `model.padding_side='left'` and `processor.tokenizer.padding_side='left'`. " - "If that's intended, ignore this warning" - ) - - with torch.no_grad(): - # ! in llava 1.6, number of patches is variable - num_images = feature_lens.size(0) - num_image_features, embed_dim = image_features.shape - if feature_lens.sum() != num_image_features: - raise ValueError(f"{feature_lens=} / {feature_lens.sum()} != {image_features.shape=}") - batch_size = input_ids.shape[0] - _left_padding = torch.any(attention_mask[:, 0] == 0) - _right_padding = torch.any(attention_mask[:, -1] == 0) - - left_padding = self.padding_side == "left" - if batch_size > 1: - if _left_padding and _right_padding: - raise ValueError(f"both side of attention_mask has zero, invalid. {attention_mask}") - elif _right_padding and left_padding: - left_padding = False - elif _left_padding and not left_padding: - left_padding = True - - # Whether to turn off right padding - # 1. Create a mask to know where special image tokens are - special_image_token_mask = input_ids == image_token_index - # special_image_token_mask: [bsz, seqlen] - num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1) - # num_special_image_tokens: [bsz] - # Reserve for padding of num_images - total_num_special_image_tokens = torch.sum(special_image_token_mask) - if total_num_special_image_tokens != num_images: - raise ValueError( - f"Number of image tokens in input_ids ({total_num_special_image_tokens}) different from num_images ({num_images})." - ) - # Compute the maximum embed dimension - # max_image_feature_lens is max_feature_lens per batch - feature_lens = feature_lens.to(input_ids.device) - feature_lens_batch = feature_lens.split(num_special_image_tokens.tolist(), dim=0) - feature_lens_batch_sum = torch.tensor([x.sum() for x in feature_lens_batch], device=input_ids.device) - embed_sequence_lengths = ( - (attention_mask == 1).long().sum(-1) - num_special_image_tokens + feature_lens_batch_sum - ) - max_embed_dim = embed_sequence_lengths.max() - - batch_indices, non_image_indices = torch.where((input_ids != image_token_index) & (attention_mask == 1)) - # 2. Compute the positions where text should be written - # Calculate new positions for text tokens in merged image-text sequence. - # `special_image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images` text tokens. - # `torch.cumsum` computes how each image token shifts subsequent text token positions. - # - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one. - # ! instead of special_image_token_mask * (num_image_patches - 1) - # special_image_token_mask * (num_feature_len - 1) - special_image_token_mask = special_image_token_mask.long() - special_image_token_mask[special_image_token_mask == 1] = feature_lens - 1 - new_token_positions = torch.cumsum((special_image_token_mask + 1), -1) - 1 - if left_padding: - # shift right token positions so that they are ending at the same number - # the below here was incorrect? new_token_positions += new_token_positions[:, -1].max() - new_token_positions[:, -1:] - new_token_positions += max_embed_dim - 1 - new_token_positions[:, -1:] - - text_to_overwrite = new_token_positions[batch_indices, non_image_indices] - - # 3. Create the full embedding, already padded to the maximum position - final_embedding = torch.zeros( - batch_size, max_embed_dim, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device - ) - final_attention_mask = torch.zeros( - batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device - ) - final_input_ids = torch.full( - (batch_size, max_embed_dim), self.pad_token_id, dtype=input_ids.dtype, device=inputs_embeds.device - ) - # In case the Vision model or the Language model has been offloaded to CPU, we need to manually - # set the corresponding tensors into their correct target device. - target_device = inputs_embeds.device - batch_indices, non_image_indices, text_to_overwrite = ( - batch_indices.to(target_device), - non_image_indices.to(target_device), - text_to_overwrite.to(target_device), - ) - attention_mask = attention_mask.to(target_device) - input_ids = input_ids.to(target_device) - - # 4. Fill the embeddings based on the mask. If we have ["hey" "", "how", "are"] - # we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features - final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices] - final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices] - final_input_ids[batch_indices, text_to_overwrite] = input_ids[batch_indices, non_image_indices] - final_labels = None - if labels is not None: - labels = labels.to(target_device) - final_labels = torch.full_like(final_attention_mask, ignore_index).to(torch.long) - final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices] - - # 5. Fill the embeddings corresponding to the images. Anything that is not `text_positions` needs filling (#29835) - with torch.no_grad(): - image_to_overwrite = torch.full( - (batch_size, max_embed_dim), True, dtype=torch.bool, device=inputs_embeds.device - ) - image_to_overwrite[batch_indices, text_to_overwrite] = False - embed_indices = torch.arange(max_embed_dim).unsqueeze(0).to(target_device) - embed_indices = embed_indices.expand(batch_size, max_embed_dim) - embed_seq_lens = embed_sequence_lengths[:, None].to(target_device) - - if left_padding: - # exclude padding on the left - max_embed_dim = max_embed_dim.to(target_device) - val = (max_embed_dim - embed_indices) <= embed_seq_lens - else: - # exclude padding on the right - val = embed_indices < embed_seq_lens - image_to_overwrite &= val - - if image_to_overwrite.sum() != num_image_features: - raise ValueError( - f"{image_to_overwrite.sum()=} != {num_image_features=} The input provided to the model are wrong. " - f"The number of image tokens is {torch.sum(special_image_token_mask)} while" - f" the number of image given to the model is {num_images}. " - f"This prevents correct indexing and breaks batch generation." - ) - final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim).to(target_device) - final_attention_mask |= image_to_overwrite - position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1) - - return final_embedding, final_attention_mask, position_ids, final_labels, final_input_ids - def pack_image_features(self, image_features, image_sizes, vision_feature_select_strategy, image_newline=None): """ Reshape, unpad and then pack each image_feature into a single image_features tensor containing all visual vectors. @@ -948,14 +717,14 @@ def forward( image_newline=self.image_newline, ) - n_image_tokens = (input_ids == self.config.image_token_index).sum().item() - n_image_features = image_features.shape[0] - if n_image_tokens != n_image_features: + special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1) + special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) + if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): + n_image_tokens = (input_ids == self.config.image_token_index).sum() + n_image_features = image_features.shape[0] raise ValueError( f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" ) - special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1) - special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) @@ -970,14 +739,14 @@ def forward( video_features = torch.cat(video_features, dim=0) video_feature_lens = torch.tensor(video_feature_lens, dtype=torch.long, device=video_features.device) - n_video_tokens = (input_ids == self.config.video_token_index).sum().item() - n_video_features = video_features.shape[0] - if n_video_tokens != n_video_features: + special_image_mask = (input_ids == self.config.video_token_index).unsqueeze(-1) + special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) + if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != video_features.numel(): + n_video_tokens = (input_ids == self.config.video_token_index).sum().item() + n_video_features = video_features.shape[0] raise ValueError( f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" ) - special_image_mask = (input_ids == self.config.video_token_index).unsqueeze(-1) - special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype) inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, video_features) diff --git a/src/transformers/models/llava_next_video/modular_llava_next_video.py b/src/transformers/models/llava_next_video/modular_llava_next_video.py index b2e06c337c1b..8769f8db4131 100644 --- a/src/transformers/models/llava_next_video/modular_llava_next_video.py +++ b/src/transformers/models/llava_next_video/modular_llava_next_video.py @@ -30,6 +30,7 @@ from ...configuration_utils import PretrainedConfig from ...utils import ( + is_torchdynamo_compiling, logging, ) from ..auto import CONFIG_MAPPING, AutoConfig @@ -52,8 +53,6 @@ class LlavaNextVideoConfig(PretrainedConfig): The config object or dictionary of the vision backbone. text_config (`Union[AutoConfig, dict]`, *optional*, defaults to `LlamaConfig`): The config object or dictionary of the text backbone. - ignore_index (`int`, *optional*, defaults to -100): - The ignore index for the loss function. image_token_index (`int`, *optional*, defaults to 32001): The image token index to encode the image prompt. projector_hidden_act (`str`, *optional*, defaults to `"gelu"`): @@ -110,7 +109,6 @@ def __init__( self, vision_config=None, text_config=None, - ignore_index=-100, image_token_index=32001, projector_hidden_act="gelu", multimodal_projector_bias=True, @@ -130,7 +128,6 @@ def __init__( self.spatial_pool_stride = spatial_pool_stride self.image_seq_length = image_seq_length self.video_seq_length = video_seq_length - self.ignore_index = ignore_index self.image_token_index = image_token_index self.projector_hidden_act = projector_hidden_act self.multimodal_projector_bias = multimodal_projector_bias @@ -479,14 +476,14 @@ def forward( image_newline=self.image_newline, ) - n_image_tokens = (input_ids == self.config.image_token_index).sum().item() - n_image_features = image_features.shape[0] - if n_image_tokens != n_image_features: + special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1) + special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) + if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): + n_image_tokens = (input_ids == self.config.image_token_index).sum() + n_image_features = image_features.shape[0] raise ValueError( f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" ) - special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1) - special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) @@ -501,14 +498,14 @@ def forward( video_features = torch.cat(video_features, dim=0) video_feature_lens = torch.tensor(video_feature_lens, dtype=torch.long, device=video_features.device) - n_video_tokens = (input_ids == self.config.video_token_index).sum().item() - n_video_features = video_features.shape[0] - if n_video_tokens != n_video_features: + special_image_mask = (input_ids == self.config.video_token_index).unsqueeze(-1) + special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) + if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != video_features.numel(): + n_video_tokens = (input_ids == self.config.video_token_index).sum().item() + n_video_features = video_features.shape[0] raise ValueError( f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" ) - special_image_mask = (input_ids == self.config.video_token_index).unsqueeze(-1) - special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype) inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, video_features) diff --git a/src/transformers/models/llava_onevision/modeling_llava_onevision.py b/src/transformers/models/llava_onevision/modeling_llava_onevision.py index ed584bda7f5d..e86ce394e13d 100644 --- a/src/transformers/models/llava_onevision/modeling_llava_onevision.py +++ b/src/transformers/models/llava_onevision/modeling_llava_onevision.py @@ -30,6 +30,7 @@ from ...modeling_utils import PreTrainedModel from ...utils import ( add_start_docstrings, + is_torchdynamo_compiling, logging, ) from ...utils.deprecation import deprecate_kwarg @@ -250,7 +251,7 @@ class LlavaOnevisionPreTrainedModel(PreTrainedModel): _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True _supports_cache_class = True - _supports_static_cache = False # Qwen2 doesn't but llava has no reasons to not support + _supports_static_cache = True _supports_quantized_cache = True _supports_sdpa = True @@ -712,19 +713,15 @@ def forward( image_newline=self.image_newline, vision_aspect_ratio=vision_aspect_ratio, ) - n_image_tokens = (input_ids == self.config.image_token_index).sum().item() - n_image_features = image_features.shape[0] - if n_image_tokens != n_image_features: + special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1) + special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) + if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): + n_image_tokens = (input_ids == self.config.image_token_index).sum() + n_image_features = image_features.shape[0] raise ValueError( f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" ) - special_image_mask = ( - (input_ids == self.config.image_token_index) - .unsqueeze(-1) - .expand_as(inputs_embeds) - .to(inputs_embeds.device) - ) image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) @@ -741,18 +738,14 @@ def forward( video_features = torch.cat((video_features, image_newline), dim=1) video_features = video_features.flatten(0, 1) - n_video_tokens = (input_ids == self.config.video_token_index).sum().item() - n_video_features = video_features.shape[0] - if n_video_tokens != n_video_features: + special_video_mask = (input_ids == self.config.video_token_index).unsqueeze(-1) + special_video_mask = special_video_mask.expand_as(inputs_embeds).to(inputs_embeds.device) + if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != video_features.numel(): + n_video_tokens = (input_ids == self.config.video_token_index).sum() + n_video_features = video_features.shape[0] raise ValueError( f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" ) - special_video_mask = ( - (input_ids == self.config.video_token_index) - .unsqueeze(-1) - .expand_as(inputs_embeds) - .to(inputs_embeds.device) - ) video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype) inputs_embeds = inputs_embeds.masked_scatter(special_video_mask, video_features) diff --git a/src/transformers/models/opt/modeling_opt.py b/src/transformers/models/opt/modeling_opt.py index 1969acf2f5b1..f1f1ef1821c7 100644 --- a/src/transformers/models/opt/modeling_opt.py +++ b/src/transformers/models/opt/modeling_opt.py @@ -22,10 +22,10 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache, StaticCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import ( - _prepare_4d_causal_attention_mask, - _prepare_4d_causal_attention_mask_for_sdpa, + AttentionMaskConverter, ) from ...modeling_outputs import ( BaseModelOutputWithPast, @@ -98,6 +98,7 @@ class OPTAttention(nn.Module): def __init__( self, config: OPTConfig, + layer_idx: int = None, **kwargs, ): super().__init__() @@ -106,6 +107,13 @@ def __init__( self.num_heads = config.num_attention_heads self.dropout = config.attention_dropout self.enable_bias = config.enable_bias + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " + "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) self.head_dim = self.embed_dim // self.num_heads self.is_causal = True @@ -122,9 +130,6 @@ def __init__( self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=self.enable_bias) self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=self.enable_bias) - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int) -> torch.Tensor: - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() - def forward( self, hidden_states: torch.Tensor, @@ -134,52 +139,33 @@ def forward( output_attentions: bool = False, # isn't needed in normal attention, but needed in flash attention so to keep the signature same position_ids: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + cache_position: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]: """Input shape: Batch x Time x Channel""" bsz, tgt_len, _ = hidden_states.size() # get query proj query_states = self.q_proj(hidden_states) * self.scaling - # get key, value proj - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - if past_key_value is not None: - # reuse k, v, self_attention - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - - past_key_value = (key_states, value_states) + query_states = query_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) - proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) - key_states = key_states.view(*proj_shape) - value_states = value_states.view(*proj_shape) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) - src_len = key_states.size(1) - attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) - - if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): - raise ValueError( - f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" - f" {attn_weights.size()}" + if past_key_value is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} ) + attn_weights = torch.matmul(query_states, key_states.transpose(3, 2)) if attention_mask is not None: - if attention_mask.size() != (bsz, 1, tgt_len, src_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" - ) - attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask - attn_weights = torch.max( - attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min, device=attn_weights.device) - ) - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask # upcast to fp32 if the weights are in fp16. Please see https://github.com/huggingface/transformers/pull/17437 - if attn_weights.dtype == torch.float16: - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(torch.float16) - else: - attn_weights = nn.functional.softmax(attn_weights, dim=-1) + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) if layer_head_mask is not None: if layer_head_mask.size() != (self.num_heads,): @@ -187,39 +173,19 @@ def forward( f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" f" {layer_head_mask.size()}" ) - attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - if output_attentions: - # this operation is a bit awkward, but it's required to - # make sure that attn_weights keeps its gradient. - # In order to do so, attn_weights have to be reshaped - # twice and have to be reused in the following - attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) - else: - attn_weights_reshaped = None + attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + attn_output = torch.matmul(attn_probs, value_states) - attn_output = torch.bmm(attn_probs, value_states) - - if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) - attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.transpose(1, 2).contiguous() # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be # partitioned aross GPUs when using tensor-parallelism. attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) - attn_output = self.out_proj(attn_output) - return attn_output, attn_weights_reshaped, past_key_value + return attn_output, attn_probs, past_key_value class OptFlashAttention2(OPTAttention): @@ -245,33 +211,33 @@ def forward( layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, position_ids: Optional[torch.Tensor] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" - bsz, _, _ = hidden_states.size() - # get query proj - query_states = self.q_proj(hidden_states) - # get key, value proj - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - if past_key_value is not None: - # reuse k, v, self_attention - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) + bsz, query_length, _ = hidden_states.size() - past_key_value = (key_states, value_states) + query_states = self.q_proj(hidden_states) + query_states = query_states.view(bsz, -1, self.num_heads, self.head_dim) - query_length = query_states.shape[1] - tgt_len = key_states.shape[-2] + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) - # Flash attention requires the input to have the shape - # batch_size x seq_length x head_dim x hidden_dim - query_states = query_states.view(bsz, query_length, self.num_heads, self.head_dim) - key_states = key_states.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim) - value_states = value_states.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim) + if past_key_value is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) attn_dropout = self.dropout if self.training else 0.0 + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + # In PEFT, usually we cast the layer norms in float32 for training stability reasons # therefore the input hidden states gets silently casted in float32. Hence, we need # cast them back in float16 just to be sure everything works as expected. @@ -331,6 +297,7 @@ def forward( layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, position_ids: Optional[torch.Tensor] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if output_attentions or layer_head_mask is not None: logger.warning_once( @@ -344,24 +311,24 @@ def forward( layer_head_mask=layer_head_mask, past_key_value=past_key_value, output_attentions=output_attentions, - ) # TODO after merge add position_ids=position_ids + cache_position=cache_position, + ) bsz, q_len, _ = hidden_states.size() - query_states = self.q_proj(hidden_states) * self.scaling - query_states = self._shape(query_states, -1, bsz) - - # get key, value proj - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - if past_key_value is not None: - # reuse k, v, self_attention - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) + query_states = self.q_proj(hidden_states) + query_states = query_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) - past_key_value = (key_states, value_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) - # shape now is (bsz, num_heads, seq_len, head_dim), all are continuous + if past_key_value is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) causal_mask = attention_mask if attention_mask is not None: @@ -378,10 +345,6 @@ def forward( attn_mask=causal_mask, dropout_p=self.dropout if self.training else 0.0, is_causal=is_causal, - # this model uses the scaling factor in the query projection for some reason, but not in Q@K^T - # so we need to scale to remove scaling in SDPA to have similar results with eager. - # Maybe needs a change in the model to remove scaling in query projection - scale=1.0, ) attn_output = attn_output.transpose(1, 2).contiguous() @@ -399,11 +362,11 @@ def forward( class OPTDecoderLayer(nn.Module): - def __init__(self, config: OPTConfig): + def __init__(self, config: OPTConfig, layer_idx: int = None): super().__init__() self.embed_dim = config.hidden_size - self.self_attn = OPT_ATTENTION_CLASSES[config._attn_implementation](config=config) + self.self_attn = OPT_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) self.do_layer_norm_before = config.do_layer_norm_before self.dropout = config.dropout @@ -425,6 +388,7 @@ def forward( output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, position_ids: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Args: @@ -440,6 +404,8 @@ def forward( If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence.. """ residual = hidden_states @@ -456,6 +422,7 @@ def forward( attention_mask=attention_mask, layer_head_mask=layer_head_mask, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states @@ -524,6 +491,9 @@ class OPTPreTrainedModel(PreTrainedModel): _no_split_modules = ["OPTDecoderLayer"] _supports_flash_attn_2 = True _supports_sdpa = True + _supports_cache_class = True + _supports_quantized_cache = True + _supports_static_cache = True def _init_weights(self, module): std = self.config.init_std @@ -601,6 +571,10 @@ def _init_weights(self, module): config.n_positions - 1]`. for padding use -1. [What are position IDs?](../glossary#position-ids) + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. """ @@ -643,9 +617,7 @@ def __init__(self, config: OPTConfig): else: self.final_layer_norm = None - self.layers = nn.ModuleList([OPTDecoderLayer(config) for _ in range(config.num_hidden_layers)]) - self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" - self._use_sdpa = config._attn_implementation == "sdpa" + self.layers = nn.ModuleList([OPTDecoderLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)]) self.gradient_checkpointing = False # Initialize weights and apply final processing @@ -657,48 +629,130 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.embed_tokens = value + # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask def _update_causal_mask( self, - inputs_embeds: torch.Tensor, - input_shape: Tuple[int, int], - past_key_values_length: int, - attention_mask: Optional[torch.Tensor] = None, - head_mask: Optional[torch.Tensor] = None, - output_attentions: Optional[bool] = None, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, ): - """ - Updates the causal mask for the decoder. - """ - batch_size, seq_length = input_shape - mask_seq_length = past_key_values_length + seq_length - if self._use_flash_attention_2: - # 2d mask is passed through the layers - causal_attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None - attention_mask = ( - torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) - if attention_mask is None - else attention_mask + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and (attention_mask == 0.0).any(): + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_cache_shape() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 ) - return causal_attention_mask, attention_mask + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) - if attention_mask is None: - attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) - elif attention_mask.shape[1] != mask_seq_length: - raise ValueError( - f"The provided attention mask has length {attention_mask.shape[1]}, but its length should be " - f"{mask_seq_length} (sum of the lengths of current and past inputs)" - ) - if self._use_sdpa and not output_attentions and head_mask is None: - causal_attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( - attention_mask, input_shape, inputs_embeds, past_key_values_length - ) + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type in ["cuda", "xpu"] + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + # Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to plcae the 4D attention mask on. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask else: - causal_attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, input_shape, inputs_embeds, past_key_values_length + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) - return causal_attention_mask, attention_mask + return causal_mask def forward( self, @@ -712,6 +766,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, position_ids: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[Tuple, BaseModelOutputWithPast]: r""" Args: @@ -764,6 +819,10 @@ def forward( config.n_positions - 1]`. for padding use -1. [What are position IDs?](../glossary#position-ids) + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -773,51 +832,65 @@ def forward( return_dict = return_dict if return_dict is not None else self.config.use_return_dict - # retrieve input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") - elif input_ids is not None: - input_shape = input_ids.size() - input_ids = input_ids.view(-1, input_shape[-1]) - elif inputs_embeds is not None: - input_shape = inputs_embeds.size()[:-1] - else: - raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False + + if input_ids is not None: + input_ids = input_ids.view(-1, input_ids.shape[-1]) if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + return_legacy_cache = True + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + if past_key_values is None: + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.53.0. " + "You should pass an instance of `DynamicCache` instead, e.g. " + "`past_key_values=DynamicCache.from_legacy_cache(past_key_values)`." + ) + + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + if cache_position is None: + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if attention_mask is None: + seq_length = past_seen_tokens + inputs_embeds.shape[1] + attention_mask = torch.ones(inputs_embeds.shape[0], seq_length, device=inputs_embeds.device) - causal_attention_mask, attention_mask = self._update_causal_mask( - inputs_embeds, input_shape, past_key_values_length, attention_mask, head_mask, output_attentions + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions ) - # embed positions + # embed positions if position_ids is None: + # position_ids = cache_position.unsqueeze(0) position_ids = torch.cumsum(attention_mask, dim=1) position_ids = (position_ids * attention_mask - 1).long() - # cut positions if `past_key_values_length` is > 0 - position_ids = position_ids[:, past_key_values_length:] + # cut positions if `past_seen_tokens` is > 0 + position_ids = position_ids[:, past_seen_tokens:] - pos_embeds = self.embed_positions(attention_mask, past_key_values_length, position_ids=position_ids) + pos_embeds = self.embed_positions(attention_mask, past_seen_tokens, position_ids=position_ids) if self.project_in is not None: inputs_embeds = self.project_in(inputs_embeds) hidden_states = inputs_embeds + pos_embeds.to(inputs_embeds.device) - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None - next_decoder_cache = () if use_cache else None + next_decoder_cache = None # check if head_mask has a correct number of layers specified if desired for attn_mask, mask_name in zip([head_mask], ["head_mask"]): @@ -838,34 +911,34 @@ def forward( if dropout_probability < self.layerdrop: continue - past_key_value = past_key_values[idx] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, - causal_attention_mask, + causal_mask, head_mask[idx] if head_mask is not None else None, None, output_attentions, use_cache, position_ids, + cache_position, ) else: layer_outputs = decoder_layer( hidden_states, - attention_mask=causal_attention_mask, + attention_mask=causal_mask, position_ids=position_ids, layer_head_mask=(head_mask[idx] if head_mask is not None else None), - past_key_value=past_key_value, + past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, + cache_position=cache_position, ) hidden_states = layer_outputs[0] if use_cache: - next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + next_decoder_cache = layer_outputs[2 if output_attentions else 1] if output_attentions: all_self_attns += (layer_outputs[1],) @@ -881,6 +954,9 @@ def forward( all_hidden_states += (hidden_states,) next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = next_cache.to_legacy_cache() + if not return_dict: return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) return BaseModelOutputWithPast( @@ -930,6 +1006,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, position_ids: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -950,6 +1027,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) if not return_dict: @@ -1008,6 +1086,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, position_ids: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.Tensor] = None, **kwargs, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" @@ -1069,6 +1148,10 @@ def forward( config.n_positions - 1]`. for padding use -1. [What are position IDs?](../glossary#position-ids) + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. Returns: @@ -1107,6 +1190,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) logits = self.lm_head(outputs[0]).contiguous() diff --git a/src/transformers/models/paligemma/modeling_paligemma.py b/src/transformers/models/paligemma/modeling_paligemma.py index 9172b98c069e..35ad047a00dd 100644 --- a/src/transformers/models/paligemma/modeling_paligemma.py +++ b/src/transformers/models/paligemma/modeling_paligemma.py @@ -29,6 +29,7 @@ add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, + is_torchdynamo_compiling, logging, replace_return_docstrings, ) @@ -508,7 +509,7 @@ def forward( special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1) special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) - if inputs_embeds[special_image_mask].numel() != image_features.numel(): + if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): image_tokens_in_text = torch.sum(input_ids == self.config.image_token_index) raise ValueError( f"Number of images does not match number of special image tokens in the input text. " diff --git a/src/transformers/models/video_llava/configuration_video_llava.py b/src/transformers/models/video_llava/configuration_video_llava.py index becd20040332..e761481d8259 100644 --- a/src/transformers/models/video_llava/configuration_video_llava.py +++ b/src/transformers/models/video_llava/configuration_video_llava.py @@ -38,8 +38,6 @@ class VideoLlavaConfig(PretrainedConfig): text_config (`Union[AutoConfig, dict]`, *optional*): The config object of the text backbone. Can be any of `LlamaConfig` or `MistralConfig`. Defaults to `LlamaConfig` if not indicated. - ignore_index (`int`, *optional*, defaults to -100): - The ignore index for the loss function. image_token_index (`int`, *optional*, defaults to 32000): The image token index to encode the image prompt. video_token_index (`int`, *optional*, defaults to 32001): @@ -88,7 +86,6 @@ def __init__( self, vision_config=None, text_config=None, - ignore_index=-100, image_token_index=32000, video_token_index=32001, projector_hidden_act="gelu", @@ -99,7 +96,6 @@ def __init__( multimodal_projector_bias=True, **kwargs, ): - self.ignore_index = ignore_index self.image_token_index = image_token_index self.video_token_index = video_token_index self.projector_hidden_act = projector_hidden_act diff --git a/src/transformers/models/video_llava/modeling_video_llava.py b/src/transformers/models/video_llava/modeling_video_llava.py index d8da974b9862..ba4de6537442 100644 --- a/src/transformers/models/video_llava/modeling_video_llava.py +++ b/src/transformers/models/video_llava/modeling_video_llava.py @@ -28,6 +28,7 @@ from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, + is_torchdynamo_compiling, logging, replace_return_docstrings, ) @@ -137,6 +138,8 @@ class VideoLlavaPreTrainedModel(PreTrainedModel): _supports_cache_class = True _supports_flash_attn_2 = True _supports_sdpa = True + _supports_quantized_cache = True + _supports_static_cache = True def _init_weights(self, module): std = ( @@ -276,92 +279,6 @@ def set_decoder(self, decoder): def get_decoder(self): return self.language_model.get_decoder() - def _merge_input_ids_with_visual_features( - self, visual_features, inputs_embeds, input_ids, attention_mask, labels, num_frames=1 - ): - num_images, num_image_patches, embed_dim = visual_features.shape - batch_size, sequence_length = input_ids.shape - left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.pad_token_id)) - special_vision_token = self.config.video_token_index if num_frames > 1 else self.config.image_token_index - - # 1. Create a mask to know where special image tokens are - special_image_token_mask = input_ids == special_vision_token - num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1) - # Compute the maximum embed dimension - max_seq_len = (num_special_image_tokens.max() * (num_image_patches * num_frames - 1)) + sequence_length - batch_indices, non_image_indices = torch.where(input_ids != special_vision_token) - - # 2. Compute the positions where text should be written - # Calculate new positions for text tokens in merged image-text sequence. - # `special_image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images - 1` text tokens. - # `torch.cumsum` computes how each image token shifts subsequent text token positions. - # - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one. - new_token_positions = ( - torch.cumsum((special_image_token_mask * (num_image_patches * num_frames - 1) + 1), dim=-1) - 1 - ) - nb_image_pad = max_seq_len - 1 - new_token_positions[:, -1] - if left_padding: - new_token_positions += nb_image_pad[:, None] # offset for left padding - text_to_overwrite = new_token_positions[batch_indices, non_image_indices] - - # 3. Create the full embedding, already padded to the maximum position - # expand input ids so that the second "merge" with videos does not fail - final_embedding = torch.zeros( - batch_size, max_seq_len, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device - ) - final_attention_mask = torch.zeros( - batch_size, max_seq_len, dtype=attention_mask.dtype, device=inputs_embeds.device - ) - final_input_ids = torch.full( - (batch_size, max_seq_len), self.pad_token_id, dtype=input_ids.dtype, device=inputs_embeds.device - ) - # In case the Vision model or the Language model has been offloaded to CPU, we need to manually - # set the corresponding tensors into their correct target device. - target_device = inputs_embeds.device - batch_indices, non_image_indices, text_to_overwrite = ( - batch_indices.to(target_device), - non_image_indices.to(target_device), - text_to_overwrite.to(target_device), - ) - attention_mask = attention_mask.to(target_device) - - # 4. Fill the embeddings based on the mask. If we have ["hey" "", "how", "are"] - # we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features - final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices] - final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices] - final_input_ids[batch_indices, text_to_overwrite] = input_ids[batch_indices, non_image_indices] - if labels is not None: - final_labels = torch.full( - (batch_size, max_seq_len), self.config.ignore_index, dtype=input_ids.dtype, device=input_ids.device - ) - final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices] - else: - final_labels = None - - # 5. Fill the embeddings corresponding to the images. Anything that is still zeros needs filling - image_to_overwrite = torch.full((batch_size, max_seq_len), True, dtype=torch.bool, device=inputs_embeds.device) - image_to_overwrite[batch_indices, text_to_overwrite] = False - if left_padding: - image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device) - else: - mask = torch.ones_like(image_to_overwrite, dtype=torch.bool).cumsum(-1) - 1 - padding_mask = mask <= new_token_positions[:, -1:].to(target_device) - image_to_overwrite &= padding_mask - - if image_to_overwrite.sum() != visual_features.shape[:-1].numel(): - visual_type = "videos" if num_frames == 8 else "images" - num_images //= num_frames - raise ValueError( - f"The input provided to the model are wrong. The number of {visual_type} tokens is {torch.sum(special_image_token_mask)} while" - f" the number of {visual_type} given to the model is {num_images}. This prevents correct indexing and breaks batch generation." - ) - - final_embedding[image_to_overwrite] = visual_features.contiguous().reshape(-1, embed_dim).to(target_device) - final_attention_mask |= image_to_overwrite - position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1) - - return final_embedding, final_attention_mask, final_labels, position_ids, final_input_ids - def get_image_features( self, pixel_values_images: torch.FloatTensor, @@ -579,14 +496,14 @@ def forward( vision_feature_layer=vision_feature_layer, vision_feature_select_strategy=vision_feature_select_strategy, ) - n_image_tokens = (input_ids == self.config.image_token_index).sum().item() - n_image_features = image_features.shape[0] * image_features.shape[1] - if n_image_tokens != n_image_features: + special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1) + special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) + if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): + n_image_tokens = (input_ids == self.config.image_token_index).sum() + n_image_features = image_features.shape[0] * image_features.shape[1] raise ValueError( f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" ) - special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1) - special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) @@ -595,14 +512,14 @@ def forward( pixel_values_videos=pixel_values_videos, vision_feature_layer=vision_feature_layer ) - n_video_tokens = (input_ids == self.config.video_token_index).sum().item() - n_video_features = video_features.shape[0] * video_features.shape[1] - if n_video_tokens != n_video_features: + special_image_mask = (input_ids == self.config.video_token_index).unsqueeze(-1) + special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) + if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != video_features.numel(): + n_video_tokens = (input_ids == self.config.video_token_index).sum() + n_video_features = video_features.shape[0] * video_features.shape[1] raise ValueError( f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" ) - special_image_mask = (input_ids == self.config.video_token_index).unsqueeze(-1) - special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype) inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, video_features) diff --git a/src/transformers/models/vipllava/configuration_vipllava.py b/src/transformers/models/vipllava/configuration_vipllava.py index 94d890c4b84e..ac24cce24129 100644 --- a/src/transformers/models/vipllava/configuration_vipllava.py +++ b/src/transformers/models/vipllava/configuration_vipllava.py @@ -37,8 +37,6 @@ class VipLlavaConfig(PretrainedConfig): Custom vision config or dict text_config (`Union[AutoConfig, dict]`, *optional*): The config object of the text backbone. Can be any of `LlamaConfig` or `MistralConfig`. - ignore_index (`int`, *optional*, defaults to -100): - The ignore index for the loss function. image_token_index (`int`, *optional*, defaults to 32000): The image token index to encode the image prompt. projector_hidden_act (`str`, *optional*, defaults to `"gelu"`): @@ -78,7 +76,6 @@ def __init__( self, vision_config=None, text_config=None, - ignore_index=-100, image_token_index=32000, projector_hidden_act="gelu", projector_layernorm_eps=1e-5, @@ -86,7 +83,6 @@ def __init__( image_seq_length=576, **kwargs, ): - self.ignore_index = ignore_index self.image_token_index = image_token_index self.projector_hidden_act = projector_hidden_act self.projector_layernorm_eps = projector_layernorm_eps diff --git a/src/transformers/models/vipllava/modeling_vipllava.py b/src/transformers/models/vipllava/modeling_vipllava.py index 71201db2098e..ef4b3bff3958 100644 --- a/src/transformers/models/vipllava/modeling_vipllava.py +++ b/src/transformers/models/vipllava/modeling_vipllava.py @@ -28,6 +28,7 @@ from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, + is_torchdynamo_compiling, logging, replace_return_docstrings, ) @@ -137,6 +138,8 @@ class VipLlavaPreTrainedModel(PreTrainedModel): _supports_cache_class = True _supports_flash_attn_2 = True _supports_sdpa = True + _supports_quantized_cache = True + _supports_static_cache = True def _init_weights(self, module): # important: this ported version of VipLlava isn't meant for training from scratch - only @@ -297,89 +300,6 @@ def get_image_features(self, pixel_values: torch.FloatTensor, vision_feature_lay image_features = self.multi_modal_projector(image_features) return image_features - def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids, attention_mask, labels): - num_images, num_image_patches, embed_dim = image_features.shape - batch_size, sequence_length = input_ids.shape - left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.pad_token_id)) - # 1. Create a mask to know where special image tokens are - special_image_token_mask = input_ids == self.config.image_token_index - num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1) - # Compute the maximum embed dimension - max_embed_dim = (num_special_image_tokens.max() * (num_image_patches - 1)) + sequence_length - batch_indices, non_image_indices = torch.where(input_ids != self.config.image_token_index) - - # 2. Compute the positions where text should be written - # Calculate new positions for text tokens in merged image-text sequence. - # `special_image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images - 1` text tokens. - # `torch.cumsum` computes how each image token shifts subsequent text token positions. - # - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one. - new_token_positions = torch.cumsum((special_image_token_mask * (num_image_patches - 1) + 1), -1) - 1 - nb_image_pad = max_embed_dim - 1 - new_token_positions[:, -1] - if left_padding: - new_token_positions += nb_image_pad[:, None] # offset for left padding - text_to_overwrite = new_token_positions[batch_indices, non_image_indices] - - # 3. Create the full embedding, already padded to the maximum position - final_embedding = torch.zeros( - batch_size, max_embed_dim, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device - ) - final_attention_mask = torch.zeros( - batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device - ) - if labels is not None: - final_labels = torch.full( - (batch_size, max_embed_dim), self.config.ignore_index, dtype=input_ids.dtype, device=input_ids.device - ) - # In case the Vision model or the Language model has been offloaded to CPU, we need to manually - # set the corresponding tensors into their correct target device. - target_device = inputs_embeds.device - batch_indices, non_image_indices, text_to_overwrite = ( - batch_indices.to(target_device), - non_image_indices.to(target_device), - text_to_overwrite.to(target_device), - ) - attention_mask = attention_mask.to(target_device) - - # 4. Fill the embeddings based on the mask. If we have ["hey" "", "how", "are"] - # we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features - final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices] - final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices] - if labels is not None: - final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices] - - # 5. Fill the embeddings corresponding to the images. Anything that is not `text_positions` needs filling (#29835) - image_to_overwrite = torch.full( - (batch_size, max_embed_dim), True, dtype=torch.bool, device=inputs_embeds.device - ) - image_to_overwrite[batch_indices, text_to_overwrite] = False - if left_padding: - image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device) - else: - mask = torch.ones_like(image_to_overwrite, dtype=torch.bool).cumsum(-1) - 1 - padding_mask = mask <= new_token_positions[:, -1:].to(target_device) - image_to_overwrite &= padding_mask - - if image_to_overwrite.sum() != image_features.shape[:-1].numel(): - raise ValueError( - f"The input provided to the model are wrong. The number of image tokens is {torch.sum(special_image_token_mask)} while" - f" the number of image given to the model is {num_images}. This prevents correct indexing and breaks batch generation." - ) - - final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim).to(target_device) - final_attention_mask |= image_to_overwrite - position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1) - - # 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens. - batch_indices, pad_indices = torch.where(input_ids == self.pad_token_id) - indices_to_mask = new_token_positions[batch_indices, pad_indices] - - final_embedding[batch_indices, indices_to_mask] = 0 - - if labels is None: - final_labels = None - - return final_embedding, final_attention_mask, final_labels, position_ids - @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(VIPLLAVA_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=VipLlavaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) @@ -469,14 +389,14 @@ def forward( pixel_values=pixel_values, vision_feature_layers=vision_feature_layers ) - n_image_tokens = (input_ids == self.config.image_token_index).sum().item() - n_image_features = image_features.shape[0] * image_features.shape[1] - if n_image_tokens != n_image_features: + special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1) + special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) + if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): + n_image_tokens = (input_ids == self.config.image_token_index).sum() + n_image_features = image_features.shape[0] * image_features.shape[1] raise ValueError( f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" ) - special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1) - special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index ce31cc844f19..3b9700dc20c9 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -1783,12 +1783,12 @@ def test_generate_from_inputs_embeds_with_static_cache(self): model.config.use_cache = True model.config.is_decoder = True batch_size = input_ids.shape[0] - max_length = 30 + max_new_tokens = 10 # here we force to not stop at eos and go until max-length model.generation_config.eos_token_id = model.config.get_text_config().eos_token_id = -1 generation_kwargs = { - "max_length": max_length, + "max_new_tokens": max_new_tokens, "cache_implementation": "static", "return_dict_in_generate": True, # Required to return `past_key_values` } @@ -1811,10 +1811,11 @@ def test_generate_from_inputs_embeds_with_static_cache(self): # we should get `max_length - 1` in shape, not `max_length - embeds_length`. # -1 because the last generated token isn't yet in the cache. - cache_shape = (batch_size, num_key_value_heads, max_length - 1, head_dim) - self.assertTrue(isinstance(outputs.past_key_values, StaticCache)) - self.assertTrue(len(outputs.past_key_values.key_cache) == num_hidden_layers) - self.assertTrue(outputs.past_key_values.key_cache[0].shape == cache_shape) + max_length = max_new_tokens + inputs_embeds.shape[1] - 1 + cache_shape = [batch_size, num_key_value_heads, max_length, head_dim] + self.assertIsInstance(outputs.past_key_values, StaticCache) + self.assertEqual(len(outputs.past_key_values.key_cache), num_hidden_layers) + self.assertListEqual(list(outputs.past_key_values.key_cache[0].shape), cache_shape) @pytest.mark.generate def test_generate_continue_from_past_key_values(self): @@ -2022,7 +2023,7 @@ def test_generate_with_static_cache(self): config.is_decoder = True batch_size = main_input.shape[0] - seq_length = main_input.shape[-1] + seq_length = self.model_tester.seq_length max_new_tokens = 20 for dtype in (torch.float32, torch.float16): @@ -2134,7 +2135,15 @@ def test_generate_compile_model_forward(self): # compilation-specific setup torch.compiler.reset() # prevent cached compilation from being used in the test has_defined_cache_implementation = model.generation_config.cache_implementation is not None - model.generation_config.compile_config._compile_all_devices = True # force compilation (e.g. fast CI, CPU) + + # BLIP is the only exception with custom generate which call `self.lm.generate()` + # We should avoid such calls in all subsequent multimodal models and try to make `generate()` + # compatible with multimodality + if "blip" in model.__class__.__name__.lower(): + model.language_model.generation_config.compile_config._compile_all_devices = True + else: + # force compilation (e.g. fast CI, CPU + model.generation_config.compile_config._compile_all_devices = True generation_kwargs = { "do_sample": False, @@ -2175,7 +2184,14 @@ def test_generate_compile_model_forward(self): ) self.assertFalse(isinstance(decoder_cache, DynamicCache)) self.assertTrue(decoder_cache.is_compileable) - self.assertTrue(hasattr(model, "_compiled_call")) # our auto compile should have been called + + # BLIP is the only exception with custom generate which call `self.lm.generate()` + # We should avoid such calls in all subsequent multimodal models and try to make `generate()` + # compatible with multimodality + if "blip" in model.__class__.__name__.lower(): + self.assertTrue(hasattr(model.language_model, "_compiled_call")) + else: + self.assertTrue(hasattr(model, "_compiled_call")) # our auto compile should have been called for dynamic_result, compiled_result in zip(dynamic_outputs, compiled_outputs): self._check_similar_generate_outputs(dynamic_result, compiled_result) @@ -2198,9 +2214,19 @@ def test_generate_compilation_all_outputs(self): # compilation-specific setup torch.compiler.reset() # prevent cached compilation from being used in the test has_defined_cache_implementation = model.generation_config.cache_implementation is not None - model.generation_config.compile_config._compile_all_devices = True # force compilation (e.g. fast CI, CPU) - if not has_defined_cache_implementation: - model.generation_config.cache_implementation = "static" + + # BLIP is the only exception with custom generate which call `self.lm.generate()` + # We should avoid such calls in all subsequent multimodal models and try to make `generate()` + # compatible with multimodality + if "blip" in model.__class__.__name__.lower(): + model.language_model.generation_config.compile_config._compile_all_devices = True + if not has_defined_cache_implementation: + model.language_model.generation_config.cache_implementation = "static" + else: + # force compilation (e.g. fast CI, CPU) + model.generation_config.compile_config._compile_all_devices = True + if not has_defined_cache_implementation: + model.generation_config.cache_implementation = "static" logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False, config=model.config) output_generate = model.generate( @@ -2218,8 +2244,10 @@ def test_generate_compilation_all_outputs(self): **inputs_dict, ) - # Sanity check: compilation has happened - self.assertTrue(hasattr(model, "_compiled_call")) + if "blip" in model.__class__.__name__.lower(): + self.assertTrue(hasattr(model.language_model, "_compiled_call")) + else: + self.assertTrue(hasattr(model, "_compiled_call")) # our auto compile should have been called if model.config.is_encoder_decoder: self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1) diff --git a/tests/models/aria/test_modeling_aria.py b/tests/models/aria/test_modeling_aria.py index f12ff24b17f1..8b5e62de14c7 100644 --- a/tests/models/aria/test_modeling_aria.py +++ b/tests/models/aria/test_modeling_aria.py @@ -286,10 +286,18 @@ def test_generate_from_inputs_embeds_0_greedy(self): def test_generate_from_inputs_embeds_1_beam_search(self): pass - @unittest.skip(reason="Unsupported") + @unittest.skip(reason="Dynamic control flow due to MoE") def test_generate_with_static_cache(self): pass + @unittest.skip(reason="Dynamic control flow due to MoE") + def test_generate_from_inputs_embeds_with_static_cache(self): + pass + + @unittest.skip(reason="Dynamic control flow due to MoE") + def test_generate_compile_model_forward(self): + pass + @require_torch class AriaForConditionalGenerationIntegrationTest(unittest.TestCase): diff --git a/tests/models/blip_2/test_modeling_blip_2.py b/tests/models/blip_2/test_modeling_blip_2.py index e26232e3eb43..a405a1f97fb3 100644 --- a/tests/models/blip_2/test_modeling_blip_2.py +++ b/tests/models/blip_2/test_modeling_blip_2.py @@ -816,6 +816,10 @@ def _prepare_model_kwargs(input_ids, attention_mask, signature): def test_generate_from_inputs_embeds(self, _, num_beams): pass + @unittest.skip("BLIP2 cannot generate only from input ids, and requires pixel values in all cases to be present") + def test_generate_from_inputs_embeds_with_static_cache(self): + pass + # this class is based on `T5ModelTester` found in tests/models/t5/test_modeling_t5.py class Blip2TextModelTester: diff --git a/tests/models/emu3/test_modeling_emu3.py b/tests/models/emu3/test_modeling_emu3.py index 4563cc17dfce..491fd9f9ec4f 100644 --- a/tests/models/emu3/test_modeling_emu3.py +++ b/tests/models/emu3/test_modeling_emu3.py @@ -386,10 +386,6 @@ def test_disk_offload_bin(self): def test_cpu_offload(self): pass - @unittest.skip("Doesn't work, tensors are not almost same") # TODO raushan fixme - def test_custom_4d_attention_mask(self): - pass - @unittest.skip("VQ-VAE module doesn't initialize weights properly") def test_initialization(self): pass diff --git a/tests/models/got_ocr2/test_modeling_got_ocr2.py b/tests/models/got_ocr2/test_modeling_got_ocr2.py index ac044de5ca96..178bec98ac62 100644 --- a/tests/models/got_ocr2/test_modeling_got_ocr2.py +++ b/tests/models/got_ocr2/test_modeling_got_ocr2.py @@ -256,12 +256,6 @@ def test_generate_from_inputs_embeds_with_static_cache(self): def test_past_key_values_format(self): pass - @unittest.skip( - reason="GotOcr2 needs a dynamic control flow to pass pixel values to the forward function only in the first generation step" - ) - def test_generate_compile_1_end_to_end(self): - pass - @unittest.skip("FlashAttention only support fp16 and bf16 data type") def test_flash_attn_2_fp32_ln(self): pass diff --git a/tests/models/idefics/test_modeling_idefics.py b/tests/models/idefics/test_modeling_idefics.py index 5d19f5b02025..32c45d6e71f7 100644 --- a/tests/models/idefics/test_modeling_idefics.py +++ b/tests/models/idefics/test_modeling_idefics.py @@ -838,6 +838,14 @@ def test_contrastive_generate_low_memory(self): def test_custom_4d_attention_mask(self): pass + @unittest.skip(reason="IDEFICS cannot compile due to dynamic control flow when checking inputs") + def test_generate_with_static_cache(self): + pass + + @unittest.skip(reason="IDEFICS cannot compile due to dynamic control flow when checking inputs") + def test_generate_compile_model_forward(self): + pass + @unittest.skip(reason="We only test the model that takes in multiple images") def test_model(self): pass diff --git a/tests/models/instructblip/test_modeling_instructblip.py b/tests/models/instructblip/test_modeling_instructblip.py index e072499ad3f1..bbf877289040 100644 --- a/tests/models/instructblip/test_modeling_instructblip.py +++ b/tests/models/instructblip/test_modeling_instructblip.py @@ -530,6 +530,12 @@ def test_save_load_fast_init_from_base(self): def test_save_load_fast_init_to_base(self): pass + @unittest.skip( + "InstructBLIP cannot generate only from input ids, and requires pixel values in all cases to be present" + ) + def test_generate_from_inputs_embeds_with_static_cache(self): + pass + def test_forward_signature(self): config, _ = self.model_tester.prepare_config_and_inputs_for_common() diff --git a/tests/models/instructblipvideo/test_modeling_instructblipvideo.py b/tests/models/instructblipvideo/test_modeling_instructblipvideo.py index 0534b4f5ea73..351dea3d6fae 100644 --- a/tests/models/instructblipvideo/test_modeling_instructblipvideo.py +++ b/tests/models/instructblipvideo/test_modeling_instructblipvideo.py @@ -546,6 +546,12 @@ def test_save_load_fast_init_from_base(self): def test_save_load_fast_init_to_base(self): pass + @unittest.skip( + "InstructBLIPVideo cannot generate only from input ids, and requires pixel values in all cases to be present" + ) + def test_generate_from_inputs_embeds_with_static_cache(self): + pass + def test_forward_signature(self): config, _ = self.model_tester.prepare_config_and_inputs_for_common() diff --git a/tests/models/llava/test_modeling_llava.py b/tests/models/llava/test_modeling_llava.py index 25e1a747ce9f..b47423a02ec7 100644 --- a/tests/models/llava/test_modeling_llava.py +++ b/tests/models/llava/test_modeling_llava.py @@ -316,14 +316,6 @@ def test_training_gradient_checkpointing_use_reentrant(self): def test_training_gradient_checkpointing_use_reentrant_false(self): pass - @unittest.skip(reason="Compile not yet supported because in LLava models") - def test_sdpa_can_compile_dynamic(self): - pass - - @unittest.skip(reason="Compile not yet supported because in LLava models") - def test_sdpa_can_dispatch_on_flash(self): - pass - @unittest.skip("FlashAttention only support fp16 and bf16 data type") def test_flash_attn_2_fp32_ln(self): pass diff --git a/tests/models/llava_next/test_modeling_llava_next.py b/tests/models/llava_next/test_modeling_llava_next.py index eaeda3cecb7b..0c75df53c1bb 100644 --- a/tests/models/llava_next/test_modeling_llava_next.py +++ b/tests/models/llava_next/test_modeling_llava_next.py @@ -365,22 +365,6 @@ def test_training_gradient_checkpointing_use_reentrant(self): def test_training_gradient_checkpointing_use_reentrant_false(self): pass - @unittest.skip(reason="Feedforward chunking is not yet supported") - def test_feed_forward_chunking(self): - pass - - @unittest.skip(reason="CPU offload is not yet supported") - def test_cpu_offload(self): - pass - - @unittest.skip(reason="Compile not yet supported because in LLava models") - def test_sdpa_can_compile_dynamic(self): - pass - - @unittest.skip(reason="Compile not yet supported because in LLava models") - def test_sdpa_can_dispatch_on_flash(self): - pass - @unittest.skip("FlashAttention only support fp16 and bf16 data type") def test_flash_attn_2_fp32_ln(self): pass @@ -391,6 +375,10 @@ def test_flash_attn_2_fp32_ln(self): def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self): pass + @unittest.skip("LLaVA Next has dynamic control flow in unpadding") + def test_generate_compile_model_forward(self): + pass + @require_torch class LlavaNextForConditionalGenerationIntegrationTest(unittest.TestCase): diff --git a/tests/models/llava_next_video/test_modeling_llava_next_video.py b/tests/models/llava_next_video/test_modeling_llava_next_video.py index 0f4642402644..6d4df92f5c22 100644 --- a/tests/models/llava_next_video/test_modeling_llava_next_video.py +++ b/tests/models/llava_next_video/test_modeling_llava_next_video.py @@ -382,26 +382,6 @@ def test_training_gradient_checkpointing_use_reentrant(self): def test_training_gradient_checkpointing_use_reentrant_false(self): pass - @unittest.skip(reason="Feedforward chunking is not yet supported") - def test_feed_forward_chunking(self): - pass - - @unittest.skip(reason="CPU offload is not yet supported") - def test_cpu_offload(self): - pass - - @unittest.skip( - reason="Compile not yet supported because in LLava models (https://github.com/huggingface/transformers/issues/29891)" - ) - def test_sdpa_can_compile_dynamic(self): - pass - - @unittest.skip( - reason="Compile not yet supported because in LLava models (https://github.com/huggingface/transformers/issues/29891)" - ) - def test_sdpa_can_dispatch_on_flash(self): - pass - @unittest.skip("FlashAttention only support fp16 and bf16 data type") def test_flash_attn_2_fp32_ln(self): pass @@ -412,6 +392,10 @@ def test_flash_attn_2_fp32_ln(self): def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self): pass + @unittest.skip("LLaVA Next Video has dynamic control flow in unpadding") + def test_generate_compile_model_forward(self): + pass + @require_torch class LlavaNextVideoForConditionalGenerationIntegrationTest(unittest.TestCase): diff --git a/tests/models/llava_onevision/test_modeling_llava_onevision.py b/tests/models/llava_onevision/test_modeling_llava_onevision.py index 63be10a774db..c9bb448278e7 100644 --- a/tests/models/llava_onevision/test_modeling_llava_onevision.py +++ b/tests/models/llava_onevision/test_modeling_llava_onevision.py @@ -346,6 +346,10 @@ def test_flash_attn_2_fp32_ln(self): def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self): pass + @unittest.skip("LLaVA OneVision has dynamic control flow in unpadding") + def test_generate_compile_model_forward(self): + pass + @require_torch class LlavaOnevisionForConditionalGenerationIntegrationTest(unittest.TestCase): diff --git a/tests/models/mt5/test_modeling_mt5.py b/tests/models/mt5/test_modeling_mt5.py index 3c3256da8b24..994d88444809 100644 --- a/tests/models/mt5/test_modeling_mt5.py +++ b/tests/models/mt5/test_modeling_mt5.py @@ -540,7 +540,6 @@ def prepare_config_and_inputs_for_common(self): "attention_mask": attention_mask, "decoder_input_ids": decoder_input_ids, "decoder_attention_mask": decoder_attention_mask, - "use_cache": False, } return config, inputs_dict diff --git a/tests/models/opt/test_modeling_opt.py b/tests/models/opt/test_modeling_opt.py index 3e3d2159a022..dad740cde721 100644 --- a/tests/models/opt/test_modeling_opt.py +++ b/tests/models/opt/test_modeling_opt.py @@ -81,7 +81,7 @@ def __init__( hidden_act="gelu", hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, - max_position_embeddings=20, + max_position_embeddings=50, eos_token_id=2, pad_token_id=1, bos_token_id=0, @@ -89,7 +89,6 @@ def __init__( num_labels=3, word_embed_proj_dim=16, type_sequence_label_size=2, - attn_implementation="eager", ): self.parent = parent self.batch_size = batch_size @@ -113,7 +112,6 @@ def __init__( self.type_sequence_label_size = type_sequence_label_size self.word_embed_proj_dim = word_embed_proj_dim self.is_encoder_decoder = False - self.attn_implementation = attn_implementation def prepare_config_and_inputs(self): input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size).clamp( @@ -143,7 +141,6 @@ def get_config(self): embed_dim=self.embed_dim, is_encoder_decoder=False, word_embed_proj_dim=self.word_embed_proj_dim, - attn_implementation=self.attn_implementation, ) def get_pipeline_config(self): diff --git a/tests/models/t5/test_modeling_t5.py b/tests/models/t5/test_modeling_t5.py index 9886684d6088..a0439550f8f0 100644 --- a/tests/models/t5/test_modeling_t5.py +++ b/tests/models/t5/test_modeling_t5.py @@ -545,7 +545,6 @@ def prepare_config_and_inputs_for_common(self): "attention_mask": attention_mask, "decoder_input_ids": decoder_input_ids, "decoder_attention_mask": decoder_attention_mask, - "use_cache": False, } return config, inputs_dict diff --git a/tests/models/video_llava/test_modeling_video_llava.py b/tests/models/video_llava/test_modeling_video_llava.py index b8d4d4167e57..528f125693f7 100644 --- a/tests/models/video_llava/test_modeling_video_llava.py +++ b/tests/models/video_llava/test_modeling_video_llava.py @@ -226,14 +226,6 @@ def test_training_gradient_checkpointing_use_reentrant(self): def test_training_gradient_checkpointing_use_reentrant_false(self): pass - @unittest.skip(reason="Pass because video-LLava requires `attention_mask is not None`") - def test_sdpa_can_compile_dynamic(self): - pass - - @unittest.skip(reason="Pass because video-LLava requires `attention_mask is not None`") - def test_sdpa_can_dispatch_on_flash(self): - pass - @unittest.skip("FlashAttention only support fp16 and bf16 data type") def test_flash_attn_2_fp32_ln(self): pass diff --git a/tests/models/vipllava/test_modeling_vipllava.py b/tests/models/vipllava/test_modeling_vipllava.py index f6a601c8a02d..24f99d4b0b18 100644 --- a/tests/models/vipllava/test_modeling_vipllava.py +++ b/tests/models/vipllava/test_modeling_vipllava.py @@ -306,14 +306,6 @@ def test_training_gradient_checkpointing_use_reentrant(self): def test_training_gradient_checkpointing_use_reentrant_false(self): pass - @unittest.skip(reason="Compile not yet supported because it is not yet supported in LLava") - def test_sdpa_can_compile_dynamic(self): - pass - - @unittest.skip(reason="Compile not yet supported because in LLava models") - def test_sdpa_can_dispatch_on_flash(self): - pass - @unittest.skip("FlashAttention only support fp16 and bf16 data type") def test_flash_attn_2_fp32_ln(self): pass diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 9dd5877c8b90..a707b25a3110 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -4324,10 +4324,6 @@ def test_sdpa_can_dispatch_on_flash(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() inputs_dict = self._prepare_for_class(inputs_dict, model_class) - if config.model_type in ["llava", "llava_next", "vipllava", "video_llava"]: - self.skipTest( - reason="Llava-like models currently (transformers==4.39.1) requires an attention_mask input" - ) if config.model_type in ["paligemma"]: self.skipTest( "PaliGemma-like models currently (transformers==4.41.0) requires an attention_mask input" @@ -4778,6 +4774,9 @@ def test_custom_4d_attention_mask(self): model = model_class(config).to(device=torch_device, dtype=torch.float32) set_model_for_less_flaky_test(model) + if "position_ids" not in inspect.signature(model.forward).parameters: + continue # this model doesn't accept position ids as input + ( input_ids, position_ids,