replace cublaslt with custom matmul
ngc92 committed Jun 12, 2024
1 parent 7f3f3ac commit d679364
Showing 3 changed files with 152 additions and 139 deletions.
104 changes: 54 additions & 50 deletions dev/cuda/matmul_forward.cu
@@ -84,7 +84,8 @@ __global__ void add_bias(float* out, const float* bias, int B, int T, int OC) {
     }
 }
 
-// kernel 4: register reuse kernel
+// kernel 4: semi-efficient handwritten kernel
+// see trimat_forward.cu for some intermediate development steps
 __device__ float4 ld_vec(const float* address) {
     return *reinterpret_cast<const float4*>(address);
 }
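The ld_vec/st_vec helpers introduced above are what turn the kernel's global and shared memory traffic into 128-bit transactions. As a quick illustration (not part of this commit), the reinterpret_cast pattern assumes 16-byte-aligned addresses; the kernel guarantees this by only vectorizing over offsets that are multiples of four floats. A toy copy kernel using the same pattern:

```cuda
// Toy example of the float4 load/store pattern behind ld_vec/st_vec (hypothetical kernel,
// not from the diff). Assumes n is a multiple of 4 and src/dst come from cudaMalloc,
// so they are 16-byte aligned and the reinterpret_cast is legal.
__global__ void copy_vec4(float* dst, const float* src, int n) {
    int i = 4 * (blockIdx.x * blockDim.x + threadIdx.x);
    if (i + 3 < n) {
        float4 v = *reinterpret_cast<const float4*>(src + i);  // one 128-bit load
        *reinterpret_cast<float4*>(dst + i) = v;               // one 128-bit store
    }
}
```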
@@ -95,69 +96,72 @@ __device__ void st_vec(float* address, float4 val) {
 
 __global__ void __launch_bounds__(16*16) matmul_forward_kernel4(float* out,
                                                                 const float* inp, const float* weight, const float* bias,
-                                                                int BT, int C, int OC) {
+                                                                int C, int OC) {
     // out is (B,T,OC). OC is short for "output channels", e.g. OC = 4 * C
     // inp is (B,T,C), weight is (OC, C), bias is (OC)
-    // in the naive kernel, every thread handles one element of out
-    int bt = 8*(blockIdx.x * blockDim.x + threadIdx.x);
+    // each thread handles 8x8 elements; each block 128 by 128 elements.
     int oc = 8*(blockIdx.y * blockDim.y + threadIdx.y);
 
     // buffers to cache chunks of the input matrices
     __shared__ float lhs_s[128][32];
     __shared__ float rhs_s[128][32];
 
     // adjust our pointers for the current block
     inp += 128 * blockIdx.x * C;
     weight += 128 * blockIdx.y * C;
     out += 128 * blockIdx.x * OC + 128 * blockIdx.y;
 
-    if (bt < BT && oc < OC) {
-        float vals[8][8] = {};
-        if(bias != NULL) {
-            for (int i = 0; i < 8; i++) {
-                for (int j = 0; j < 8; j += 4) {
-                    float4 b = ld_vec(bias + oc + j);
-                    vals[i][j+0] = b.x;
-                    vals[i][j+1] = b.y;
-                    vals[i][j+2] = b.z;
-                    vals[i][j+3] = b.w;
-                }
-            }
-        }
+    float vals[8][8] = {};
+    if(bias != NULL) {
+        for (int i = 0; i < 8; i++) {
+            for (int j = 0; j < 8; j += 4) {
+                float4 b = ld_vec(bias + oc + j);
+                vals[i][j+0] = b.x;
+                vals[i][j+1] = b.y;
+                vals[i][j+2] = b.z;
+                vals[i][j+3] = b.w;
+            }
+        }
+    }
 
-        int si_start = 16 * threadIdx.y + threadIdx.x;
-        for (int so = 0; so < C; so += 32) {
-            __syncthreads();
-            int xmod8 = threadIdx.x % 8;
-            int xby8 = threadIdx.x / 8;
-            for(int y = 2 * threadIdx.y + xby8; y < 128; y += 32) {
-                int xo = 4 * xmod8;
-                st_vec(&lhs_s[y][xo], ld_vec(inp + y * C + so + xo));
-                st_vec(&rhs_s[y][xo], ld_vec(weight + y * C + so + xo));
-            }
-            __syncthreads();
+    int si_start = 4*(16 * threadIdx.y + threadIdx.x);
+    for (int so = 0; so < C; so += 32) {
+        __syncthreads();
+        int xmod8 = threadIdx.x % 8;
+        int xby8 = threadIdx.x / 8;
+        int xo = 4 * xmod8;
+        for(int y = 2 * threadIdx.y + xby8; y < 128; y += 32) {
+            st_vec(&lhs_s[y][xo], ld_vec(inp + y * C + so + xo));
+            st_vec(&rhs_s[y][xo], ld_vec(weight + y * C + so + xo));
+        }
+        __syncthreads();
 
-            for (int si = si_start; si < si_start + 32; ++si) {
-                float rhs[8];
-                for (int u = 0; u < 8; ++u) {
-                    rhs[u] = rhs_s[u + 8 * threadIdx.y][si % 32];
-                }
+        for (int si = si_start; si < si_start + 32; si += 4) {
+            float4 rhs[8];
+            for (int u = 0; u < 8; ++u) {
+                rhs[u] = ld_vec(&rhs_s[u + 8 * threadIdx.y][si % 32]);
+            }
 
-                for (int ii = 0; ii < 8; ++ii) {
-                    float lhs = lhs_s[ii + 8 * threadIdx.x][si % 32];
-                    for (int ji = 0; ji < 8; ++ji) {
-                        vals[ii][ji] += lhs * rhs[ji];
-                    }
-                }
-            }
-        }
+            for (int ii = 0; ii < 8; ++ii) {
+                float4 lhs = ld_vec(&lhs_s[ii + 8 * threadIdx.x][si % 32]);
+                for (int ji = 0; ji < 8; ++ji) {
+                    vals[ii][ji] += lhs.x * rhs[ji].x;
+                    vals[ii][ji] += lhs.y * rhs[ji].y;
+                    vals[ii][ji] += lhs.z * rhs[ji].z;
+                    vals[ii][ji] += lhs.w * rhs[ji].w;
+                }
+            }
+        }
+    }
 
-        for (int i = 0; i < 8; ++i) {
-            for (int j = 0; j < 8; j += 4) {
-                float4 result;
-                result.x = vals[i][j + 0];
-                result.y = vals[i][j + 1];
-                result.z = vals[i][j + 2];
-                result.w = vals[i][j + 3];
-                st_vec(out + (8*threadIdx.x+i) * OC + 8*threadIdx.y + j, result);
-            }
-        }
-    }
+    for (int i = 0; i < 8; ++i) {
+        for (int j = 0; j < 8; j += 4) {
+            float4 result;
+            result.x = vals[i][j + 0];
+            result.y = vals[i][j + 1];
+            result.z = vals[i][j + 2];
+            result.w = vals[i][j + 3];
+            st_vec(out + (8*threadIdx.x+i) * OC + 8*threadIdx.y + j, result);
+        }
+    }
 }
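To summarize the new structure: each 16x16 thread block now owns a 128x128 tile of out, streams 32-column slices of inp and weight through the lhs_s/rhs_s shared-memory buffers, and each thread accumulates an 8x8 sub-tile of results in registers before writing it out with vectorized stores. A minimal host-side sketch (illustrative only, with made-up example indices) of the output mapping implied by the pointer adjustments and the final st_vec call:

```cuda
// Which 8x8 patch of out a given (blockIdx, threadIdx) pair covers in the new kernel,
// derived from: out += 128*blockIdx.x*OC + 128*blockIdx.y
// and           st_vec(out + (8*threadIdx.x+i)*OC + 8*threadIdx.y + j, result)
#include <cstdio>

int main() {
    int block_x = 3, block_y = 1;    // example blockIdx.x / blockIdx.y
    int thread_x = 5, thread_y = 7;  // example threadIdx.x / threadIdx.y, each in [0, 16)

    int row0 = 128 * block_x + 8 * thread_x;  // first of 8 consecutive rows (the B*T dimension)
    int col0 = 128 * block_y + 8 * thread_y;  // first of 8 consecutive columns (the OC dimension)

    printf("this thread writes out[%d..%d][%d..%d]\n", row0, row0 + 7, col0, col0 + 7);
    return 0;
}
```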
@@ -296,7 +300,7 @@ void matmul_forward3(float* out,
     cublasCheck(cublasLtMatrixLayoutDestroy(biasLayout));
 }
 
-// kernel 1 is the most naive matmul kernel
+// handwritten, relatively efficient non-tensorcore matmul kernel
 void matmul_forward4(float* out,
                      const float* inp, const float* weight, const float* bias,
                      int B, int T, int C, int OC,
@@ -305,9 +309,9 @@ void matmul_forward4(float* out,
     // inp is (B,T,C), weight is (OC, C), bias is (OC)
     sqrt_block_size = 16;
 
-    dim3 gridDim(ceil_div(B * T, sqrt_block_size), ceil_div(OC, sqrt_block_size));
+    dim3 gridDim(ceil_div(B * T, 8*sqrt_block_size), ceil_div(OC, 8*sqrt_block_size));
     dim3 blockDim(sqrt_block_size, sqrt_block_size);
-    matmul_forward_kernel4<<<gridDim, blockDim>>>(out, inp, weight, bias, B*T, C, OC);
+    matmul_forward_kernel4<<<gridDim, blockDim>>>(out, inp, weight, bias, C, OC);
     cudaCheck(cudaGetLastError());
 }
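Because each block now covers 8*sqrt_block_size = 128 rows and 128 columns of the output, the grid shrinks by 8x in each dimension; with the bt/oc bounds check gone from the kernel, B*T and OC also appear to be assumed multiples of 128. A quick check of the launch shape for the benchmark's default configuration (a sketch, not part of the commit):

```cuda
// Launch-shape arithmetic for B=32, T=1024, OC=4*768 with sqrt_block_size=16.
#include <cstdio>

static int ceil_div(int a, int b) { return (a + b - 1) / b; }

int main() {
    int B = 32, T = 1024, OC = 768 * 4;
    int sqrt_block_size = 16;
    int grid_x = ceil_div(B * T, 8 * sqrt_block_size);  // 32768 / 128 = 256 blocks
    int grid_y = ceil_div(OC, 8 * sqrt_block_size);     //  3072 / 128 =  24 blocks
    printf("gridDim = (%d, %d), blockDim = (%d, %d)\n",
           grid_x, grid_y, sqrt_block_size, sqrt_block_size);  // (256, 24), (16, 16)
    return 0;
}
```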

@@ -341,7 +345,7 @@ void matmul_forward(int kernel_num,
 int main(int argc, char **argv) {
     srand(0);
 
-    int B = 8;
+    int B = 32;
     int T = 1024;
     int C = 768;
     int OC = 768 * 4; // expansion of 4, e.g. in the MLP
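For scale: with the larger batch the benchmarked matmul is (B*T)xC times CxOC = 32768x768 times 768x3072, i.e. roughly 2 * 32768 * 768 * 3072 ≈ 155 GFLOP per call, so the reported kernel times convert directly into achieved FLOP/s.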
@@ -357,7 +361,7 @@ int main(int argc, char **argv) {
     cublasCheck(cublasCreate(&cublas_handle));
     cublasCheck(cublasLtCreate(&cublaslt_handle));
     // TF32 precision is equivalent to torch.set_float32_matmul_precision('high')
-    int enable_tf32 = false; //deviceProp.major >= 8 ? 1 : 0;
+    int enable_tf32 = deviceProp.major >= 8 ? 1 : 0;
     printf("enable_tf32: %d\n", enable_tf32);
     cublas_compute_type = enable_tf32 ? CUBLAS_COMPUTE_32F_FAST_TF32 : CUBLAS_COMPUTE_32F;
     cublasMath_t cublas_math_mode = enable_tf32 ? CUBLAS_TF32_TENSOR_OP_MATH : CUBLAS_DEFAULT_MATH;
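Note that TF32 tensor-core math only exists on compute capability 8.0 (Ampere) and newer, hence the deviceProp.major >= 8 gate; re-enabling it here presumably gives the cuBLAS/cuBLASLt baselines their best-case throughput when comparing against the handwritten kernel.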
4 changes: 0 additions & 4 deletions test_gpt2_fp32.cu
@@ -36,15 +36,13 @@ int main(int argc, char *argv[]) {
 
     // setup cuBLAS and cuBLASLt
     cublasCheck(cublasCreate(&cublas_handle));
-    cublasCheck(cublasLtCreate(&cublaslt_handle));
     // TF32 precision is equivalent to torch.set_float32_matmul_precision('high')
     int enable_tf32 = deviceProp.major >= 8 ? 1 : 0;
     enable_tf32 = 0; // NOTE: disable TF32 for testing!!!
     printf("enable_tf32: %d\n", enable_tf32);
     cublas_compute_type = enable_tf32 ? CUBLAS_COMPUTE_32F_FAST_TF32 : CUBLAS_COMPUTE_32F;
     cublasMath_t cublas_math_mode = enable_tf32 ? CUBLAS_TF32_TENSOR_OP_MATH : CUBLAS_DEFAULT_MATH;
     cublasCheck(cublasSetMathMode(cublas_handle, cublas_math_mode));
-    cudaCheck(cudaMalloc(&cublaslt_workspace, cublaslt_workspace_size));
 
     // build the GPT-2 model from a checkpoint
     GPT2 model;
@@ -231,9 +229,7 @@ int main(int argc, char *argv[]) {
     free(expected_grads_memory);
     free(calculated_grads_memory);
     gpt2_free(&model);
-    cudaCheck(cudaFree(cublaslt_workspace));
     cublasCheck(cublasDestroy(cublas_handle));
-    cublasCheck(cublasLtDestroy(cublaslt_handle));
 
     return 0;
 }
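The cuBLASLt handle and workspace are dropped from this test presumably because the fp32 path now routes its matmuls through the handwritten kernel above rather than cublasLtMatmul.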