remove all copied from attentions
Cyrilvallez committed Dec 16, 2024
1 parent 0f565fb commit c9ac84d
Showing 41 changed files with 337 additions and 51 deletions.
2 changes: 1 addition & 1 deletion src/transformers/modeling_flash_attention_utils.py
@@ -276,7 +276,7 @@ def _flash_attention_forward(
if not use_top_left_mask:
causal = is_causal
else:
# TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__.
# TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1.
causal = is_causal and query_length != 1

# Assuming 4D tensors, key_states.shape[1] is the key/value sequence length (source length).
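Note on this hunk: the deleted half of the TODO pointed readers at LlamaFlashAttention2.__init__, a class that, per the "NO LONGER EXIST" comments elsewhere in this diff, has been removed in the attention refactor. As background for the `__init__` hunks below, here is a minimal illustrative sketch, not part of the diff: the class name is a hypothetical stand-in, and the version check is assumed to be transformers' `is_flash_attn_greater_or_equal_2_10`. It shows where `use_top_left_mask` typically comes from and why `query_length != 1` matters.

```python
import torch.nn as nn

from transformers.utils import is_flash_attn_greater_or_equal_2_10


class ToyFlashAttention2(nn.Module):
    """Hypothetical stand-in for the *FlashAttention2 classes edited below."""

    def __init__(self):
        super().__init__()
        # flash_attn < 2.1 (e.g. on RoCm) builds a top-left aligned causal mask,
        # while _flash_attention_forward expects the bottom-right alignment that
        # became the default in flash_attn >= 2.1.
        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

    def causal_flag(self, is_causal: bool, query_length: int) -> bool:
        # Mirrors the hunk above: with the top-left workaround active, causal
        # masking must be skipped for single-token (decode-step) queries.
        if not self._flash_attn_uses_top_left_mask:
            return is_causal
        return is_causal and query_length != 1
```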
1 change: 0 additions & 1 deletion src/transformers/models/bark/modeling_bark.py
@@ -197,7 +197,6 @@ class BarkSelfFlashAttention2(BarkSelfAttention):
flash attention and deal with padding tokens in case the input contains any of them.
"""

# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

1 change: 0 additions & 1 deletion src/transformers/models/bart/modeling_bart.py
@@ -294,7 +294,6 @@ class BartFlashAttention2(BartAttention):
flash attention and deal with padding tokens in case the input contains any of them.
"""

# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

2 changes: 1 addition & 1 deletion src/transformers/models/chameleon/modeling_chameleon.py
@@ -362,7 +362,7 @@ def forward(
return attn_output, attn_weights, past_key_value


# copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2 with Llama->Chameleon
# NO LONGER EXIST copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2 with Llama->Chameleon
# TODO(joao): add me back asap :)
class ChameleonFlashAttention2(ChameleonAttention):
"""
1 change: 0 additions & 1 deletion src/transformers/models/clip/modeling_clip.py
@@ -401,7 +401,6 @@ class CLIPFlashAttention2(CLIPAttention):
flash attention and deal with padding tokens in case the input contains any of them.
"""

# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

3 changes: 2 additions & 1 deletion src/transformers/models/cohere/modeling_cohere.py
@@ -351,7 +351,8 @@ def forward(
return attn_output, attn_weights, past_key_value


# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2 with Llama->Cohere
# NO LONGER EXIST Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2 with Llama->Cohere
# TODO cyril: modular
class CohereFlashAttention2(CohereAttention):
"""
Cohere flash attention module. This module inherits from `CohereAttention` as the weights of the module stays
src/transformers/models/data2vec/modeling_data2vec_audio.py
@@ -489,7 +489,6 @@ class Data2VecAudioFlashAttention2(Data2VecAudioAttention):
flash attention and deal with padding tokens in case the input contains any of them.
"""

# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

1 change: 0 additions & 1 deletion src/transformers/models/dbrx/modeling_dbrx.py
@@ -318,7 +318,6 @@ class DbrxFlashAttention2(DbrxAttention):
calls the public API of flash attention.
"""

# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

1 change: 0 additions & 1 deletion src/transformers/models/distilbert/modeling_distilbert.py
@@ -245,7 +245,6 @@ class DistilBertFlashAttention2(MultiHeadSelfAttention):
API of flash attention and deal with padding tokens in case the input contains any of them.
"""

# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

1 change: 0 additions & 1 deletion src/transformers/models/falcon/modeling_falcon.py
@@ -492,7 +492,6 @@ class FalconFlashAttention2(FalconAttention):
flash attention and deal with padding tokens in case the input contains any of them.
"""

# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
@@ -278,7 +278,6 @@ class GPTBigCodeFlashAttention2(GPTBigCodeAttention):
API of flash attention and deal with padding tokens in case the input contains any of them.
"""

# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

1 change: 0 additions & 1 deletion src/transformers/models/gpt_neo/modeling_gpt_neo.py
@@ -278,7 +278,6 @@ class GPTNeoFlashAttention2(GPTNeoSelfAttention):
flash attention and deal with padding tokens in case the input contains any of them.
"""

# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

1 change: 0 additions & 1 deletion src/transformers/models/gptj/modeling_gptj.py
@@ -266,7 +266,6 @@ class GPTJFlashAttention2(GPTJAttention):
flash attention and deal with padding tokens in case the input contains any of them.
"""

# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

6 changes: 4 additions & 2 deletions src/transformers/models/granitemoe/modeling_granitemoe.py
@@ -510,7 +510,8 @@ def forward(
return attn_output, attn_weights, past_key_value


# Copied from transformers.models.granite.modeling_granite.GraniteFlashAttention2 with Granite->GraniteMoe
# NO LONGER EXIST Copied from transformers.models.granite.modeling_granite.GraniteFlashAttention2 with Granite->GraniteMoe
# TODO cyril: modular
class GraniteMoeFlashAttention2(GraniteMoeAttention):
"""
GraniteMoe flash attention module. This module inherits from `GraniteMoeAttention` as the weights of the module stays
@@ -617,7 +618,8 @@ def forward(
return attn_output, attn_weights, past_key_value


# Copied from transformers.models.granite.modeling_granite.GraniteSdpaAttention with Granite->GraniteMoe
# NO LONGER EXIST Copied from transformers.models.granite.modeling_granite.GraniteSdpaAttention with Granite->GraniteMoe
# TODO cyril: modular
class GraniteMoeSdpaAttention(GraniteMoeAttention):
"""
GraniteMoe attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
1 change: 0 additions & 1 deletion src/transformers/models/hubert/modeling_hubert.py
@@ -563,7 +563,6 @@ class HubertFlashAttention2(HubertAttention):
flash attention and deal with padding tokens in case the input contains any of them.
"""

# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

5 changes: 2 additions & 3 deletions src/transformers/models/idefics2/modeling_idefics2.py
@@ -272,7 +272,6 @@ class Idefics2VisionFlashAttention2(Idefics2VisionAttention):
flash attention and deal with padding tokens in case the input contains any of them.
"""

# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

@@ -859,15 +858,15 @@ def forward(
return attn_output, attn_weights, past_key_value


# Copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2 with MistralAttention->Idefics2PerceiverAttention,MistralFlashAttention->Idefics2PerceiverFlashAttention,Mistral->Idefics2
# NO LONGER EXIST Copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2 with MistralAttention->Idefics2PerceiverAttention,MistralFlashAttention->Idefics2PerceiverFlashAttention,Mistral->Idefics2
# TODO cyril: modular
class Idefics2PerceiverFlashAttention2(Idefics2PerceiverAttention):
"""
Idefics2 flash attention module. This module inherits from `Idefics2PerceiverAttention` as the weights of the module stays
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
flash attention and deal with padding tokens in case the input contains any of them.
"""

# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

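The "NO LONGER EXIST ... TODO cyril: modular" notes here and in the Cohere, GraniteMoe, and Mimi hunks flag classes whose copy-source has disappeared and that are meant to be ported to the library's modular format later. As a rough, hypothetical sketch of that direction (the file name, class pairing, and body below are illustrative assumptions, not the actual follow-up change): a modular_*.py declares the class by inheriting from a reference implementation, and the flat modeling_*.py is then regenerated from it by the repository's modular converter script.

```python
# Hypothetical modular_idefics2.py-style sketch; the real port is the
# "TODO cyril: modular" follow-up, so treat every name here as illustrative.
from transformers.models.mistral.modeling_mistral import MistralAttention


class Idefics2PerceiverAttention(MistralAttention):
    # Inheritance replaces the old textual "Copied from ... Mistral" marker:
    # only genuinely Idefics2-specific behaviour would be overridden here,
    # everything else is reused from the Mistral implementation directly.
    pass
```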
1 change: 0 additions & 1 deletion src/transformers/models/idefics3/modeling_idefics3.py
@@ -273,7 +273,6 @@ class Idefics3VisionFlashAttention2(Idefics3VisionAttention):
flash attention and deal with padding tokens in case the input contains any of them.
"""

# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

1 change: 0 additions & 1 deletion src/transformers/models/jamba/modeling_jamba.py
@@ -384,7 +384,6 @@ class JambaFlashAttention2(JambaAttention):
flash attention and deal with padding tokens in case the input contains any of them.
"""

# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

1 change: 0 additions & 1 deletion src/transformers/models/jetmoe/modeling_jetmoe.py
@@ -641,7 +641,6 @@ def forward(


class JetMoeFlashAttention2(JetMoeAttention):
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

1 change: 0 additions & 1 deletion src/transformers/models/m2m_100/modeling_m2m_100.py
@@ -348,7 +348,6 @@ class M2M100FlashAttention2(M2M100Attention):
flash attention and deal with padding tokens in case the input contains any of them.
"""

# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

1 change: 0 additions & 1 deletion src/transformers/models/mbart/modeling_mbart.py
@@ -291,7 +291,6 @@ class MBartFlashAttention2(MBartAttention):
flash attention and deal with padding tokens in case the input contains any of them.
"""

# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

6 changes: 4 additions & 2 deletions src/transformers/models/mimi/modeling_mimi.py
@@ -559,7 +559,8 @@ def forward(
return attn_output, attn_weights, past_key_value


# Copied from transformers.models.gemma.modeling_gemma.GemmaFlashAttention2 with Gemma->Mimi
# NO LONGER EXIST Copied from transformers.models.gemma.modeling_gemma.GemmaFlashAttention2 with Gemma->Mimi
# TODO cyril: modular
class MimiFlashAttention2(MimiAttention):
"""
Mimi flash attention module. This module inherits from `MimiAttention` as the weights of the module stays
@@ -670,7 +671,8 @@ def forward(
return attn_output, attn_weights, past_key_value


# Copied from transformers.models.gemma.modeling_gemma.GemmaSdpaAttention with Gemma->Mimi
# NO LONGER EXIST Copied from transformers.models.gemma.modeling_gemma.GemmaSdpaAttention with Gemma->Mimi
# TODO cyril: modular
class MimiSdpaAttention(MimiAttention):
"""
Mimi attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from