[Kernel][Core] Add AWQ support to the Marlin kernel #6612

Merged: 25 commits, Jul 21, 2024
Changes from all commits
1 change: 1 addition & 0 deletions CMakeLists.txt
@@ -172,6 +172,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
    "csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu"
    "csrc/quantization/gptq_marlin/gptq_marlin.cu"
    "csrc/quantization/gptq_marlin/gptq_marlin_repack.cu"
+   "csrc/quantization/gptq_marlin/awq_marlin_repack.cu"
    "csrc/quantization/fp8/fp8_marlin.cu"
    "csrc/custom_all_reduce.cu"
    "csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu"
12 changes: 8 additions & 4 deletions csrc/ops.h
@@ -89,15 +89,19 @@ torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                                  int64_t size_k);

 torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
-                               torch::Tensor& b_scales, torch::Tensor& g_idx,
-                               torch::Tensor& perm, torch::Tensor& workspace,
-                               int64_t num_bits, int64_t size_m, int64_t size_n,
-                               int64_t size_k, bool is_k_full);
+                               torch::Tensor& b_scales, torch::Tensor& b_zeros,
+                               torch::Tensor& g_idx, torch::Tensor& perm,
+                               torch::Tensor& workspace, int64_t num_bits,
+                               int64_t size_m, int64_t size_n, int64_t size_k,
+                               bool is_k_full, bool has_zp);

 torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
                                  int64_t size_k, int64_t size_n,
                                  int64_t num_bits);

+torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k,
+                                int64_t size_n, int64_t num_bits);
+
 torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                               torch::Tensor& b_scales, torch::Tensor& workspace,
                               int64_t num_bits, int64_t size_m, int64_t size_n,
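For orientation, here is a minimal host-side sketch of how the two entry points touched above fit together for an AWQ checkpoint: repack the AWQ weights once, then call the extended GEMM with zero points enabled. The helper function, tensor shapes, and the empty g_idx/perm convention are illustrative assumptions, not part of this diff; the real call sites live in vLLM's Python quantization layer.

// Illustrative only -- exercises the declarations above; not part of the diff.
#include <torch/torch.h>

#include "ops.h"  // assumes the vLLM extension headers are on the include path

torch::Tensor awq_marlin_mm_example(torch::Tensor a,           // [size_m, size_k], fp16/bf16
                                    torch::Tensor b_q_weight,  // AWQ-packed int32, [size_k, size_n / pack_factor]
                                    torch::Tensor b_scales,    // [num_groups, size_n]
                                    torch::Tensor b_zeros,     // AWQ group zero points
                                    torch::Tensor workspace) {
  int64_t const num_bits = 4;
  int64_t const size_m = a.size(0);
  int64_t const size_k = a.size(1);
  int64_t const size_n = b_scales.size(1);

  // One-time repack of the AWQ layout into the Marlin tile layout.
  torch::Tensor marlin_qw = awq_marlin_repack(b_q_weight, size_k, size_n, num_bits);

  // AWQ has no activation reordering, so (presumably) empty g_idx/perm tensors
  // are passed, and has_zp = true selects the zero-point dequantization path.
  torch::Tensor g_idx = torch::empty({0}, a.options().dtype(torch::kInt));
  torch::Tensor perm = torch::empty({0}, a.options().dtype(torch::kInt));

  return gptq_marlin_gemm(a, marlin_qw, b_scales, b_zeros, g_idx, perm, workspace,
                          num_bits, size_m, size_n, size_k,
                          /*is_k_full=*/true, /*has_zp=*/true);
}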
33 changes: 15 additions & 18 deletions csrc/quantization/fp8/fp8_marlin.cu
@@ -19,10 +19,10 @@
  * Adapted from https://github.com/IST-DASLab/marlin
  */

-#include "../gptq_marlin/gptq_marlin.cuh"
-#include "../gptq_marlin/gptq_marlin_dtypes.cuh"
+#include "../gptq_marlin/marlin.cuh"
+#include "../gptq_marlin/marlin_dtypes.cuh"

-using namespace gptq_marlin;
+using namespace marlin;

 #define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t)    \
   static_assert(std::is_same<scalar_t, half>::value || \
@@ -1224,16 +1224,15 @@ torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
               ", size_k = ", size_k);

   // Verify B
-  TORCH_CHECK(size_k % gptq_marlin::tile_size == 0, "size_k = ", size_k,
-              " is not divisible by tile_size = ", gptq_marlin::tile_size);
-  TORCH_CHECK((size_k / gptq_marlin::tile_size) == b_q_weight.size(0),
+  TORCH_CHECK(size_k % marlin::tile_size == 0, "size_k = ", size_k,
+              " is not divisible by tile_size = ", marlin::tile_size);
+  TORCH_CHECK((size_k / marlin::tile_size) == b_q_weight.size(0),
               "Shape mismatch: b_q_weight.size(0) = ", b_q_weight.size(0),
-              ", size_k = ", size_k, ", tile_size = ", gptq_marlin::tile_size);
-  TORCH_CHECK(b_q_weight.size(1) % gptq_marlin::tile_size == 0,
+              ", size_k = ", size_k, ", tile_size = ", marlin::tile_size);
+  TORCH_CHECK(b_q_weight.size(1) % marlin::tile_size == 0,
               "b_q_weight.size(1) = ", b_q_weight.size(1),
-              " is not divisible by tile_size = ", gptq_marlin::tile_size);
-  int actual_size_n =
-      (b_q_weight.size(1) / gptq_marlin::tile_size) * pack_factor;
+              " is not divisible by tile_size = ", marlin::tile_size);
+  int actual_size_n = (b_q_weight.size(1) / marlin::tile_size) * pack_factor;
   TORCH_CHECK(size_n == actual_size_n, "size_n = ", size_n,
               ", actual_size_n = ", actual_size_n);
@@ -1274,11 +1273,9 @@ torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
   num_groups = b_scales.size(0);

   // Verify workspace size
-  TORCH_CHECK(
-      size_n % gptq_marlin::min_thread_n == 0, "size_n = ", size_n,
-      ", is not divisible by min_thread_n = ", gptq_marlin::min_thread_n);
-  int min_workspace_size =
-      (size_n / gptq_marlin::min_thread_n) * gptq_marlin::max_par;
+  TORCH_CHECK(size_n % marlin::min_thread_n == 0, "size_n = ", size_n,
+              ", is not divisible by min_thread_n = ", marlin::min_thread_n);
+  int min_workspace_size = (size_n / marlin::min_thread_n) * marlin::max_par;
   TORCH_CHECK(workspace.numel() >= min_workspace_size,
               "workspace.numel = ", workspace.numel(),
               " is below min_workspace_size = ", min_workspace_size);
@@ -1290,14 +1287,14 @@ torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
         b_scales.data_ptr<at::Half>(), size_m, size_n, size_k,
         workspace.data_ptr(), num_bits, num_groups, group_size, dev,
         at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,
-        gptq_marlin::max_par);
+        marlin::max_par);
   } else if (a.scalar_type() == at::ScalarType::BFloat16) {
     fp8_marlin::marlin_mm_f16i4<nv_bfloat16>(
         a.data_ptr<at::BFloat16>(), b_q_weight.data_ptr(),
         c.data_ptr<at::BFloat16>(), b_scales.data_ptr<at::BFloat16>(), size_m,
         size_n, size_k, workspace.data_ptr(), num_bits, num_groups, group_size,
         dev, at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,
-        gptq_marlin::max_par);
+        marlin::max_par);
   } else {
     TORCH_CHECK(false, "fp8_marlin_gemm only supports bfloat16 and float16");
   }
269 changes: 269 additions & 0 deletions csrc/quantization/gptq_marlin/awq_marlin_repack.cu
@@ -0,0 +1,269 @@
#include "marlin.cuh"

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800

namespace marlin {

template <int const num_threads, int const num_bits, bool const has_perm>
__global__ void awq_marlin_repack_kernel(
uint32_t const* __restrict__ b_q_weight_ptr, uint32_t* __restrict__ out_ptr,
int size_k, int size_n) {}

} // namespace marlin

torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
int64_t size_k, int64_t size_n,
int64_t num_bits) {
TORCH_CHECK_NOT_IMPLEMENTED(
false, "marlin_repack_from_gptq(..) requires CUDA_ARCH >= 8.0");
return torch::empty({1, 1});
}

#else

namespace marlin {

[Review comment — Member] How is this different from GPTQ? It looks similar to me at a glance.

[Reply — Author] The internal unpacking is different: AWQ packs over columns while GPTQ packs over rows, and AWQ also interleaves groups of 8 values (for 4-bit) or groups of 4 (for 8-bit) to be compatible with the dequantization PTX assembly.
template <int const num_threads, int const num_bits>
__global__ void awq_marlin_repack_kernel(
    uint32_t const* __restrict__ b_q_weight_ptr, uint32_t* __restrict__ out_ptr,
    int size_k, int size_n) {
  constexpr int pack_factor = 32 / num_bits;

  int k_tiles = size_k / tile_k_size;
  int n_tiles = size_n / tile_n_size;
  int block_k_tiles = div_ceil(k_tiles, gridDim.x);

  int start_k_tile = blockIdx.x * block_k_tiles;
  if (start_k_tile >= k_tiles) {
    return;
  }

  int finish_k_tile = min(start_k_tile + block_k_tiles, k_tiles);

  // Wait until the next thread tile has been loaded to shared memory.
  auto wait_for_stage = [&]() {
    // We only have `stages - 2` active fetches since we are double buffering
    // and can only issue the next fetch when it is guaranteed that the previous
    // shared memory load is fully complete (as it may otherwise be
    // overwritten).
    cp_async_wait<repack_stages - 2>();
    __syncthreads();
  };

  extern __shared__ int4 sh[];

  constexpr int tile_n_ints = tile_n_size / pack_factor;

  constexpr int stage_n_threads = tile_n_ints / 4;
  constexpr int stage_k_threads = tile_k_size;
  constexpr int stage_size = stage_k_threads * stage_n_threads;

  auto fetch_to_shared = [&](int pipe, int k_tile_id, int n_tile_id) {
    if (n_tile_id >= n_tiles) {
      cp_async_fence();
      return;
    }

    int first_n = n_tile_id * tile_n_size;
    int first_n_packed = first_n / pack_factor;

    int4* sh_ptr = sh + stage_size * pipe;

    if (threadIdx.x < stage_size) {
      int k_id = threadIdx.x / stage_n_threads;
      int n_id = threadIdx.x % stage_n_threads;

      int first_k = k_tile_id * tile_k_size;

      cp_async4(&sh_ptr[k_id * stage_n_threads + n_id],
                reinterpret_cast<int4 const*>(
                    &(b_q_weight_ptr[(first_k + k_id) * (size_n / pack_factor) +
                                     first_n_packed + (n_id * 4)])));
    }

    cp_async_fence();
  };

  auto repack_tile = [&](int pipe, int k_tile_id, int n_tile_id) {
    if (n_tile_id >= n_tiles) {
      return;
    }

    int warp_id = threadIdx.x / 32;
    int th_id = threadIdx.x % 32;

    if (warp_id >= 4) {
      return;
    }

    int tc_col = th_id / 4;
    int tc_row = (th_id % 4) * 2;

    constexpr int tc_offsets[4] = {0, 1, 8, 9};

    int cur_n = warp_id * 16 + tc_col;
    int cur_n_packed = cur_n / pack_factor;
    int cur_n_pos = cur_n % pack_factor;

    constexpr int sh_stride = tile_n_ints;
    constexpr uint32_t mask = (1 << num_bits) - 1;

    int4* sh_stage_ptr = sh + stage_size * pipe;
    uint32_t* sh_stage_int_ptr = reinterpret_cast<uint32_t*>(sh_stage_ptr);

    // Undo interleaving
    int cur_n_pos_unpacked;
    if constexpr (num_bits == 4) {
      constexpr int undo_pack[8] = {0, 4, 1, 5, 2, 6, 3, 7};
      cur_n_pos_unpacked = undo_pack[cur_n_pos];
    } else {
      constexpr int undo_pack[4] = {0, 2, 1, 3};
      cur_n_pos_unpacked = undo_pack[cur_n_pos];
    }

    uint32_t vals[8];
#pragma unroll
    for (int i = 0; i < 4; i++) {
      int cur_elem = tc_row + tc_offsets[i];

      int packed_src_0 = sh_stage_int_ptr[cur_n_packed + sh_stride * cur_elem];
      int packed_src_1 = sh_stage_int_ptr[cur_n_packed + (8 / pack_factor) +
                                          sh_stride * cur_elem];

      vals[i] = (packed_src_0 >> (cur_n_pos_unpacked * num_bits)) & mask;
      vals[4 + i] = (packed_src_1 >> (cur_n_pos_unpacked * num_bits)) & mask;
    }

    constexpr int tile_size = tile_k_size * tile_n_size / pack_factor;
    int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size;

    // Result of:
    // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
    if constexpr (num_bits == 4) {
      constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};

      uint32_t res = 0;
#pragma unroll
      for (int i = 0; i < 8; i++) {
        res |= vals[pack_idx[i]] << (i * 4);
      }

      out_ptr[out_offset + th_id * 4 + warp_id] = res;

    } else {
      constexpr int pack_idx[4] = {0, 2, 1, 3};

      uint32_t res1 = 0;
      uint32_t res2 = 0;
#pragma unroll
      for (int i = 0; i < 4; i++) {
        res1 |= vals[pack_idx[i]] << (i * 8);
        res2 |= vals[4 + pack_idx[i]] << (i * 8);
      }

      out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 0] = res1;
      out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 1] = res2;
    }
  };

  auto start_pipes = [&](int k_tile_id, int n_tile_id) {
#pragma unroll
    for (int pipe = 0; pipe < repack_stages - 1; pipe++) {
      fetch_to_shared(pipe, k_tile_id, n_tile_id + pipe);
    }

    wait_for_stage();
  };

#pragma unroll
  for (int k_tile_id = start_k_tile; k_tile_id < finish_k_tile; k_tile_id++) {
    int n_tile_id = 0;

    start_pipes(k_tile_id, n_tile_id);

    while (n_tile_id < n_tiles) {
#pragma unroll
      for (int pipe = 0; pipe < repack_stages; pipe++) {
        fetch_to_shared((pipe + repack_stages - 1) % repack_stages, k_tile_id,
                        n_tile_id + pipe + repack_stages - 1);
        repack_tile(pipe, k_tile_id, n_tile_id + pipe);
        wait_for_stage();
      }
      n_tile_id += repack_stages;
    }
  }
}

} // namespace marlin

#define CALL_IF(NUM_BITS)                                                    \
  else if (num_bits == NUM_BITS) {                                           \
    cudaFuncSetAttribute(                                                    \
        marlin::awq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS>,  \
        cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem);        \
    marlin::awq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS>       \
        <<<blocks, marlin::repack_threads, max_shared_mem, stream>>>(        \
            b_q_weight_ptr, out_ptr, size_k, size_n);                        \
  }

torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k,
                                int64_t size_n, int64_t num_bits) {
  // Verify compatibility with marlin tile of 16x64
  TORCH_CHECK(size_k % marlin::tile_k_size == 0, "size_k = ", size_k,
              " is not divisible by tile_k_size = ", marlin::tile_k_size);
  TORCH_CHECK(size_n % marlin::tile_n_size == 0, "size_n = ", size_n,
              " is not divisible by tile_n_size = ", marlin::tile_n_size);

  TORCH_CHECK(num_bits == 4 || num_bits == 8,
              "num_bits must be 4 or 8. Got = ", num_bits);
  int const pack_factor = 32 / num_bits;

  // Verify B
  TORCH_CHECK(b_q_weight.size(0) == size_k,
              "b_q_weight.size(0) = ", b_q_weight.size(0),
              " is not size_k = ", size_k);
  TORCH_CHECK((size_n / pack_factor) == b_q_weight.size(1),
              "Shape mismatch: b_q_weight.size(1) = ", b_q_weight.size(1),
              ", size_n = ", size_n, ", pack_factor = ", pack_factor);

  // Verify device and strides
  TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU");
  TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous");
  TORCH_CHECK(b_q_weight.dtype() == at::kInt, "b_q_weight type is not kInt");

  // Alloc buffers
  const at::cuda::OptionalCUDAGuard device_guard(device_of(b_q_weight));
  auto options = torch::TensorOptions()
                     .dtype(b_q_weight.dtype())
                     .device(b_q_weight.device());
  torch::Tensor out = torch::empty(
      {size_k / marlin::tile_size, size_n * marlin::tile_size / pack_factor},
      options);

  // Get ptrs
  uint32_t const* b_q_weight_ptr =
      reinterpret_cast<uint32_t const*>(b_q_weight.data_ptr());
  uint32_t* out_ptr = reinterpret_cast<uint32_t*>(out.data_ptr());

  // Get dev info
  int dev = b_q_weight.get_device();
  cudaStream_t stream = at::cuda::getCurrentCUDAStream(dev);
  int blocks;
  cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev);

  int max_shared_mem = 0;
  cudaDeviceGetAttribute(&max_shared_mem,
                         cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
  TORCH_CHECK(max_shared_mem > 0);

  if (false) {
  }
  CALL_IF(4)
  CALL_IF(8)
  else {
    TORCH_CHECK(false, "Unsupported repack config: num_bits = ", num_bits);
  }

  return out;
}

#endif
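As a quick sanity check on the shapes this entry point works with, here is a small sketch with hypothetical layer dimensions; it assumes marlin::tile_size == 16, consistent with the 16x64 tile mentioned in the checks above.

// Shape bookkeeping for awq_marlin_repack, using illustrative sizes.
// Input : AWQ qweight of shape [size_k, size_n / pack_factor], dtype int32.
// Output: Marlin-tiled weight of shape
//         [size_k / tile_size, size_n * tile_size / pack_factor].
#include <cassert>
#include <cstdint>

int main() {
  const int64_t size_k = 4096, size_n = 11008, num_bits = 4;  // hypothetical layer
  const int64_t pack_factor = 32 / num_bits;                  // 8 values per int32
  const int64_t tile_size = 16;                               // assumed marlin::tile_size

  // Same divisibility requirements the host function checks (tile_k_size / tile_n_size).
  assert(size_k % 16 == 0 && size_n % 64 == 0);

  const int64_t in_rows = size_k, in_cols = size_n / pack_factor;
  const int64_t out_rows = size_k / tile_size;
  const int64_t out_cols = size_n * tile_size / pack_factor;

  // The repack only permutes packed int32 words, so their total count is preserved.
  assert(in_rows * in_cols == out_rows * out_cols);
  return 0;
}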