-
Notifications
You must be signed in to change notification settings - Fork 28.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Fix cross-attention head mask for Torch encoder-decoder models (#10605)
* Fix cross-attention head mask for Torch BART models * Fix head masking for cross-attention module for the following models: BART, Blenderbot, Blenderbot_small, M2M_100, Marian, MBart, Pegasus * Enable test_headmasking for M2M_100 model * Fix cross_head_mask for FSMT, LED and T5 * This commit fixes `head_mask` for cross-attention modules in the following models: FSMT, LED, T5 * It also contains some smaller changes in doc so that it is be perfectly clear the shape of `cross_head_mask` is the same as of `decoder_head_mask` * Update template * Fix template for BartForCausalLM * Fix cross_head_mask for Speech2Text models * Fix cross_head_mask in templates * Fix args order in BartForCausalLM template * Fix doc in BART templates * Make more explicit naming * `cross_head_mask` -> `cross_attn_head_mask` * `cross_layer_head_mask` -> `cross_attn_layer_head_mask` * Fix doc * make style quality * Fix speech2text docstring
- Loading branch information
Showing
23 changed files
with
587 additions
and
389 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.