diff --git a/src/transformers/models/longformer/modeling_longformer.py b/src/transformers/models/longformer/modeling_longformer.py index 9ca1d9a182e7..aea9f4a90250 100755 --- a/src/transformers/models/longformer/modeling_longformer.py +++ b/src/transformers/models/longformer/modeling_longformer.py @@ -814,7 +814,7 @@ def _sliding_chunks_query_key_matmul(self, query: torch.Tensor, key: torch.Tenso # matrix multiplication # bcxd: batch_size * num_heads x chunks x 2window_overlap x head_dim # bcyd: batch_size * num_heads x chunks x 2window_overlap x head_dim - # bcxy: batch_size * num_heads x chunks x 2window_overlap x window_overlap + # bcxy: batch_size * num_heads x chunks x 2window_overlap x 2window_overlap diagonal_chunked_attention_scores = torch.einsum("bcxd,bcyd->bcxy", (query, key)) # multiply # convert diagonals into columns