From d6793645f5e7717105274ee587ebb462ff6fe194 Mon Sep 17 00:00:00 2001 From: Erik Schultheis Date: Wed, 12 Jun 2024 09:35:42 +0300 Subject: [PATCH] replace cublaslt with custom matmul --- dev/cuda/matmul_forward.cu | 104 +++++++++++---------- test_gpt2_fp32.cu | 4 - train_gpt2_fp32.cu | 183 ++++++++++++++++++++----------------- 3 files changed, 152 insertions(+), 139 deletions(-) diff --git a/dev/cuda/matmul_forward.cu b/dev/cuda/matmul_forward.cu index 46ba1e1ea..1a03e9a2c 100644 --- a/dev/cuda/matmul_forward.cu +++ b/dev/cuda/matmul_forward.cu @@ -84,7 +84,8 @@ __global__ void add_bias(float* out, const float* bias, int B, int T, int OC) { } } -// kernel 4: register reuse kernel +// kernel 4: semi-efficient handwritten kernel +// see trimat_forward.cu for some intermediate development steps __device__ float4 ld_vec(const float* address) { return *reinterpret_cast(address); } @@ -95,69 +96,72 @@ __device__ void st_vec(float* address, float4 val) { __global__ void __launch_bounds__(16*16) matmul_forward_kernel4(float* out, const float* inp, const float* weight, const float* bias, - int BT, int C, int OC) { + int C, int OC) { // out is (B,T,OC). OC is short for "output channels", e.g. OC = 4 * C // inp is (B,T,C), weight is (OC, C), bias is (OC) - // in the naive kernel, every thread handles one element of out - int bt = 8*(blockIdx.x * blockDim.x + threadIdx.x); + // each thread handles 8x8 elements; each block 128 by 128 elements. int oc = 8*(blockIdx.y * blockDim.y + threadIdx.y); + // buffers to cache chunks of the input matrices __shared__ float lhs_s[128][32]; __shared__ float rhs_s[128][32]; + // adjust our pointers for the current block inp += 128 * blockIdx.x * C; weight += 128 * blockIdx.y * C; out += 128 * blockIdx.x * OC + 128 * blockIdx.y; - if (bt < BT && oc < OC) { - float vals[8][8] = {}; - if(bias != NULL) { - for (int i = 0; i < 8; i++) { - for (int j = 0; j < 8; j += 4) { - float4 b = ld_vec(bias + oc + j); - vals[i][j+0] = b.x; - vals[i][j+1] = b.y; - vals[i][j+2] = b.z; - vals[i][j+3] = b.w; - } + float vals[8][8] = {}; + if(bias != NULL) { + for (int i = 0; i < 8; i++) { + for (int j = 0; j < 8; j += 4) { + float4 b = ld_vec(bias + oc + j); + vals[i][j+0] = b.x; + vals[i][j+1] = b.y; + vals[i][j+2] = b.z; + vals[i][j+3] = b.w; } } + } - int si_start = 16 * threadIdx.y + threadIdx.x; - for (int so = 0; so < C; so += 32) { - __syncthreads(); - int xmod8 = threadIdx.x % 8; - int xby8 = threadIdx.x / 8; - for(int y = 2 * threadIdx.y + xby8; y < 128; y += 32) { - int xo = 4 * xmod8; - st_vec(&lhs_s[y][xo], ld_vec(inp + y * C + so + xo)); - st_vec(&rhs_s[y][xo], ld_vec(weight + y * C + so + xo)); - } - __syncthreads(); + int si_start = 4*(16 * threadIdx.y + threadIdx.x); + for (int so = 0; so < C; so += 32) { + __syncthreads(); + int xmod8 = threadIdx.x % 8; + int xby8 = threadIdx.x / 8; + int xo = 4 * xmod8; + for(int y = 2 * threadIdx.y + xby8; y < 128; y += 32) { + st_vec(&lhs_s[y][xo], ld_vec(inp + y * C + so + xo)); + st_vec(&rhs_s[y][xo], ld_vec(weight + y * C + so + xo)); + } + __syncthreads(); - for (int si = si_start; si < si_start + 32; ++si) { - float rhs[8]; - for (int u = 0; u < 8; ++u) { - rhs[u] = rhs_s[u + 8 * threadIdx.y][si % 32]; - } + for (int si = si_start; si < si_start + 32; si += 4) { + float4 rhs[8]; + for (int u = 0; u < 8; ++u) { + rhs[u] = ld_vec(&rhs_s[u + 8 * threadIdx.y][si % 32]); + } - for (int ii = 0; ii < 8; ++ii) { - float lhs = lhs_s[ii + 8 * threadIdx.x][si % 32]; - for (int ji = 0; ji < 8; ++ji) { - vals[ii][ji] += lhs * rhs[ji]; - } + 
for (int ii = 0; ii < 8; ++ii) { + float4 lhs = ld_vec(&lhs_s[ii + 8 * threadIdx.x][si % 32]); + for (int ji = 0; ji < 8; ++ji) { + vals[ii][ji] += lhs.x * rhs[ji].x; + vals[ii][ji] += lhs.y * rhs[ji].y; + vals[ii][ji] += lhs.z * rhs[ji].z; + vals[ii][ji] += lhs.w * rhs[ji].w; } } } - for (int i = 0; i < 8; ++i) { - for (int j = 0; j < 8; j += 4) { - float4 result; - result.x = vals[i][j + 0]; - result.y = vals[i][j + 1]; - result.z = vals[i][j + 2]; - result.w = vals[i][j + 3]; - st_vec(out + (8*threadIdx.x+i) * OC + 8*threadIdx.y + j, result); - } + } + + for (int i = 0; i < 8; ++i) { + for (int j = 0; j < 8; j += 4) { + float4 result; + result.x = vals[i][j + 0]; + result.y = vals[i][j + 1]; + result.z = vals[i][j + 2]; + result.w = vals[i][j + 3]; + st_vec(out + (8*threadIdx.x+i) * OC + 8*threadIdx.y + j, result); } } } @@ -296,7 +300,7 @@ void matmul_forward3(float* out, cublasCheck(cublasLtMatrixLayoutDestroy(biasLayout)); } -// kernel 1 is the most naive matmul kernel +// handwritten, relatively efficient non-tensorcore matmul kernel void matmul_forward4(float* out, const float* inp, const float* weight, const float* bias, int B, int T, int C, int OC, @@ -305,9 +309,9 @@ void matmul_forward4(float* out, // inp is (B,T,C), weight is (OC, C), bias is (OC) sqrt_block_size = 16; - dim3 gridDim(ceil_div(B * T, sqrt_block_size), ceil_div(OC, sqrt_block_size)); + dim3 gridDim(ceil_div(B * T, 8*sqrt_block_size), ceil_div(OC, 8*sqrt_block_size)); dim3 blockDim(sqrt_block_size, sqrt_block_size); - matmul_forward_kernel4<<>>(out, inp, weight, bias, B*T, C, OC); + matmul_forward_kernel4<<>>(out, inp, weight, bias, C, OC); cudaCheck(cudaGetLastError()); } @@ -341,7 +345,7 @@ void matmul_forward(int kernel_num, int main(int argc, char **argv) { srand(0); - int B = 8; + int B = 32; int T = 1024; int C = 768; int OC = 768 * 4; // expansion of 4, e.g. in the MLP @@ -357,7 +361,7 @@ int main(int argc, char **argv) { cublasCheck(cublasCreate(&cublas_handle)); cublasCheck(cublasLtCreate(&cublaslt_handle)); // TF32 precision is equivalent to torch.set_float32_matmul_precision('high') - int enable_tf32 = false; //deviceProp.major >= 8 ? 1 : 0; + int enable_tf32 = deviceProp.major >= 8 ? 1 : 0; printf("enable_tf32: %d\n", enable_tf32); cublas_compute_type = enable_tf32 ? CUBLAS_COMPUTE_32F_FAST_TF32 : CUBLAS_COMPUTE_32F; cublasMath_t cublas_math_mode = enable_tf32 ? CUBLAS_TF32_TENSOR_OP_MATH : CUBLAS_DEFAULT_MATH; diff --git a/test_gpt2_fp32.cu b/test_gpt2_fp32.cu index 01440072a..e1b6af3d8 100644 --- a/test_gpt2_fp32.cu +++ b/test_gpt2_fp32.cu @@ -36,7 +36,6 @@ int main(int argc, char *argv[]) { // setup cuBLAS and cuBLASLt cublasCheck(cublasCreate(&cublas_handle)); - cublasCheck(cublasLtCreate(&cublaslt_handle)); // TF32 precision is equivalent to torch.set_float32_matmul_precision('high') int enable_tf32 = deviceProp.major >= 8 ? 1 : 0; enable_tf32 = 0; // NOTE: disable TF32 for testing!!! @@ -44,7 +43,6 @@ int main(int argc, char *argv[]) { cublas_compute_type = enable_tf32 ? CUBLAS_COMPUTE_32F_FAST_TF32 : CUBLAS_COMPUTE_32F; cublasMath_t cublas_math_mode = enable_tf32 ? 
CUBLAS_TF32_TENSOR_OP_MATH : CUBLAS_DEFAULT_MATH; cublasCheck(cublasSetMathMode(cublas_handle, cublas_math_mode)); - cudaCheck(cudaMalloc(&cublaslt_workspace, cublaslt_workspace_size)); // build the GPT-2 model from a checkpoint GPT2 model; @@ -231,9 +229,7 @@ int main(int argc, char *argv[]) { free(expected_grads_memory); free(calculated_grads_memory); gpt2_free(&model); - cudaCheck(cudaFree(cublaslt_workspace)); cublasCheck(cublasDestroy(cublas_handle)); - cublasCheck(cublasLtDestroy(cublaslt_handle)); return 0; } \ No newline at end of file diff --git a/train_gpt2_fp32.cu b/train_gpt2_fp32.cu index 6553ca009..ebd7c9257 100644 --- a/train_gpt2_fp32.cu +++ b/train_gpt2_fp32.cu @@ -23,7 +23,6 @@ the layernorms are connected to the residuals so we += in layernorm backward. // GPU / CUDA related #include #include -#include #include #include // our own utilities @@ -60,12 +59,8 @@ void cublasCheck(cublasStatus_t status, const char *file, int line) } #define cublasCheck(status) { cublasCheck((status), __FILE__, __LINE__); } -// cuBLAS workspace. Hardcoding to 32MiB but only Hopper needs 32, for others 4 is OK -static size_t cublaslt_workspace_size = 32 * 1024 * 1024; -static void* cublaslt_workspace = NULL; static cublasComputeType_t cublas_compute_type; cublasHandle_t cublas_handle; -cublasLtHandle_t cublaslt_handle; namespace cg = cooperative_groups; @@ -611,6 +606,87 @@ __global__ void fused_classifier_kernel3(float* logits, float* losses, float* pr } } +__device__ float4 ld_vec(const float* address) { + return *reinterpret_cast(address); +} + +__device__ void st_vec(float* address, float4 val) { + *reinterpret_cast(address) = val; +} + +__global__ void __launch_bounds__(16*16, 2) matmul_forward_kernel4(float* out, + const float* inp, const float* weight, const float* bias, + int C, int OC) { + // out is (B,T,OC). OC is short for "output channels", e.g. OC = 4 * C + // inp is (B,T,C), weight is (OC, C), bias is (OC) + // each thread handles 8x8 elements; each block 128 by 128 elements. 
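+ // [illustrative comments, not part of the original patch]
+ // blockDim is (16,16), i.e. 256 threads per block; each thread accumulates an
+ // 8x8 tile in registers (vals), so the block as a whole covers a 128x128 output tile.
+ // the reduction over C is consumed 32 columns at a time through the shared-memory
+ // staging buffers lhs_s and rhs_s declared below.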
+ int oc = 8*(blockIdx.y * blockDim.y + threadIdx.y); + + // buffers to cache chunks of the input matrices + __shared__ float lhs_s[128][32]; + __shared__ float rhs_s[128][32]; + + // adjust our pointers for the current block + inp += 128 * blockIdx.x * C; + weight += 128 * blockIdx.y * C; + out += 128 * blockIdx.x * OC + 128 * blockIdx.y; + + float vals[8][8] = {}; + if(bias != NULL) { + for (int i = 0; i < 8; i++) { + for (int j = 0; j < 8; j += 4) { + float4 b = ld_vec(bias + oc + j); + vals[i][j+0] = b.x; + vals[i][j+1] = b.y; + vals[i][j+2] = b.z; + vals[i][j+3] = b.w; + } + } + } + + int si_start = 4*(16 * threadIdx.y + threadIdx.x); + for (int so = 0; so < C; so += 32) { + __syncthreads(); + int xmod8 = threadIdx.x % 8; + int xby8 = threadIdx.x / 8; + int xo = 4 * xmod8; + for(int y = 2 * threadIdx.y + xby8; y < 128; y += 32) { + st_vec(&lhs_s[y][xo], ld_vec(inp + y * C + so + xo)); + st_vec(&rhs_s[y][xo], ld_vec(weight + y * C + so + xo)); + } + __syncthreads(); + + for (int si = si_start; si < si_start + 32; si += 4) { + float4 rhs[8]; + for (int u = 0; u < 8; ++u) { + rhs[u] = ld_vec(&rhs_s[u + 8 * threadIdx.y][si % 32]); + } + + for (int ii = 0; ii < 8; ++ii) { + float4 lhs = ld_vec(&lhs_s[ii + 8 * threadIdx.x][si % 32]); + for (int ji = 0; ji < 8; ++ji) { + vals[ii][ji] += lhs.x * rhs[ji].x; + vals[ii][ji] += lhs.y * rhs[ji].y; + vals[ii][ji] += lhs.z * rhs[ji].z; + vals[ii][ji] += lhs.w * rhs[ji].w; + } + } + } + } + + for (int i = 0; i < 8; ++i) { + for (int j = 0; j < 8; j += 4) { + float4 result; + result.x = vals[i][j + 0]; + result.y = vals[i][j + 1]; + result.z = vals[i][j + 2]; + result.w = vals[i][j + 3]; + st_vec(out + (8*threadIdx.x+i) * OC + 8*threadIdx.y + j, result); + } + } +} + + // ---------------------------------------------------------------------------- // kernel launchers @@ -645,77 +721,18 @@ void layernorm_forward(float* out, float* mean, float* rstd, cudaCheck(cudaGetLastError()); } -// uses cuBLASLt to fuse the bias and gelu. 
does not work with OC = 50257 (last layer) -// https://docs.nvidia.com/cuda/cublas/#cublasltmatmul -// https://github.com/NVIDIA/CUDALibrarySamples/blob/master/cuBLASLt/LtSgemm/sample_cublasLt_LtSgemm.cu -void matmul_forward_cublaslt(float* out, - float* inp, float* weight, float* bias, - int B, int T, int C, int OC) { - int has_bias = (bias != NULL); - - // check bias alignment - if(((uintptr_t)bias % 16) != 0) { - printf("Bias pointer is not aligned (cuBLASLt requirement)!\n"); - exit(EXIT_FAILURE); - } - - int returnedResults = 0; - cublasLtMatmulDesc_t operationDesc; - cublasLtMatmulPreference_t preference; - cublasLtMatrixLayout_t weightLayout; - cublasLtMatrixLayout_t inputLayout; - cublasLtMatrixLayout_t outputLayout; - cublasLtMatrixLayout_t biasLayout; - cublasLtMatmulHeuristicResult_t heuristic; - - // create the operation descriptor - cublasOperation_t opNoTranspose = CUBLAS_OP_N; - cublasOperation_t opTranspose = CUBLAS_OP_T; - cublasLtEpilogue_t epilogueBias = CUBLASLT_EPILOGUE_BIAS; - cublasCheck(cublasLtMatmulDescCreate(&operationDesc, cublas_compute_type, CUDA_R_32F)); - cublasCheck(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &opTranspose, sizeof(opTranspose))); - cublasCheck(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &opNoTranspose, sizeof(opNoTranspose))); - if(has_bias) { - cublasCheck(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogueBias, - sizeof(epilogueBias))); - } - cublasCheck(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(bias))); - - // define matrix layouts - cublasCheck(cublasLtMatrixLayoutCreate(&weightLayout, CUDA_R_32F, C, OC, C)); - cublasCheck(cublasLtMatrixLayoutCreate(&inputLayout, CUDA_R_32F, C, B*T, C)); - cublasCheck(cublasLtMatrixLayoutCreate(&outputLayout, CUDA_R_32F, OC, B*T, OC)); - cublasCheck(cublasLtMatrixLayoutCreate(&biasLayout, CUDA_R_32F, OC, 1, OC)); - - // create a preference handle with specified max workspace - cublasCheck(cublasLtMatmulPreferenceCreate(&preference)); - cublasCheck(cublasLtMatmulPreferenceSetAttribute(preference, - CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, - &cublaslt_workspace_size, sizeof(cublaslt_workspace_size))); - - // find a suitable algorithm - cublasCheck(cublasLtMatmulAlgoGetHeuristic(cublaslt_handle, operationDesc, - weightLayout, inputLayout, outputLayout, outputLayout, - preference, 1, &heuristic, &returnedResults)); - if (returnedResults == 0) { - printf("No cuBLASLt algorithm: B: %d, T: %d, C: %d, OC: %d, bias: %d\n", B, T, C, OC, has_bias); - exit(EXIT_FAILURE); - } - - // call the matmul - const float alpha = 1.0f, beta = 0.0f; - cublasCheck(cublasLtMatmul(cublaslt_handle, operationDesc, - &alpha, weight, weightLayout, inp, inputLayout, &beta, - out, outputLayout, out, outputLayout, &heuristic.algo, - cublaslt_workspace, cublaslt_workspace_size, 0)); - - // cleanups - cublasCheck(cublasLtMatmulPreferenceDestroy(preference)); - cublasCheck(cublasLtMatmulDescDestroy(operationDesc)); - cublasCheck(cublasLtMatrixLayoutDestroy(weightLayout)); - cublasCheck(cublasLtMatrixLayoutDestroy(inputLayout)); - cublasCheck(cublasLtMatrixLayoutDestroy(outputLayout)); - cublasCheck(cublasLtMatrixLayoutDestroy(biasLayout)); +// kernel 1 is the most naive matmul kernel +void matmul_forward(float* out, + const float* inp, const float* weight, const float* bias, + int B, int T, int C, int OC) { + // out is (B,T,OC). OC is short for "output channels", e.g. 
OC = 4 * C + // inp is (B,T,C), weight is (OC, C), bias is (OC) + int sqrt_block_size = 16; + + dim3 gridDim(CEIL_DIV(B * T, 8*sqrt_block_size), CEIL_DIV(OC, 8*sqrt_block_size)); + dim3 blockDim(sqrt_block_size, sqrt_block_size); + matmul_forward_kernel4<<>>(out, inp, weight, bias, C, OC); + cudaCheck(cudaGetLastError()); } void attention_forward(float* out, float* qkvr, float* att, @@ -1255,20 +1272,20 @@ void gpt2_forward(GPT2 *model, int* inputs, int* targets, int B, int T) { // now do the forward pass layernorm_forward(l_ln1, l_ln1_mean, l_ln1_rstd, residual, l_ln1w, l_ln1b, B, T, C); - matmul_forward_cublaslt(scratch, l_ln1, l_qkvw, l_qkvb, B, T, C, 3*C); + matmul_forward(scratch, l_ln1, l_qkvw, l_qkvb, B, T, C, 3*C); attention_forward(l_atty, l_qkvr, l_att, scratch, B, T, C, NH); - matmul_forward_cublaslt(l_attproj, l_atty, l_attprojw, l_attprojb, B, T, C, C); + matmul_forward(l_attproj, l_atty, l_attprojw, l_attprojb, B, T, C, C); residual_forward(l_residual2, residual, l_attproj, B*T*C); layernorm_forward(l_ln2, l_ln2_mean, l_ln2_rstd, l_residual2, l_ln2w, l_ln2b, B, T, C); - matmul_forward_cublaslt(l_fch, l_ln2, l_fcw, l_fcb, B, T, C, 4*C); + matmul_forward(l_fch, l_ln2, l_fcw, l_fcb, B, T, C, 4*C); gelu_forward(l_fch_gelu, l_fch, B*T*4*C); - matmul_forward_cublaslt(l_fcproj, l_fch_gelu, l_fcprojw, l_fcprojb, B, T, 4*C, C); + matmul_forward(l_fcproj, l_fch_gelu, l_fcprojw, l_fcprojb, B, T, 4*C, C); residual_forward(l_residual3, l_residual2, l_fcproj, B*T*C); } residual = acts.residual3 + (L-1) * B * T * C; // last residual is in residual3 layernorm_forward(acts.lnf, acts.lnf_mean, acts.lnf_rstd, residual, params.lnfw, params.lnfb, B, T, C); - matmul_forward_cublaslt(acts.output, acts.lnf, params.wte, NULL, B, T, C, Vp); + matmul_forward(acts.output, acts.lnf, params.wte, NULL, B, T, C, Vp); // also forward the cross-entropy loss function if we have the targets if (targets != NULL) { @@ -1594,13 +1611,11 @@ int main(int argc, char *argv[]) { cudaGetDeviceProperties(&deviceProp, deviceIdx); // setup cuBLAS and cuBLASLt cublasCheck(cublasCreate(&cublas_handle)); - cublasCheck(cublasLtCreate(&cublaslt_handle)); // TF32 precision is equivalent to torch.set_float32_matmul_precision('high') int enable_tf32 = deviceProp.major >= 8 ? 1 : 0; cublas_compute_type = enable_tf32 ? CUBLAS_COMPUTE_32F_FAST_TF32 : CUBLAS_COMPUTE_32F; cublasMath_t cublas_math_mode = enable_tf32 ? CUBLAS_TF32_TENSOR_OP_MATH : CUBLAS_DEFAULT_MATH; cublasCheck(cublasSetMathMode(cublas_handle, cublas_math_mode)); - cudaCheck(cudaMalloc(&cublaslt_workspace, cublaslt_workspace_size)); printf("| device | %-50s |\n", deviceProp.name); printf("| TF32 | %-50s |\n", enable_tf32 ? "enabled" : "disabled"); printf("+-----------------------+----------------------------------------------------+\n"); @@ -1732,9 +1747,7 @@ int main(int argc, char *argv[]) { gpt2_free(&model); free(cpu_logits); free(gen_tokens); - cudaCheck(cudaFree(cublaslt_workspace)); cublasCheck(cublasDestroy(cublas_handle)); - cublasCheck(cublasLtDestroy(cublaslt_handle)); logger_free(&logger); return 0;
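
Note on the replacement kernel's assumptions: matmul_forward_kernel4 writes full 128x128 output tiles with no bounds checks and moves data with float4 vector loads/stores, so it effectively requires B*T and OC to be multiples of 128, C to be a multiple of 32, and 16-byte-aligned pointers; the default GPT-2 shapes used in these files satisfy this. For spot-checking the kernel against a reference, a minimal CPU sketch of the same operation is given below; the function name matmul_forward_cpu is illustrative and not part of the patch.

// illustrative CPU reference (not part of the patch): computes the same
// out = inp @ weight^T + bias operation as matmul_forward_kernel4.
// out is (B,T,OC), inp is (B,T,C), weight is (OC,C), bias is (OC) or NULL.
void matmul_forward_cpu(float* out,
                        const float* inp, const float* weight, const float* bias,
                        int B, int T, int C, int OC) {
    for (int bt = 0; bt < B * T; bt++) {
        for (int o = 0; o < OC; o++) {
            float val = (bias != NULL) ? bias[o] : 0.0f;
            for (int c = 0; c < C; c++) {
                val += inp[bt * C + c] * weight[o * C + c];
            }
            out[bt * OC + o] = val;
        }
    }
}

Small differences against the GPU result are expected from differing fp32 accumulation order, so comparisons should use a tolerance rather than exact equality.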