-
Notifications
You must be signed in to change notification settings - Fork 360
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: Add support for is_causal in attention
- Loading branch information
Showing
5 changed files
with
168 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
import unittest | ||
|
||
import torch | ||
import torch.nn as nn | ||
from parameterized import parameterized | ||
from torch.testing._internal.common_utils import run_tests | ||
|
||
from .harness import DispatchTestCase | ||
|
||
|
||
class TestScaledDotProductAttention(DispatchTestCase): | ||
@parameterized.expand([((32, 8, 128, 64), (32, 8, 128, 64))]) | ||
def test_sdpa_no_causal(self, query_shape, key_shape): | ||
class SDPA(nn.Module): | ||
def forward(self, query, key, value): | ||
return torch.nn.functional.scaled_dot_product_attention( | ||
query, key, value, None, 0.0, False, scale=None | ||
) | ||
|
||
inputs = [] | ||
query = torch.randn(query_shape, dtype=torch.float16) | ||
key = torch.rand(key_shape, dtype=torch.float16) | ||
value = torch.rand(key_shape, dtype=torch.float16) | ||
inputs.extend([query, key, value]) | ||
self.run_test(SDPA(), inputs, precision=torch.float16) | ||
|
||
@parameterized.expand([((32, 8, 128, 64), (32, 8, 128, 64))]) | ||
def test_sdpa_causal(self, query_shape, key_shape): | ||
class SDPA(nn.Module): | ||
def forward(self, query, key, value): | ||
return torch.nn.functional.scaled_dot_product_attention( | ||
query, key, value, None, 0.0, True, scale=None | ||
) | ||
|
||
inputs = [] | ||
query = torch.randn(query_shape, dtype=torch.float16) | ||
key = torch.rand(key_shape, dtype=torch.float16) | ||
value = torch.rand(key_shape, dtype=torch.float16) | ||
inputs.extend([query, key, value]) | ||
self.run_test(SDPA(), inputs, precision=torch.float16) | ||
|
||
|
||
@unittest.skipIf( | ||
torch.cuda.get_device_properties(torch.cuda.current_device()).major < 8, | ||
"GPU compute capability is too low to run flash attention, need Ampere (8.0) or greater", | ||
) | ||
class TestFlashAttention(DispatchTestCase): | ||
@parameterized.expand([((32, 8, 128, 64), (32, 8, 128, 64))]) | ||
def test_sdpa_causal(self, query_shape, key_shape): | ||
class SDPA(nn.Module): | ||
def forward(self, query, key, value): | ||
attn = torch.ops.aten._scaled_dot_product_flash_attention.default( | ||
query, | ||
key, | ||
value, | ||
0, | ||
True, # is_causal | ||
False, | ||
scale=0.25, | ||
) | ||
return attn[0] | ||
|
||
inputs = [] | ||
query = torch.randn(query_shape, dtype=torch.float16) | ||
key = torch.rand(key_shape, dtype=torch.float16) | ||
value = torch.rand(key_shape, dtype=torch.float16) | ||
inputs.extend([query, key, value]) | ||
self.run_test( | ||
SDPA(), | ||
inputs, | ||
precision=torch.float16, | ||
enable_passes=True, | ||
) | ||
|
||
|
||
class TestEfficientAttention(DispatchTestCase): | ||
@parameterized.expand([((32, 8, 128, 64), (32, 8, 128, 64))]) | ||
def test_sdpa_causal(self, query_shape, key_shape): | ||
class SDPA(nn.Module): | ||
def forward(self, query, key, value): | ||
attn = torch.ops.aten._scaled_dot_product_efficient_attention.default( | ||
query, | ||
key, | ||
value, | ||
None, | ||
False, | ||
0, | ||
True, # is_causal | ||
scale=0.5, | ||
) | ||
return attn[0] | ||
|
||
inputs = [] | ||
query = torch.randn(query_shape, dtype=torch.float16) | ||
key = torch.rand(key_shape, dtype=torch.float16) | ||
value = torch.rand(key_shape, dtype=torch.float16) | ||
inputs.extend([query, key, value]) | ||
self.run_test( | ||
SDPA(), | ||
inputs, | ||
precision=torch.float16, | ||
enable_passes=True, | ||
) | ||
|
||
|
||
if __name__ == "__main__": | ||
run_tests() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters