From 314bdec788357507cb0c6b1931b6219fd867d712 Mon Sep 17 00:00:00 2001
From: Varun Sundar Rabindranath
Date: Thu, 20 Jun 2024 19:03:21 +0530
Subject: [PATCH] [Kernel] Update Cutlass int8 kernel configs for SM80 (#5275)

Co-authored-by: Varun Sundar Rabindranath
---
 csrc/quantization/cutlass_w8a8/common.hpp     |   7 +
 .../cutlass_w8a8/scaled_mm_c2x.cu             | 127 ++++++++++++++++--
 .../cutlass_w8a8/scaled_mm_c3x.cu             |   5 -
 3 files changed, 123 insertions(+), 16 deletions(-)

diff --git a/csrc/quantization/cutlass_w8a8/common.hpp b/csrc/quantization/cutlass_w8a8/common.hpp
index 999b7b251ab33..23d0587bbdc5d 100644
--- a/csrc/quantization/cutlass_w8a8/common.hpp
+++ b/csrc/quantization/cutlass_w8a8/common.hpp
@@ -1,6 +1,7 @@
 #pragma once
 
 #include "cutlass/cutlass.h"
+#include <climits>
 
 /**
  * Helper function for checking CUTLASS errors
@@ -10,3 +11,9 @@
   TORCH_CHECK(status == cutlass::Status::kSuccess, \
               cutlassGetStatusString(status))      \
   }
+
+inline uint32_t next_pow_2(uint32_t const num) {
+  if (num <= 1) return num;
+  return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
+}
+
diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu
index 7651268dc5316..740b9fb64a754 100644
--- a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu
+++ b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu
@@ -250,8 +250,120 @@ void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
   CUTLASS_CHECK(status);
 }
 
+template <typename InType, typename OutType,
+          template <typename, typename> typename Epilogue>
+struct sm80_config_default {
+  // This config is used in 2 cases,
+  // - M in (128, inf)
+  // - M in (64, 128] and N >= 8192
+  static_assert(std::is_same<InType, int8_t>());
+  using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
+  using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
+  using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
+  using Cutlass2xGemm =
+      cutlass_2x_gemm<cutlass::arch::Sm80, enable_sm80_to_sm89, InType,
+                      OutType, Epilogue, TileShape, WarpShape,
+                      InstructionShape, 5>;
+};
+
+template <typename InType, typename OutType,
+          template <typename, typename> typename Epilogue>
+struct sm80_config_M64 {
+  // This config is used in 2 cases,
+  // - M in (32, 64]
+  // - M in (64, 128] and N < 8192
+  static_assert(std::is_same<InType, int8_t>());
+  using TileShape = typename cutlass::gemm::GemmShape<64, 128, 128>;
+  using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
+  using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
+  using Cutlass2xGemm =
+      cutlass_2x_gemm<cutlass::arch::Sm80, enable_sm80_to_sm89, InType,
+                      OutType, Epilogue, TileShape, WarpShape,
+                      InstructionShape, 5>;
+};
+
+template <typename InType, typename OutType,
+          template <typename, typename> typename Epilogue>
+struct sm80_config_M32 {
+  // M in (16, 32]
+  static_assert(std::is_same<InType, int8_t>());
+  using TileShape = typename cutlass::gemm::GemmShape<32, 64, 128>;
+  using WarpShape = typename cutlass::gemm::GemmShape<32, 64, 64>;
+  using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
+  using Cutlass2xGemm =
+      cutlass_2x_gemm<cutlass::arch::Sm80, enable_sm80_to_sm89, InType,
+                      OutType, Epilogue, TileShape, WarpShape,
+                      InstructionShape, 5>;
+};
+
+template <typename InType, typename OutType,
+          template <typename, typename> typename Epilogue>
+struct sm80_config_M16 {
+  // M in [1, 16]
+  static_assert(std::is_same<InType, int8_t>());
+  using TileShape = typename cutlass::gemm::GemmShape<16, 64, 128>;
+  using WarpShape = typename cutlass::gemm::GemmShape<16, 64, 64>;
+  using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
+  using Cutlass2xGemm =
+      cutlass_2x_gemm<cutlass::arch::Sm80, enable_sm80_to_sm89, InType,
+                      OutType, Epilogue, TileShape, WarpShape,
+                      InstructionShape, 5>;
+};
+
 }  // namespace
 
+template <typename InType, typename OutType,
+          template <typename, typename> typename Epilogue,
+          typename... EpilogueArgs>
+void cutlass_gemm_sm80_dispatch(torch::Tensor& out, torch::Tensor const& a,
+                                torch::Tensor const& b,
+                                EpilogueArgs&&... args) {
+  static_assert(std::is_same<InType, int8_t>());
+  TORCH_CHECK(a.dtype() == torch::kInt8);
+  TORCH_CHECK(b.dtype() == torch::kInt8);
+
+  using Cutlass2xGemmDefault =
+      typename sm80_config_default<InType, OutType, Epilogue>::Cutlass2xGemm;
+  using Cutlass2xGemmM128BigN =
+      typename sm80_config_default<InType, OutType, Epilogue>::Cutlass2xGemm;
+  using Cutlass2xGemmM128SmallN =
+      typename sm80_config_M64<InType, OutType, Epilogue>::Cutlass2xGemm;
+  using Cutlass2xGemmM64 =
+      typename sm80_config_M64<InType, OutType, Epilogue>::Cutlass2xGemm;
+  using Cutlass2xGemmM32 =
+      typename sm80_config_M32<InType, OutType, Epilogue>::Cutlass2xGemm;
+  using Cutlass2xGemmM16 =
+      typename sm80_config_M16<InType, OutType, Epilogue>::Cutlass2xGemm;
+
+  uint32_t const m = a.size(0);
+  uint32_t const mp2 =
+      std::max(static_cast<uint32_t>(16), next_pow_2(m));  // next power of 2
+  if (mp2 <= 16) {
+    // M in [1, 16]
+    return cutlass_gemm_caller<Cutlass2xGemmM16>(
+        out, a, b, std::forward<EpilogueArgs>(args)...);
+  } else if (mp2 <= 32) {
+    // M in (16, 32]
+    return cutlass_gemm_caller<Cutlass2xGemmM32>(
+        out, a, b, std::forward<EpilogueArgs>(args)...);
+  } else if (mp2 <= 64) {
+    // M in (32, 64]
+    return cutlass_gemm_caller<Cutlass2xGemmM64>(
+        out, a, b, std::forward<EpilogueArgs>(args)...);
+  } else if (mp2 <= 128) {
+    // M in (64, 128]
+    uint32_t const n = out.size(1);
+    bool const small_n = n < 8192;
+    if (small_n) {
+      return cutlass_gemm_caller<Cutlass2xGemmM128SmallN>(
+          out, a, b, std::forward<EpilogueArgs>(args)...);
+    } else {
+      return cutlass_gemm_caller<Cutlass2xGemmM128BigN>(
+          out, a, b, std::forward<EpilogueArgs>(args)...);
+    }
+  } else {
+    // M in (128, inf)
+    return cutlass_gemm_caller<Cutlass2xGemmDefault>(
+        out, a, b, std::forward<EpilogueArgs>(args)...);
+  }
+}
+
 void cutlass_scaled_mm_sm75(torch::Tensor& out, torch::Tensor const& a,
                             torch::Tensor const& b,
                             torch::Tensor const& a_scales,
@@ -288,20 +400,13 @@ void cutlass_scaled_mm_sm80(torch::Tensor& out, torch::Tensor const& a,
   TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
   TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
 
-  using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
-  using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
-  using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
-
   if (out.dtype() == torch::kBFloat16) {
-    return cutlass_gemm_caller<cutlass_2x_gemm<
-        cutlass::arch::Sm80, enable_sm80_to_sm89, int8_t, cutlass::bfloat16_t,
-        ScaledEpilogue, TileShape, WarpShape, InstructionShape, 5>>(
-        out, a, b, a_scales, b_scales);
+    return cutlass_gemm_sm80_dispatch<int8_t, cutlass::bfloat16_t,
+                                      ScaledEpilogue>(out, a, b, a_scales,
+                                                      b_scales);
   } else {
     TORCH_CHECK(out.dtype() == torch::kFloat16);
-    return cutlass_gemm_caller<cutlass_2x_gemm<
-        cutlass::arch::Sm80, enable_sm80_to_sm89, int8_t, cutlass::half_t,
-        ScaledEpilogue, TileShape, WarpShape, InstructionShape, 5>>(
+    return cutlass_gemm_sm80_dispatch<int8_t, cutlass::half_t, ScaledEpilogue>(
         out, a, b, a_scales, b_scales);
   }
 }
diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu
index 8f2aa9425a029..cfa8f80f7ea04 100644
--- a/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu
+++ b/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu
@@ -44,11 +44,6 @@ using namespace cute;
 
 namespace {
 
-uint32_t next_pow_2(uint32_t const num) {
-  if (num <= 1) return num;
-  return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
-}
-
 // A wrapper for the GEMM kernel that is used to guard against compilation on
 // architectures that will never use the kernel. The purpose of this is to
 // reduce the size of the compiled binary.
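
Reviewer note: the dispatch heuristic above is easy to check in isolation. The sketch below is a minimal standalone illustration, not vLLM code: selected_tile_m and the main harness are hypothetical names, and only next_pow_2 is copied verbatim from the common.hpp hunk. It compiles with GCC or Clang (it relies on __builtin_clz, as the patch does).

#include <algorithm>
#include <climits>
#include <cstdint>
#include <cstdio>

// Copied from the common.hpp hunk: round num up to the next power of two
// (num is returned unchanged when it is 0, 1, or already a power of two).
inline uint32_t next_pow_2(uint32_t const num) {
  if (num <= 1) return num;
  return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
}

// Hypothetical stand-in for cutlass_gemm_sm80_dispatch: returns the tile-M
// of the config the dispatcher would select for M rows and N output columns,
// mirroring the bucketing introduced by this patch.
uint32_t selected_tile_m(uint32_t m, uint32_t n) {
  uint32_t const mp2 = std::max(static_cast<uint32_t>(16), next_pow_2(m));
  if (mp2 <= 16) return 16;    // sm80_config_M16:  M in [1, 16]
  if (mp2 <= 32) return 32;    // sm80_config_M32:  M in (16, 32]
  if (mp2 <= 64) return 64;    // sm80_config_M64:  M in (32, 64]
  if (mp2 <= 128)              // M in (64, 128]: N picks the tile
    return n < 8192 ? 64 : 128;
  return 128;                  // sm80_config_default: M in (128, inf)
}

int main() {
  // M=100 rounds up to mp2=128, so N breaks the tie: a narrow GEMM keeps the
  // 64-wide tile (64x128x128), a wide one takes the default 128x128x64 tile.
  std::printf("%u\n", selected_tile_m(100, 4096));  // prints 64
  std::printf("%u\n", selected_tile_m(100, 8192));  // prints 128
  return 0;
}

Rounding M up to a power of two (clamped to at least 16) keeps the number of compiled kernel instantiations small while still matching each tile's M dimension to the runtime row count.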