ffpa_attn_F16F16F16_L1.cu
#include "launch_templates.cuh"
using namespace ffpa;
void ffpa_mma_acc_f16_L1(torch::Tensor Q,
torch::Tensor K,
torch::Tensor V,
torch::Tensor O,
int stages) {
CHECK_TORCH_TENSOR_DTYPE(Q, torch::kHalf) // Q [B,H,N,D]
CHECK_TORCH_TENSOR_DTYPE(K, torch::kHalf) // K [B,H,N,D]
CHECK_TORCH_TENSOR_DTYPE(V, torch::kHalf) // V [B,H,N,D]
CHECK_TORCH_TENSOR_DTYPE(O, torch::kHalf) // O [B,H,N,D]
const int d = Q.size(3); // B, H, N, d
// Q@K^T or P@V, 0 MMA Acc with fp16, 1 MMA Acc with fp32.
constexpr int kMmaAccFloat32QK = 0;
constexpr int kMmaAccFloat32PV = 0;
#ifdef ENABLE_FFPA_ALL_STAGES
// dispatch stages
if (stages == 2) {
DISPATCH_HEADDIM(LAUNCHER_L1, 2);
} else if (stages == 3) {
DISPATCH_HEADDIM(LAUNCHER_L1, 3);
} else if (stages == 4) {
DISPATCH_HEADDIM(LAUNCHER_L1, 4);
} else {
DISPATCH_HEADDIM(LAUNCHER_L1, 1);
}
#else
// dispatch stages
if (stages == 2) {
DISPATCH_HEADDIM(LAUNCHER_L1, 2);
} else {
DISPATCH_HEADDIM(LAUNCHER_L1, 1);
}
#endif
}
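
// ---------------------------------------------------------------------------
// Illustrative sketch (not from this repo): one plausible shape for the
// DISPATCH_HEADDIM / LAUNCHER_L1 idiom used above -- a runtime switch on the
// head dim `d` that forwards to a launcher templated on compile-time
// constants, so each (headdim, stages, acc-precision) combination gets its
// own fully specialized kernel. The real macros live in launch_templates.cuh;
// the names, signatures, and supported head-dim set below are assumptions
// for explanation only.
#if 0 // sketch, not compiled
#include <stdexcept>

// Hypothetical templated launcher; the template parameters mirror the
// constants selected in ffpa_mma_acc_f16_L1 above.
template <int kHeadDim, int kStages, int kAccF32QK, int kAccF32PV>
void launch_ffpa_l1_sketch(torch::Tensor Q, torch::Tensor K,
                           torch::Tensor V, torch::Tensor O);

// Hypothetical dispatch: map the runtime head dim onto a compile-time
// template parameter for each supported size (case values are placeholders).
#define DISPATCH_HEADDIM_SKETCH(STAGES)                                      \
  switch (d) {                                                               \
    case 64:                                                                 \
      launch_ffpa_l1_sketch<64, (STAGES), kMmaAccFloat32QK,                  \
                            kMmaAccFloat32PV>(Q, K, V, O);                   \
      break;                                                                 \
    case 128:                                                                \
      launch_ffpa_l1_sketch<128, (STAGES), kMmaAccFloat32QK,                 \
                            kMmaAccFloat32PV>(Q, K, V, O);                   \
      break;                                                                 \
    default:                                                                 \
      throw std::runtime_error("unsupported head dim");                      \
  }
#endif
// Design note: specializing on the head dim and stage count at compile time
// lets the compiler unroll loops and size shared-memory tiles statically,
// at the cost of binary size -- which is why the 3- and 4-stage instantiations
// above are gated behind ENABLE_FFPA_ALL_STAGES.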