diff --git a/Makefile b/Makefile index ef152d2467ed5..dc3de3cb14e44 100644 --- a/Makefile +++ b/Makefile @@ -596,7 +596,7 @@ ifdef GGML_RPC OBJ_GGML_EXT += ggml/src/ggml-rpc.o endif # GGML_RPC -OBJ_CUDA_TMPL = $(patsubst %.cu,%.o,$(wildcard ggml/src/ggml-cuda/template-instances/fattn-wmma*.cu)) +OBJ_CUDA_TMPL = $(patsubst %.cu,%.o,$(wildcard ggml/src/ggml-cuda/template-instances/fattn-mma*.cu)) OBJ_CUDA_TMPL += $(patsubst %.cu,%.o,$(wildcard ggml/src/ggml-cuda/template-instances/mmq*.cu)) ifdef GGML_CUDA_FA_ALL_QUANTS diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 1198dc1fd9378..5bd8d9c8b5023 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -1775,7 +1775,7 @@ extern "C" { struct ggml_tensor * a, int k); -#define GGML_KQ_MASK_PAD 32 +#define GGML_KQ_MASK_PAD 64 // q: [n_embd, n_batch, n_head, 1] // k: [n_embd, n_kv, n_head_kv, 1] diff --git a/ggml/src/ggml-cuda/CMakeLists.txt b/ggml/src/ggml-cuda/CMakeLists.txt index 14761650f8a8e..119fd39b8e437 100644 --- a/ggml/src/ggml-cuda/CMakeLists.txt +++ b/ggml/src/ggml-cuda/CMakeLists.txt @@ -28,7 +28,7 @@ if (CUDAToolkit_FOUND) list(APPEND GGML_HEADERS_CUDA "../../include/ggml-cuda.h") file(GLOB GGML_SOURCES_CUDA "*.cu") - file(GLOB SRCS "template-instances/fattn-wmma*.cu") + file(GLOB SRCS "template-instances/fattn-mma*.cu") list(APPEND GGML_SOURCES_CUDA ${SRCS}) file(GLOB SRCS "template-instances/mmq*.cu") list(APPEND GGML_SOURCES_CUDA ${SRCS}) diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 8d8d3932e0e58..88be8fc8a6ec6 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -148,7 +148,7 @@ typedef float2 dfloat2; #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING -#define INT8_MMA_AVAILABLE +#define NEW_MMA_AVAILABLE #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING #if !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1) @@ -159,11 +159,13 @@ static constexpr bool fast_fp16_available(const int cc) { return cc >= GGML_CUDA_CC_PASCAL && cc != 610; } +// Any FP16 tensor cores are available. static constexpr bool fp16_mma_available(const int cc) { return cc < GGML_CUDA_CC_OFFSET_AMD && cc >= GGML_CUDA_CC_VOLTA; } -static constexpr bool int8_mma_available(const int cc) { +// Volta technically had FP16 tensor cores but they work very differently compared to Turing and later. +static constexpr bool new_mma_available(const int cc) { return cc < GGML_CUDA_CC_OFFSET_AMD && cc >= GGML_CUDA_CC_TURING; } diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh index ee9752da6a845..cfd7c0f4475dc 100644 --- a/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ggml/src/ggml-cuda/fattn-common.cuh @@ -516,6 +516,104 @@ constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) { nullptr; } +template // D == head size +#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) +__launch_bounds__(D, 1) +#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) +static __global__ void flash_attn_stream_k_fixup( + float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne11) { + const float * dst_fixup_data = ((const float *) dst_fixup) + gridDim.x*(2*2*ncols); + + const int iter_k = ne11 / KQ_stride; + const int iter_j = (ne01 + (ncols - 1)) / ncols; + + const int bidx0 = blockIdx.x; + + const int kbc0 = (bidx0 + 0)*iter_k*iter_j*ne02 / gridDim.x; + const int kbc0_stop = (bidx0 + 1)*iter_k*iter_j*ne02 / gridDim.x; + + const bool did_not_have_any_data = kbc0 == kbc0_stop; + const bool wrote_beginning_of_tile = kbc0 % iter_k == 0; + const bool did_not_write_last = kbc0/iter_k == kbc0_stop/iter_k && kbc0_stop % iter_k != 0; + if (did_not_have_any_data || wrote_beginning_of_tile || did_not_write_last) { + return; + } + + const int channel = kbc0 / (iter_k*iter_j); + const int jt = (kbc0 - channel*iter_k*iter_j) / iter_k; + + dst += jt*ncols*ne02*D + channel*D; + + // Load the partial result that needs a fixup: + float dst_val[ncols] = {0.0f}; + float max_val[ncols] = {0.0f}; + float rowsum[ncols] = {0.0f}; +#pragma unroll + for (int j = 0; j < ncols; ++j) { + if (jt*ncols + j >= ne01) { + break; + } + dst_val[j] = dst[j*ne02*D + threadIdx.x]; + + const float2 tmp = dst_fixup[bidx0*ncols + j]; + max_val[j] = tmp.x; + rowsum[j] = tmp.y; + } + + // Iterate over previous blocks and compute the combined results. + // All CUDA blocks that get here must have a previous block that needs a fixup. + int bidx = bidx0 - 1; + int kbc_stop = kbc0; + while(true) { + const int kbc = bidx*iter_k*iter_j*ne02 / gridDim.x; + if (kbc == kbc_stop) { // Did not have any data. + bidx--; + kbc_stop = kbc; + continue; + } + +#pragma unroll + for (int j = 0; j < ncols; ++j) { + if (jt*ncols + j >= ne01) { + break; + } + const float dst_add = dst_fixup_data[bidx*ncols*D + j*D + threadIdx.x]; + + const float2 tmp = dst_fixup[(gridDim.x + bidx)*ncols + j]; + + // Scale the current and new value accumulators depending on the max. values. + const float max_val_new = fmaxf(max_val[j], tmp.x); + + const float diff_val = max_val[j] - max_val_new; + const float diff_add = tmp.x - max_val_new; + + const float scale_val = diff_val >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_val) : 0.0f; + const float scale_add = diff_add >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_add) : 0.0f; + + dst_val[j] = scale_val*dst_val[j] + scale_add*dst_add; + rowsum[j] = scale_val*rowsum[j] + scale_add*tmp.y; + + max_val[j] = max_val_new; + } + + // If this block started in a previous tile we are done and don't need to combine additional partial results. + if (kbc % iter_k == 0 || kbc/iter_k < kbc0/iter_k) { + break; + } + bidx--; + kbc_stop = kbc; + } + + // Write back final result: +#pragma unroll + for (int j = 0; j < ncols; ++j) { + if (jt*ncols + j >= ne01) { + return; + } + dst[j*ne02*D + threadIdx.x] = dst_val[j] / rowsum[j]; + } +} + template // D == head size #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) __launch_bounds__(D, 1) @@ -581,10 +679,11 @@ static void on_no_fattn_vec_case(const int D) { } } -template +// parallel_blocks == 0 is stream-k decomposition +template void launch_fattn( ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, - const int nwarps, const int cols_per_block, const bool need_f16_K, const bool need_f16_V + const int nwarps, const size_t nbytes_shared, const bool need_f16_K, const bool need_f16_V ) { const ggml_tensor * Q = dst->src[0]; const ggml_tensor * K = dst->src[1]; @@ -603,20 +702,23 @@ void launch_fattn( GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding."); + GGML_ASSERT(Q->ne[3] == 1); + ggml_cuda_pool & pool = ctx.pool(); cudaStream_t main_stream = ctx.stream(); + const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm; ggml_cuda_pool_alloc K_f16(pool); ggml_cuda_pool_alloc V_f16(pool); ggml_cuda_pool_alloc dst_tmp(pool); ggml_cuda_pool_alloc dst_tmp_meta(pool); - char * K_data = (char *) K->data; + const char * K_data = (const char *) K->data; size_t nb11 = K->nb[1]; size_t nb12 = K->nb[2]; size_t nb13 = K->nb[3]; - char * V_data = (char *) V->data; + const char * V_data = (const char *) V->data; size_t nb21 = V->nb[1]; size_t nb22 = V->nb[2]; size_t nb23 = V->nb[3]; @@ -649,39 +751,60 @@ void launch_fattn( nb23 = nb23*bs*sizeof(half)/ts; } - if (parallel_blocks > 1) { - dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV)); - dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV)); - } + const int ntiles_x = ((Q->ne[1] + cols_per_block - 1) / cols_per_block); + const int ntiles_total = ntiles_x*Q->ne[2]*Q->ne[3]; const dim3 block_dim(WARP_SIZE, nwarps, 1); - const dim3 blocks_num(parallel_blocks*((Q->ne[1] + cols_per_block - 1) / cols_per_block), Q->ne[2], Q->ne[3]); - const int shmem = 0; + dim3 blocks_num; + if (parallel_blocks == 0) { + // For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup. + const int tiles_nwaves = (ntiles_total - nsm - 1) / nsm; + const bool tiles_inefficient = 3*nsm < 2*tiles_nwaves*ntiles_total; + const bool short_context = K->ne[1] < 4096; + + const int nblocks_stream_k = 2*nsm; + + blocks_num.x = short_context && !tiles_inefficient ? ntiles_total : nblocks_stream_k; + blocks_num.y = 1; + blocks_num.z = 1; + + dst_tmp_meta.alloc(blocks_num.x*cols_per_block * (2*2 + D) * sizeof(float)); + } else { + blocks_num.x = parallel_blocks*ntiles_x; + blocks_num.y = Q->ne[2]; + blocks_num.z = Q->ne[3]; + + if (parallel_blocks > 1) { + dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV)); + dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV)); + } + } + float scale = 1.0f; float max_bias = 0.0f; float logit_softcap = 0.0f; - memcpy(&scale, (float *) KQV->op_params + 0, sizeof(float)); - memcpy(&max_bias, (float *) KQV->op_params + 1, sizeof(float)); - memcpy(&logit_softcap, (float *) KQV->op_params + 2, sizeof(float)); + memcpy(&scale, (const float *) KQV->op_params + 0, sizeof(float)); + memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float)); + memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float)); if (logit_softcap != 0.0f) { scale /= logit_softcap; } const uint32_t n_head = Q->ne[2]; - const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head)); + const uint32_t n_head_log2 = 1u << uint32_t(floorf(log2f(float(n_head)))); const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); - fattn_kernel<<>>( + fattn_kernel<<>>( (const char *) Q->data, K_data, V_data, mask ? ((const char *) mask->data) : nullptr, - (parallel_blocks) == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr, + (parallel_blocks) > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr, scale, max_bias, m0, m1, n_head_log2, logit_softcap, Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], K->ne[0], K->ne[1], K->ne[2], K->ne[3], @@ -693,16 +816,22 @@ void launch_fattn( ); CUDA_CHECK(cudaGetLastError()); - if ((parallel_blocks) == 1) { - return; - } + if constexpr (parallel_blocks == 0) { + if (blocks_num.x % ntiles_total != 0) { // Fixup is only needed if the SMs work on fractional tiles. + const dim3 block_dim_combine(D, 1, 1); + const dim3 blocks_num_combine = blocks_num; - const dim3 block_dim_combine(D, 1, 1); - const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z); - const int shmem_combine = 0; + flash_attn_stream_k_fixup + <<>> + ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], K->ne[1]); + } + } else if constexpr (parallel_blocks > 1) { + const dim3 block_dim_combine(D, 1, 1); + const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z); - flash_attn_combine_results - <<>> - (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data); + flash_attn_combine_results + <<>> + (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data); + } CUDA_CHECK(cudaGetLastError()); } diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh new file mode 100644 index 0000000000000..05bc91a3b8d6f --- /dev/null +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -0,0 +1,637 @@ +#include "common.cuh" +#include "mma.cuh" +#include "fattn-common.cuh" + +template +static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( + const float2 * const __restrict__ Q_f2, + const half2 * const __restrict__ K_h2, + const half2 * const __restrict__ V_h2, + const half * const __restrict__ maskh, + float2 * const __restrict__ dstk, + float2 * const __restrict__ dstk_fixup, + const float scale, + const float slope, + const float logit_softcap, + const int ne00, + const int ne01, + const int ne02, + const int ne03, + const int ne10, + const int ne11, + const int ne12, + const int ne13, + const int ne31, + const int nb31, + const int nb01, + const int nb02, + const int nb03, + const int nb11, + const int nb12, + const int nb13, + const int nb21, + const int nb22, + const int nb23, + const int ne0, + const int ne1, + const int ne2, + const int ne3, + const int jt, + const int kb0_start, + const int kb0_stop) { +#ifdef NEW_MMA_AVAILABLE + //In this kernel Q, K, V are matrices while i, j, k are matrix indices. + + typedef mma_A_I16K8 mma_A; + typedef mma_B_J8K8 mma_B; + typedef mma_C_I16J8 mma_C_KQ; + typedef mma_C_I16J8 mma_C_VKQ; + + static_assert(nwarps*mma_B::J % ncols == 0, "bad nwarps"); + constexpr int np = nwarps*mma_B::J / ncols; // Number of parallel CUDA warps per Q column. + + static_assert(D % nwarps == 0, "bad D"); + static_assert(KQ_stride % nwarps == 0, "bad KQ_stride"); + + constexpr int D2_padded = D/2 + 4; // Size of D in half2, padded to avoid shared memory bank conflicts. + extern __shared__ half2 tile_KV[]; // Temporary shared buffer for loading K/V data with KQ_stride*D logical elements. + + const int stride_Q = nb01 / sizeof(float2); + const int stride_KV = nb11 / sizeof(half2); + const int stride_mask = nb31 / sizeof(half); + + mma_B Q_B[D/(2*mma_B::K)]; + mma_C_VKQ VKQ_C[D/mma_C_VKQ::I]; + + float2 KQ_rowsum = {0.0f, 0.0f}; + float2 KQ_max = {-FLT_MAX/2.0f, -FLT_MAX/2.0f}; + float2 KQ_max_scale = {0.0f, 0.0f}; + + // Temporarily load Q data into tile_KV, will be loaded into registers afterwards. + // The loading is done with decreasing granularity for D for better memory bandwidth. + const half2 scale_h2 = make_half2(scale, scale); +#pragma unroll + for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) { + const int k0_start = stride_k == WARP_SIZE ? 0 : D/2 - (D/2) % (2*stride_k); + const int k0_stop = D/2 - (D/2) % (1*stride_k); + const int stride_j = WARP_SIZE / stride_k; + + if (nwarps*stride_j > ncols && threadIdx.y*stride_j >= ncols) { + break; + } + +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += nwarps*stride_j) { + const int j = j0 + threadIdx.y*stride_j + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k); + + if (jt*ncols + j < ne01) { +#pragma unroll + for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { + const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); + + const float2 tmp = Q_f2[(jt*ncols + j)*stride_Q + k]; + tile_KV[j*D2_padded + k] = scale_h2 * make_half2(tmp.x, tmp.y); + } + } else { +#pragma unroll + for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { + const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); + + tile_KV[j*D2_padded + k] = make_half2(0.0f, 0.0f); + } + } + } + } + + __syncthreads(); + + { + const int j0 = (threadIdx.y / np) * mma_B::J; + +#pragma unroll + for (int k0 = 0; k0 < D/2; k0 += mma_B::K) { + Q_B[k0/mma_B::K].load_ldmatrix(tile_KV + j0*D2_padded + k0, D2_padded); + } + } + + __syncthreads(); + + // Iterate over ne11 == previous tokens: + for (int kb0 = kb0_start; kb0 < kb0_stop; ++kb0) { + const int k_VKQ_0 = kb0*KQ_stride; + mma_C_KQ KQ_C[KQ_stride/(np*mma_C_KQ::I)]; + + // Load K data into tile with decreasing granularity for D for better memory bandwidth: + static_assert(KQ_stride % (4*nwarps) == 0, "out of bounds"); +#pragma unroll + for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) { + const int k0_start = stride_k == WARP_SIZE ? 0 : D/2 - (D/2) % (2*stride_k); + const int k0_stop = D/2 - (D/2) % (1*stride_k); + const int stride_i = WARP_SIZE / stride_k; + +#pragma unroll + for (int i_KQ_0 = 0; i_KQ_0 < KQ_stride; i_KQ_0 += nwarps*stride_i) { + const int i_KQ = i_KQ_0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k); + +#pragma unroll + for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += stride_k) { + const int k_KQ = k_KQ_0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); + + tile_KV[i_KQ*D2_padded + k_KQ] = K_h2[(k_VKQ_0 + i_KQ)*stride_KV + k_KQ]; + } + } + } + + __syncthreads(); + + // Calculate tile of KQ: +#pragma unroll + for (int i_KQ_00 = 0; i_KQ_00 < KQ_stride; i_KQ_00 += np*mma_A::I) { + const int i_KQ_0 = i_KQ_00 + (threadIdx.y % np)*mma_A::I; +#pragma unroll + for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += mma_A::K) { + mma_A K_A; + K_A.load_ldmatrix(tile_KV + i_KQ_0*D2_padded + k_KQ_0, D2_padded); + KQ_C[i_KQ_00/(np*mma_A::I)].mma(K_A, Q_B[k_KQ_0/mma_A::K]); + } + } + + __syncthreads(); + + if (use_logit_softcap) { + static_assert(KQ_stride % (np*mma_C_KQ::I) == 0, "bad loop size"); +#pragma unroll + for (int i = 0; i < KQ_stride/(np*mma_C_KQ::I); ++i) { +#pragma unroll + for (int l = 0; l < mma_C_KQ::ne; ++l) { + KQ_C[i].x[l] = logit_softcap*tanhf(KQ_C[i].x[l]); + } + } + } + + if (maskh) { + static_assert(KQ_stride % (np *mma_C_KQ::I) == 0, "bad loop size"); + static_assert(ncols % (nwarps/np*mma_C_KQ::J) == 0, "bad loop size"); +#pragma unroll + for (int i00 = 0; i00 < KQ_stride; i00 += np*mma_C_KQ::I) { + const int i0 = i00 + (threadIdx.y % np)*mma_C_KQ::I; +#pragma unroll + for (int l = 0; l < mma_C_KQ::ne; ++l) { + const int i = i0 + mma_C_KQ::get_i(l); + const int j = (threadIdx.y / np)*mma_C_KQ::J + mma_C_KQ::get_j(l); + + KQ_C[i00/(np*mma_C_KQ::I)].x[l] += slope*__half2float(maskh[j*stride_mask + k_VKQ_0 + i]); + } + } + } + + // Calculate softmax for each KQ column using the current max. value. + // The divisor is stored in KQ_rowsum and will be applied at the end. + float2 KQ_max_new = KQ_max; + static_assert(KQ_stride % (np*mma_C_KQ::I) == 0, "bad loop size"); +#pragma unroll + for (int k = 0; k < KQ_stride/(np*mma_C_KQ::I); ++k) { +#pragma unroll + for (int l0 = 0; l0 < mma_C_KQ::ne; l0 += 2) { + KQ_max_new.x = fmaxf(KQ_max_new.x, KQ_C[k].x[l0 + 0]); + KQ_max_new.y = fmaxf(KQ_max_new.y, KQ_C[k].x[l0 + 1]); + } + } + + // Values per KQ column are spread across 8 threads, does not need full warp reduce: +#pragma unroll + for (int offset = 16; offset > 2; offset >>= 1) { + KQ_max_new.x = fmaxf(KQ_max_new.x, __shfl_xor_sync(0xFFFFFFFF, KQ_max_new.x, offset, WARP_SIZE)); + KQ_max_new.y = fmaxf(KQ_max_new.y, __shfl_xor_sync(0xFFFFFFFF, KQ_max_new.y, offset, WARP_SIZE)); + } + + { + const float2 diff = make_float2(KQ_max.x - KQ_max_new.x, KQ_max.y - KQ_max_new.y); + KQ_max_scale = make_float2(expf(diff.x), expf(diff.y)); + if (diff.x <= SOFTMAX_FTZ_THRESHOLD) { + KQ_max_scale.x = 0.0f; + } + if (diff.y <= SOFTMAX_FTZ_THRESHOLD) { + KQ_max_scale.y = 0.0f; + } + KQ_max = KQ_max_new; + } + + float2 KQ_rowsum_add = make_float2(0.0f, 0.0f); + static_assert(KQ_stride % (np*mma_C_KQ::I) == 0, "bad loop size"); +#pragma unroll + for (int k = 0; k < KQ_stride/(np*mma_C_KQ::I); ++k) { +#pragma unroll + for (int l = 0; l < mma_C_KQ::ne; ++l) { + const float KQ_max_l = l % 2 == 0 ? KQ_max.x : KQ_max.y; + const float diff = KQ_C[k].x[l] - KQ_max_l; + KQ_C[k].x[l] = expf(diff); + if (diff <= SOFTMAX_FTZ_THRESHOLD) { + KQ_C[k].x[l] = 0.0f; + } + + if (l % 2 == 0) { + KQ_rowsum_add.x += KQ_C[k].x[l]; + } else { + KQ_rowsum_add.y += KQ_C[k].x[l]; + } + } + } + + // Scale previous KQ_rowsum to account for a potential increase in KQ_max: + KQ_rowsum.x = KQ_max_scale.x*KQ_rowsum.x + KQ_rowsum_add.x; + KQ_rowsum.y = KQ_max_scale.y*KQ_rowsum.y + KQ_rowsum_add.y; + + const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale.x, KQ_max_scale.y); +#pragma unroll + for (int i = 0; i < D/mma_C_VKQ::I; ++i) { +#pragma unroll + for (int l = 0; l < mma_C_VKQ::ne; ++l) { + VKQ_C[i].x[l] *= KQ_max_scale_h2; + } + } + + // Convert KQ C tiles into B tiles for VKQ calculation: + mma_B B[KQ_stride/(np*2*mma_B::K)]; + static_assert(KQ_stride % (np*2*mma_B::K) == 0, "bad loop size"); +#pragma unroll + for (int k = 0; k < KQ_stride/(np*2*mma_B::K); ++k) { + B[k] = KQ_C[k].to_mma_B(); + } + + // Load V data into tile with decreasing granularity for D for better memory bandwidth: + static_assert(KQ_stride % (4*nwarps) == 0, "out of bounds"); +#pragma unroll + for (int stride_i : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) { + const int i0_start = stride_i == WARP_SIZE ? 0 : D/2 - (D/2) % (2*stride_i); + const int i0_stop = D/2 - (D/2) % (1*stride_i); + const int stride_k = WARP_SIZE / stride_i; + +#pragma unroll + for (int k_V_0 = 0; k_V_0 < KQ_stride; k_V_0 += nwarps*stride_k) { + const int k_V = k_V_0 + threadIdx.y*stride_k + (stride_i == WARP_SIZE ? 0 : threadIdx.x / stride_i); + +#pragma unroll + for (int i_V_0 = i0_start; i_V_0 < i0_stop; i_V_0 += stride_i) { + const int i_V = i_V_0 + (stride_i == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_i); + + tile_KV[k_V*D2_padded + i_V] = V_h2[(k_VKQ_0 + k_V)*stride_KV + i_V]; + } + } + } + + __syncthreads(); + + // Calculate VKQ tile: +#pragma unroll + for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += mma_C_VKQ::I) { + static_assert((KQ_stride/2) % (np*mma_A::K) == 0, "bad loop size"); +#pragma unroll + for (int k00 = 0; k00 < KQ_stride/2; k00 += np*mma_A::K) { + const int k0 = k00 + (threadIdx.y % np)*mma_A::K; + + mma_A A; + A.load_ldmatrix_trans(tile_KV + 2*k0*D2_padded + i_VKQ_0/2, D2_padded); + VKQ_C[i_VKQ_0/mma_C_VKQ::I].mma(A, B[k00/(np*mma_A::K)]); + } + } + + __syncthreads(); + } + + // Finally, sum up partial KQ rowsums. + // The partial sums are spread across 8 threads each, does not need full reduce. +#pragma unroll + for (int offset = 16; offset > 2; offset >>= 1) { + KQ_rowsum.x += __shfl_xor_sync(0xFFFFFFFF, KQ_rowsum.x, offset, WARP_SIZE); + KQ_rowsum.y += __shfl_xor_sync(0xFFFFFFFF, KQ_rowsum.y, offset, WARP_SIZE); + } + + // Write VKQ accumulators to shared memory in column-major format. + // It's faster to do small writes to shared memory, then large write to VRAM than to do small writes to VRAM. + // Also for np > 1 the combination is done via these values in shared memory. + const int j_cwd = threadIdx.y*mma_B::J + mma_B::get_j(-1); // j combine write data +#pragma unroll + for (int k0 = 0; k0 < D/2; k0 += mma_B::K) { + const mma_B B = VKQ_C[k0/mma_B::K].to_mma_B(); // Conversion of C to B matrix puts it in column-major format. + +#pragma unroll + for (int l = 0; l < mma_B::ne; ++l) { + const int k = k0 + mma_B::get_k(l); + + tile_KV[j_cwd*D2_padded + k] = B.x[l]; + } + } + + const int j_cwmo = (threadIdx.x % (2*mma_C_VKQ::J)) / mma_C_VKQ::J; // j combine write meta offset + const int j_cwm = threadIdx.y*(2*mma_C_VKQ::J) + 2*mma_C_VKQ::get_j(-1) + j_cwmo; // j combine write meta + const float2 KQ_cmr = make_float2(((const float *) &KQ_max)[j_cwmo], ((const float *) &KQ_rowsum)[j_cwmo]); // KQ combine max rowsum + + if (((!needs_fixup && !is_fixup) || np > 1) && threadIdx.x < 2*mma_C_VKQ::J) { + // Use the 16 bytes of padding in each row to store the meta data: KQ max, KQ rowsum, KQ max scale. + ((float2 *) tile_KV)[j_cwm*(D2_padded/2) + D/4] = KQ_cmr; + } + + __syncthreads(); + + static_assert(np == 1 || np == 2 || np == 4, "bad np"); + if (np == 1) { + // No combination is needed, the meta data can be directly written from registers to VRAM. + if (needs_fixup && threadIdx.x < mma_B::J) { + float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols; + dstk_fixup_meta[j_cwm] = KQ_cmr; + } + if (is_fixup && threadIdx.x < mma_B::J) { + float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols; + dstk_fixup_meta[j_cwm] = KQ_cmr; + } + } else if (threadIdx.y % np == 0) { + // Combine the meta data for parallel warps via shared memory. + // Warps with threadIdx.y % np != 0 must NOT return early. + // All threads must return simultaneously to avoid race conditions with work on the next tile. + + float * meta_j = (float *) tile_KV + (threadIdx.y*mma_B::J + threadIdx.x)*D2_padded + D/2; + + float KQ_cm = -FLT_MAX/2; // KQ combine max per parallel warp. + if (np*mma_B::J == WARP_SIZE || threadIdx.x < np*mma_B::J) { + KQ_cm = meta_j[0]; + } + + float KQ_cmn = KQ_cm; // KQ combine max new, max between all parallel warps. +#pragma unroll + for (int offset = np*mma_B::J/2; offset >= mma_B::J; offset >>= 1) { + KQ_cmn = fmaxf(KQ_cmn, __shfl_xor_sync(0xFFFFFFFF, KQ_cmn, offset, WARP_SIZE)); + } + + const float KQ_cms = expf(KQ_cm - KQ_cmn); // KQ combine max scale per warp. + float KQ_crs = 0.0f; // KQ combine rowsum, scaled sum of all parallel warps. + if (np*mma_B::J == WARP_SIZE || threadIdx.x < np*mma_B::J) { + KQ_crs = KQ_cms*meta_j[1]; + } +#pragma unroll + for (int offset = np*mma_B::J/2; offset >= mma_B::J; offset >>= 1) { + KQ_crs += __shfl_xor_sync(0xFFFFFFFF, KQ_crs, offset, WARP_SIZE); + } + + // Write back combined meta data: + if (np*mma_B::J == WARP_SIZE || threadIdx.x < np*mma_B::J) { + meta_j[0] = KQ_cmn; // Combined max. KQ values. + meta_j[1] = KQ_crs; // Combined KQ rowsums. + meta_j[2] = KQ_cms; // KQ max scales per parallel warp. + } + if (needs_fixup && threadIdx.x < mma_B::J) { + float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols; + dstk_fixup_meta[(threadIdx.y/np)*mma_B::J + threadIdx.x] = make_float2(KQ_cmn, KQ_crs); + } + if (is_fixup && threadIdx.x < mma_B::J) { + float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols; + dstk_fixup_meta[(threadIdx.y/np)*mma_B::J + threadIdx.x] = make_float2(KQ_cmn, KQ_crs); + } + } + + if (np > 1) { + __syncthreads(); + } + + if (np == 1 || threadIdx.y % np == 0) { + // The first 2*2*gridDim.x*ncols floats in dstk_fixup are for storing max. values and row sums. + // The values after that are for the partial results of the individual blocks. + float2 * dstk_fixup_data = dstk_fixup + gridDim.x*(2*ncols) + blockIdx.x*(ncols*(D/2)); + +#pragma unroll + for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) { + const int k0_start = stride_k == WARP_SIZE ? 0 : D/2 - (D/2) % (2*stride_k); + const int k0_stop = D/2 - (D/2) % (1*stride_k); + const int stride_j = WARP_SIZE / stride_k; + + if (nwarps*stride_j > ncols && threadIdx.y*stride_j >= ncols) { + break; + } + +#pragma unroll + for (int j0_dst = 0; j0_dst < ncols; j0_dst += (nwarps/np)*stride_j) { + const int j_dst = j0_dst + (threadIdx.y/np)*stride_j + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k); + const int j_tile_KV = (j_dst/mma_B::J)*(np*mma_B::J) + j_dst % mma_B::J; + + if (!is_fixup && jt*ncols + j_dst >= ne01) { + continue; + } + const float * meta_j = (const float *) tile_KV + j_tile_KV*D2_padded + D/2; +#pragma unroll + for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { + const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); + + float2 dstk_val = make_float2(0.0f, 0.0f); +#pragma unroll + for (int ip = 0; ip < np; ++ip) { + const float KQ_crs = np == 1 ? 1.0f : meta_j[ip*mma_B::J*D2_padded + 2]; + const float2 dstk_val_add = __half22float2(tile_KV[(j_tile_KV + ip*mma_B::J)*D2_padded + k]); + dstk_val.x += dstk_val_add.x*KQ_crs; + dstk_val.y += dstk_val_add.y*KQ_crs; + } + + if (!needs_fixup && !is_fixup) { + const float KQ_rowsum_j = meta_j[1]; + dstk_val.x /= KQ_rowsum_j; + dstk_val.y /= KQ_rowsum_j; + } + + if (is_fixup) { + dstk_fixup_data[j_dst*(D/2) + k] = dstk_val; + } else { + dstk[(jt*ncols + j_dst)*ne02*(D/2) + k] = dstk_val; + } + } + } + } + } + + if (np > 1) { + __syncthreads(); + } +#else + NO_DEVICE_CODE; +#endif // NEW_MMA_AVAILABLE +} + +template +#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) +__launch_bounds__(nwarps*WARP_SIZE, 2) +#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) +static __global__ void flash_attn_ext_f16( + const char * __restrict__ Q, + const char * __restrict__ K, + const char * __restrict__ V, + const char * __restrict__ mask, + float * __restrict__ dst, + float2 * __restrict__ dst_meta, + const float scale, + const float max_bias, + const float m0, + const float m1, + const uint32_t n_head_log2, + const float logit_softcap, + const int ne00, + const int ne01, + const int ne02, + const int ne03, + const int ne10, + const int ne11, + const int ne12, + const int ne13, + const int ne31, + const int nb31, + const int nb01, + const int nb02, + const int nb03, + const int nb11, + const int nb12, + const int nb13, + const int nb21, + const int nb22, + const int nb23, + const int ne0, + const int ne1, + const int ne2, + const int ne3) { + // Skip unused kernel variants for faster compilation: + if (use_logit_softcap && !(D == 128 || D == 256)) { + NO_DEVICE_CODE; + return; + } + + static_assert(FATTN_KQ_STRIDE % KQ_stride == 0, "bad KQ_stride"); + + const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. + + const int iter_k = ne11 / KQ_stride; + const int iter_j = (ne01 + (ncols - 1)) / ncols; + + // kbc == k block continuous, current index in continuous ijk space. + int kbc = (blockIdx.x + 0)*iter_k*iter_j*ne02 / gridDim.x; + const int kbc_stop = (blockIdx.x + 1)*iter_k*iter_j*ne02 / gridDim.x; + + // If the seams of 2 CUDA blocks fall within an output tile their results need to be combined. + // For this we need to track both the block that starts the tile (needs_fixup) and the block that finishes the tile (is_fixup). + // In the most general case >2 seams can fall into the same tile. + + // kb0 == k start index when in the output tile. + int kb0_start = kbc % iter_k; + int kb0_stop = min(iter_k, kb0_start + kbc_stop - kbc); + while (kbc < kbc_stop && kb0_stop == iter_k) { + const int channel = kbc / (iter_k*iter_j); + const int jt = (kbc - channel*iter_k*iter_j) / iter_k; // j index of current tile. + + const float2 * Q_f2 = (const float2 *) (Q + nb02* channel); + const half2 * K_h2 = (const half2 *) (K + nb12*(channel / gqa_ratio)); + const half2 * V_h2 = (const half2 *) (V + nb12*(channel / gqa_ratio)); // K and V have same shape + const half * maskh = mask ? (const half *) mask + (nb31/sizeof(half))*jt*ncols : nullptr; + float2 * dstk = ((float2 *) dst) + channel*(D/2); + + const float slope = get_alibi_slope(max_bias, channel, n_head_log2, m0, m1); + + constexpr bool is_fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer. + if (kb0_start == 0) { + constexpr bool needs_fixup = false; // CUDA block is working on an entire tile. + flash_attn_ext_f16_process_tile + (Q_f2, K_h2, V_h2, maskh, dstk, dst_meta, scale, slope, logit_softcap, + ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, ne31, nb31, nb01, nb02, nb03, nb11, nb12, nb13, nb21, nb22, nb23, ne0, ne1, ne2, ne3, + jt, kb0_start, kb0_stop); + } else { + constexpr bool needs_fixup = true; // CUDA block is working on the beginning of a tile. + flash_attn_ext_f16_process_tile + (Q_f2, K_h2, V_h2, maskh, dstk, dst_meta, scale, slope, logit_softcap, + ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, ne31, nb31, nb01, nb02, nb03, nb11, nb12, nb13, nb21, nb22, nb23, ne0, ne1, ne2, ne3, + jt, kb0_start, kb0_stop); + } + + kbc += iter_k; + kbc -= kbc % iter_k; + + kb0_start = 0; + kb0_stop = min(iter_k, kbc_stop - kbc); + } + + if (kbc >= kbc_stop) { + return; + } + + const int channel = kbc / (iter_k*iter_j); + const int jt = (kbc - channel*iter_k*iter_j) / iter_k; // j index of current tile. + + const float2 * Q_f2 = (const float2 *) (Q + nb02* channel); + const half2 * K_h2 = (const half2 *) (K + nb12*(channel / gqa_ratio)); + const half2 * V_h2 = (const half2 *) (V + nb12*(channel / gqa_ratio)); // K and V have same shape + const half * maskh = mask ? (const half *) mask + (nb31/sizeof(half))*jt*ncols : nullptr; + float2 * dstk = ((float2 *) dst) + channel*(D/2); + + const float slope = get_alibi_slope(max_bias, channel, n_head_log2, m0, m1); + + constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks. + constexpr bool needs_fixup = false; + flash_attn_ext_f16_process_tile + (Q_f2, K_h2, V_h2, maskh, dstk, dst_meta, scale, slope, logit_softcap, + ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, ne31, nb31, nb01, nb02, nb03, nb11, nb12, nb13, nb21, nb22, nb23, ne0, ne1, ne2, ne3, + jt, kb0_start, kb0_stop); +} + +template +void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + typedef mma_A_I16K8 mma_A; + typedef mma_B_J8K8 mma_B; + + static_assert(D % mma_B::K == 0, "bad D"); + static_assert(cols_per_block % mma_B::J == 0, "bad cols_per_block"); + + const ggml_tensor * KQV = dst; + + constexpr int KQ_stride = D <= 128 ? 64 : 32; + constexpr int nwarps = (KQ_stride == 32 && cols_per_block <= 16) ? + cols_per_block/mma_B::J * KQ_stride/mma_A::I : (cols_per_block <= 8 ? 4 : 8); + constexpr size_t nbytes_shared = std::max(KQ_stride, nwarps*mma_B::J) * (D + 8) * sizeof(half); + + float logit_softcap; + memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float)); + + fattn_kernel_t fattn_kernel; + if (logit_softcap == 0.0f) { + constexpr bool use_logit_softcap = false; + fattn_kernel = flash_attn_ext_f16; + } else { + constexpr bool use_logit_softcap = true; + fattn_kernel = flash_attn_ext_f16; + } + launch_fattn(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true); +} + +#define DECL_FATTN_MMA_F16_CASE(D, cols_per_block) \ + template void ggml_cuda_flash_attn_ext_mma_f16_case \ + (ggml_backend_cuda_context & ctx, ggml_tensor * dst) \ + +extern DECL_FATTN_MMA_F16_CASE( 64, 8); +extern DECL_FATTN_MMA_F16_CASE( 80, 8); +extern DECL_FATTN_MMA_F16_CASE( 96, 8); +extern DECL_FATTN_MMA_F16_CASE(112, 8); +extern DECL_FATTN_MMA_F16_CASE(128, 8); +extern DECL_FATTN_MMA_F16_CASE(256, 8); + +extern DECL_FATTN_MMA_F16_CASE( 64, 16); +extern DECL_FATTN_MMA_F16_CASE( 80, 16); +extern DECL_FATTN_MMA_F16_CASE( 96, 16); +extern DECL_FATTN_MMA_F16_CASE(112, 16); +extern DECL_FATTN_MMA_F16_CASE(128, 16); +extern DECL_FATTN_MMA_F16_CASE(256, 16); + +extern DECL_FATTN_MMA_F16_CASE( 64, 32); +extern DECL_FATTN_MMA_F16_CASE( 80, 32); +extern DECL_FATTN_MMA_F16_CASE( 96, 32); +extern DECL_FATTN_MMA_F16_CASE(112, 32); +extern DECL_FATTN_MMA_F16_CASE(128, 32); +extern DECL_FATTN_MMA_F16_CASE(256, 32); + +extern DECL_FATTN_MMA_F16_CASE( 64, 64); +extern DECL_FATTN_MMA_F16_CASE( 80, 64); +extern DECL_FATTN_MMA_F16_CASE( 96, 64); +extern DECL_FATTN_MMA_F16_CASE(112, 64); +extern DECL_FATTN_MMA_F16_CASE(128, 64); +extern DECL_FATTN_MMA_F16_CASE(256, 64); diff --git a/ggml/src/ggml-cuda/fattn-tile-f16.cu b/ggml/src/ggml-cuda/fattn-tile-f16.cu index 4d314dacb1958..d4edbad07f26a 100644 --- a/ggml/src/ggml-cuda/fattn-tile-f16.cu +++ b/ggml/src/ggml-cuda/fattn-tile-f16.cu @@ -45,7 +45,17 @@ static __global__ void flash_attn_tile_ext_f16( const int ne2, const int ne3) { #ifdef FP16_AVAILABLE + +#ifndef FLASH_ATTN_AVAILABLE + NO_DEVICE_CODE; + return; +#endif // FLASH_ATTN_AVAILABLE + // Skip unused kernel variants for faster compilation: +#ifdef FP16_MMA_AVAILABLE + NO_DEVICE_CODE; + return; +#endif // FP16_MMA_AVAILABLE if (use_logit_softcap && !(D == 128 || D == 256)) { NO_DEVICE_CODE; return; @@ -288,16 +298,18 @@ void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * const ggml_tensor * Q = dst->src[0]; switch (Q->ne[0]) { case 64: { - constexpr int D = 64; - constexpr int nwarps = 8; + constexpr int D = 64; + constexpr int nwarps = 8; + constexpr size_t nbytes_shared = 0; fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16; - launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true); + launch_fattn(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true); } break; case 128: { - constexpr int D = 128; - constexpr int nwarps = 8; + constexpr int D = 128; + constexpr int nwarps = 8; + constexpr size_t nbytes_shared = 0; fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16; - launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true); + launch_fattn(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true); } break; default: { GGML_ABORT("FlashAttention without tensor cores only supports head sizes 64 and 128."); diff --git a/ggml/src/ggml-cuda/fattn-tile-f32.cu b/ggml/src/ggml-cuda/fattn-tile-f32.cu index bb33604470fdd..0d274f33255b7 100644 --- a/ggml/src/ggml-cuda/fattn-tile-f32.cu +++ b/ggml/src/ggml-cuda/fattn-tile-f32.cu @@ -48,7 +48,12 @@ static __global__ void flash_attn_tile_ext_f32( NO_DEVICE_CODE; return; #endif // FLASH_ATTN_AVAILABLE + // Skip unused kernel variants for faster compilation: +#ifdef FP16_MMA_AVAILABLE + NO_DEVICE_CODE; + return; +#endif // FP16_MMA_AVAILABLE if (use_logit_softcap && !(D == 128 || D == 256)) { NO_DEVICE_CODE; return; @@ -287,16 +292,18 @@ void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * const ggml_tensor * Q = dst->src[0]; switch (Q->ne[0]) { case 64: { - constexpr int D = 64; - constexpr int nwarps = 8; + constexpr int D = 64; + constexpr int nwarps = 8; + constexpr size_t nbytes_shared = 0; fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32; - launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true); + launch_fattn(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true); } break; case 128: { - constexpr int D = 128; - constexpr int nwarps = 8; + constexpr int D = 128; + constexpr int nwarps = 8; + constexpr size_t nbytes_shared = 0; fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32; - launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true); + launch_fattn(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true); } break; default: { GGML_ABORT("FlashAttention without tensor cores only supports head sizes 64 and 128."); diff --git a/ggml/src/ggml-cuda/fattn-vec-f16.cuh b/ggml/src/ggml-cuda/fattn-vec-f16.cuh index 34a2992c769b9..d9ac4424606e4 100644 --- a/ggml/src/ggml-cuda/fattn-vec-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-vec-f16.cuh @@ -42,6 +42,12 @@ static __global__ void flash_attn_vec_ext_f16( const int ne2, const int ne3) { #ifdef FP16_AVAILABLE + +#ifndef FLASH_ATTN_AVAILABLE + NO_DEVICE_CODE; + return; +#endif // FLASH_ATTN_AVAILABLE + // Skip unused kernel variants for faster compilation: if (use_logit_softcap && !(D == 128 || D == 256)) { NO_DEVICE_CODE; @@ -303,7 +309,8 @@ void ggml_cuda_flash_attn_ext_vec_f16_case_impl(ggml_backend_cuda_context & ctx, fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16; constexpr bool need_f16_K = D != 128; constexpr bool need_f16_V = D != 128 && D != 64; - launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, need_f16_K, need_f16_V); + constexpr size_t nbytes_shared = 0; + launch_fattn(ctx, dst, fattn_kernel, nwarps, nbytes_shared, need_f16_K, need_f16_V); } template diff --git a/ggml/src/ggml-cuda/fattn-vec-f32.cuh b/ggml/src/ggml-cuda/fattn-vec-f32.cuh index a28fc8b7fc893..6ef8f9dcc2756 100644 --- a/ggml/src/ggml-cuda/fattn-vec-f32.cuh +++ b/ggml/src/ggml-cuda/fattn-vec-f32.cuh @@ -41,6 +41,11 @@ static __global__ void flash_attn_vec_ext_f32( const int ne1, const int ne2, const int ne3) { +#ifndef FLASH_ATTN_AVAILABLE + NO_DEVICE_CODE; + return; +#endif // FLASH_ATTN_AVAILABLE + // Skip unused kernel variants for faster compilation: if (use_logit_softcap && !(D == 128 || D == 256)) { NO_DEVICE_CODE; @@ -284,7 +289,8 @@ void ggml_cuda_flash_attn_ext_vec_f32_case_impl(ggml_backend_cuda_context & ctx, fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f32; constexpr bool need_f16_K = D != 128; constexpr bool need_f16_V = D != 128 && D != 64; - launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, need_f16_K, need_f16_V); + constexpr size_t nbytes_shared = 0; + launch_fattn(ctx, dst, fattn_kernel, nwarps, nbytes_shared, need_f16_K, need_f16_V); } template diff --git a/ggml/src/ggml-cuda/fattn-wmma-f16.cu b/ggml/src/ggml-cuda/fattn-wmma-f16.cu new file mode 100644 index 0000000000000..1054ff95d2296 --- /dev/null +++ b/ggml/src/ggml-cuda/fattn-wmma-f16.cu @@ -0,0 +1,648 @@ +// Old and deprecated WMMA FlashAttention implementation. +// It is still needed for Volta since the memory layout of NVIDIA tensor cores changed with Turing. +// Long-term the WMMA code should be replaced with a dedicated Volta implementation. + +#include "common.cuh" +#include "fattn-common.cuh" +#include "fattn-wmma-f16.cuh" + +#ifdef FP16_MMA_AVAILABLE +#include +#endif // FP16_MMA_AVAILABLE + +// D == head size, VKQ_stride == num VKQ rows calculated in parallel: +template +#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) +__launch_bounds__(nwarps*WARP_SIZE, 1) +#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) +static __global__ void flash_attn_ext_f16( + const char * __restrict__ Q, + const char * __restrict__ K, + const char * __restrict__ V, + const char * __restrict__ mask, + float * __restrict__ dst, + float2 * __restrict__ dst_meta, + const float scale, + const float max_bias, + const float m0, + const float m1, + const uint32_t n_head_log2, + const float logit_softcap, + const int ne00, + const int ne01, + const int ne02, + const int ne03, + const int ne10, + const int ne11, + const int ne12, + const int ne13, + const int ne31, + const int nb31, + const int nb01, + const int nb02, + const int nb03, + const int nb11, + const int nb12, + const int nb13, + const int nb21, + const int nb22, + const int nb23, + const int ne0, + const int ne1, + const int ne2, + const int ne3) { +#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA + // Skip unused kernel variants for faster compilation: + if (use_logit_softcap && !(D == 128 || D == 256)) { + NO_DEVICE_CODE; + return; + } + + //In this kernel Q, K, V are matrices while i, j, k are matrix indices. + + const int ic0 = ncols*(blockIdx.x / parallel_blocks); // Index of the first Q/QKV column to work on. + const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel. + + static_assert(D <= FATTN_KQ_STRIDE, "D must be <= FATTN_KQ_STRIDE."); + static_assert(ncols == 8 || ncols % 16 == 0, "ncols must be 8 or a multiple of 16."); + constexpr int frag_m = ncols == 8 ? 32 : 16; + constexpr int frag_n = ncols == 8 ? 8 : 16; + static_assert(D % frag_m == 0, "If ncols == 8 then D % frag_m must be 0."); + typedef nvcuda::wmma::fragment frag_a_K; + typedef nvcuda::wmma::fragment frag_a_V; + typedef nvcuda::wmma::fragment frag_b; + typedef nvcuda::wmma::fragment frag_c_KQ; + typedef nvcuda::wmma::fragment frag_c_VKQ; + + constexpr int KQ_stride_tc = nwarps*frag_m; // Number of KQ rows calculated in parallel. + constexpr int VKQ_ratio = KQ_stride_tc/VKQ_stride; // Number of parallel VKQ accumulators needed to keep all warps busy. + static_assert(VKQ_ratio <= nwarps, "VKQ_ratio must be <= nwarps."); + + // Pad internal representation of KQ, KQV to reduce shared memory bank conflicts: + constexpr int D_padded = D + 8; + constexpr int kqs_padded = FATTN_KQ_STRIDE + 8; + constexpr int kqar = sizeof(KQ_acc_t)/sizeof(half); + + const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. + const float * Q_f = (const float *) (Q + nb02* blockIdx.y + nb01*ic0); + const half * K_h = (const half *) (K + nb12*(blockIdx.y / gqa_ratio)); + const half * V_h = (const half *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape + const half * maskh = (const half *) mask + (nb31/sizeof(half))* ic0; + const half2 * mask2 = (const half2 *) mask + (nb31/sizeof(half))*(ic0/2); + + const int stride_Q = nb01 / sizeof(float); + const int stride_KV = nb11 / sizeof(half); + + const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1); + const half slopeh = __float2half(slopef); + const half2 slope2 = make_half2(slopef, slopef); + + const half2 logit_softcap_2 = make_half2(logit_softcap, logit_softcap); + + frag_b Q_b[D/16][ncols/frag_n]; + + // A single buffer for temporarily holding tiles of KQ and VKQ parts: + constexpr int mem_KQ = ncols*kqs_padded*kqar; + constexpr int mem_VKQ_parts = VKQ_ratio*ncols*D_padded; + __shared__ half KQ[mem_KQ >= mem_VKQ_parts ? mem_KQ : mem_VKQ_parts]; + float * KQ_f = (float *) KQ; + half2 * KQ2 = (half2 *) KQ; + + float KQ_rowsum_f[ncols/nwarps] = {0.0f}; + float KQ_max_f[ncols/nwarps]; + float KQ_max_scale_f[ncols/nwarps] = {0.0f}; + +#pragma unroll + for (int j = 0; j < ncols/nwarps; ++j) { + KQ_max_f[j] = -FLT_MAX/2.0f; + } + + half2 KQ_rowsum_h2[ncols/nwarps] = {{0.0f, 0.0f}}; + half2 KQ_max_h2[ncols/nwarps]; + half2 KQ_max_scale_h2[ncols/nwarps] = {{0.0f, 0.0f}}; + +#pragma unroll + for (int j = 0; j < ncols/nwarps; ++j) { + KQ_max_h2[j] = make_half2(-HALF_MAX_HALF, -HALF_MAX_HALF); + } + + __shared__ half VKQ[ncols*D_padded]; // Accumulator for final VKQ slice. + half2 * VKQ2 = (half2 *) VKQ; +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += nwarps) { + const int j = j0 + threadIdx.y; +#pragma unroll + for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + if (i0 + WARP_SIZE > D/2 && i >= D/2) { + break; + } + VKQ2[j*(D_padded/2) + i] = make_half2(0.0f, 0.0f); + } + } + + // Convert Q to half and apply scale, temporarily store in KQ: +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += nwarps) { + const int j = j0 + threadIdx.y; +#pragma unroll + for (int i0 = 0; i0 < D; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + if (i0 + WARP_SIZE > D && i >= D) { + break; + } + KQ[j*D_padded + i] = ic0 + j < ne01 ? Q_f[j*stride_Q + i] * scale : 0.0f; + } + } + + __syncthreads(); + + // Load Q into tensor core fragments/registers since it will be used frequently: +#pragma unroll + for (int i0 = 0; i0 < D; i0 += 16) { +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += frag_n) { + nvcuda::wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded); + } + } + + __syncthreads(); + + // Iterate over ne11 == previous tokens: + for (int k_VKQ_0 = ip*FATTN_KQ_STRIDE; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*FATTN_KQ_STRIDE) { + // Calculate tile of KQ: +#pragma unroll + for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE; i_KQ_0 += KQ_stride_tc) { + frag_c_KQ KQ_c[ncols/frag_n]; +#pragma unroll + for (int j = 0; j < ncols/frag_n; ++j) { + nvcuda::wmma::fill_fragment(KQ_c[j], 0.0f); + } +#pragma unroll + for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) { + frag_a_K K_a; + nvcuda::wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV); +#pragma unroll + for (int j = 0; j < ncols/frag_n; ++j) { + nvcuda::wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]); + } + } +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += frag_n) { + nvcuda::wmma::store_matrix_sync((KQ_acc_t *) KQ + j0*kqs_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], kqs_padded, nvcuda::wmma::mem_col_major); + } + } + + __syncthreads(); + + // Calculate softmax for each KQ column using the current max. value. + // The divisor is stored in KQ_rowsum and will be applied at the end. +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += nwarps) { + const int j = j0 + threadIdx.y; + + if (std::is_same::value) { + float KQ_f_tmp[FATTN_KQ_STRIDE / WARP_SIZE]; +#pragma unroll + for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) { + const int k = k0 + threadIdx.x; + + KQ_f_tmp[k0/WARP_SIZE] = KQ_f[j*kqs_padded + k]; + + if (use_logit_softcap) { + KQ_f_tmp[k0/WARP_SIZE] = logit_softcap*tanhf(KQ_f_tmp[k0/WARP_SIZE]); + } + } + + float KQ_max_new = KQ_max_f[j0/nwarps]; +#pragma unroll + for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) { + const int k = k0 + threadIdx.x; + + KQ_f_tmp[k0/WARP_SIZE] += mask ? __half2float(slopeh*maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f; + KQ_max_new = max(KQ_max_new, KQ_f_tmp[k0/WARP_SIZE]); + } + KQ_max_new = warp_reduce_max(KQ_max_new); + + const float diff = KQ_max_f[j0/nwarps] - KQ_max_new; + KQ_max_scale_f[j0/nwarps] = expf(diff); + if (diff <= SOFTMAX_FTZ_THRESHOLD) { + KQ_max_scale_f[j0/nwarps] = 0.0f; + } + KQ_max_f[j0/nwarps] = KQ_max_new; + + float KQ_rowsum_add = 0.0f; +#pragma unroll + for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) { + const int k = k0 + threadIdx.x; + + const float diff = KQ_f_tmp[k0/WARP_SIZE] - KQ_max_f[j0/nwarps]; + KQ_f_tmp[k0/WARP_SIZE] = expf(diff); + if (diff <= SOFTMAX_FTZ_THRESHOLD) { + KQ_f_tmp[k0/WARP_SIZE] = 0.0f; + } + KQ_rowsum_add += KQ_f_tmp[k0/WARP_SIZE]; + KQ[j*(kqar*kqs_padded) + k] = KQ_f_tmp[k0/WARP_SIZE]; + } + KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add); + + // Scale previous KQ_rowsum to account for a potential increase in KQ_max: + KQ_rowsum_f[j0/nwarps] = KQ_max_scale_f[j0/nwarps]*KQ_rowsum_f[j0/nwarps] + KQ_rowsum_add; + } else { + half2 KQ2_tmp[FATTN_KQ_STRIDE/(2*WARP_SIZE)]; +#pragma unroll + for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) { + const int k = k0 + threadIdx.x; + + KQ2_tmp[k0/WARP_SIZE] = KQ2[j*(kqs_padded/2) + k]; + + if (use_logit_softcap) { + // There is no dedicated tangens hyperbolicus function for half2. + KQ2_tmp[k0/WARP_SIZE] = h2exp(KQ2_tmp[k0/WARP_SIZE]*make_half2(2.0f, 2.0f)); + KQ2_tmp[k0/WARP_SIZE] = (KQ2_tmp[k0/WARP_SIZE] - make_half2(1.0f, 1.0f)) + /(KQ2_tmp[k0/WARP_SIZE] + make_half2(1.0f, 1.0f)); + + KQ2_tmp[k0/WARP_SIZE] *= logit_softcap_2; + } + } + + half2 KQ_max_new = KQ_max_h2[j0/nwarps]; +#pragma unroll + for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) { + const int k = k0 + threadIdx.x; + + KQ2_tmp[k0/WARP_SIZE] += mask ? slope2*mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f); + KQ_max_new = ggml_cuda_hmax2(KQ_max_new, KQ2_tmp[k0/WARP_SIZE]); + } + KQ_max_new = __half2half2(warp_reduce_max(ggml_cuda_hmax(__low2half(KQ_max_new), __high2half(KQ_max_new)))); + const half2 diff = KQ_max_h2[j0/nwarps] - KQ_max_new; + KQ_max_scale_h2[j0/nwarps] = h2exp(diff); + const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD)); + *((uint32_t *) &KQ_max_scale_h2[j0/nwarps]) &= ftz_mask; + KQ_max_h2[j0/nwarps] = KQ_max_new; + + half2 KQ_rowsum_add = make_half2(0.0f, 0.0f); +#pragma unroll + for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) { + const int k = k0 + threadIdx.x; + + const half2 diff = KQ2_tmp[k0/WARP_SIZE] - KQ_max_h2[j0/nwarps]; + KQ2_tmp[k0/WARP_SIZE] = h2exp(diff); + const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD)); + *((uint32_t *) &KQ2_tmp[k0/WARP_SIZE]) &= ftz_mask; + KQ_rowsum_add += KQ2_tmp[k0/WARP_SIZE]; + KQ2[j*(kqs_padded/2) + k] = KQ2_tmp[k0/WARP_SIZE]; + } + KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add); + + // Scale previous KQ_rowsum to account for a potential increase in KQ_max: + KQ_rowsum_h2[j0/nwarps] = KQ_max_scale_h2[j0/nwarps]*KQ_rowsum_h2[j0/nwarps] + KQ_rowsum_add; + } + } + + __syncthreads(); + + frag_b KQ_b[FATTN_KQ_STRIDE/(VKQ_ratio*16)][ncols/frag_n]; +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += frag_n) { +#pragma unroll + for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) { + const int k = k0 + (threadIdx.y % VKQ_ratio)*16; + nvcuda::wmma::load_matrix_sync( + KQ_b[k0/(VKQ_ratio*16)][j0/frag_n], + KQ + j0*(kqar*kqs_padded) + k, + kqar*kqs_padded); + } + } + + frag_c_VKQ VKQ_c[D/VKQ_stride][ncols/frag_n]; +#pragma unroll + for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += VKQ_stride) { +#pragma unroll + for (int j = 0; j < ncols/frag_n; ++j) { + nvcuda::wmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], 0.0f); + } + +#pragma unroll + for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) { + const int k = k0 + (threadIdx.y % VKQ_ratio)*16; + + frag_a_V v_a; + nvcuda::wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV); +#pragma unroll + for (int j = 0; j < ncols/frag_n; ++j) { + nvcuda::wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]); + } + } + } + + __syncthreads(); + + const int offset_k = (threadIdx.y % VKQ_ratio) * (ncols*D_padded); +#pragma unroll + for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += VKQ_stride) { +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += frag_n) { + nvcuda::wmma::store_matrix_sync( + KQ + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio), + VKQ_c[i_KQ_0/VKQ_stride][j0/frag_n], + D_padded, nvcuda::wmma::mem_col_major); + } + } + + __syncthreads(); + +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += nwarps) { + const int j = j0 + threadIdx.y; + + half2 VKQ_scale; + if (std::is_same::value) { + VKQ_scale = make_half2(KQ_max_scale_f[j0/nwarps], KQ_max_scale_f[j0/nwarps]); + } else { + VKQ_scale = KQ_max_scale_h2[j0/nwarps]; + } + +#pragma unroll + for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + if (i0 + WARP_SIZE > D/2 && i >= D/2) { + break; + } + + half2 VKQ_add = make_half2(0.0f, 0.0f); +#pragma unroll + for (int l = 0; l < VKQ_ratio; ++l) { + VKQ_add += KQ2[l*(ncols*D_padded/2) + j*(D_padded/2) + i]; + } + VKQ2[j*(D_padded/2) + i] = VKQ_scale*VKQ2[j*(D_padded/2) + i] + VKQ_add; + } + } + + __syncthreads(); + } + +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += nwarps) { + const int j_VKQ = j0 + threadIdx.y; + if (ic0 + j_VKQ >= ne01) { + return; + } + const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip; + + float KQ_rowsum_j; + if (std::is_same::value) { + KQ_rowsum_j = KQ_rowsum_f[j0/nwarps]; + } else { + KQ_rowsum_j = __low2float(KQ_rowsum_h2[j0/nwarps]) + __high2float(KQ_rowsum_h2[j0/nwarps]); + } + +#pragma unroll + for (int i0 = 0; i0 < D; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + if (i0 + WARP_SIZE > D && i >= D) { + break; + } + float dst_val = VKQ[j_VKQ*D_padded + i]; + if (parallel_blocks == 1) { + dst_val /= KQ_rowsum_j; + } + dst[j_dst*gridDim.y*D + blockIdx.y*D + i] = dst_val; + } + + if (parallel_blocks == 1 || threadIdx.x != 0) { + continue; + } + + float2 dst_meta_val; + if (std::is_same::value) { + dst_meta_val.x = KQ_max_f[j0/nwarps]; + } else { + dst_meta_val.x = __low2float(KQ_max_h2[j0/nwarps]); + } + dst_meta_val.y = KQ_rowsum_j; + dst_meta[(ic0 + j_VKQ)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = dst_meta_val; + } +#else + NO_DEVICE_CODE; +#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA +} + +constexpr int get_max_power_of_2(int x) { + return x % 2 == 0 ? 2*get_max_power_of_2(x/2) : 1; +} + +static_assert(get_max_power_of_2(1) == 1, "Test failed."); +static_assert(get_max_power_of_2(2) == 2, "Test failed."); +static_assert(get_max_power_of_2(4) == 4, "Test failed."); +static_assert(get_max_power_of_2(6) == 2, "Test failed."); + +// Number of VKQ rows calculated in parallel: +constexpr int get_VKQ_stride(int D, int nwarps, int frag_m) { + return (get_max_power_of_2(D/frag_m) < nwarps ? get_max_power_of_2(D/frag_m) : nwarps)*frag_m; +} + +static_assert(get_VKQ_stride(128, 1, 32) == 32, "Test failed."); +static_assert(get_VKQ_stride(128, 2, 32) == 64, "Test failed."); +static_assert(get_VKQ_stride(128, 4, 32) == 128, "Test failed."); +static_assert(get_VKQ_stride( 64, 1, 32) == 32, "Test failed."); +static_assert(get_VKQ_stride( 64, 2, 32) == 64, "Test failed."); +static_assert(get_VKQ_stride( 64, 4, 32) == 64, "Test failed."); +static_assert(get_VKQ_stride( 80, 1, 16) == 16, "Test failed."); +static_assert(get_VKQ_stride( 80, 2, 16) == 16, "Test failed."); +static_assert(get_VKQ_stride( 80, 4, 16) == 16, "Test failed."); + +template +void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * KQV = dst; + const ggml_tensor * Q = dst->src[0]; + + constexpr int nwarps = 4; + + constexpr int frag_m = cols_per_block == 8 && D % 32 == 0 ? 32 : 16; + const int blocks_num_pb1 = ((Q->ne[1] + cols_per_block - 1) / cols_per_block)*Q->ne[2]*Q->ne[3]; + const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm; + + float logit_softcap; + memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float)); + + if (4*blocks_num_pb1 < 2*nsm) { + constexpr int parallel_blocks = 4; + fattn_kernel_t fattn_kernel; + if (logit_softcap == 0.0f) { + constexpr bool use_logit_softcap = false; + fattn_kernel = flash_attn_ext_f16< + D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>; + } else { + constexpr bool use_logit_softcap = true; + fattn_kernel = flash_attn_ext_f16< + D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>; + } + launch_fattn(ctx, dst, fattn_kernel, nwarps, 0, true, true); + return; + } + if (2*blocks_num_pb1 < 2*nsm) { + constexpr int parallel_blocks = 2; + fattn_kernel_t fattn_kernel; + if (logit_softcap == 0.0f) { + constexpr bool use_logit_softcap = false; + fattn_kernel = flash_attn_ext_f16< + D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>; + } else { + constexpr bool use_logit_softcap = true; + fattn_kernel = flash_attn_ext_f16< + D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>; + } + launch_fattn(ctx, dst, fattn_kernel, nwarps, 0, true, true); + return; + } + constexpr int parallel_blocks = 1; + fattn_kernel_t fattn_kernel; + if (logit_softcap == 0.0f) { + constexpr bool use_logit_softcap = false; + fattn_kernel = flash_attn_ext_f16< + D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>; + } else { + constexpr bool use_logit_softcap = true; + fattn_kernel = flash_attn_ext_f16< + D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>; + } + launch_fattn(ctx, dst, fattn_kernel, nwarps, 0, true, true); +} + +void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * KQV = dst; + const ggml_tensor * Q = dst->src[0]; + + const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV); + + if (prec != GGML_PREC_DEFAULT) { + if (Q->ne[1] <= 32 || Q->ne[0] > 128) { + constexpr int cols_per_block = 16; + switch (Q->ne[0]) { + case 64: + ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, float>(ctx, dst); + break; + case 80: + ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, float>(ctx, dst); + break; + case 96: + ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, float>(ctx, dst); + break; + case 112: + ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, float>(ctx, dst); + break; + case 128: + ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst); + break; + case 256: + ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, float>(ctx, dst); + break; + default: + GGML_ABORT("fatal error"); + break; + } + } else { + constexpr int cols_per_block = 32; + switch (Q->ne[0]) { + case 64: + ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, float>(ctx, dst); + break; + case 80: + ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, float>(ctx, dst); + break; + case 96: + ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, float>(ctx, dst); + break; + case 112: + ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, float>(ctx, dst); + break; + case 128: + ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst); + break; + // case 256: + // ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst); + // break; + default: + GGML_ABORT("fatal error"); + break; + } + } + return; + } + + if (Q->ne[1] <= 8 && Q->ne[0] % WARP_SIZE == 0) { + constexpr int cols_per_block = 8; + switch (Q->ne[0]) { + case 64: + ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst); + break; + case 96: + ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst); + break; + case 128: + ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst); + break; + case 256: + ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst); + break; + default: + GGML_ABORT("fatal error"); + break; + } + return; + } + + if (Q->ne[1] <= 32) { + constexpr int cols_per_block = 16; + switch (Q->ne[0]) { + case 64: + ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst); + break; + case 80: + ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, half>(ctx, dst); + break; + case 96: + ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst); + break; + case 112: + ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, half>(ctx, dst); + break; + case 128: + ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst); + break; + case 256: + ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst); + break; + default: + GGML_ABORT("fatal error"); + break; + } + return; + } + + constexpr int cols_per_block = 32; + switch (Q->ne[0]) { + case 64: + ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst); + break; + case 80: + ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, half>(ctx, dst); + break; + case 96: + ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst); + break; + case 112: + ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, half>(ctx, dst); + break; + case 128: + ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst); + break; + case 256: + ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst); + break; + default: + GGML_ABORT("fatal error"); + break; + } +} diff --git a/ggml/src/ggml-cuda/fattn-wmma-f16.cuh b/ggml/src/ggml-cuda/fattn-wmma-f16.cuh index 860d0e6dc2fa4..beeea95eb1d62 100644 --- a/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-wmma-f16.cuh @@ -1,543 +1,3 @@ #include "common.cuh" -#include "fattn-common.cuh" -#ifdef FP16_MMA_AVAILABLE -#include -#endif // FP16_MMA_AVAILABLE - -// D == head size, VKQ_stride == num VKQ rows calculated in parallel: -template -#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) -__launch_bounds__(nwarps*WARP_SIZE, 1) -#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) -static __global__ void flash_attn_ext_f16( - const char * __restrict__ Q, - const char * __restrict__ K, - const char * __restrict__ V, - const char * __restrict__ mask, - float * __restrict__ dst, - float2 * __restrict__ dst_meta, - const float scale, - const float max_bias, - const float m0, - const float m1, - const uint32_t n_head_log2, - const float logit_softcap, - const int ne00, - const int ne01, - const int ne02, - const int ne03, - const int ne10, - const int ne11, - const int ne12, - const int ne13, - const int ne31, - const int nb31, - const int nb01, - const int nb02, - const int nb03, - const int nb11, - const int nb12, - const int nb13, - const int nb21, - const int nb22, - const int nb23, - const int ne0, - const int ne1, - const int ne2, - const int ne3) { -#ifdef FP16_MMA_AVAILABLE - // Skip unused kernel variants for faster compilation: - if (use_logit_softcap && !(D == 128 || D == 256)) { - NO_DEVICE_CODE; - return; - } - - //In this kernel Q, K, V are matrices while i, j, k are matrix indices. - - const int ic0 = ncols*(blockIdx.x / parallel_blocks); // Index of the first Q/QKV column to work on. - const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel. - - static_assert(D <= FATTN_KQ_STRIDE, "D must be <= FATTN_KQ_STRIDE."); - static_assert(ncols == 8 || ncols % 16 == 0, "ncols must be 8 or a multiple of 16."); - constexpr int frag_m = ncols == 8 ? 32 : 16; - constexpr int frag_n = ncols == 8 ? 8 : 16; - static_assert(D % frag_m == 0, "If ncols == 8 then D % frag_m must be 0."); - typedef nvcuda::wmma::fragment frag_a_K; - typedef nvcuda::wmma::fragment frag_a_V; - typedef nvcuda::wmma::fragment frag_b; - typedef nvcuda::wmma::fragment frag_c_KQ; - typedef nvcuda::wmma::fragment frag_c_VKQ; - - constexpr int KQ_stride_tc = nwarps*frag_m; // Number of KQ rows calculated in parallel. - constexpr int VKQ_ratio = KQ_stride_tc/VKQ_stride; // Number of parallel VKQ accumulators needed to keep all warps busy. - static_assert(VKQ_ratio <= nwarps, "VKQ_ratio must be <= nwarps."); - - // Pad internal representation of KQ, KQV to reduce shared memory bank conflicts: - constexpr int D_padded = D + 8; - constexpr int kqs_padded = FATTN_KQ_STRIDE + 8; - constexpr int kqar = sizeof(KQ_acc_t)/sizeof(half); - - const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. - const float * Q_f = (const float *) (Q + nb02* blockIdx.y + nb01*ic0); - const half * K_h = (const half *) (K + nb12*(blockIdx.y / gqa_ratio)); - const half * V_h = (const half *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape - const half * maskh = (const half *) mask + (nb31/sizeof(half))* ic0; - const half2 * mask2 = (const half2 *) mask + (nb31/sizeof(half))*(ic0/2); - - const int stride_Q = nb01 / sizeof(float); - const int stride_KV = nb11 / sizeof(half); - - const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1); - const half slopeh = __float2half(slopef); - const half2 slope2 = make_half2(slopef, slopef); - - const half2 logit_softcap_2 = make_half2(logit_softcap, logit_softcap); - - frag_b Q_b[D/16][ncols/frag_n]; - - // A single buffer for temporarily holding tiles of KQ and VKQ parts: - constexpr int mem_KQ = ncols*kqs_padded*kqar; - constexpr int mem_VKQ_parts = VKQ_ratio*ncols*D_padded; - __shared__ half KQ[mem_KQ >= mem_VKQ_parts ? mem_KQ : mem_VKQ_parts]; - float * KQ_f = (float *) KQ; - half2 * KQ2 = (half2 *) KQ; - - float KQ_rowsum_f[ncols/nwarps] = {0.0f}; - float KQ_max_f[ncols/nwarps]; - float KQ_max_scale_f[ncols/nwarps] = {0.0f}; - -#pragma unroll - for (int j = 0; j < ncols/nwarps; ++j) { - KQ_max_f[j] = -FLT_MAX/2.0f; - } - - half2 KQ_rowsum_h2[ncols/nwarps] = {{0.0f, 0.0f}}; - half2 KQ_max_h2[ncols/nwarps]; - half2 KQ_max_scale_h2[ncols/nwarps] = {{0.0f, 0.0f}}; - -#pragma unroll - for (int j = 0; j < ncols/nwarps; ++j) { - KQ_max_h2[j] = make_half2(-HALF_MAX_HALF, -HALF_MAX_HALF); - } - - __shared__ half VKQ[ncols*D_padded]; // Accumulator for final VKQ slice. - half2 * VKQ2 = (half2 *) VKQ; -#pragma unroll - for (int j0 = 0; j0 < ncols; j0 += nwarps) { - const int j = j0 + threadIdx.y; -#pragma unroll - for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { - const int i = i0 + threadIdx.x; - if (i0 + WARP_SIZE > D/2 && i >= D/2) { - break; - } - VKQ2[j*(D_padded/2) + i] = make_half2(0.0f, 0.0f); - } - } - - // Convert Q to half and apply scale, temporarily store in KQ: -#pragma unroll - for (int j0 = 0; j0 < ncols; j0 += nwarps) { - const int j = j0 + threadIdx.y; -#pragma unroll - for (int i0 = 0; i0 < D; i0 += WARP_SIZE) { - const int i = i0 + threadIdx.x; - if (i0 + WARP_SIZE > D && i >= D) { - break; - } - KQ[j*D_padded + i] = ic0 + j < ne01 ? Q_f[j*stride_Q + i] * scale : 0.0f; - } - } - - __syncthreads(); - - // Load Q into tensor core fragments/registers since it will be used frequently: -#pragma unroll - for (int i0 = 0; i0 < D; i0 += 16) { -#pragma unroll - for (int j0 = 0; j0 < ncols; j0 += frag_n) { - nvcuda::wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded); - } - } - - __syncthreads(); - - // Iterate over ne11 == previous tokens: - for (int k_VKQ_0 = ip*FATTN_KQ_STRIDE; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*FATTN_KQ_STRIDE) { - // Calculate tile of KQ: -#pragma unroll - for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE; i_KQ_0 += KQ_stride_tc) { - frag_c_KQ KQ_c[ncols/frag_n]; -#pragma unroll - for (int j = 0; j < ncols/frag_n; ++j) { - nvcuda::wmma::fill_fragment(KQ_c[j], 0.0f); - } -#pragma unroll - for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) { - frag_a_K K_a; - nvcuda::wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV); -#pragma unroll - for (int j = 0; j < ncols/frag_n; ++j) { - nvcuda::wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]); - } - } -#pragma unroll - for (int j0 = 0; j0 < ncols; j0 += frag_n) { - nvcuda::wmma::store_matrix_sync((KQ_acc_t *) KQ + j0*kqs_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], kqs_padded, nvcuda::wmma::mem_col_major); - } - } - - __syncthreads(); - - // Calculate softmax for each KQ column using the current max. value. - // The divisor is stored in KQ_rowsum and will be applied at the end. -#pragma unroll - for (int j0 = 0; j0 < ncols; j0 += nwarps) { - const int j = j0 + threadIdx.y; - - if (std::is_same::value) { - float KQ_f_tmp[FATTN_KQ_STRIDE / WARP_SIZE]; -#pragma unroll - for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) { - const int k = k0 + threadIdx.x; - - KQ_f_tmp[k0/WARP_SIZE] = KQ_f[j*kqs_padded + k]; - - if (use_logit_softcap) { - KQ_f_tmp[k0/WARP_SIZE] = logit_softcap*tanhf(KQ_f_tmp[k0/WARP_SIZE]); - } - } - - float KQ_max_new = KQ_max_f[j0/nwarps]; -#pragma unroll - for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) { - const int k = k0 + threadIdx.x; - - KQ_f_tmp[k0/WARP_SIZE] += mask ? __half2float(slopeh*maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f; - KQ_max_new = max(KQ_max_new, KQ_f_tmp[k0/WARP_SIZE]); - } - KQ_max_new = warp_reduce_max(KQ_max_new); - - const float diff = KQ_max_f[j0/nwarps] - KQ_max_new; - KQ_max_scale_f[j0/nwarps] = expf(diff); - if (diff <= SOFTMAX_FTZ_THRESHOLD) { - KQ_max_scale_f[j0/nwarps] = 0.0f; - } - KQ_max_f[j0/nwarps] = KQ_max_new; - - float KQ_rowsum_add = 0.0f; -#pragma unroll - for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) { - const int k = k0 + threadIdx.x; - - const float diff = KQ_f_tmp[k0/WARP_SIZE] - KQ_max_f[j0/nwarps]; - KQ_f_tmp[k0/WARP_SIZE] = expf(diff); - if (diff <= SOFTMAX_FTZ_THRESHOLD) { - KQ_f_tmp[k0/WARP_SIZE] = 0.0f; - } - KQ_rowsum_add += KQ_f_tmp[k0/WARP_SIZE]; - KQ[j*(kqar*kqs_padded) + k] = KQ_f_tmp[k0/WARP_SIZE]; - } - KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add); - - // Scale previous KQ_rowsum to account for a potential increase in KQ_max: - KQ_rowsum_f[j0/nwarps] = KQ_max_scale_f[j0/nwarps]*KQ_rowsum_f[j0/nwarps] + KQ_rowsum_add; - } else { - half2 KQ2_tmp[FATTN_KQ_STRIDE/(2*WARP_SIZE)]; -#pragma unroll - for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) { - const int k = k0 + threadIdx.x; - - KQ2_tmp[k0/WARP_SIZE] = KQ2[j*(kqs_padded/2) + k]; - - if (use_logit_softcap) { - // There is no dedicated tangens hyperbolicus function for half2. - KQ2_tmp[k0/WARP_SIZE] = h2exp(KQ2_tmp[k0/WARP_SIZE]*make_half2(2.0f, 2.0f)); - KQ2_tmp[k0/WARP_SIZE] = (KQ2_tmp[k0/WARP_SIZE] - make_half2(1.0f, 1.0f)) - /(KQ2_tmp[k0/WARP_SIZE] + make_half2(1.0f, 1.0f)); - - KQ2_tmp[k0/WARP_SIZE] *= logit_softcap_2; - } - } - - half2 KQ_max_new = KQ_max_h2[j0/nwarps]; -#pragma unroll - for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) { - const int k = k0 + threadIdx.x; - - KQ2_tmp[k0/WARP_SIZE] += mask ? slope2*mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f); - KQ_max_new = ggml_cuda_hmax2(KQ_max_new, KQ2_tmp[k0/WARP_SIZE]); - } - KQ_max_new = __half2half2(warp_reduce_max(ggml_cuda_hmax(__low2half(KQ_max_new), __high2half(KQ_max_new)))); - const half2 diff = KQ_max_h2[j0/nwarps] - KQ_max_new; - KQ_max_scale_h2[j0/nwarps] = h2exp(diff); - const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD)); - *((uint32_t *) &KQ_max_scale_h2[j0/nwarps]) &= ftz_mask; - KQ_max_h2[j0/nwarps] = KQ_max_new; - - half2 KQ_rowsum_add = make_half2(0.0f, 0.0f); -#pragma unroll - for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) { - const int k = k0 + threadIdx.x; - - const half2 diff = KQ2_tmp[k0/WARP_SIZE] - KQ_max_h2[j0/nwarps]; - KQ2_tmp[k0/WARP_SIZE] = h2exp(diff); - const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD)); - *((uint32_t *) &KQ2_tmp[k0/WARP_SIZE]) &= ftz_mask; - KQ_rowsum_add += KQ2_tmp[k0/WARP_SIZE]; - KQ2[j*(kqs_padded/2) + k] = KQ2_tmp[k0/WARP_SIZE]; - } - KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add); - - // Scale previous KQ_rowsum to account for a potential increase in KQ_max: - KQ_rowsum_h2[j0/nwarps] = KQ_max_scale_h2[j0/nwarps]*KQ_rowsum_h2[j0/nwarps] + KQ_rowsum_add; - } - } - - __syncthreads(); - - frag_b KQ_b[FATTN_KQ_STRIDE/(VKQ_ratio*16)][ncols/frag_n]; -#pragma unroll - for (int j0 = 0; j0 < ncols; j0 += frag_n) { -#pragma unroll - for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) { - const int k = k0 + (threadIdx.y % VKQ_ratio)*16; - nvcuda::wmma::load_matrix_sync( - KQ_b[k0/(VKQ_ratio*16)][j0/frag_n], - KQ + j0*(kqar*kqs_padded) + k, - kqar*kqs_padded); - } - } - - frag_c_VKQ VKQ_c[D/VKQ_stride][ncols/frag_n]; -#pragma unroll - for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += VKQ_stride) { -#pragma unroll - for (int j = 0; j < ncols/frag_n; ++j) { - nvcuda::wmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], 0.0f); - } - -#pragma unroll - for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) { - const int k = k0 + (threadIdx.y % VKQ_ratio)*16; - - frag_a_V v_a; - nvcuda::wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV); -#pragma unroll - for (int j = 0; j < ncols/frag_n; ++j) { - nvcuda::wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]); - } - } - } - - __syncthreads(); - - const int offset_k = (threadIdx.y % VKQ_ratio) * (ncols*D_padded); -#pragma unroll - for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += VKQ_stride) { -#pragma unroll - for (int j0 = 0; j0 < ncols; j0 += frag_n) { - nvcuda::wmma::store_matrix_sync( - KQ + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio), - VKQ_c[i_KQ_0/VKQ_stride][j0/frag_n], - D_padded, nvcuda::wmma::mem_col_major); - } - } - - __syncthreads(); - -#pragma unroll - for (int j0 = 0; j0 < ncols; j0 += nwarps) { - const int j = j0 + threadIdx.y; - - half2 VKQ_scale; - if (std::is_same::value) { - VKQ_scale = make_half2(KQ_max_scale_f[j0/nwarps], KQ_max_scale_f[j0/nwarps]); - } else { - VKQ_scale = KQ_max_scale_h2[j0/nwarps]; - } - -#pragma unroll - for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { - const int i = i0 + threadIdx.x; - if (i0 + WARP_SIZE > D/2 && i >= D/2) { - break; - } - - half2 VKQ_add = make_half2(0.0f, 0.0f); -#pragma unroll - for (int l = 0; l < VKQ_ratio; ++l) { - VKQ_add += KQ2[l*(ncols*D_padded/2) + j*(D_padded/2) + i]; - } - VKQ2[j*(D_padded/2) + i] = VKQ_scale*VKQ2[j*(D_padded/2) + i] + VKQ_add; - } - } - - __syncthreads(); - } - -#pragma unroll - for (int j0 = 0; j0 < ncols; j0 += nwarps) { - const int j_VKQ = j0 + threadIdx.y; - if (ic0 + j_VKQ >= ne01) { - return; - } - const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip; - - float KQ_rowsum_j; - if (std::is_same::value) { - KQ_rowsum_j = KQ_rowsum_f[j0/nwarps]; - } else { - KQ_rowsum_j = __low2float(KQ_rowsum_h2[j0/nwarps]) + __high2float(KQ_rowsum_h2[j0/nwarps]); - } - -#pragma unroll - for (int i0 = 0; i0 < D; i0 += WARP_SIZE) { - const int i = i0 + threadIdx.x; - if (i0 + WARP_SIZE > D && i >= D) { - break; - } - float dst_val = VKQ[j_VKQ*D_padded + i]; - if (parallel_blocks == 1) { - dst_val /= KQ_rowsum_j; - } - dst[j_dst*gridDim.y*D + blockIdx.y*D + i] = dst_val; - } - - if (parallel_blocks == 1 || threadIdx.x != 0) { - continue; - } - - float2 dst_meta_val; - if (std::is_same::value) { - dst_meta_val.x = KQ_max_f[j0/nwarps]; - } else { - dst_meta_val.x = __low2float(KQ_max_h2[j0/nwarps]); - } - dst_meta_val.y = KQ_rowsum_j; - dst_meta[(ic0 + j_VKQ)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = dst_meta_val; - } -#else - NO_DEVICE_CODE; -#endif // FP16_MMA_AVAILABLE -} - -constexpr int get_max_power_of_2(int x) { - return x % 2 == 0 ? 2*get_max_power_of_2(x/2) : 1; -} - -static_assert(get_max_power_of_2(1) == 1, "Test failed."); -static_assert(get_max_power_of_2(2) == 2, "Test failed."); -static_assert(get_max_power_of_2(4) == 4, "Test failed."); -static_assert(get_max_power_of_2(6) == 2, "Test failed."); - -// Number of VKQ rows calculated in parallel: -constexpr int get_VKQ_stride(int D, int nwarps, int frag_m) { - return (get_max_power_of_2(D/frag_m) < nwarps ? get_max_power_of_2(D/frag_m) : nwarps)*frag_m; -} - -static_assert(get_VKQ_stride(128, 1, 32) == 32, "Test failed."); -static_assert(get_VKQ_stride(128, 2, 32) == 64, "Test failed."); -static_assert(get_VKQ_stride(128, 4, 32) == 128, "Test failed."); -static_assert(get_VKQ_stride( 64, 1, 32) == 32, "Test failed."); -static_assert(get_VKQ_stride( 64, 2, 32) == 64, "Test failed."); -static_assert(get_VKQ_stride( 64, 4, 32) == 64, "Test failed."); -static_assert(get_VKQ_stride( 80, 1, 16) == 16, "Test failed."); -static_assert(get_VKQ_stride( 80, 2, 16) == 16, "Test failed."); -static_assert(get_VKQ_stride( 80, 4, 16) == 16, "Test failed."); - -template -void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - const ggml_tensor * KQV = dst; - const ggml_tensor * Q = dst->src[0]; - - constexpr int nwarps = 4; - - constexpr int frag_m = cols_per_block == 8 && D % 32 == 0 ? 32 : 16; - const int blocks_num_pb1 = ((Q->ne[1] + cols_per_block - 1) / cols_per_block)*Q->ne[2]*Q->ne[3]; - const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm; - - float logit_softcap; - memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float)); - - if (4*blocks_num_pb1 < 2*nsm) { - constexpr int parallel_blocks = 4; - fattn_kernel_t fattn_kernel; - if (logit_softcap == 0.0f) { - constexpr bool use_logit_softcap = false; - fattn_kernel = flash_attn_ext_f16< - D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>; - } else { - constexpr bool use_logit_softcap = true; - fattn_kernel = flash_attn_ext_f16< - D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>; - } - launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true); - return; - } - if (2*blocks_num_pb1 < 2*nsm) { - constexpr int parallel_blocks = 2; - fattn_kernel_t fattn_kernel; - if (logit_softcap == 0.0f) { - constexpr bool use_logit_softcap = false; - fattn_kernel = flash_attn_ext_f16< - D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>; - } else { - constexpr bool use_logit_softcap = true; - fattn_kernel = flash_attn_ext_f16< - D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>; - } - launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true); - return; - } - constexpr int parallel_blocks = 1; - fattn_kernel_t fattn_kernel; - if (logit_softcap == 0.0f) { - constexpr bool use_logit_softcap = false; - fattn_kernel = flash_attn_ext_f16< - D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>; - } else { - constexpr bool use_logit_softcap = true; - fattn_kernel = flash_attn_ext_f16< - D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>; - } - launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true); -} - -#define DECL_FATTN_WMMA_F16_CASE(D, cols_per_block, KQ_acc_t) \ - template void ggml_cuda_flash_attn_ext_wmma_f16_case \ - (ggml_backend_cuda_context & ctx, ggml_tensor * dst) \ - -extern DECL_FATTN_WMMA_F16_CASE( 64, 16, float); -extern DECL_FATTN_WMMA_F16_CASE( 80, 16, float); -extern DECL_FATTN_WMMA_F16_CASE( 96, 16, float); -extern DECL_FATTN_WMMA_F16_CASE(112, 16, float); -extern DECL_FATTN_WMMA_F16_CASE(128, 16, float); -extern DECL_FATTN_WMMA_F16_CASE(256, 16, float); - -extern DECL_FATTN_WMMA_F16_CASE( 64, 32, float); -extern DECL_FATTN_WMMA_F16_CASE( 80, 32, float); -extern DECL_FATTN_WMMA_F16_CASE( 96, 32, float); -extern DECL_FATTN_WMMA_F16_CASE(112, 32, float); -extern DECL_FATTN_WMMA_F16_CASE(128, 32, float); -// extern DECL_FATTN_WMMA_F16_CASE(256, 16, float); - -extern DECL_FATTN_WMMA_F16_CASE( 64, 8, half); -extern DECL_FATTN_WMMA_F16_CASE( 96, 8, half); -extern DECL_FATTN_WMMA_F16_CASE(128, 8, half); -extern DECL_FATTN_WMMA_F16_CASE(256, 8, half); - -extern DECL_FATTN_WMMA_F16_CASE( 64, 16, half); -extern DECL_FATTN_WMMA_F16_CASE( 80, 16, half); -extern DECL_FATTN_WMMA_F16_CASE( 96, 16, half); -extern DECL_FATTN_WMMA_F16_CASE(112, 16, half); -extern DECL_FATTN_WMMA_F16_CASE(128, 16, half); -extern DECL_FATTN_WMMA_F16_CASE(256, 16, half); - -extern DECL_FATTN_WMMA_F16_CASE( 64, 32, half); -extern DECL_FATTN_WMMA_F16_CASE( 80, 32, half); -extern DECL_FATTN_WMMA_F16_CASE( 96, 32, half); -extern DECL_FATTN_WMMA_F16_CASE(112, 32, half); -extern DECL_FATTN_WMMA_F16_CASE(128, 32, half); -extern DECL_FATTN_WMMA_F16_CASE(256, 16, half); +void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index 0b26b0f8e0595..b1e66d470832c 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -1,5 +1,6 @@ #include "common.cuh" #include "fattn-common.cuh" +#include "fattn-mma-f16.cuh" #include "fattn-tile-f16.cuh" #include "fattn-tile-f32.cuh" #include "fattn-vec-f16.cuh" @@ -7,144 +8,56 @@ #include "fattn-wmma-f16.cuh" #include "fattn.cuh" -#include +template +static void ggml_cuda_flash_attn_ext_mma_f16_switch_hs(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * Q = dst->src[0]; -static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - const ggml_tensor * KQV = dst; - const ggml_tensor * Q = dst->src[0]; - - const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV); - - if (prec != GGML_PREC_DEFAULT) { - if (Q->ne[1] <= 32 || Q->ne[0] > 128) { - constexpr int cols_per_block = 16; - switch (Q->ne[0]) { - case 64: - ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, float>(ctx, dst); - break; - case 80: - ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, float>(ctx, dst); - break; - case 96: - ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, float>(ctx, dst); - break; - case 112: - ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, float>(ctx, dst); - break; - case 128: - ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst); - break; - case 256: - ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, float>(ctx, dst); - break; - default: - GGML_ABORT("fatal error"); - break; - } - } else { - constexpr int cols_per_block = 32; - switch (Q->ne[0]) { - case 64: - ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, float>(ctx, dst); - break; - case 80: - ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, float>(ctx, dst); - break; - case 96: - ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, float>(ctx, dst); - break; - case 112: - ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, float>(ctx, dst); - break; - case 128: - ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst); - break; - // case 256: - // ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst); - // break; - default: - GGML_ABORT("fatal error"); - break; - } - } - return; - } - - if (Q->ne[1] <= 8 && Q->ne[0] % WARP_SIZE == 0) { - constexpr int cols_per_block = 8; - switch (Q->ne[0]) { - case 64: - ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst); - break; - case 96: - ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst); - break; - case 128: - ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst); - break; - case 256: - ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst); - break; - default: - GGML_ABORT("fatal error"); - break; - } - return; - } - - if (Q->ne[1] <= 32) { - constexpr int cols_per_block = 16; - switch (Q->ne[0]) { - case 64: - ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst); - break; - case 80: - ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, half>(ctx, dst); - break; - case 96: - ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst); - break; - case 112: - ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, half>(ctx, dst); - break; - case 128: - ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst); - break; - case 256: - ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst); - break; - default: - GGML_ABORT("fatal error"); - break; - } - return; - } - - constexpr int cols_per_block = 32; switch (Q->ne[0]) { case 64: - ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst); + ggml_cuda_flash_attn_ext_mma_f16_case< 64, cols_per_block>(ctx, dst); break; case 80: - ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, half>(ctx, dst); + ggml_cuda_flash_attn_ext_mma_f16_case< 80, cols_per_block>(ctx, dst); break; case 96: - ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst); + ggml_cuda_flash_attn_ext_mma_f16_case< 96, cols_per_block>(ctx, dst); break; case 112: - ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, half>(ctx, dst); + ggml_cuda_flash_attn_ext_mma_f16_case<112, cols_per_block>(ctx, dst); break; case 128: - ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst); + ggml_cuda_flash_attn_ext_mma_f16_case<128, cols_per_block>(ctx, dst); break; case 256: - ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst); + ggml_cuda_flash_attn_ext_mma_f16_case<256, cols_per_block>(ctx, dst); break; default: GGML_ABORT("fatal error"); break; } } + +static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * Q = dst->src[0]; + + if (Q->ne[1] <= 8) { + ggml_cuda_flash_attn_ext_mma_f16_switch_hs<8>(ctx, dst); + return; + } + + if (Q->ne[1] <= 16) { + ggml_cuda_flash_attn_ext_mma_f16_switch_hs<16>(ctx, dst); + return; + } + + if (Q->ne[1] <= 32) { + ggml_cuda_flash_attn_ext_mma_f16_switch_hs<32>(ctx, dst); + return; + } + + ggml_cuda_flash_attn_ext_mma_f16_switch_hs<64>(ctx, dst); +} + #define FATTN_VEC_F16_CASE(D, type_K, type_V) \ if (Q->ne[0] == (D) && K->type == (type_K) && V->type == (type_V)) { \ ggml_cuda_flash_attn_ext_vec_f16_case(ctx, dst); \ @@ -322,11 +235,19 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst return; } - if (!fp16_mma_available(cc)) { - if (Q->ne[1] <= 8) { - ggml_cuda_flash_attn_ext_vec_f16(ctx, dst); + if (!new_mma_available(cc)) { + if (prec == GGML_PREC_DEFAULT) { + if (Q->ne[1] <= 8) { + ggml_cuda_flash_attn_ext_vec_f16(ctx, dst); + } else { + ggml_cuda_flash_attn_ext_tile_f16(ctx, dst); + } } else { - ggml_cuda_flash_attn_ext_tile_f16(ctx, dst); + if (Q->ne[1] <= 8) { + ggml_cuda_flash_attn_ext_vec_f32(ctx, dst); + } else { + ggml_cuda_flash_attn_ext_tile_f32(ctx, dst); + } } return; } @@ -341,5 +262,10 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst } } - ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst); + // The MMA implementation needs Turing or newer, use the old WMMA code for Volta: + if (cc == GGML_CUDA_CC_VOLTA) { + ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst); + } + + ggml_cuda_flash_attn_ext_mma_f16(ctx, dst); } diff --git a/ggml/src/ggml-cuda/mma.cuh b/ggml/src/ggml-cuda/mma.cuh index 7d11540afd237..fbfd15527ddb6 100644 --- a/ggml/src/ggml-cuda/mma.cuh +++ b/ggml/src/ggml-cuda/mma.cuh @@ -1,11 +1,29 @@ +// This file contains primitives that expose the tensor core PTX instructions for CUDA code. +// The primitives can be used in a similar way as the nvcuda::wmma interface but with a well-defined memory layout. +// The documentation for the PTX instructions can be found under: +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-multiply-accumulate-operation-using-mma-instruction +// +// Like with nvcuda::wmma there are three types of matrix tiles: A, B, and C with A @ B = C. +// A is a row-major matrix with shape I x K. +// B is a column-major matrix with shape K x J. +// C is a column-major matrix with shape I x J. +// Note that along their lowest dimension I, J, and K are measured in physical 32 bit elements instead of logical elements. +// The functions get_i, get_j, and get_k can be used to get the physical 32 bit index of the lth element of a thread within a tile. +// All matrix tiles have ne physical 32 bit elements per warp. +// +// As desribted in the documentation, all pointers for load_ldmatrix must be to shared memory and aligned to 16 bytes. + #include "common.cuh" -struct mma_int_A_I16K4 { +template +struct mma_A_I16K4 { + static_assert(sizeof(T) == 4, "bad type size"); + static constexpr int I = 16; static constexpr int K = 4; static constexpr int ne = 2; - int x[ne] = {0}; + T x[ne]; static __device__ __forceinline__ int get_i(const int l) { const int ret = (l%2) * (I/2) + threadIdx.x / K; @@ -21,27 +39,35 @@ struct mma_int_A_I16K4 { return ret; } - __device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) { -#if defined(INT8_MMA_AVAILABLE) - const int * xs = xs0 + (threadIdx.x%I)*stride; - asm("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];" - : "+r"(x[0]), "+r"(x[1]) - : "l"(xs)); -#else + __device__ __forceinline__ void load_generic(const T * __restrict__ xs0, const int & stride) { #pragma unroll for (int l = 0; l < ne; ++l) { x[l] = xs0[get_i(l)*stride + get_k(l)]; } -#endif // defined(INT8_MMA_AVAILABLE) + } + + __device__ __forceinline__ void load_ldmatrix(const T * __restrict__ xs0, const int & stride) { +#ifdef NEW_MMA_AVAILABLE + int * xi = (int *) x; + const int * xs = (const int *) xs0 + (threadIdx.x%I)*stride; + asm("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];" + : "+r"(xi[0]), "+r"(xi[1]) + : "l"(xs)); +#else + load_generic(xs0, stride); +#endif // NEW_MMA_AVAILABLE } }; -struct mma_int_A_I16K8 { +template +struct mma_A_I16K8 { + static_assert(sizeof(T) == 4, "bad type size"); + static constexpr int I = 16; static constexpr int K = 8; static constexpr int ne = 4; - int x[ne] = {0}; + T x[ne]; static __device__ __forceinline__ int get_i(const int l) { const int ret = (l%2) * (I/2) + threadIdx.x / (K/2); @@ -57,31 +83,69 @@ struct mma_int_A_I16K8 { return ret; } - __device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) { -#if defined(INT8_MMA_AVAILABLE) - const int * xs = xs0 + (threadIdx.x%I)*stride + (threadIdx.x/I)*(K/2); - asm("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];" - : "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3]) - : "l"(xs)); -#else + __device__ __forceinline__ void load_generic(const T * __restrict__ xs0, const int & stride) { #pragma unroll for (int l = 0; l < ne; ++l) { x[l] = xs0[get_i(l)*stride + get_k(l)]; } -#endif // defined(INT8_MMA_AVAILABLE) } - __device__ __forceinline__ void load_low(const int * __restrict__ xs0, const int & stride) { - ((mma_int_A_I16K4 *) x)[0].load(xs0, stride); + __device__ __forceinline__ void load_ldmatrix(const T * __restrict__ xs0, const int & stride) { +#ifdef NEW_MMA_AVAILABLE + int * xi = (int * ) x; + const int * xs = (const int *) xs0 + (threadIdx.x%I)*stride + (threadIdx.x/I)*(K/2); + asm("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];" + : "+r"(xi[0]), "+r"(xi[1]), "+r"(xi[2]), "+r"(xi[3]) + : "l"(xs)); +#else + GGML_UNUSED(xs0); + GGML_UNUSED(stride); + NO_DEVICE_CODE; +#endif // NEW_MMA_AVAILABLE + } + + __device__ __forceinline__ void load_ldmatrix_trans(const T * __restrict__ xs0, const int & stride) { +#ifdef NEW_MMA_AVAILABLE + int * xi = (int * ) x; + const int * xs = (const int *) xs0 + (threadIdx.x%I)*stride + (threadIdx.x/I)*(K/2); + asm("ldmatrix.sync.aligned.m8n8.x4.trans.b16 {%0, %1, %2, %3}, [%4];" + : "+r"(xi[0]), "+r"(xi[2]), "+r"(xi[1]), "+r"(xi[3]) + : "l"(xs)); +#else + GGML_UNUSED(xs0); + GGML_UNUSED(stride); + NO_DEVICE_CODE; +#endif // NEW_MMA_AVAILABLE + } + + __device__ __forceinline__ void transpose() { +#ifdef NEW_MMA_AVAILABLE + int * xi = (int *) x; + asm("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;" + : "+r"(xi[0]) : "r"(xi[0])); + int tmp = 0; + asm("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;" + : "+r"(tmp) : "r"(xi[1])); + asm("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;" + : "+r"(xi[1]) : "r"(xi[2])); + xi[2] = tmp; + asm("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;" + : "+r"(xi[3]) : "r"(xi[3])); +#else + NO_DEVICE_CODE; +#endif // NEW_MMA_AVAILABLE } }; -struct mma_int_B_J8K4 { +template +struct mma_B_J8K4 { + static_assert(sizeof(T) == 4, "bad type size"); + static constexpr int J = 8; static constexpr int K = 4; static constexpr int ne = 1; - int x[ne] = {0}; + T x[ne]; static __device__ __forceinline__ int get_j(const int /* l */) { const int ret = threadIdx.x / K; @@ -97,27 +161,34 @@ struct mma_int_B_J8K4 { return ret; } - __device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) { -#if defined(INT8_MMA_AVAILABLE) && false // Loading as 4 byte values is faster - const int * xs = xs0 + (threadIdx.x%J)*stride; - asm("ldmatrix.sync.aligned.m8n8.x1.b16 {%0}, [%1];" - : "+r"(x[0]) - : "l"(xs)); -#else + __device__ __forceinline__ void load_generic(const T * __restrict__ xs0, const int & stride) { #pragma unroll for (int l = 0; l < ne; ++l) { x[l] = xs0[get_j(l)*stride + get_k(l)]; } -#endif // defined(INT8_MMA_AVAILABLE) + } + + __device__ __forceinline__ void load_ldmatrix(const T * __restrict__ xs0, const int & stride) { +#ifdef NEW_MMA_AVAILABLE + int * xi = (int *) x; + const int * xs = (const int *) xs0 + (threadIdx.x%J)*stride; + asm("ldmatrix.sync.aligned.m8n8.x1.b16 {%0}, [%1];" + : "+r"(xi[0]) : "l"(xs)); +#else + load_generic(xs0, stride); +#endif // NEW_MMA_AVAILABLE } }; -struct mma_int_B_J8K8 { +template +struct mma_B_J8K8 { + static_assert(sizeof(T) == 4, "bad type size"); + static constexpr int J = 8; static constexpr int K = 8; static constexpr int ne = 2; - int x[ne] = {0}; + T x[ne]; static __device__ __forceinline__ int get_j(const int /* l */) { const int ret = threadIdx.x / (K/2); @@ -133,22 +204,31 @@ struct mma_int_B_J8K8 { return ret; } - __device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) { -#if defined(INT8_MMA_AVAILABLE) && false // Loading as 4 byte values is faster - const int * xs = xs0 + (threadIdx.x%J)*stride + ((threadIdx.x/J)*(K/2)) % K; - asm("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];" - : "+r"(x[0]), "+r"(x[1]) - : "l"(xs)); -#else + __device__ __forceinline__ void load_generic(const T * __restrict__ xs0, const int & stride) { #pragma unroll for (int l = 0; l < ne; ++l) { x[l] = xs0[get_j(l)*stride + get_k(l)]; } -#endif // defined(INT8_MMA_AVAILABLE) + } + + __device__ __forceinline__ void load_ldmatrix(const T * __restrict__ xs0, const int & stride) { +#ifdef NEW_MMA_AVAILABLE + int * xi = (int *) x; + const int * xs = (const int *) xs0 + (threadIdx.x%J)*stride + ((threadIdx.x/J)*(K/2)) % K; + asm("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];" + : "+r"(xi[0]), "+r"(xi[1]) + : "l"(xs)); +#else + load_generic(xs0, stride); +#endif // NEW_MMA_AVAILABLE } }; -struct mma_int_C_I16J8 { +template +struct mma_C_I16J8 {}; + +template <> +struct mma_C_I16J8 { static constexpr int I = 16; static constexpr int J = 8; static constexpr int ne = 4; @@ -169,8 +249,8 @@ struct mma_int_C_I16J8 { return ret; } - __device__ __forceinline__ void mma_K4(const mma_int_A_I16K4 & mma_A, const mma_int_B_J8K4 & mma_B) { -#ifdef INT8_MMA_AVAILABLE + __device__ __forceinline__ void mma(const mma_A_I16K4 & mma_A, const mma_B_J8K4 & mma_B) { +#ifdef NEW_MMA_AVAILABLE #if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE asm("mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};" : "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3]) @@ -188,11 +268,11 @@ struct mma_int_C_I16J8 { GGML_UNUSED(mma_A); GGML_UNUSED(mma_B); NO_DEVICE_CODE; -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } - __device__ __forceinline__ void mma_K8(const mma_int_A_I16K8 & mma_A, const mma_int_B_J8K8 & mma_B) { -#ifdef INT8_MMA_AVAILABLE + __device__ __forceinline__ void mma(const mma_A_I16K8 & mma_A, const mma_B_J8K8 & mma_B) { +#ifdef NEW_MMA_AVAILABLE #if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE asm("mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};" : "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3]) @@ -216,6 +296,144 @@ struct mma_int_C_I16J8 { GGML_UNUSED(mma_A); GGML_UNUSED(mma_B); NO_DEVICE_CODE; -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE + } +}; + +template <> +struct mma_C_I16J8 { + static constexpr int I = 16; + static constexpr int J = 4; + static constexpr int ne = 2; + + half2 x[ne] = {{0.0f, 0.0f}, {0.0f, 0.0f}}; + + static __device__ __forceinline__ int get_i(const int l) { + const int ret = l * (I/2) + threadIdx.x / J; + GGML_CUDA_ASSUME(ret >= 0); + GGML_CUDA_ASSUME(ret < I); + return ret; + } + + static __device__ __forceinline__ int get_j(const int /* l */) { + const int ret = threadIdx.x % J; + GGML_CUDA_ASSUME(ret >= 0); + GGML_CUDA_ASSUME(ret < J); + return ret; + } + + __device__ __forceinline__ void mma(const mma_A_I16K8 & mma_A, const mma_B_J8K8 & mma_B) { +#ifdef NEW_MMA_AVAILABLE + int * Axi = (int *) mma_A.x; + int * Bxi = (int *) mma_B.x; + int * xi = (int *) x; +#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE + asm("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%0, %1};" + : "+r"(xi[0]), "+r"(xi[1]) + : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1])); +#else + // On Turing m16n8k16 mma is not available, use 2x m8n8k8 mma instead: + asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};" + : "+r"(xi[0]), "+r"(xi[1]) + : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0])); + asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};" + : "+r"(xi[0]), "+r"(xi[1]) + : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1])); +#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE +#else + GGML_UNUSED(mma_A); + GGML_UNUSED(mma_B); + NO_DEVICE_CODE; +#endif // NEW_MMA_AVAILABLE + } + + __device__ __forceinline__ mma_B_J8K8 to_mma_B() { + mma_B_J8K8 mma_B; + +#ifdef NEW_MMA_AVAILABLE + int * xi = (int *) x; + int * Bxi = (int *) mma_B.x; + asm("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;" + : "+r"(Bxi[0]) : "r"(xi[0])); + asm("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;" + : "+r"(Bxi[1]) : "r"(xi[1])); +#else + NO_DEVICE_CODE; +#endif // NEW_MMA_AVAILABLE + + return mma_B; + } +}; + +template <> +struct mma_C_I16J8 { + static constexpr int I = 16; + static constexpr int J = 8; + static constexpr int ne = 4; + + float x[ne] = {0.0f, 0.0f, 0.0f, 0.0f}; + + static __device__ __forceinline__ int get_i(const int l) { + const int ret = (l/2) * (I/2) + threadIdx.x / (J/2); + GGML_CUDA_ASSUME(ret >= 0); + GGML_CUDA_ASSUME(ret < I); + return ret; + } + + static __device__ __forceinline__ int get_j(const int l) { + const int ret = 2 * (threadIdx.x % (J/2)) + l%2; + GGML_CUDA_ASSUME(ret >= 0); + GGML_CUDA_ASSUME(ret < J); + return ret; + } + + __device__ __forceinline__ void mma(const mma_A_I16K8 & mma_A, const mma_B_J8K8 & mma_B) { +#ifdef NEW_MMA_AVAILABLE + int * Axi = (int *) mma_A.x; + int * Bxi = (int *) mma_B.x; + int * xi = (int *) x; +#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE + asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};" + : "+r"(xi[0]), "+r"(xi[1]), "+r"(xi[2]), "+r"(xi[3]) + : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1])); +#else + // On Turing m16n8k16 mma is not available, use 2x m8n8k8 mma instead: + asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};" + : "+r"(xi[0]), "+r"(xi[1]), "+r"(xi[2]), "+r"(xi[3]) + : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0])); + asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};" + : "+r"(xi[0]), "+r"(xi[1]), "+r"(xi[2]), "+r"(xi[3]) + : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1])); +#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE +#else + GGML_UNUSED(mma_A); + GGML_UNUSED(mma_B); + NO_DEVICE_CODE; +#endif // NEW_MMA_AVAILABLE + } + + __device__ __forceinline__ mma_B_J8K8 to_mma_B() { + mma_B_J8K8 mma_B; + mma_B.x[0] = make_half2(x[0], x[1]); + mma_B.x[1] = make_half2(x[2], x[3]); + +#ifdef NEW_MMA_AVAILABLE + int * Bxi = (int *) mma_B.x; + asm("movmatrix.sync.aligned.m8n8.trans.b16 %0, %0;" + : "+r"(Bxi[0]) : ); + asm("movmatrix.sync.aligned.m8n8.trans.b16 %0, %0;" + : "+r"(Bxi[1]) : ); +#else + NO_DEVICE_CODE; +#endif // NEW_MMA_AVAILABLE + + return mma_B; + } + + __device__ __forceinline__ void load_generic(const float * __restrict__ xs0, const int & stride) { +#pragma unroll + for (int l = 0; l < ne; ++l) { + x[l] = xs0[get_j(l)*stride + get_i(l)]; + } } }; diff --git a/ggml/src/ggml-cuda/mmq.cu b/ggml/src/ggml-cuda/mmq.cu index 270251df4f115..83cb78cbdd451 100644 --- a/ggml/src/ggml-cuda/mmq.cu +++ b/ggml/src/ggml-cuda/mmq.cu @@ -132,7 +132,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) { return false; } - if (int8_mma_available(cc)) { + if (new_mma_available(cc)) { return true; } diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh index 3cd508a1d4a1c..c05c8477812b1 100644 --- a/ggml/src/ggml-cuda/mmq.cuh +++ b/ggml/src/ggml-cuda/mmq.cuh @@ -87,7 +87,7 @@ struct tile_x_sizes { }; static constexpr int get_mmq_x_max_host(const int cc) { - return int8_mma_available(cc) ? 128 : + return new_mma_available(cc) ? 128 : #ifdef GGML_CUDA_FORCE_MMQ cc >= GGML_CUDA_CC_VOLTA && cc < GGML_CUDA_CC_OFFSET_AMD ? 128 : 64; #else @@ -96,9 +96,9 @@ static constexpr int get_mmq_x_max_host(const int cc) { } static constexpr __device__ int get_mmq_x_max_device() { -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE return 128; -#else // INT8_MMA_AVAILABLE +#else // NEW_MMA_AVAILABLE #if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) return 128; @@ -116,7 +116,7 @@ static constexpr __device__ int get_mmq_x_max_device() { #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA #endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } static constexpr int get_mmq_y_host(const int cc) { @@ -209,10 +209,10 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) { #define MMQ_TILE_Y_K (WARP_SIZE + WARP_SIZE/QI8_1) static int mmq_get_granularity_host(const int mmq_x, const int cc) { - return int8_mma_available(cc) && mmq_x >= 48 ? 16 : 8; + return new_mma_available(cc) && mmq_x >= 48 ? 16 : 8; } -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE static constexpr __device__ int mmq_get_granularity_device(const int mmq_x) { return mmq_x >= 48 ? 16 : 8; } @@ -220,21 +220,21 @@ static constexpr __device__ int mmq_get_granularity_device(const int mmq_x) { static constexpr __device__ int mmq_get_granularity_device(const int /* mmq_x */) { return 8; } -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE // ------------------------------------------------------------ template static __device__ __forceinline__ void load_tiles_q4_0( const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + 2*WARP_SIZE); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE const int kbx = threadIdx.x / QI4_0; const int kqsx = threadIdx.x % QI4_0; @@ -250,12 +250,12 @@ template static __device__ __forceinlin const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbx; const int qs0 = get_int_b2(bxi->qs, kqsx); -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI4_0) + kqsx + 0] = __vsubss4((qs0 >> 0) & 0x0F0F0F0F, 0x08080808); x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI4_0) + kqsx + QI4_0] = __vsubss4((qs0 >> 4) & 0x0F0F0F0F, 0x08080808); #else x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = qs0; -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } const int blocks_per_tile_x_row = WARP_SIZE / QI4_0; @@ -271,11 +271,11 @@ template static __device__ __forceinlin const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbxd; -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d; #else x_df[i*(WARP_SIZE/QI4_0) + i/QI4_0 + kbxd] = bxi->d; -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } } @@ -322,14 +322,14 @@ static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a( template static __device__ __forceinline__ void load_tiles_q4_1( const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE int * x_qs = (int *) x_tile; half2 * x_dm = (half2 *) (x_qs + 2*WARP_SIZE); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_1, mmq_y); int * x_qs = (int *) x_tile; half2 * x_dm = (half2 *) (x_qs + txs.qs); -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE const int kbx = threadIdx.x / QI4_1; const int kqsx = threadIdx.x % QI4_1; @@ -345,12 +345,12 @@ template static __device__ __forceinlin const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbx; const int qs0 = get_int_b4(bxi->qs, kqsx); -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI4_1) + kqsx + 0] = (qs0 >> 0) & 0x0F0F0F0F; x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI4_1) + kqsx + QI4_1] = (qs0 >> 4) & 0x0F0F0F0F; #else x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = qs0; -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } const int blocks_per_tile_x_row = WARP_SIZE / QI4_1; @@ -366,11 +366,11 @@ template static __device__ __forceinlin const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbxd; -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = bxi->dm; #else x_dm[i*(WARP_SIZE/QI4_1) + i/QI4_1 + kbxd] = bxi->dm; -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } } @@ -417,14 +417,14 @@ static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a( template static __device__ __forceinline__ void load_tiles_q5_0( const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + WARP_SIZE*2); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_0, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE const int kbx = threadIdx.x / QI5_0; const int kqsx = threadIdx.x % QI5_0; @@ -456,13 +456,13 @@ template static __device__ __forceinlin qs1 |= (qh << 9) & 0x10000000; // 19 -> 28 qs1 = __vsubss4(qs1, 0x10101010); // subtract 16 -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI5_0) + kqsx + 0] = qs0; x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1; #else x_qs[i*(2*WARP_SIZE + 1) + kbx*(2*QI5_0) + kqsx + 0] = qs0; x_qs[i*(2*WARP_SIZE + 1) + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1; -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } const int blocks_per_tile_x_row = WARP_SIZE / QI5_0; @@ -478,25 +478,25 @@ template static __device__ __forceinlin const block_q5_0 * bxi = (const block_q5_0 *) x + kbx0 + i*stride + kbxd; -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d; #else x_df[i*(WARP_SIZE/QI5_0) + i/QI5_0 + kbxd] = bxi->d; -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } } template static __device__ __forceinline__ void load_tiles_q5_1( const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE int * x_qs = (int *) x_tile; half2 * x_dm = (half2 *) (x_qs + 2*WARP_SIZE); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_1, mmq_y); int * x_qs = (int *) x_tile; half2 * x_dm = (half2 *) (x_qs + txs.qs); -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE const int kbx = threadIdx.x / QI5_1; const int kqsx = threadIdx.x % QI5_1; @@ -526,13 +526,13 @@ template static __device__ __forceinlin qs1 |= (qh << 2) & 0x00100000; // 18 -> 20 qs1 |= (qh << 9) & 0x10000000; // 19 -> 28 -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI5_1) + kqsx + 0] = qs0; x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1; #else x_qs[i*(2*WARP_SIZE + 1) + kbx*(2*QI5_1) + kqsx + 0] = qs0; x_qs[i*(2*WARP_SIZE + 1) + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1; -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } const int blocks_per_tile_x_row = WARP_SIZE / QI5_1; @@ -548,25 +548,25 @@ template static __device__ __forceinlin const block_q5_1 * bxi = (const block_q5_1 *) x + kbx0 + i*stride + kbxd; -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = bxi->dm; #else x_dm[i*(WARP_SIZE/QI5_1) + i/QI5_1 + kbxd] = bxi->dm; -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } } template static __device__ __forceinline__ void load_tiles_q8_0( const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE int * x_qs = (int *) x_tile; float * x_df = (float *) (x_tile + 2*WARP_SIZE); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE const int kbx = threadIdx.x / QI8_0; const int kqsx = threadIdx.x % QI8_0; @@ -581,13 +581,13 @@ template static __device__ __forceinlin const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbx; -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 0 + threadIdx.x] = get_int_b2(bxi[0].qs, kqsx); x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + WARP_SIZE + threadIdx.x] = get_int_b2(bxi[WARP_SIZE/QI8_0].qs, kqsx); #else x_qs[i*(2*WARP_SIZE + 1) + 0 + threadIdx.x] = get_int_b2(bxi[0].qs, kqsx); x_qs[i*(2*WARP_SIZE + 1) + WARP_SIZE + threadIdx.x] = get_int_b2(bxi[WARP_SIZE/QI8_0].qs, kqsx); -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } const int blocks_per_tile_x_row = 2*WARP_SIZE / QI8_0; @@ -603,11 +603,11 @@ template static __device__ __forceinlin const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbxd; -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d; #else x_df[i*(2*WARP_SIZE/QI8_0) + i/(QI8_0/2) + kbxd] = bxi->d; -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } } @@ -645,9 +645,9 @@ template static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) { - typedef mma_int_A_I16K8 mma_A; - typedef mma_int_B_J8K8 mma_B; - typedef mma_int_C_I16J8 mma_C; + typedef mma_A_I16K8 mma_A; + typedef mma_B_J8K8 mma_B; + typedef mma_C_I16J8 mma_C; constexpr int granularity = mmq_get_granularity_device(mmq_x); constexpr int rows_per_warp = 2 * granularity; @@ -672,7 +672,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma( for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) { const int k0 = k00 + k01; - A[n][k01/QI8_0].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0); + A[n][k01/QI8_0].load_ldmatrix(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0); } #pragma unroll @@ -695,7 +695,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma( mma_B B; float dB[mma_C::ne/2]; - B.load(y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); + B.load_generic(y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); // faster than load_ldmatrix #pragma unroll for (int l = 0; l < mma_C::ne/2; ++l) { @@ -711,7 +711,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma( #pragma unroll for (int n = 0; n < ntx; ++n) { mma_C C; - C.mma_K8(A[n][k01/QI8_0], B); + C.mma(A[n][k01/QI8_0], B); #pragma unroll for (int l = 0; l < mma_C::ne; ++l) { @@ -756,9 +756,9 @@ template static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) { - typedef mma_int_A_I16K8 mma_A; - typedef mma_int_B_J8K8 mma_B; - typedef mma_int_C_I16J8 mma_C; + typedef mma_A_I16K8 mma_A; + typedef mma_B_J8K8 mma_B; + typedef mma_C_I16J8 mma_C; constexpr int granularity = mmq_get_granularity_device(mmq_x); constexpr int rows_per_warp = 2 * granularity; @@ -782,7 +782,7 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma( for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) { const int k0 = k00 + k01; - A[n][k01/QI8_1].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q8_1 + k0, MMQ_MMA_TILE_X_K_Q8_1); + A[n][k01/QI8_1].load_ldmatrix(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q8_1 + k0, MMQ_MMA_TILE_X_K_Q8_1); } #pragma unroll @@ -805,7 +805,7 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma( mma_B B; float2 dsB[mma_C::ne/2]; - B.load(y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); + B.load_generic(y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); // faster than load_ldmatrix #pragma unroll for (int l = 0; l < mma_C::ne/2; ++l) { @@ -817,7 +817,7 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma( #pragma unroll for (int n = 0; n < ntx; ++n) { mma_C C; - C.mma_K8(A[n][k01/QI8_1], B); + C.mma(A[n][k01/QI8_1], B); #pragma unroll for (int l = 0; l < mma_C::ne; ++l) { @@ -864,12 +864,12 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_dp4a( template static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) { -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE - typedef mma_int_A_I16K4 mma_A; - typedef mma_int_A_I16K8 mma_A_K8; - typedef mma_int_B_J8K4 mma_B; - typedef mma_int_C_I16J8 mma_C; + typedef mma_A_I16K4 mma_A; + typedef mma_A_I16K8 mma_A_K8; + typedef mma_B_J8K4 mma_B; + typedef mma_C_I16J8 mma_C; constexpr int granularity = mmq_get_granularity_device(mmq_x); constexpr int rows_per_warp = 2 * granularity; @@ -893,7 +893,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma( for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) { const int k0 = k00 + k01; - ((mma_A_K8 *) A[n])[k01/8].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K); + ((mma_A_K8 *) A[n])[k01/8].load_ldmatrix(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K); } #pragma unroll @@ -916,8 +916,9 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma( mma_B B[2]; float dB[mma_C::ne/2]; - B[0].load(y_qs + j0*MMQ_TILE_Y_K + (k01 + 0), MMQ_TILE_Y_K); - B[1].load(y_qs + j0*MMQ_TILE_Y_K + (k01 + mma_B::K), MMQ_TILE_Y_K); + // Here load_generic is faster than load_ldmatrix. + B[0].load_generic(y_qs + j0*MMQ_TILE_Y_K + (k01 + 0), MMQ_TILE_Y_K); + B[1].load_generic(y_qs + j0*MMQ_TILE_Y_K + (k01 + mma_B::K), MMQ_TILE_Y_K); #pragma unroll for (int l = 0; l < mma_C::ne/2; ++l) { @@ -929,8 +930,8 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma( #pragma unroll for (int n = 0; n < ntx; ++n) { mma_C C[2]; - C[0].mma_K4(A[n][k01/4 + 0], B[0]); - C[1].mma_K4(A[n][k01/4 + 1], B[1]); + C[0].mma(A[n][k01/4 + 0], B[0]); + C[1].mma(A[n][k01/4 + 1], B[1]); #pragma unroll for (int l = 0; l < mma_C::ne; ++l) { @@ -942,20 +943,20 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma( #else GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum); NO_DEVICE_CODE; -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } template static __device__ __forceinline__ void load_tiles_q2_K( const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE int * x_qs = (int *) x_tile; half2 * x_dm = (half2 *) (x_qs + 2*WARP_SIZE); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y); int * x_qs = (int *) x_tile; half2 * x_dm = (half2 *) (x_qs + txs.qs); -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE const int kqsx = threadIdx.x % QI2_K; @@ -977,11 +978,11 @@ template static __device__ __forceinlin const int x_qs_k = (x_ql_0 >> (2*l)) & 0x03030303; -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE x_qs[i*MMQ_MMA_TILE_X_K_Q2_K + k] = x_qs_k; #else x_qs[i*(2*WARP_SIZE + 1) + k] = x_qs_k; -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } const int sc_m = bxi->scales[kqsx]; @@ -992,11 +993,11 @@ template static __device__ __forceinlin const half2 x_dm_ik = make_half2(bxi_dmf.x*(sc_m & 0x0F), bxi_dmf.y*(sc_m >> 4)); #endif // FAST_FP16_AVAILABLE -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + kqsx] = x_dm_ik; #else x_dm[i*(WARP_SIZE + 1) + kqsx] = x_dm_ik; -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } } @@ -1051,12 +1052,12 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a( template static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) { -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE - typedef mma_int_A_I16K4 mma_A; - typedef mma_int_A_I16K8 mma_A_K8; - typedef mma_int_B_J8K4 mma_B; - typedef mma_int_C_I16J8 mma_C; + typedef mma_A_I16K4 mma_A; + typedef mma_A_I16K8 mma_A_K8; + typedef mma_B_J8K4 mma_B; + typedef mma_C_I16J8 mma_C; constexpr int granularity = mmq_get_granularity_device(mmq_x); constexpr int rows_per_warp = 2 * granularity; @@ -1081,7 +1082,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma( for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) { const int k0 = k00 + k01; - ((mma_A_K8 *) A[n])[k01/QI8_1].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K); + ((mma_A_K8 *) A[n])[k01/QI8_1].load_ldmatrix(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K); } } @@ -1118,24 +1119,25 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma( for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) { mma_B B[2]; - B[0].load(y_qs + j0*MMQ_TILE_Y_K + (k01 + 0), MMQ_TILE_Y_K); - B[1].load(y_qs + j0*MMQ_TILE_Y_K + (k01 + mma_B::K), MMQ_TILE_Y_K); + // Here load_generic is faster than load_ldmatrix. + B[0].load_generic(y_qs + j0*MMQ_TILE_Y_K + (k01 + 0), MMQ_TILE_Y_K); + B[1].load_generic(y_qs + j0*MMQ_TILE_Y_K + (k01 + mma_B::K), MMQ_TILE_Y_K); mma_C Cm[2]; if (k01 >= WARP_SIZE * 3/4) { mma_A A1; A1.x[0] = 0x01010101; A1.x[1] = 0x01010101; - Cm[0].mma_K4(A1, B[0]); - Cm[1].mma_K4(A1, B[1]); + Cm[0].mma(A1, B[0]); + Cm[1].mma(A1, B[1]); } #pragma unroll for (int n = 0; n < ntx; ++n) { mma_C Cd[2]; - Cd[0].mma_K4(A[n][k01/4 + 0], B[0]); - Cd[1].mma_K4(A[n][k01/4 + 1], B[1]); + Cd[0].mma(A[n][k01/4 + 0], B[0]); + Cd[1].mma(A[n][k01/4 + 1], B[1]); #pragma unroll for (int l = 0; l < mma_C::ne; ++l) { @@ -1172,13 +1174,13 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma( #else GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum); NO_DEVICE_CODE; -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } template static __device__ __forceinline__ void load_tiles_q3_K( const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + WARP_SIZE*2); #else @@ -1186,7 +1188,7 @@ template static __device__ __forceinlin int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); int * x_sc = (int *) (x_df + txs.dm); -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE const int kqsx = threadIdx.x % QI3_K; @@ -1212,11 +1214,11 @@ template static __device__ __forceinlin const int x_qs_k = __vsubss4(x_ql_k | x_qh_k, 0x04040404); -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + k] = x_qs_k; #else x_qs[i*(2*WARP_SIZE + 1) + k] = x_qs_k; -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } } @@ -1242,7 +1244,7 @@ template static __device__ __forceinlin const int sc = __vsubss4(sc_low | sc_high, 0x20202020); -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE const int8_t * sc8 = (const int8_t *) ≻ const float d = bxi->d; @@ -1252,10 +1254,10 @@ template static __device__ __forceinlin } #else x_sc[i*(WARP_SIZE/8) + i/8 + threadIdx.x % (WARP_SIZE/8)] = sc; -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } -#ifndef INT8_MMA_AVAILABLE +#ifndef NEW_MMA_AVAILABLE #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += nwarps*WARP_SIZE) { int i = (i0 + threadIdx.y*WARP_SIZE + threadIdx.x) % mmq_y; @@ -1268,7 +1270,7 @@ template static __device__ __forceinlin x_df[i] = bxi->d; } -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } template @@ -1317,7 +1319,7 @@ static __device__ __forceinline__ int unpack_scales_q45_K(const int * scales, co template static __device__ __forceinline__ void load_tiles_q4_K( const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE int * x_qs = (int *) x_tile; half2 * x_dm = (half2 *) (x_qs + 2*WARP_SIZE); #else @@ -1325,7 +1327,7 @@ template static __device__ __forceinlin int * x_qs = (int *) x_tile; half2 * x_dm = (half2 *) (x_qs + txs.qs); int * x_sc = (int *) (x_dm + txs.dm); -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { @@ -1338,15 +1340,15 @@ template static __device__ __forceinlin const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride; const int qs0 = get_int_b4(bxi->qs, threadIdx.x); -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(threadIdx.x/8) + threadIdx.x % 8 + 0] = (qs0 >> 0) & 0x0F0F0F0F; x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(threadIdx.x/8) + threadIdx.x % 8 + 8] = (qs0 >> 4) & 0x0F0F0F0F; #else x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = qs0; -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += nwarps*16) { @@ -1407,7 +1409,7 @@ template static __device__ __forceinlin x_sc[i*(WARP_SIZE/8) + i/8 + ksc] = scales8; } -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } template @@ -1446,7 +1448,7 @@ static __device__ __forceinline__ void vec_dot_q4_K_q8_1_dp4a( template static __device__ __forceinline__ void load_tiles_q5_K( const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE int * x_qs = (int *) x_tile; half2 * x_dm = (half2 *) (x_qs + WARP_SIZE*2); #else @@ -1454,7 +1456,7 @@ template static __device__ __forceinlin int * x_qs = (int *) x_tile; half2 * x_dm = (half2 *) (x_qs + txs.qs); int * x_sc = (int *) (x_dm + txs.dm); -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { @@ -1478,16 +1480,16 @@ template static __device__ __forceinlin const int kq0 = ky - ky % (QI5_K/2) + threadIdx.x % (QI5_K/4) + 0; const int kq1 = ky - ky % (QI5_K/2) + threadIdx.x % (QI5_K/4) + QI5_K/4; -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kq0] = ql0 | qh0; x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kq1] = ql1 | qh1; #else x_qs[i*(2*WARP_SIZE + 1) + kq0] = ql0 | qh0; x_qs[i*(2*WARP_SIZE + 1) + kq1] = ql1 | qh1; -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += nwarps*16) { @@ -1548,7 +1550,7 @@ template static __device__ __forceinlin x_sc[i*(WARP_SIZE/8) + i/8 + ksc] = scales8; } -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } template @@ -1587,7 +1589,7 @@ static __device__ __forceinline__ void vec_dot_q5_K_q8_1_dp4a( template static __device__ __forceinline__ void load_tiles_q6_K( const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + WARP_SIZE*2); int * x_sc = (int *) (x_df + WARP_SIZE/QI6_K); @@ -1596,7 +1598,7 @@ template static __device__ __forceinlin int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); int * x_sc = (int *) (x_df + txs.dm); -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { @@ -1619,13 +1621,13 @@ template static __device__ __forceinlin const int kq0 = 2*threadIdx.x - threadIdx.x % (QI6_K/2) + 0; const int kq1 = 2*threadIdx.x - threadIdx.x % (QI6_K/2) + QI6_K/2; -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE x_qs[i*MMQ_MMA_TILE_X_K_Q6_K + kq0] = __vsubss4(ql0 | qh0, 0x20202020); x_qs[i*MMQ_MMA_TILE_X_K_Q6_K + kq1] = __vsubss4(ql1 | qh1, 0x20202020); #else x_qs[i*(2*WARP_SIZE + 1) + kq0] = __vsubss4(ql0 | qh0, 0x20202020); x_qs[i*(2*WARP_SIZE + 1) + kq1] = __vsubss4(ql1 | qh1, 0x20202020); -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } const int blocks_per_tile_x_row = WARP_SIZE / QI6_K; // == 1 if QK_K == 256 @@ -1641,11 +1643,11 @@ template static __device__ __forceinlin const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + kbxd; -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE x_df[i*MMQ_MMA_TILE_X_K_Q6_K + kbxd] = bxi->d; #else x_df[i*(WARP_SIZE/QI6_K) + i/QI6_K + kbxd] = bxi->d; -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } #pragma unroll @@ -1658,11 +1660,11 @@ template static __device__ __forceinlin const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + (threadIdx.x % (WARP_SIZE/8)) / 4; -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + threadIdx.x % (WARP_SIZE/8)] = get_int_b2(bxi->scales, threadIdx.x % (QI6_K/8)); #else x_sc[i*(WARP_SIZE/8) + i/8 + threadIdx.x % (WARP_SIZE/8)] = get_int_b2(bxi->scales, threadIdx.x % (QI6_K/8)); -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } } @@ -1702,11 +1704,11 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a( template static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) { -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE - typedef mma_int_A_I16K4 mma_A; - typedef mma_int_B_J8K4 mma_B; - typedef mma_int_C_I16J8 mma_C; + typedef mma_A_I16K4 mma_A; + typedef mma_B_J8K4 mma_B; + typedef mma_C_I16J8 mma_C; constexpr int granularity = mmq_get_granularity_device(mmq_x); constexpr int rows_per_warp = 2 * granularity; @@ -1732,8 +1734,8 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma( for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) { const int k0 = k00 + k01; - A[n][k01/4 + 0].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + 0), MMQ_MMA_TILE_X_K_Q6_K); - A[n][k01/4 + 1].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + mma_A::K), MMQ_MMA_TILE_X_K_Q6_K); + A[n][k01/4 + 0].load_ldmatrix(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + 0), MMQ_MMA_TILE_X_K_Q6_K); + A[n][k01/4 + 1].load_ldmatrix(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + mma_A::K), MMQ_MMA_TILE_X_K_Q6_K); } #pragma unroll @@ -1771,8 +1773,9 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma( mma_B B[2]; float dB[mma_C::ne/2]; - B[0].load(y_qs + j0*MMQ_TILE_Y_K + 0 + k01, MMQ_TILE_Y_K); - B[1].load(y_qs + j0*MMQ_TILE_Y_K + mma_B::K + k01, MMQ_TILE_Y_K); + // Here load_generic is faster than load_ldmatrix. + B[0].load_generic(y_qs + j0*MMQ_TILE_Y_K + 0 + k01, MMQ_TILE_Y_K); + B[1].load_generic(y_qs + j0*MMQ_TILE_Y_K + mma_B::K + k01, MMQ_TILE_Y_K); #pragma unroll for (int l = 0; l < mma_C::ne/2; ++l) { @@ -1784,8 +1787,8 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma( #pragma unroll for (int n = 0; n < ntx; ++n) { mma_C C[2]; - C[0].mma_K4(A[n][k01/4 + 0], B[0]); - C[1].mma_K4(A[n][k01/4 + 1], B[1]); + C[0].mma(A[n][k01/4 + 0], B[0]); + C[1].mma(A[n][k01/4 + 1], B[1]); #pragma unroll for (int l = 0; l < mma_C::ne; ++l) { @@ -1805,20 +1808,20 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma( #else GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum); NO_DEVICE_CODE; -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } template static __device__ __forceinline__ void load_tiles_iq4_nl( const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + WARP_SIZE*2); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_NL, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE const int kbx = threadIdx.x / QI4_NL; const int kqsx = threadIdx.x % QI4_NL; @@ -1836,13 +1839,13 @@ template static __device__ __forceinlin const int aux_q4 = get_int_b2(bxi->qs, kqsx); const int2 v = get_int_from_table_16(aux_q4); const int k0 = 8 * (threadIdx.x / 4) + threadIdx.x % 4; -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x; x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 4] = v.y; #else x_qs[i*(2*WARP_SIZE + 1) + k0 + 0] = v.x; x_qs[i*(2*WARP_SIZE + 1) + k0 + 4] = v.y; -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } const int blocks_per_tile_x_row = WARP_SIZE / QI4_NL; @@ -1858,25 +1861,25 @@ template static __device__ __forceinlin const block_iq4_nl * bxi = (const block_iq4_nl *) x + kbx0 + i*stride + kbxd; -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = __half2float(bxi->d); #else x_df[i*(WARP_SIZE/4) + i/4 + kbxd] = __half2float(bxi->d); -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } } template static __device__ __forceinline__ void load_tiles_iq2_xxs( const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + WARP_SIZE*2); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ2_XXS, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE const int kqsx = threadIdx.x % (QI2_XXS/2); @@ -1905,36 +1908,36 @@ template static __device__ __forceinlin const int signs1 = __vcmpne4(((signs_packed & 0x30) << 3) | ((signs_packed & 0xC0) << 17), 0x00000000); const int grid1 = __vsub4(grid_pos[1] ^ signs1, signs1); -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid0; x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 1)] = grid1; #else x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 0)] = grid0; x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 1)] = grid1; -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } const int ls = aux32 >> 28; const float d = bxi->d; -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/4; #else x_df[i*(WARP_SIZE/4) + i/4 + kqsx] = (ls*d + d/2)/4; -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } } template static __device__ __forceinline__ void load_tiles_iq2_xs( const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + WARP_SIZE*2); #else constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16; int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE const int kqsx = threadIdx.x % (QI2_XS/2); @@ -1959,38 +1962,38 @@ template static __device__ __forceinlin const int grid_l = __vsub4(grid_pos[0] ^ signs[0], signs[0]); const int grid_h = __vsub4(grid_pos[1] ^ signs[1], signs[1]); -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l; x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 1)] = grid_h; #else x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 0)] = grid_l; x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 1)] = grid_h; -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } const int ls = bxi->scales[kqsx]; const float d = bxi->d; -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4; x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4; #else x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4; x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4; -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } } template static __device__ __forceinline__ void load_tiles_iq2_s( const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + WARP_SIZE*2); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ2_S, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE const int kqsx = threadIdx.x % (QI2_S/2); @@ -2022,38 +2025,38 @@ template static __device__ __forceinlin const int grid_l = __vsub4(grid_pos[0] ^ signs0, signs0); const int grid_h = __vsub4(grid_pos[1] ^ signs1, signs1); -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l; x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 1)] = grid_h; #else x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 0)] = grid_l; x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 1)] = grid_h; -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } const int ls = bxi->scales[kqsx]; const float d = bxi->d; -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4; x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4; #else x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4; x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4; -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } } template static __device__ __forceinline__ void load_tiles_iq3_xxs( const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + WARP_SIZE*2); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_XXS, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE const int kqsx = threadIdx.x % (QI3_XXS/2); @@ -2080,36 +2083,36 @@ template static __device__ __forceinlin const int grid_l = __vsub4(grid_pos.x ^ signs[0], signs[0]); const int grid_h = __vsub4(grid_pos.y ^ signs[1], signs[1]); -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid_l; x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 1)] = grid_h; #else x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 0)] = grid_l; x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 1)] = grid_h; -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } const int ls = aux32 >> 28; const float d = bxi->d; -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/2; #else x_df[i*(WARP_SIZE/4) + i/4 + kqsx] = (ls*d + d/2)/2; -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } } template static __device__ __forceinline__ void load_tiles_iq3_s( const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + WARP_SIZE*2); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_S, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE const int kqsx = threadIdx.x % (QI3_S/2); @@ -2143,36 +2146,36 @@ template static __device__ __forceinlin const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0); const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1); -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l+0)] = grid_l; x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l+1)] = grid_h; #else x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l+0)] = grid_l; x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l+1)] = grid_h; -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } const int ls = 1 + 2*((bxi->scales[kqsx/2] >> (((2*kqsx) << 1) & 0x04)) & 0x0F); const float d = bxi->d; -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = ls*d; #else x_df[i*(WARP_SIZE/4) + i/4 + kqsx] = ls*d; -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } } template static __device__ __forceinline__ void load_tiles_iq1_s( const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE int * x_qs = (int *) x_tile; half2 * x_ds = (half2 *) (x_qs + WARP_SIZE*2); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_S, mmq_y); int * x_qs = (int *) x_tile; half2 * x_ds = (half2 *) (x_qs + txs.qs); -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE const int kqsx = threadIdx.x % QI1_S; @@ -2198,37 +2201,37 @@ template static __device__ __forceinlin const int grid0 = (grid >> 0) & 0x0F0F0F0F; const int grid1 = (grid >> 4) & 0x0F0F0F0F; -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 8*kqsx + (2*l+0)] = grid0; x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 8*kqsx + (2*l+1)] = grid1; #else x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l+0)] = grid0; x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l+1)] = grid1; -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } const float d1q = __half2float(bxi->d) * (((qh >> 11) & 0x0E) + 1); const float delta = -1.0f + IQ1S_DELTA - (qh & 0x8000) * (2.0f*IQ1S_DELTA/0x8000); -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE x_ds[i*MMQ_MMA_TILE_X_K_Q8_1 + kqsx] = make_half2(d1q, d1q*delta); #else x_ds[i*(WARP_SIZE/4) + i/4 + kqsx] = make_half2(d1q, d1q*delta); -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } } template static __device__ __forceinline__ void load_tiles_iq4_xs( const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + WARP_SIZE*2); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_XS, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE const int kbx = 0; // threadIdx.x / QI4_XS const int kqsx = threadIdx.x; // threadIdx.x % QI4_XS @@ -2246,13 +2249,13 @@ template static __device__ __forceinlin const int aux_q4 = get_int_b4(bxi->qs, kqsx); const int2 v = get_int_from_table_16(aux_q4); const int k0 = 8 * (threadIdx.x / 4) + threadIdx.x % 4; -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x; x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 4] = v.y; #else x_qs[i*(2*WARP_SIZE + 1) + k0 + 0] = v.x; x_qs[i*(2*WARP_SIZE + 1) + k0 + 4] = v.y; -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } #pragma unroll @@ -2270,11 +2273,11 @@ template static __device__ __forceinlin const int ls = ((bxi->scales_l[(threadIdx.x % 8)/2] >> (4*(threadIdx.x % 2))) & 0x0F) | (((bxi->scales_h >> (2*(threadIdx.x % 8))) & 0x03) << 4); -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x % 8] = d * (ls - 32); #else x_df[i*(WARP_SIZE/4) + i/4 + threadIdx.x % 8] = d * (ls - 32); -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE } } @@ -2307,16 +2310,16 @@ template static __device__ __forceinline__ void mmq_write_back_mma( const float * __restrict__ sum, float * __restrict__ dst, const int & stride, const int & i_max, const int & j_max) { - typedef mma_int_C_I16J8 mma_C; + typedef mma_C_I16J8 mma_C; constexpr int granularity = mmq_get_granularity_device(mmq_x); constexpr int rows_per_warp = 2 * granularity; constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp. const int i0 = (threadIdx.y / ntx) * (ntx*mma_C::I); -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE static_assert(nwarps*mma_C::I == mmq_y, "nwarps*mma_C::I != mmq_y"); -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) { @@ -2505,13 +2508,13 @@ static __device__ void mul_mat_q_process_tile( int * tile_y = (int *) data_mul_mat_q; int * tile_x = tile_y + GGML_PAD(mmq_x*(WARP_SIZE + WARP_SIZE/QI8_1), nwarps*WARP_SIZE); -#ifdef INT8_MMA_AVAILABLE +#ifdef NEW_MMA_AVAILABLE constexpr vec_dot_mmq_t vec_dot = mmq_type_traits::vec_dot_mma; constexpr mmq_write_back_t write_back = mmq_write_back_mma; #else constexpr vec_dot_mmq_t vec_dot = mmq_type_traits::vec_dot_dp4a; constexpr mmq_write_back_t write_back = mmq_write_back_dp4a; -#endif // INT8_MMA_AVAILABLE +#endif // NEW_MMA_AVAILABLE constexpr int blocks_per_iter = MMQ_ITER_K / qk; @@ -2643,7 +2646,7 @@ static __global__ void mul_mat_q( const int jt = kbc / (blocks_per_ne00*nty); const int it = (kbc - jt*(blocks_per_ne00*nty)) / blocks_per_ne00; - constexpr bool fixup = true; // Last index writes it data to fixup buffer to avoid data races with other blocks. + constexpr bool fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks. mul_mat_q_process_tile (x, yc, dst, tmp_fixup, ne00, ne01, stride01, ne10, ne11, stride11, ne0, it, jt, kb0_start, kb0_stop); @@ -2749,7 +2752,7 @@ template static int mmq_get_shmem(const int mmq_x, const int mmq_y, const int cc) { const tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(type, mmq_y); const int mmq_tile_x_k = mmq_get_mma_tile_x_k(type); - const int shmem_x = int8_mma_available(cc) ? mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int); + const int shmem_x = new_mma_available(cc) ? mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int); const int shmem_y = mmq_x*sizeof(block_q8_1_mmq); return shmem_x + GGML_PAD(shmem_y, MMQ_NWARPS*WARP_SIZE*sizeof(int)); } diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-cpb16.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-cpb16.cu new file mode 100644 index 0000000000000..f09bdeff79a43 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-cpb16.cu @@ -0,0 +1,10 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-mma-f16.cuh" + +DECL_FATTN_MMA_F16_CASE(64, 16); +DECL_FATTN_MMA_F16_CASE(80, 16); +DECL_FATTN_MMA_F16_CASE(96, 16); +DECL_FATTN_MMA_F16_CASE(112, 16); +DECL_FATTN_MMA_F16_CASE(128, 16); +DECL_FATTN_MMA_F16_CASE(256, 16); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-cpb32.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-cpb32.cu new file mode 100644 index 0000000000000..221108873a0e1 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-cpb32.cu @@ -0,0 +1,10 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-mma-f16.cuh" + +DECL_FATTN_MMA_F16_CASE(64, 32); +DECL_FATTN_MMA_F16_CASE(80, 32); +DECL_FATTN_MMA_F16_CASE(96, 32); +DECL_FATTN_MMA_F16_CASE(112, 32); +DECL_FATTN_MMA_F16_CASE(128, 32); +DECL_FATTN_MMA_F16_CASE(256, 32); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-cpb64.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-cpb64.cu new file mode 100644 index 0000000000000..d24b085758d25 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-cpb64.cu @@ -0,0 +1,10 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-mma-f16.cuh" + +DECL_FATTN_MMA_F16_CASE(64, 64); +DECL_FATTN_MMA_F16_CASE(80, 64); +DECL_FATTN_MMA_F16_CASE(96, 64); +DECL_FATTN_MMA_F16_CASE(112, 64); +DECL_FATTN_MMA_F16_CASE(128, 64); +DECL_FATTN_MMA_F16_CASE(256, 64); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-cpb8.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-cpb8.cu new file mode 100644 index 0000000000000..bdf86c0eabaf6 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-cpb8.cu @@ -0,0 +1,10 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-mma-f16.cuh" + +DECL_FATTN_MMA_F16_CASE(64, 8); +DECL_FATTN_MMA_F16_CASE(80, 8); +DECL_FATTN_MMA_F16_CASE(96, 8); +DECL_FATTN_MMA_F16_CASE(112, 8); +DECL_FATTN_MMA_F16_CASE(128, 8); +DECL_FATTN_MMA_F16_CASE(256, 8); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb16.cu b/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb16.cu deleted file mode 100644 index 2d94e65c28c29..0000000000000 --- a/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb16.cu +++ /dev/null @@ -1,10 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../fattn-wmma-f16.cuh" - -DECL_FATTN_WMMA_F16_CASE(64, 16, float); -DECL_FATTN_WMMA_F16_CASE(80, 16, float); -DECL_FATTN_WMMA_F16_CASE(96, 16, float); -DECL_FATTN_WMMA_F16_CASE(112, 16, float); -DECL_FATTN_WMMA_F16_CASE(128, 16, float); -DECL_FATTN_WMMA_F16_CASE(256, 16, float); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb32.cu b/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb32.cu deleted file mode 100644 index c3d9df3c44313..0000000000000 --- a/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb32.cu +++ /dev/null @@ -1,9 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../fattn-wmma-f16.cuh" - -DECL_FATTN_WMMA_F16_CASE(64, 32, float); -DECL_FATTN_WMMA_F16_CASE(80, 32, float); -DECL_FATTN_WMMA_F16_CASE(96, 32, float); -DECL_FATTN_WMMA_F16_CASE(112, 32, float); -DECL_FATTN_WMMA_F16_CASE(128, 32, float); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb16.cu b/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb16.cu deleted file mode 100644 index bb680e401f7da..0000000000000 --- a/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb16.cu +++ /dev/null @@ -1,10 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../fattn-wmma-f16.cuh" - -DECL_FATTN_WMMA_F16_CASE(64, 16, half); -DECL_FATTN_WMMA_F16_CASE(80, 16, half); -DECL_FATTN_WMMA_F16_CASE(96, 16, half); -DECL_FATTN_WMMA_F16_CASE(112, 16, half); -DECL_FATTN_WMMA_F16_CASE(128, 16, half); -DECL_FATTN_WMMA_F16_CASE(256, 16, half); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb32.cu b/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb32.cu deleted file mode 100644 index 073f71b1f3e26..0000000000000 --- a/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb32.cu +++ /dev/null @@ -1,10 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../fattn-wmma-f16.cuh" - -DECL_FATTN_WMMA_F16_CASE(64, 32, half); -DECL_FATTN_WMMA_F16_CASE(80, 32, half); -DECL_FATTN_WMMA_F16_CASE(96, 32, half); -DECL_FATTN_WMMA_F16_CASE(112, 32, half); -DECL_FATTN_WMMA_F16_CASE(128, 32, half); -DECL_FATTN_WMMA_F16_CASE(256, 32, half); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb8.cu b/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb8.cu deleted file mode 100644 index d30710c5fa21b..0000000000000 --- a/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb8.cu +++ /dev/null @@ -1,8 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../fattn-wmma-f16.cuh" - -DECL_FATTN_WMMA_F16_CASE(64, 8, half); -DECL_FATTN_WMMA_F16_CASE(96, 8, half); -DECL_FATTN_WMMA_F16_CASE(128, 8, half); -DECL_FATTN_WMMA_F16_CASE(256, 8, half); diff --git a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py index d7874e6eaf832..a2628f16e57d1 100755 --- a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +++ b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py @@ -12,13 +12,13 @@ DECL_FATTN_VEC_F{vkq_size}_CASE({head_size}, {type_k}, {type_v}); """ -SOURCE_FATTN_WMMA_START = """// This file has been autogenerated by generate_cu_files.py, do not edit manually. +SOURCE_FATTN_MMA_START = """// This file has been autogenerated by generate_cu_files.py, do not edit manually. -#include "../fattn-wmma-f16.cuh" +#include "../fattn-mma-f16.cuh" """ -SOURCE_FATTN_WMMA_CASE = "DECL_FATTN_WMMA_F16_CASE({head_size}, {cols_per_block}, {kq_acc_t});\n" +SOURCE_FATTN_MMA_CASE = "DECL_FATTN_MMA_F16_CASE({head_size}, {cols_per_block});\n" TYPES_MMQ = [ "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0", @@ -57,20 +57,12 @@ def get_head_sizes(type_k, type_v): with open(f"fattn-vec-f{vkq_size}-instance-hs{head_size}-{get_short_name(type_k)}-{get_short_name(type_v)}.cu", "w") as f: f.write(SOURCE_FATTN_VEC.format(vkq_size=vkq_size, head_size=head_size, type_k=type_k, type_v=type_v)) -for kq_acc_t in ["half", "float"]: - for cols_per_block in [8, 16, 32]: - if kq_acc_t == "float" and cols_per_block == 8: - continue +for cols_per_block in [8, 16, 32, 64]: + with open(f"fattn-mma-f16-instance-cpb{cols_per_block}.cu", "w") as f: + f.write(SOURCE_FATTN_MMA_START) - with open(f"fattn-wmma-f16-instance-kq{kq_acc_t}-cpb{cols_per_block}.cu", "w") as f: - f.write(SOURCE_FATTN_WMMA_START) - - for head_size in [64, 80, 96, 112, 128, 256]: - if cols_per_block == 8 and head_size % 32 != 0: # wmma fragment is 8x32 - continue - if kq_acc_t == "float" and cols_per_block == 32 and head_size == 256: # register spilling, bad performance - continue - f.write(SOURCE_FATTN_WMMA_CASE.format(kq_acc_t=kq_acc_t, cols_per_block=cols_per_block, head_size=head_size)) + for head_size in [64, 80, 96, 112, 128, 256]: + f.write(SOURCE_FATTN_MMA_CASE.format(cols_per_block=cols_per_block, head_size=head_size)) for type in TYPES_MMQ: with open(f"mmq-instance-{get_short_name(type)}.cu", "w") as f: diff --git a/ggml/src/ggml-hip/CMakeLists.txt b/ggml/src/ggml-hip/CMakeLists.txt index 7a877bdc11a6f..eb03e10fa48a1 100644 --- a/ggml/src/ggml-hip/CMakeLists.txt +++ b/ggml/src/ggml-hip/CMakeLists.txt @@ -50,7 +50,7 @@ file(GLOB GGML_HEADERS_ROCM "../ggml-cuda/*.cuh") list(APPEND GGML_HEADERS_ROCM "../../include/ggml-cuda.h") file(GLOB GGML_SOURCES_ROCM "../ggml-cuda/*.cu") -file(GLOB SRCS "../ggml-cuda/template-instances/fattn-wmma*.cu") +file(GLOB SRCS "../ggml-cuda/template-instances/fattn-mma*.cu") list(APPEND GGML_SOURCES_ROCM ${SRCS}) file(GLOB SRCS "../ggml-cuda/template-instances/mmq*.cu") list(APPEND GGML_SOURCES_ROCM ${SRCS}) diff --git a/ggml/src/ggml-musa/CMakeLists.txt b/ggml/src/ggml-musa/CMakeLists.txt index 415b2b2e09c1b..2f555416e62cf 100644 --- a/ggml/src/ggml-musa/CMakeLists.txt +++ b/ggml/src/ggml-musa/CMakeLists.txt @@ -29,7 +29,7 @@ if (MUSAToolkit_FOUND) list(APPEND GGML_HEADERS_MUSA "../../include/ggml-cuda.h") file(GLOB GGML_SOURCES_MUSA "../ggml-cuda/*.cu") - file(GLOB SRCS "../ggml-cuda/template-instances/fattn-wmma*.cu") + file(GLOB SRCS "../ggml-cuda/template-instances/fattn-mma*.cu") list(APPEND GGML_SOURCES_MUSA ${SRCS}) file(GLOB SRCS "../ggml-cuda/template-instances/mmq*.cu") list(APPEND GGML_SOURCES_MUSA ${SRCS})