Skip to content

Commit

Permalink
Update modular_starcoder2.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Cyrilvallez committed Dec 16, 2024
1 parent 5e56d9c commit 598b7bb
Showing 1 changed file with 2 additions and 4 deletions.
6 changes: 2 additions & 4 deletions src/transformers/models/starcoder2/modular_starcoder2.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,14 +132,12 @@ def forward(
if self.config._attn_implementation != "eager":
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

dropout_rate = 0.0 if not self.training else self.attention_dropout

attn_output, attn_weights = attention_interface(
self,
query_states,
key_states,
value_states,
dropout=dropout_rate,
dropout=0.0 if not self.training else self.attention_dropout,
**kwargs,
)

Expand Down Expand Up @@ -283,7 +281,7 @@ class Starcoder2ForTokenClassification(LlamaForTokenClassification):
__all__ = [
"Starcoder2ForCausalLM",
"Starcoder2Model",
"Starcoder2PreTrainedModel",
"Starcoder2PreTrainedModel", # noqa: F822
"Starcoder2ForSequenceClassification",
"Starcoder2ForTokenClassification",
]

0 comments on commit 598b7bb

Please sign in to comment.