From 33990037187fc477e78d7db477f0d656cb12f2a8 Mon Sep 17 00:00:00 2001 From: calpt Date: Sat, 27 Jul 2024 14:25:34 +0200 Subject: [PATCH 1/8] Upgrade Transformers to v4.43.x --- hf_transformers | 2 +- setup.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/hf_transformers b/hf_transformers index fc35907f95..47c29ccfaf 160000 --- a/hf_transformers +++ b/hf_transformers @@ -1 +1 @@ -Subproject commit fc35907f95459d7a6c5281dfadd680b6f7b620e3 +Subproject commit 47c29ccfaf56947d845971a439cbe75a764b63d7 diff --git a/setup.py b/setup.py index 78c526fdf6..1f68f2272b 100644 --- a/setup.py +++ b/setup.py @@ -58,7 +58,7 @@ "sphinx-multiversion==0.2.4", "timeout-decorator", "torch>=1.10,!=1.12.0", - "transformers~=4.42.4", + "transformers~=4.43.3", ] From e9e3848e63eca534d4d30600a7e93e6faee7b04d Mon Sep 17 00:00:00 2001 From: calpt Date: Sat, 27 Jul 2024 15:51:49 +0200 Subject: [PATCH 2/8] Re-copy Llama & Beit --- src/adapters/models/beit/modeling_beit.py | 11 +++- src/adapters/models/llama/modeling_llama.py | 65 ++++++++++++++++----- 2 files changed, 60 insertions(+), 16 deletions(-) diff --git a/src/adapters/models/beit/modeling_beit.py b/src/adapters/models/beit/modeling_beit.py index 6e56d2b864..865fcdeae5 100644 --- a/src/adapters/models/beit/modeling_beit.py +++ b/src/adapters/models/beit/modeling_beit.py @@ -35,6 +35,7 @@ def forward( output_attentions: bool = False, relative_position_bias: Optional["BeitRelativePositionBias"] = None, interpolate_pos_encoding: bool = False, + resolution: Optional[Tuple[int]] = None, ) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]: mixed_query_layer = self.query(hidden_states) @@ -51,9 +52,11 @@ def forward( # Add relative position bias if present. if self.relative_position_bias is not None: + height, width = resolution + window_size = (height // self.config.patch_size, width // self.config.patch_size) attention_scores = attention_scores + self.relative_position_bias( - interpolate_pos_encoding, attention_scores.shape[2] - ).unsqueeze(0) + window_size, interpolate_pos_encoding, dim_size=hidden_states.shape[1] + ) # Add shared relative position bias if provided. 
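As context for the BEiT hunk above: in Transformers v4.43 the relative position bias is computed for a window derived from the input `resolution` rather than inferred from the attention scores, which is why the re-copied forward now takes `resolution` and builds `window_size` from the configured patch size. A minimal, self-contained sketch of that bookkeeping (the concrete numbers are illustrative, not taken from the patch):

```python
import torch

# Illustrative values; in the model they come from the pixel inputs and BeitConfig.
height, width, patch_size = 224, 224, 16

# Window over which the relative position bias is defined (one entry per patch).
window_size = (height // patch_size, width // patch_size)  # -> (14, 14)

# Sequence length seen by self-attention: all patches plus the [CLS] token.
seq_len = window_size[0] * window_size[1] + 1  # -> 197

# The bias is added to the raw attention scores of shape (batch, heads, seq_len, seq_len).
attention_scores = torch.randn(2, 12, seq_len, seq_len)
bias = torch.zeros(1, 12, seq_len, seq_len)  # stand-in for the learned relative position bias
attention_scores = attention_scores + bias
```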
if relative_position_bias is not None: @@ -89,8 +92,9 @@ def forward( hidden_states: torch.Tensor, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, - relative_position_bias: Optional[BeitRelativePositionBias] = None, + relative_position_bias: Optional["BeitRelativePositionBias"] = None, interpolate_pos_encoding: bool = False, + resolution: Optional[Tuple[int]] = None, ) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]: self_attention_outputs = self.attention( self.layernorm_before(hidden_states), # in BEiT, layernorm is applied before self-attention @@ -98,6 +102,7 @@ def forward( output_attentions=output_attentions, relative_position_bias=relative_position_bias, interpolate_pos_encoding=interpolate_pos_encoding, + resolution=resolution, ) attention_output = self_attention_outputs[0] outputs = self_attention_outputs[1:] # add self attentions if we output attention weights diff --git a/src/adapters/models/llama/modeling_llama.py b/src/adapters/models/llama/modeling_llama.py index 2987cb3b45..33bdb28c3c 100644 --- a/src/adapters/models/llama/modeling_llama.py +++ b/src/adapters/models/llama/modeling_llama.py @@ -29,6 +29,7 @@ from adapters.composition import adjust_tensors_for_parallel, match_attn_matrices_for_parallel from transformers.cache_utils import Cache, StaticCache +from transformers.modeling_flash_attention_utils import _flash_attention_forward from transformers.models.llama.modeling_llama import ( LlamaAttention, LlamaDecoderLayer, @@ -57,6 +58,7 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45 **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() @@ -94,8 +96,16 @@ def forward( (attention_mask,) = adjust_tensors_for_parallel(query_states, attention_mask) # >>> END AH Changes <<< - past_key_value = getattr(self, "past_key_value", past_key_value) - cos, sin = self.rotary_emb(value_states, position_ids) + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." 
+ ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: @@ -133,7 +143,7 @@ def forward( attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + attn_output = attn_output.reshape(bsz, q_len, -1) if self.config.pretraining_tp > 1: attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2) @@ -158,7 +168,7 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - **kwargs, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if isinstance(past_key_value, StaticCache): raise ValueError( @@ -188,7 +198,16 @@ def forward( (attention_mask,) = adjust_tensors_for_parallel(query_states, attention_mask) # >>> END AH Changes <<< - cos, sin = self.rotary_emb(value_states, position_ids) + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." + ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: @@ -220,7 +239,7 @@ def forward( # in fp32. (LlamaRMSNorm handles it correctly) input_dtype = query_states.dtype - if input_dtype == torch.float32 or key_states.dtype == torch.float32: + if input_dtype == torch.float32: if torch.is_autocast_enabled(): target_dtype = torch.get_autocast_gpu_dtype() # Handle the case where the model is quantized @@ -239,11 +258,19 @@ def forward( key_states = key_states.to(target_dtype) value_states = value_states.to(target_dtype) - attn_output = self._flash_attention_forward( - query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + dropout=dropout_rate, + sliding_window=getattr(self, "sliding_window", None), + use_top_left_mask=self._flash_attn_uses_top_left_mask, + is_causal=self.is_causal, ) - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() attn_output = self.o_proj(attn_output) if not output_attentions: @@ -264,6 +291,8 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45 + **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if output_attentions: # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. 
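The deprecation warning repeated in the Llama hunks above marks the v4.43 transition from computing rotary embeddings inside every attention layer (from `position_ids`) to computing the cos/sin tensors once in the model and passing them down as `position_embeddings`, so each layer only applies them. A simplified sketch of what applying rotary position embeddings means, with toy shapes and the standard RoPE frequency formula (not the exact upstream implementation):

```python
import torch

def rotate_half(x):
    # Swap the two halves of the head dimension and negate the second half.
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary(q, k, cos, sin):
    # cos/sin: (batch, seq_len, head_dim); add a head axis so they broadcast.
    cos, sin = cos.unsqueeze(1), sin.unsqueeze(1)
    return q * cos + rotate_half(q) * sin, k * cos + rotate_half(k) * sin

bsz, n_heads, seq_len, head_dim = 1, 8, 16, 64
q = torch.randn(bsz, n_heads, seq_len, head_dim)
k = torch.randn(bsz, n_heads, seq_len, head_dim)

# Computed once per forward pass and handed to every layer as `position_embeddings`.
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, head_dim, 2).float() / head_dim))
angles = torch.arange(seq_len).float()[:, None] * inv_freq[None, :]
emb = torch.cat((angles, angles), dim=-1)
cos, sin = emb.cos()[None], emb.sin()[None]

q_rot, k_rot = apply_rotary(q, k, cos, sin)
```

Precomputing the tuple once avoids repeating the frequency computation in every decoder layer, which is why `position_ids` is slated for removal in v4.45.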
@@ -279,6 +308,7 @@ def forward( output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, + position_embeddings=position_embeddings, ) bsz, q_len, _ = hidden_states.size() @@ -298,7 +328,16 @@ def forward( (attention_mask,) = adjust_tensors_for_parallel(query_states, attention_mask) # >>> END AH Changes <<< - cos, sin = self.rotary_emb(value_states, position_ids) + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." + ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: @@ -329,8 +368,8 @@ def forward( key_states = key_states.contiguous() value_states = value_states.contiguous() - # We dispatch to SDPA's Flash Attention or Efficient kernels via this if statement instead of an - # inline conditional assignment to support both torch.compile's `dynamic=True` and `fullgraph=True` + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. is_causal = True if causal_mask is None and q_len > 1 else False attn_output = torch.nn.functional.scaled_dot_product_attention( @@ -343,7 +382,7 @@ def forward( ) attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.view(bsz, q_len, self.hidden_size) + attn_output = attn_output.view(bsz, q_len, -1) attn_output = self.o_proj(attn_output) From e8b03aa8b986e594d9bdc73ec305822b08ff519b Mon Sep 17 00:00:00 2001 From: calpt Date: Sat, 27 Jul 2024 18:59:06 +0200 Subject: [PATCH 3/8] Add Clip Sdp/ flash attn --- src/adapters/models/clip/modeling_clip.py | 167 +++++++++++++++++++++- 1 file changed, 166 insertions(+), 1 deletion(-) diff --git a/src/adapters/models/clip/modeling_clip.py b/src/adapters/models/clip/modeling_clip.py index fecbb105c8..7328e532c4 100644 --- a/src/adapters/models/clip/modeling_clip.py +++ b/src/adapters/models/clip/modeling_clip.py @@ -21,11 +21,25 @@ import torch.utils.checkpoint from torch import nn -from transformers.models.clip.modeling_clip import CLIPAttention, CLIPEncoderLayer +from transformers.models.clip.modeling_clip import ( + CLIPAttention, + CLIPEncoderLayer, + CLIPFlashAttention2, + CLIPSdpaAttention, +) +from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_2 +from transformers.utils import is_flash_attn_2_available, logging + + +if is_flash_attn_2_available(): + from transformers.modeling_flash_attention_utils import _flash_attention_forward from .mixin_clip import CLIPAttentionAdaptersMixin, CLIPEncoderLayerAdaptersMixin +logger = logging.get_logger(__name__) + + class CLIPAttentionWithAdapters(CLIPAttentionAdaptersMixin, CLIPAttention): def forward( self, @@ -46,9 +60,11 @@ def forward( proj_shape = (bsz * self.num_heads, -1, self.head_dim) query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) + # >>> START AH Changes <<< key_states, value_states, attention_mask = self.prefix_tuning( 
key_states, value_states, hidden_states, attention_mask ) + # >>> END AH Changes <<< key_states = key_states.view(*proj_shape) value_states = value_states.view(*proj_shape) @@ -115,6 +131,155 @@ def forward( return attn_output, attn_weights_reshaped +class CLIPFlashAttention2WithAdapters(CLIPAttentionAdaptersMixin, CLIPFlashAttention2): + # Adapted from transformers.models.llama.modeling_llama.LlamaFlashAttention2.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + causal_attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + output_attentions = False + + batch_size, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # >>> START AH Changes <<< + key_states, value_states, attention_mask = self.prefix_tuning( + key_states, value_states, hidden_states, attention_mask + ) + # >>> END AH Changes <<< + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim) + key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim) + value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim) + + dropout_rate = self.dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." 
+ ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + dropout=dropout_rate, + is_causal=causal_attention_mask is not None, + use_top_left_mask=self._flash_attn_uses_top_left_mask, + ) + + attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim).contiguous() + attn_output = self.out_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights + + +class CLIPSdpaAttentionWithAdapters(CLIPAttentionAdaptersMixin, CLIPSdpaAttention): + # Adapted from CLIPAttention.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + causal_attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "CLIPModel is using CLIPSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not " + "support `output_attentions=True`. Falling back to the manual attention implementation, but specifying " + "the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can " + 'be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + causal_attention_mask=causal_attention_mask, + output_attentions=output_attentions, + ) + + # CLIP text model uses both `causal_attention_mask` and `attention_mask` + if attention_mask is not None and causal_attention_mask is not None: + attn_mask = attention_mask + causal_attention_mask + elif causal_attention_mask is not None: + attn_mask = causal_attention_mask + else: + attn_mask = attention_mask + + bsz, tgt_len, embed_dim = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + + # >>> START AH Changes <<< + key_states, value_states, attn_mask = self.prefix_tuning(key_states, value_states, hidden_states, attn_mask) + # >>> END AH Changes <<< + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if not is_torch_greater_or_equal_than_2_2 and query_states.device.type == "cuda" and attn_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + # CLIP text model uses both `causal_attention_mask` and `attention_mask` sequentially. 
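Before the SDPA call that follows: the new SDPA path above folds CLIP's padding mask and causal mask into a single additive float mask and hands it to `torch.nn.functional.scaled_dot_product_attention`. A self-contained illustration of that pattern (shapes and values are made up for the example, not taken from the adapters code):

```python
import torch
import torch.nn.functional as F

bsz, n_heads, seq_len, head_dim = 2, 8, 10, 64
q = torch.randn(bsz, n_heads, seq_len, head_dim)
k = torch.randn(bsz, n_heads, seq_len, head_dim)
v = torch.randn(bsz, n_heads, seq_len, head_dim)

# Additive float masks: 0.0 keeps a position, -inf removes it from the softmax.
causal_mask = torch.triu(torch.full((seq_len, seq_len), float("-inf")), diagonal=1)
padding_mask = torch.zeros(bsz, 1, 1, seq_len)
padding_mask[:, :, :, -2:] = float("-inf")  # pretend the last two tokens are padding

# Because both masks are additive, combining them is a plain sum, as in the hunk above.
attn_mask = padding_mask + causal_mask  # broadcasts to (bsz, 1, seq_len, seq_len)

out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0)
print(out.shape)  # torch.Size([2, 8, 10, 64])
```

When no explicit mask is needed, passing `is_causal=True` instead lets PyTorch dispatch to its fused causal kernels, which is the same consideration behind the `is_causal` branch in the Llama SDPA hunk earlier in this series.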
+ attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=attn_mask, + dropout_p=self.dropout if self.training else 0.0, + scale=self.scale, + ) + + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.reshape(bsz, tgt_len, embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, None + + class CLIPEncoderLayerWithAdapters(CLIPEncoderLayerAdaptersMixin, CLIPEncoderLayer): def forward( self, From 7942eb27f2e896be2781f208a260353c48b83eb6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Leon=20Engl=C3=A4nder?= Date: Thu, 1 Aug 2024 23:20:15 +0200 Subject: [PATCH 4/8] Fix our tie_weights method. --- src/adapters/heads/model_mixin.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/adapters/heads/model_mixin.py b/src/adapters/heads/model_mixin.py index 11a194ef92..9a27bbd764 100644 --- a/src/adapters/heads/model_mixin.py +++ b/src/adapters/heads/model_mixin.py @@ -134,6 +134,8 @@ def tie_weights(self): self = getattr(self, self.base_model_prefix) self._tie_encoder_decoder_weights(self.encoder, self.decoder, self.base_model_prefix) + super().tie_weights() + def _resize_token_embeddings(self, new_num_tokens, pad_to_multiple_of=None): old_embeddings = self.get_input_embeddings() new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens, pad_to_multiple_of) From 34c1d54695649adf9556884bd341e839550dceae Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Leon=20Engl=C3=A4nder?= Date: Fri, 2 Aug 2024 00:12:21 +0200 Subject: [PATCH 5/8] Increase minimum pytorch version to provide torch.nn.attention and support FlashAttention2 - overall speedup - fixes a failing test: Hugging Face now uses this import, which previously resulted in an error: `from torch.nn.attention import SDPBackend, sdpa_kernel` --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 1f68f2272b..9cd8fd051c 100644 --- a/setup.py +++ b/setup.py @@ -57,7 +57,7 @@ "sphinx-intl==2.1.0", "sphinx-multiversion==0.2.4", "timeout-decorator", - "torch>=1.10,!=1.12.0", + "torch>=2.3", "transformers~=4.43.3", ] From 7b073ca1b5bf6bccd4b925848b02595b7ec82823 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Leon=20Engl=C3=A4nder?= Date: Fri, 2 Aug 2024 00:21:09 +0200 Subject: [PATCH 6/8] increasing pytorch version in github action --- .github/workflows/tests_torch.yml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/tests_torch.yml b/.github/workflows/tests_torch.yml index 668beb9e62..f7a394ce4a 100644 --- a/.github/workflows/tests_torch.yml +++ b/.github/workflows/tests_torch.yml @@ -39,7 +39,7 @@ jobs: key: ${{ runner.os }}-pip-${{ hashFiles('setup.py') }} - name: Install run: | - pip install torch==2.1.2 + pip install torch==2.3 pip install .[quality] - name: Check Quality and Repo Consistency run: | @@ -62,7 +62,7 @@ ${{ runner.os }}-pip- - name: Install run: | - pip install torch==2.1.2 + pip install torch==2.3 pip install .[sklearn,testing,sentencepiece] - name: Test run: | @@ -85,7 +85,7 @@ ${{ runner.os }}-pip- - name: Install run: | - pip install torch==2.1.2 + pip install torch==2.3 pip install .[sklearn,testing,sentencepiece] - name: Test run: | @@ -108,7 +108,7 @@ ${{ runner.os }}-pip- - name: Install run: | - pip install torch==2.1.2 + pip install torch==2.3 pip install .[sklearn,testing,sentencepiece] pip install conllu seqeval - name: Test Examples From 3afdbc5cf835215017ffe57ec792e1f6ae7d1be5 Mon Sep 17 00:00:00
2001 From: calpt Date: Fri, 2 Aug 2024 21:26:59 +0200 Subject: [PATCH 7/8] Test precision --- tests/methods/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/methods/base.py b/tests/methods/base.py index 09dfbed75b..615f00cedd 100644 --- a/tests/methods/base.py +++ b/tests/methods/base.py @@ -243,7 +243,7 @@ def run_full_model_load_test(self, adapter_config): output1 = model1(**input_data) output2 = model2(**input_data) self.assertEqual(len(output1), len(output2)) - self.assertTrue(torch.equal(output1[0], output2[0])) + self.assertTrue(torch.allclose(output1[0], output2[0], atol=1e-4)) def trainings_run(self, model, lr=1.0, steps=8): # setup dataset From 632b786b50afc393ba2600600ac68fcfb57083a8 Mon Sep 17 00:00:00 2001 From: calpt Date: Fri, 2 Aug 2024 22:56:16 +0200 Subject: [PATCH 8/8] unpin torch --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 9cd8fd051c..39d994e999 100644 --- a/setup.py +++ b/setup.py @@ -57,7 +57,7 @@ "sphinx-intl==2.1.0", "sphinx-multiversion==0.2.4", "timeout-decorator", - "torch>=2.3", + "torch", "transformers~=4.43.3", ]
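A note on the test change in PATCH 7: exact bitwise equality is brittle once different attention kernels (eager, SDPA, flash) may be in play, since fused kernels can introduce tiny floating-point differences, so the full-model load test now compares outputs within a tolerance instead of with `torch.equal`. A small illustration of the difference (values are synthetic):

```python
import torch

a = torch.randn(4, 8)
b = a + 1e-6 * torch.randn(4, 8)  # simulate tiny kernel-level numerical differences

print(torch.equal(a, b))                # False: requires bit-identical tensors
print(torch.allclose(a, b, atol=1e-4))  # True: |a - b| <= atol + rtol * |b| element-wise
```

`torch.allclose` keeps its default `rtol=1e-5`, so the explicit `atol=1e-4` mainly loosens the absolute part of the check for values near zero.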