Skip to content

Commit

Permalink
Add small M support (#3682)
Browse files Browse the repository at this point in the history
Summary:

X-link: facebookresearch/FBGEMM#758

add small m (m = 1, 2, 4) support for fast gemv

- bf16_fast_gemv [+]
- bf16fp8bf16_fast_gemv[+]
- fp8fp8bf16_fast_gemv[+]

Differential Revision: D69492556
  • Loading branch information
YUNQIUGUO authored and facebook-github-bot committed Feb 12, 2025
1 parent 35fd9d7 commit 71fb213
Show file tree
Hide file tree
Showing 7 changed files with 75 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ at::Tensor bf16_fast_gemv(at::Tensor X, at::Tensor W) {

check_if_valid_block_dimensions(m, n, k, block_dim);

dim3 grid_dim(1, n / block_dim.y);
dim3 grid_dim(m, n / block_dim.y);
unsigned int num_per_thread = k / block_dim.x;

auto stream = at::cuda::getCurrentCUDAStream();
Expand All @@ -62,6 +62,8 @@ at::Tensor bf16_fast_gemv(at::Tensor X, at::Tensor W) {
reinterpret_cast<__nv_bfloat16*>(X.data_ptr()), // vec
reinterpret_cast<__nv_bfloat16*>(Y.data_ptr()), // res
k,
m,
n,
num_per_thread);

C10_CUDA_KERNEL_LAUNCH_CHECK();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ bf16fp8bf16_fast_gemv(at::Tensor X, at::Tensor W, double w_scale, double w_zp) {

check_if_valid_block_dimensions(m, n, k, block_dim);

dim3 grid_dim(1, n / block_dim.y);
dim3 grid_dim(m, n / block_dim.y);
unsigned int num_per_thread = k / block_dim.x;

auto stream = at::cuda::getCurrentCUDAStream();
Expand All @@ -65,6 +65,8 @@ bf16fp8bf16_fast_gemv(at::Tensor X, at::Tensor W, double w_scale, double w_zp) {
reinterpret_cast<__nv_bfloat16*>(X.data_ptr()), // vec
reinterpret_cast<__nv_bfloat16*>(Y.data_ptr()), // res
k,
m,
n,
__float2half(float(w_scale)),
__float2half(float(w_zp)),
num_per_thread);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ fp8fp8bf16_fast_gemv(at::Tensor X, at::Tensor W, double scale, double zp) {

check_if_valid_block_dimensions(m, n, k, block_dim);

dim3 grid_dim(1, n / block_dim.y);
dim3 grid_dim(m, n / block_dim.y);
unsigned int num_per_thread = k / block_dim.x;

auto stream = at::cuda::getCurrentCUDAStream();
Expand All @@ -65,6 +65,8 @@ fp8fp8bf16_fast_gemv(at::Tensor X, at::Tensor W, double scale, double zp) {
reinterpret_cast<cutlass::float_e4m3_t*>(X.data_ptr()), // vec
reinterpret_cast<__nv_bfloat16*>(Y.data_ptr()), // res
k,
m,
n,
__float2half(scale),
__float2half(zp),
num_per_thread);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,12 @@ namespace fbgemm_gpu {
namespace {

void check_if_valid_block_dimensions(int m, int n, int k, dim3 block_dim) {
TORCH_CHECK(
m > 4,
"Invalid value for m: m (",
m,
") must be greater than 4. The kernel cannot be run with the current value of m."
" Please use an `m` smaller or equal to 4.")
TORCH_CHECK(
n % block_dim.y == 0,
"Invalid block dimensions: n (",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,22 +53,25 @@ __global__ void gemv_bf16(
__nv_bfloat16* mat,
__nv_bfloat16* vec,
__nv_bfloat16* res,
unsigned int k,
unsigned int m,
unsigned int n,
unsigned int num_per_thread) {
float sum = 0;
// each thread load num_per_thread elements from global
unsigned int tid = threadIdx.x;
unsigned int row = blockIdx.y * blockDim.y + threadIdx.y;
unsigned int col = blockIdx.x;
unsigned int start_idx = threadIdx.x;
float4* mat4 = reinterpret_cast<float4*>(mat);
float4* vec4 = reinterpret_cast<float4*>(vec);

#pragma unroll
for (int iter = 0; iter < num_per_thread >> 3; iter++) {
unsigned int j = start_idx + iter * blockDim.x;
if (j < n >> 3) {
float4 vec_val = vec4[j];
float4 mat_val = mat4[row * (n >> 3) + j];
if (j < k >> 3) {
float4 vec_val = vec4[col * (k >> 3) + j];
float4 mat_val = mat4[row * (k >> 3) + j];
const bfloat16_2* vec_h1 = (bfloat16_2*)&vec_val.x;
const bfloat16_2* vec_h2 = (bfloat16_2*)&vec_val.y;
const bfloat16_2* vec_h3 = (bfloat16_2*)&vec_val.z;
Expand Down Expand Up @@ -108,7 +111,7 @@ __global__ void gemv_bf16(

if (blockDim.x <= WARP_SIZE) {
if (tid == 0) {
res[row] = __float2bfloat16(sum);
res[row + blockIdx.x * n] = __float2bfloat16(sum);
}
return;
}
Expand All @@ -128,7 +131,7 @@ __global__ void gemv_bf16(
if (warpId == 0)
sum = warpReduceSum(sum, blockDim.x / WARP_SIZE);
if (tid == 0) {
res[row] = __float2bfloat16(sum);
res[row + blockIdx.x * n] = __float2bfloat16(sum);
}
}

Expand All @@ -139,6 +142,8 @@ __global__ void gemv_quantized_bf16_fp8(
cutlass::float_e4m3_t* mat,
__nv_bfloat16* vec,
__nv_bfloat16* res,
unsigned int k,
unsigned int m,
unsigned int n,
half scale,
half zero_point,
Expand All @@ -147,6 +152,7 @@ __global__ void gemv_quantized_bf16_fp8(
// each thread load num_per_thread elements from global
unsigned int tid = threadIdx.x;
unsigned int row = blockIdx.y * blockDim.y + threadIdx.y;
unsigned int col = blockIdx.x;
unsigned int start_idx = threadIdx.x;
half4* mat4 = reinterpret_cast<half4*>(mat);
float4* vec4 = reinterpret_cast<float4*>(vec);
Expand All @@ -157,9 +163,9 @@ __global__ void gemv_quantized_bf16_fp8(
#pragma unroll
for (int iter = 0; iter < num_per_thread >> 3; iter++) {
unsigned int j = start_idx + iter * blockDim.x;
if (j < n >> 3) {
float4 vec_val = vec4[j];
half4 mat_val = mat4[row * (n >> 3) + j];
if (j < k >> 3) {
float4 vec_val = vec4[col * (k >> 3) + j];
half4 mat_val = mat4[row * (k >> 3) + j];
const bfloat16_2* vec_h1 = (bfloat16_2*)&vec_val.x;
const bfloat16_2* vec_h2 = (bfloat16_2*)&vec_val.y;
const bfloat16_2* vec_h3 = (bfloat16_2*)&vec_val.z;
Expand Down Expand Up @@ -217,7 +223,7 @@ __global__ void gemv_quantized_bf16_fp8(

if (blockDim.x <= WARP_SIZE) {
if (tid == 0) {
res[row] = __float2bfloat16(sum);
res[row + blockIdx.x * n] = __float2bfloat16(sum);
}
return;
}
Expand All @@ -237,7 +243,7 @@ __global__ void gemv_quantized_bf16_fp8(
if (warpId == 0)
sum = warpReduceSum(sum, blockDim.x / WARP_SIZE);
if (tid == 0) {
res[row] = __float2bfloat16(sum);
res[row + blockIdx.x * n] = __float2bfloat16(sum);
}
}

Expand All @@ -247,6 +253,8 @@ __global__ void gemv_quantized_fp8_fp8(
cutlass::float_e4m3_t* mat,
cutlass::float_e4m3_t* vec,
__nv_bfloat16* res,
unsigned int k,
unsigned int m,
unsigned int n,
half scale,
half zero_point,
Expand All @@ -255,6 +263,7 @@ __global__ void gemv_quantized_fp8_fp8(
// each thread load num_per_thread elements from global
unsigned int tid = threadIdx.x;
unsigned int row = blockIdx.y * blockDim.y + threadIdx.y;
unsigned int col = blockIdx.x;
unsigned int start_idx = threadIdx.x;
half4* mat4 = reinterpret_cast<half4*>(mat);
half4* vec4 = reinterpret_cast<half4*>(vec);
Expand All @@ -266,9 +275,9 @@ __global__ void gemv_quantized_fp8_fp8(
#pragma unroll
for (int iter = 0; iter < num_per_thread >> 3; iter++) {
unsigned int j = start_idx + iter * blockDim.x;
if (j < n >> 3) {
half4 vec_val = vec4[j];
half4 mat_val = mat4[row * (n >> 3) + j];
if (j < k >> 3) {
half4 vec_val = vec4[col * (k >> 3) + j];
half4 mat_val = mat4[row * (k >> 3) + j];
const fp8_2* vec_h1 = (fp8_2*)&vec_val.x;
const fp8_2* vec_h2 = (fp8_2*)&vec_val.y;
const fp8_2* vec_h3 = (fp8_2*)&vec_val.z;
Expand Down Expand Up @@ -334,7 +343,7 @@ __global__ void gemv_quantized_fp8_fp8(

if (blockDim.x <= WARP_SIZE) {
if (tid == 0) {
res[row] = __float2bfloat16(sum);
res[row + blockIdx.x * n] = __float2bfloat16(sum);
}
return;
}
Expand All @@ -354,7 +363,7 @@ __global__ void gemv_quantized_fp8_fp8(
if (warpId == 0)
sum = warpReduceSum(sum, blockDim.x / WARP_SIZE);
if (tid == 0) {
res[row] = __float2bfloat16(sum);
res[row + blockIdx.x * n] = __float2bfloat16(sum);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,17 @@ __global__ void gemv_bf16(
__nv_bfloat16* mat,
__nv_bfloat16* vec,
__nv_bfloat16* res,
unsigned int k,
unsigned int m,
unsigned int n,
unsigned int num_per_thread);

__global__ void gemv_quantized_bf16_fp8(
cutlass::float_e4m3_t* mat,
__nv_bfloat16* vec,
__nv_bfloat16* res,
unsigned int k,
unsigned int m,
unsigned int n,
half scale,
half zero_point,
Expand All @@ -71,6 +75,8 @@ __global__ void gemv_quantized_fp8_fp8(
cutlass::float_e4m3_t* mat,
cutlass::float_e4m3_t* vec,
__nv_bfloat16* res,
unsigned int k,
unsigned int m,
unsigned int n,
half scale,
half zero_point,
Expand Down
34 changes: 30 additions & 4 deletions fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1149,6 +1149,18 @@ def test_bf16_gemv(self) -> None:
(1, 8192, 1024),
(1, 7168, 8192),
(1, 8192, 3584),
(2, 128, 256),
(2, 256, 256),
(2, 1280, 8192),
(2, 8192, 1024),
(2, 7168, 8192),
(2, 8192, 3584),
(4, 128, 256),
(4, 256, 256),
(4, 1280, 8192),
(4, 8192, 1024),
(4, 7168, 8192),
(4, 8192, 3584),
]
self.test_gemv(test_cases, torch.ops.fbgemm.bf16_fast_gemv, 9.0e-3, 9.0e-3)

Expand All @@ -1157,18 +1169,24 @@ def test_bf16_gemv(self) -> None:
)
def test_bf16_fp8_gemv(self) -> None:
test_cases = [
(1, 128, 256),
(1, 256, 256),
(1, 1280, 8192),
(1, 8192, 1024),
(1, 7168, 8192),
(1, 8192, 3584),
(2, 1280, 8192),
(2, 8192, 1024),
(2, 7168, 8192),
(2, 8192, 3584),
(4, 1280, 8192),
(4, 8192, 1024),
(4, 7168, 8192),
(4, 8192, 3584),
]
self.test_gemv(
test_cases,
torch.ops.fbgemm.bf16fp8bf16_fast_gemv,
1.0e-2,
1.0e-2,
9.0e-2,
9.0e-2,
quantize_w=True,
)

Expand All @@ -1181,6 +1199,14 @@ def test_fp8_fp8_gemv(self) -> None:
(1, 8192, 1024),
(1, 7168, 8192),
(1, 8192, 3584),
(2, 1280, 8192),
(2, 8192, 1024),
(2, 7168, 8192),
(2, 8192, 3584),
(4, 1280, 8192),
(4, 8192, 1024),
(4, 7168, 8192),
(4, 8192, 3584),
]
self.test_gemv(
test_cases,
Expand Down

0 comments on commit 71fb213

Please sign in to comment.