From 794bdda1ea2d62d4d2c0e858553058ad890ee5e3 Mon Sep 17 00:00:00 2001 From: Shanli Xing Date: Tue, 8 Oct 2024 17:22:46 -0700 Subject: [PATCH] feat: support sm90 cutlass group gemm (#509) Co-authored-by: Zihao Ye --- flashinfer-aot/csrc_aot/flashinfer_ops.cu | 1 - .../csrc_aot/flashinfer_sm90_ops.cu | 26 ++ flashinfer-aot/setup.py | 18 +- include/flashinfer/gemm/group_gemm.cuh | 4 +- .../flashinfer/gemm/group_gemm_cutlass.cuh | 57 +++- include/flashinfer/gemm/group_gemm_sm90.cuh | 248 ++++++++++++++++++ python/csrc/flashinfer_gemm_sm90_ops.cu | 28 ++ python/csrc/group_gemm_sm90.cu | 70 +++++ python/flashinfer/gemm.py | 85 ++++-- python/flashinfer/jit/__init__.py | 32 +-- python/flashinfer/jit/env.py | 5 +- python/flashinfer/utils.py | 6 + python/setup.py | 1 + tests/test_group_gemm.py | 3 +- 14 files changed, 528 insertions(+), 56 deletions(-) create mode 100644 flashinfer-aot/csrc_aot/flashinfer_sm90_ops.cu create mode 100644 include/flashinfer/gemm/group_gemm_sm90.cuh create mode 100644 python/csrc/flashinfer_gemm_sm90_ops.cu create mode 100644 python/csrc/group_gemm_sm90.cu diff --git a/flashinfer-aot/csrc_aot/flashinfer_ops.cu b/flashinfer-aot/csrc_aot/flashinfer_ops.cu index ee2091bfa..c9f0313f3 100644 --- a/flashinfer-aot/csrc_aot/flashinfer_ops.cu +++ b/flashinfer-aot/csrc_aot/flashinfer_ops.cu @@ -13,7 +13,6 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#pragma once #include void append_paged_kv_cache(torch::Tensor append_key, torch::Tensor append_value, diff --git a/flashinfer-aot/csrc_aot/flashinfer_sm90_ops.cu b/flashinfer-aot/csrc_aot/flashinfer_sm90_ops.cu new file mode 100644 index 000000000..5140982f4 --- /dev/null +++ b/flashinfer-aot/csrc_aot/flashinfer_sm90_ops.cu @@ -0,0 +1,26 @@ +/* + * Copyright (c) 2023 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include + + +torch::Tensor CutlassSegmentGEMMSM90(torch::Tensor float_workspace_buffer, torch::Tensor int_workspace_buffer, torch::Tensor seg_indptr, + torch::Tensor weight_indices, torch::Tensor x, + torch::Tensor weight, unsigned int batch_size, + bool weight_column_major); + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("cutlass_segment_gemm_sm90", &CutlassSegmentGEMMSM90, "Cutlass Segment GEMM operator for SM90"); +} \ No newline at end of file diff --git a/flashinfer-aot/setup.py b/flashinfer-aot/setup.py index 396cf334d..5c5410268 100644 --- a/flashinfer-aot/setup.py +++ b/flashinfer-aot/setup.py @@ -355,6 +355,7 @@ def __init__(self, *args, **kwargs) -> None: include_dirs = [ str(root.resolve() / "include"), str(root.resolve() / "3rdparty" / "cutlass" / "include"), # for group gemm + str(root.resolve() / "3rdparty" / "cutlass" / "tools" / "util" / "include"), ] extra_compile_args = { "cxx": [ @@ -371,6 +372,10 @@ def __init__(self, *args, **kwargs) -> None: "-use_fast_math", ], } + extra_compile_args_sm90 = extra_compile_args.copy() + extra_compile_args_sm90["nvcc"].extend( + "-gencode arch=compute_90a,code=sm_90a".split() + ) ext_modules = [] ext_modules.append( torch_cpp_ext.CUDAExtension( @@ -385,12 +390,23 @@ def __init__(self, *args, **kwargs) -> None: "csrc/quantization.cu", "csrc/group_gemm.cu", "csrc/bmm_fp8.cu", - "csrc_aot/flashinfer_ops.cu", + "csrc_aot/flashinfer_ops.cu" ], include_dirs=include_dirs, extra_compile_args=extra_compile_args, ) ) + ext_modules.append( + torch_cpp_ext.CUDAExtension( + name="flashinfer._kernels_sm90", + sources=[ + "csrc/group_gemm_sm90.cu", + "csrc_aot/flashinfer_sm90_ops.cu", + ], + include_dirs=include_dirs, + extra_compile_args=extra_compile_args_sm90, + ) + ) ext_modules.append( torch_cpp_ext.CUDAExtension( name="flashinfer._decode_kernels", diff --git a/include/flashinfer/gemm/group_gemm.cuh b/include/flashinfer/gemm/group_gemm.cuh index 968662f97..20fca5515 100644 --- a/include/flashinfer/gemm/group_gemm.cuh +++ b/include/flashinfer/gemm/group_gemm.cuh @@ -53,7 +53,7 @@ cudaError_t CutlassSegmentGEMMRun(void* workspace_buffer, size_t workspace_buffe // NOTE(Zihao): I didn't successfully launch the kernel with cudaLaunchKernel API, // so I just use the kernel function directly, need to investigate more. 
- auto compute_args_kernel = compute_cutlass_group_gemm_args; + auto compute_args_kernel = compute_sm80_cutlass_group_gemm_args; compute_args_kernel<<>>( problem_sizes_device, x_data, w_data, y_data, ld_x, ld_w, ld_y, (DType*)x, (DType*)w, (DType*)y, xy_indptr_d, w_indices_d, d_in, d_out, weight_column_major); @@ -116,4 +116,4 @@ cudaError_t CutlassSegmentGEMMRun(void* workspace_buffer, size_t workspace_buffe } // namespace flashinfer -#endif // FLASHINFER_GEMM_GROUP_GEMM_CUH_ +#endif // FLASHINFER_GEMM_GROUP_GEMM_CUH_ \ No newline at end of file diff --git a/include/flashinfer/gemm/group_gemm_cutlass.cuh b/include/flashinfer/gemm/group_gemm_cutlass.cuh index a3422bef9..0f71fa3db 100644 --- a/include/flashinfer/gemm/group_gemm_cutlass.cuh +++ b/include/flashinfer/gemm/group_gemm_cutlass.cuh @@ -16,11 +16,16 @@ #ifndef FLASHINFER_GROUP_GEMM_CUTLASS_CUH_ #define FLASHINFER_GROUP_GEMM_CUTLASS_CUH_ +#include +#include +#include + #include "cutlass/cutlass.h" #include "cutlass/gemm/device/gemm_grouped.h" #include "cutlass/gemm/kernel/default_gemm_grouped.h" #include "cutlass/layout/matrix.h" #include "cutlass/numeric_types.h" +#include "cutlass/util/packed_stride.hpp" namespace flashinfer { @@ -41,21 +46,49 @@ struct cutlass_dtype { using type = cutlass::bfloat16_t; }; -template -__global__ void compute_cutlass_group_gemm_args(cutlass::gemm::GemmCoord* all_problems, T** ptr_x, - T** ptr_w, T** ptr_y, int64_t* ld_x, int64_t* ld_w, - int64_t* ld_y, T* x, T* w, T* y, int64_t* xy_indptr, - int64_t* w_indices, size_t d_in, size_t d_out, - bool w_column_major) { +template <> +struct cutlass_dtype<__nv_fp8_e4m3> { + using type = cutlass::float_e4m3_t; +}; + +template <> +struct cutlass_dtype<__nv_fp8_e5m2> { + using type = cutlass::float_e5m2_t; +}; + +template +__global__ void compute_sm80_cutlass_group_gemm_args( + cutlass::gemm::GemmCoord* all_problems, DTypeIn** x_ptr, DTypeIn** w_ptr, DTypeOut** y_ptr, + int64_t* x_ld, int64_t* w_ld, int64_t* y_ld, DTypeIn* x, DTypeIn* w, DTypeOut* y, + int64_t* xy_indptr, int64_t* w_indices, size_t d_in, size_t d_out, bool w_column_major) { int i = blockIdx.x; int m = xy_indptr[i + 1] - xy_indptr[i], k = d_in, n = d_out; all_problems[i] = cutlass::gemm::GemmCoord(m, n, k); - ptr_w[i] = w + (w_indices == nullptr ? i : w_indices[i]) * d_in * d_out; - ptr_x[i] = x + xy_indptr[i] * d_in; - ptr_y[i] = y + xy_indptr[i] * d_out; - ld_x[i] = k; // m * k - ld_w[i] = w_column_major ? k : n; // k * n if column major, n * k if row major - ld_y[i] = n; // m * n + w_ptr[i] = w + (w_indices == nullptr ? i : w_indices[i]) * k * n; + x_ptr[i] = x + xy_indptr[i] * k; + y_ptr[i] = y + xy_indptr[i] * n; + x_ld[i] = k; // m * k + w_ld[i] = w_column_major ? k : n; // k * n if column major, n * k if row major + y_ld[i] = n; // m * n +} + +template +__global__ void compute_sm90_cutlass_group_gemm_args( + ProblemShape* all_problems, DTypeIn** x_ptr, DTypeIn** w_ptr, DTypeOut** y_ptr, + StrideA* x_stride, StrideB* w_stride, StrideCD* y_stride, DTypeIn* x, DTypeIn* w, DTypeOut* y, + int64_t* xy_indptr, int64_t* w_indices, size_t d_in, size_t d_out, bool w_column_major) { + int i = blockIdx.x; + int m = xy_indptr[i + 1] - xy_indptr[i], k = d_in, n = d_out; + all_problems[i] = ProblemShape(m, n, k); + w_ptr[i] = w + (w_indices == nullptr ? i : w_indices[i]) * k * n; + x_ptr[i] = x + xy_indptr[i] * k; + y_ptr[i] = y + xy_indptr[i] * n; + + x_stride[i] = cutlass::make_cute_packed_stride(StrideA{}, {m, k, 1}); + w_stride[i] = w_column_major ? 
cutlass::make_cute_packed_stride(StrideB{}, {k, n, 1}) + : cutlass::make_cute_packed_stride(StrideB{}, {n, k, 1}); + y_stride[i] = cutlass::make_cute_packed_stride(StrideCD{}, {m, n, 1}); } } // namespace group_gemm diff --git a/include/flashinfer/gemm/group_gemm_sm90.cuh b/include/flashinfer/gemm/group_gemm_sm90.cuh new file mode 100644 index 000000000..5d660a074 --- /dev/null +++ b/include/flashinfer/gemm/group_gemm_sm90.cuh @@ -0,0 +1,248 @@ +/* + * Copyright (c) 2024 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef FLASHINFER_GEMM_GROUP_GEMM_SM90_CUH_ +#define FLASHINFER_GEMM_GROUP_GEMM_SM90_CUH_ + +#include + +#include "group_gemm_cutlass.cuh" + +#include "../allocator.h" +#include "../utils.cuh" +#include "cute/tensor.hpp" +#include "cutlass/cutlass.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/group_array_problem_shape.hpp" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/reference/device/gemm.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/reference/device/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" + + +namespace flashinfer { + +namespace group_gemm { + +using namespace cute; + +#define DISPATCH_WEIGHT_LAYOUT(is_column_major, WEIGHT_LAYOUT, ...) 
\ + if (is_column_major) { \ + using WEIGHT_LAYOUT = cutlass::layout::ColumnMajor; \ + __VA_ARGS__ \ + } else { \ + using WEIGHT_LAYOUT = cutlass::layout::RowMajor; \ + __VA_ARGS__ \ + } + +/** + * Panic wrapper for unwinding CUTLASS errors + */ +#define CUTLASS_CHECK(status) \ + { \ + cutlass::Status error = status; \ + if (error != cutlass::Status::kSuccess) { \ + std::cerr << "Got cutlass error: " << cutlassGetStatusString(error) << " at: " << __LINE__ \ + << std::endl; \ + exit(EXIT_FAILURE); \ + } \ + } + +template +cudaError_t CutlassSegmentGEMMSM90Run(void* float_buffer, size_t float_buffer_size_in_bytes, + void* int_buffer, size_t int_buffer_size_in_bytes, DTypeIn* x, + DTypeIn* w, DTypeOut* y, int64_t* xy_indptr_d, + int64_t* w_indices_d, unsigned int batch_size, + unsigned int d_in, unsigned int d_out, + bool weight_column_major, cudaStream_t stream) { + auto compute_capacity = GetCudaComputeCapability(); + if (compute_capacity.first < 9) { + std::cerr << "CutlassSegmentGEMMSM90Run requires compute capability of at least 9.0" + << std::endl; + return cudaErrorNotSupported; + } else { + // Compute capability >= 9.0 + // Reference implementation + // - + // https://github.com/NVIDIA/cutlass/blob/f7b19de32c5d1f3cedfc735c2849f12b537522ee/examples/57_hopper_grouped_gemm/57_hopper_grouped_gemm.cu + using ProblemShape = + cutlass::gemm::GroupProblemShape>; // per group + using ElementA = DTypeIn; // Element type for A matrix operand + using ElementB = DTypeIn; // Element type for B matrix operand + using ElementC = DTypeOut; // Element type for C and D matrix operands + + DISPATCH_WEIGHT_LAYOUT(weight_column_major, WEIGHT_LAYOUT, { + if constexpr (std::is_same_v && + sizeof(DTypeIn) == 1) { + std::ostringstream err_msg; + err_msg << "Row-major layout is not supported for fp8 data type"; + throw std::runtime_error(err_msg.str()); + } else { + using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand + constexpr int AlignmentA = + 128 / cutlass::sizeof_bits::value; // Alignment of A matrix in units of + // elements (up to 16 bytes) + + // B matrix configuration + using LayoutB = WEIGHT_LAYOUT; // Layout type for B matrix operand + constexpr int AlignmentB = + 128 / cutlass::sizeof_bits::value; // Alignment of B matrix in units of + // elements (up to 16 bytes) + + // C/D matrix configuration + using LayoutC = cutlass::layout::RowMajor; // Layout type for C and D matrix operands + constexpr int AlignmentC = + 128 / cutlass::sizeof_bits::value; // Alignment of C matrix in units of + // elements (up to 16 bytes) + + constexpr bool is_fp8 = sizeof(DTypeIn) == 1; + // Core kernel configurations + using ElementAccumulator = float; // Element type for internal accumulation + using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the + // intended feature + using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag + using TileShape = + typename std::conditional, + Shape<_128, _64, _64>>::type; // Threadblock-level tile size + using ClusterShape = + typename std::conditional, Shape<_2, _1, _1>>:: + type; // Shape of the threadblocks in a cluster + using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized + // based on the tile size + using KernelSchedule = typename std::conditional< + is_fp8, cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum, + cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative>::type; // Kernel to launch + using EpilogueSchedule = + 
cutlass::epilogue::PtrArrayNoSmemWarpSpecialized; // Epilogue to launch + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator, ElementAccumulator, + ElementC, LayoutC*, AlignmentC, ElementC, LayoutC*, AlignmentC, + EpilogueSchedule>::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, ElementA, LayoutA*, AlignmentA, ElementB, LayoutB*, AlignmentB, + ElementAccumulator, TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout( + sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule>::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + // Reference device GEMM implementation type + using DeviceGemmReference = + cutlass::reference::device::Gemm; + + using StrideA = typename Gemm::GemmKernel::InternalStrideA; + using StrideB = typename Gemm::GemmKernel::InternalStrideB; + using StrideC = typename Gemm::GemmKernel::InternalStrideC; + using StrideD = typename Gemm::GemmKernel::InternalStrideD; + + AlignedAllocator allocator(int_buffer, int_buffer_size_in_bytes); + ProblemShape::UnderlyingProblemShape* problem_sizes_device = + allocator.aligned_alloc( + batch_size * sizeof(ProblemShape::UnderlyingProblemShape), 16, + "problem_sizes_device"); + DTypeIn** x_data = + allocator.aligned_alloc(batch_size * sizeof(DTypeIn*), 16, "x_data"); + DTypeIn** w_data = + allocator.aligned_alloc(batch_size * sizeof(DTypeIn*), 16, "w_data"); + DTypeOut** y_data = + allocator.aligned_alloc(batch_size * sizeof(DTypeOut*), 16, "y_data"); + StrideA* x_stride = + allocator.aligned_alloc(batch_size * sizeof(StrideA), 16, "x_stride"); + StrideB* w_stride = + allocator.aligned_alloc(batch_size * sizeof(StrideB), 16, "w_stride"); + StrideC* y_stride = + allocator.aligned_alloc(batch_size * sizeof(StrideC), 16, "y_stride"); + + cutlass::KernelHardwareInfo hw_info; + cudaGetDevice(&hw_info.device_id); + hw_info.sm_count = + cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + + typename Gemm::EpilogueOutputOp::Params params; + // TODO(Zihao): support block alpha and beta + params = typename Gemm::EpilogueOutputOp::Params(/*alpha=*/ElementAccumulator(1.f), + /*beta=*/ElementAccumulator(0.f)); + + typename Gemm::Arguments arguments; + + arguments = typename Gemm::Arguments{ + cutlass::gemm::GemmUniversalMode::kGrouped, + {int(batch_size), problem_sizes_device, nullptr}, + {const_cast(x_data), x_stride, const_cast(w_data), + w_stride}, + {params, const_cast(y_data), y_stride, y_data, y_stride}, + hw_info}; + + compute_sm90_cutlass_group_gemm_args<<>>( + problem_sizes_device, x_data, w_data, y_data, x_stride, w_stride, y_stride, (DTypeIn*)x, + (DTypeIn*)w, (DTypeOut*)y, xy_indptr_d, w_indices_d, d_in, d_out, weight_column_major); + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) { + std::cerr << "Failed to launch compute_sm90_cutlass_group_gemm_args kernel: " + << cudaGetErrorString(err) << std::endl; + return err; + } + + // Initialize the gemm kernel + Gemm gemm; + + // Using the arguments, query for extra workspace required for matrix multiplication + // computation + size_t workspace_size = Gemm::get_workspace_size(arguments); + + // Allocate workspace memory + AlignedAllocator 
float_allocator(float_buffer, float_buffer_size_in_bytes); + auto workspace_ptr = float_allocator.aligned_alloc(workspace_size, 64, + "sm90_group_gemm_float_workspace"); + + // Check if the problem size is supported or not + CUTLASS_CHECK(gemm.can_implement(arguments)); + + // Initialize CUTLASS kernel with arguments and workspace pointer + CUTLASS_CHECK(gemm.initialize(arguments, workspace_ptr)); + + // Correctness / Warmup iteration + CUTLASS_CHECK(gemm.run()); // Warmup + } + }); + } + + return cudaSuccess; +} + +} // namespace group_gemm + +} // namespace flashinfer + +#endif // FLASHINFER_GEMM_GROUP_GEMM_SM90_CUH_ diff --git a/python/csrc/flashinfer_gemm_sm90_ops.cu b/python/csrc/flashinfer_gemm_sm90_ops.cu new file mode 100644 index 000000000..0332eb8db --- /dev/null +++ b/python/csrc/flashinfer_gemm_sm90_ops.cu @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2023 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + + +torch::Tensor CutlassSegmentGEMMSM90(torch::Tensor float_workspace_buffer, + torch::Tensor int_workspace_buffer, torch::Tensor seg_indptr, + torch::Tensor weight_indices, torch::Tensor x, + torch::Tensor weight, unsigned int batch_size, + bool weight_column_major); + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("cutlass_segment_gemm_sm90", &CutlassSegmentGEMMSM90, + "Cutlass Segment GEMM operator for SM90"); +} diff --git a/python/csrc/group_gemm_sm90.cu b/python/csrc/group_gemm_sm90.cu new file mode 100644 index 000000000..a218f347c --- /dev/null +++ b/python/csrc/group_gemm_sm90.cu @@ -0,0 +1,70 @@ +/* + * Copyright (c) 2024 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +#include "pytorch_extension_utils.h" + +using namespace flashinfer::group_gemm; + +torch::Tensor CutlassSegmentGEMMSM90(torch::Tensor float_workspace_buffer, + torch::Tensor int_workspace_buffer, torch::Tensor seg_indptr, + torch::Tensor weight_indices, torch::Tensor x, + torch::Tensor weight, unsigned int batch_size, + bool weight_column_major) { + // TODO(Zihao): Add more checks here + CHECK_INPUT(seg_indptr); + CHECK_INPUT(x); + CHECK_INPUT(weight); + auto device = x.device(); + CHECK_EQ(seg_indptr.device(), device); + CHECK_EQ(weight.device(), device); + CHECK_DIM(2, x); // x: [sum(m_i), d_in] + CHECK_DIM(3, weight); // weight: [num_weights, d_out, d_in] if weight_column_major, [num_weights, + // d_in, d_out] otherwise + int64_t cumulative_batch_size = x.size(0); + int64_t d_out = weight_column_major ? 
weight.size(1) : weight.size(2); + int64_t d_in = weight_column_major ? weight.size(2) : weight.size(1); + CHECK_EQ(x.size(1), d_in); + auto y = torch::zeros({cumulative_batch_size, d_out}, x.options()); + cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index()); + seg_indptr = seg_indptr.to(torch::kInt64); + + bool weight_indices_defined = weight_indices.numel() > 0; + if (weight_indices_defined) { + CHECK_INPUT(weight_indices); + CHECK_EQ(weight_indices.device(), device); + weight_indices = weight_indices.to(torch::kInt64); + } + + // TODO(Zihao): add fp8 support + DISPATCH_PYTORCH_DTYPE_TO_CTYPE(x.scalar_type(), c_type, [&] { + using cutlass_t = typename cutlass_dtype::type; + auto status = CutlassSegmentGEMMSM90Run( + float_workspace_buffer.data_ptr(), + float_workspace_buffer.element_size() * float_workspace_buffer.size(0), + int_workspace_buffer.data_ptr(), + int_workspace_buffer.element_size() * int_workspace_buffer.size(0), + static_cast(x.data_ptr()), static_cast(weight.data_ptr()), + static_cast(y.data_ptr()), static_cast(seg_indptr.data_ptr()), + weight_indices_defined ? static_cast(weight_indices.data_ptr()) : nullptr, + batch_size, d_in, d_out, weight_column_major, torch_current_stream); + TORCH_CHECK(status == cudaSuccess, + "Failed to run CutlassSegmentGEMM: ", cudaGetErrorString(status)); + return true; + }); + + return y; +} \ No newline at end of file diff --git a/python/flashinfer/gemm.py b/python/flashinfer/gemm.py index e7e7a0515..5b938765d 100644 --- a/python/flashinfer/gemm.py +++ b/python/flashinfer/gemm.py @@ -18,12 +18,13 @@ import torch -from .utils import get_indptr +from .utils import get_indptr, get_compute_capability from .jit import load_cuda_ops, FLASHINFER_CSRC_DIR, has_prebuilt_ops from typing import Optional _gemm_module = None +_gemm_module_sm90 = None def get_gemm_module(): @@ -37,14 +38,34 @@ def get_gemm_module(): _gemm_module = load_cuda_ops( "gemm", [ - FLASHINFER_CSRC_DIR / "group_gemm.cu", FLASHINFER_CSRC_DIR / "bmm_fp8.cu", + FLASHINFER_CSRC_DIR / "group_gemm.cu", FLASHINFER_CSRC_DIR / "flashinfer_gemm_ops.cu", ], ) return _gemm_module +def get_gemm_sm90_module(): + print("get_gemm_sm90_module") + global _gemm_module_sm90 + if _gemm_module_sm90 is None: + if has_prebuilt_ops: + from . import _kernels_sm90 + + _gemm_module_sm90 = _kernels_sm90 + else: + _gemm_module_sm90 = load_cuda_ops( + "gemm_sm90", + [ + FLASHINFER_CSRC_DIR / "group_gemm_sm90.cu", + FLASHINFER_CSRC_DIR / "flashinfer_gemm_sm90_ops.cu", + ], + extra_cuda_cflags=["-gencode", "arch=compute_90a,code=sm_90a"], + ) + return _gemm_module_sm90 + + class SegmentGEMMWrapper: r"""Wrapper for segment GEMM kernels. @@ -53,7 +74,7 @@ class SegmentGEMMWrapper: >>> import torch >>> from flashinfer import SegmentGEMMWrapper >>> # create a 1MB workspace buffer - >>> workspace_buffer = torch.empty(1 * 1024 * 1024, dtype=torch.int8, device="cuda") + >>> workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device="cuda") >>> segment_gemm = SegmentGEMMWrapper(workspace_buffer) >>> seq_lens = torch.tensor([1, 2, 3, 4], dtype=torch.int64, device="cuda") >>> # create packed input tensor (10 = 1 + 2 + 3 + 4) @@ -96,27 +117,34 @@ class SegmentGEMMWrapper: True """ - def __init__(self, workspace_buffer: torch.Tensor) -> None: + def __init__(self, float_workspace_buffer: torch.Tensor) -> None: r"""Initialize the wrapper. 
Parameters ---------- - workspace_buffer : torch.Tensor - The workspace buffer for the kernels, we use it to store the metadata for the segment GEMM whose - size is proportional to the number of segments (batch size), 1MB workspace is enough for most cases. + float_workspace_buffer : torch.Tensor + The workspace buffer for the kernels, we use it for storing intermediate results in cutlass + segment GEMM kernels. Encouraged size is 128MB. """ - self._workspace_buffer = workspace_buffer + self._int_workspace_buffer = torch.empty( + (1024 * 1024,), dtype=torch.int8, device=float_workspace_buffer.device + ) + self._float_workspace_buffer = float_workspace_buffer - def reset_workspace_buffer(self, new_workspace_buffer: torch.Tensor) -> None: + def reset_workspace_buffer( + self, float_workspace_buffer: torch.Tensor, int_workspace_buffer: torch.Tensor + ) -> None: r"""Reset the workspace buffer. Parameters ---------- - new_workspace_buffer : torch.Tensor - The new workspace buffer, the device of the new workspace buffer should - be the same as the device of the input tensors. + float_workspace_buffer : torch.Tensor + The new float workspace buffer for the kernels. + int_workspace_buffer : torch.Tensor + The new int workspace buffer for the kernels. """ - self._workspace_buffer = new_workspace_buffer + self._float_workspace_buffer = float_workspace_buffer + self._int_workspace_buffer = int_workspace_buffer def run( self, @@ -193,15 +221,28 @@ def run( if weight_indices is None: # create an empty CPU tensor as placeholder weight_indices = torch.empty(0, dtype=torch.int64) - return get_gemm_module().cutlass_segment_gemm( - self._workspace_buffer, - seg_indptr, - weight_indices, - x, - weights, - batch_size, - weight_column_major, - ) + major, _ = get_compute_capability(x.device) + if major >= 9: + return get_gemm_sm90_module().cutlass_segment_gemm_sm90( + self._float_workspace_buffer, + self._int_workspace_buffer, + seg_indptr, + weight_indices, + x, + weights, + batch_size, + weight_column_major, + ) + else: + return get_gemm_module().cutlass_segment_gemm( + self._int_workspace_buffer, + seg_indptr, + weight_indices, + x, + weights, + batch_size, + weight_column_major, + ) forward = run diff --git a/python/flashinfer/jit/__init__.py b/python/flashinfer/jit/__init__.py index bbd85729c..e0e272b82 100644 --- a/python/flashinfer/jit/__init__.py +++ b/python/flashinfer/jit/__init__.py @@ -17,15 +17,16 @@ import os import re import logging +import subprocess import torch.utils.cpp_extension as torch_cpp_ext -from typing import List +from typing import List, Tuple from .env import ( FLASHINFER_WORKSPACE_DIR, FLASHINFER_JIT_DIR, FLASHINFER_GEN_SRC_DIR, FLASHINFER_INCLUDE_DIR, FLASHINFER_CSRC_DIR, - CUTLASS_INCLUDE_DIR, + CUTLASS_INCLUDE_DIRS, ) from .activation import get_act_and_mul_cu_str, gen_act_and_mul_cu from .attention import ( @@ -111,22 +112,24 @@ def remove_unwanted_pytorch_nvcc_flags(): def load_cuda_ops( name: str, sources: List[str], - extra_cflags: List[str] = ["-O3", "-Wno-switch-bool"], - extra_cuda_cflags: List[str] = [ + extra_cflags: List[str] = [], + extra_cuda_cflags: List[str] = [], + extra_ldflags=None, + extra_include_paths=None, + verbose=False, +): + cflags = ["-O3", "-Wno-switch-bool"] + cuda_cflags = [ "-O3", "-std=c++17", "--threads", "4", - # "-Xfatbin", - # "-compress-all", "-use_fast_math", "-DFLASHINFER_ENABLE_BF16", "-DFLASHINFER_ENABLE_FP8", - ], - extra_ldflags=None, - extra_include_paths=None, - verbose=False, -): + ] + cflags += extra_cflags + cuda_cflags += 
extra_cuda_cflags logger.info(f"Loading JIT ops: {name}") check_cuda_arch() build_directory = FLASHINFER_JIT_DIR / name @@ -135,14 +138,13 @@ def load_cuda_ops( if extra_include_paths is None: extra_include_paths = [ FLASHINFER_INCLUDE_DIR, - CUTLASS_INCLUDE_DIR, FLASHINFER_CSRC_DIR, - ] + ] + CUTLASS_INCLUDE_DIRS return torch_cpp_ext.load( name, list(map(lambda _: str(_), sources)), - extra_cflags=extra_cflags, - extra_cuda_cflags=extra_cuda_cflags, + extra_cflags=cflags, + extra_cuda_cflags=cuda_cflags, extra_ldflags=extra_ldflags, extra_include_paths=list(map(lambda _: str(_), extra_include_paths)), build_directory=build_directory, diff --git a/python/flashinfer/jit/env.py b/python/flashinfer/jit/env.py index 65b47eed0..e3fbec818 100644 --- a/python/flashinfer/jit/env.py +++ b/python/flashinfer/jit/env.py @@ -23,4 +23,7 @@ _project_root = pathlib.Path(__file__).resolve().parent.parent.parent FLASHINFER_INCLUDE_DIR = _project_root / "include" FLASHINFER_CSRC_DIR = _project_root / "csrc" -CUTLASS_INCLUDE_DIR = _project_root / "3rdparty" / "cutlass" / "include" +CUTLASS_INCLUDE_DIRS = [ + _project_root / "3rdparty" / "cutlass" / "include", + _project_root / "3rdparty" / "cutlass" / "tools" / "util" / "include", +] diff --git a/python/flashinfer/utils.py b/python/flashinfer/utils.py index e8e40ee5c..acaed4a00 100644 --- a/python/flashinfer/utils.py +++ b/python/flashinfer/utils.py @@ -176,3 +176,9 @@ def canonicalize_torch_dtype(dtype: Union[torch.dtype, str]) -> torch.dtype: raise TypeError( "dtype must be a string or torch.dtype, got {}".format(type(dtype)) ) + + +def get_compute_capability(device: torch.device) -> Tuple[int, int]: + if device.type != "cuda": + raise ValueError("device must be a cuda device") + return torch.cuda.get_device_capability(device.index) diff --git a/python/setup.py b/python/setup.py index ffb3debb2..52166d513 100644 --- a/python/setup.py +++ b/python/setup.py @@ -46,6 +46,7 @@ def clear_aot_config(): if __name__ == "__main__": generate_build_meta() + clear_aot_config() setuptools.setup( name="flashinfer", version=get_version(), diff --git a/tests/test_group_gemm.py b/tests/test_group_gemm.py index 96c48fb88..fb35c0a30 100644 --- a/tests/test_group_gemm.py +++ b/tests/test_group_gemm.py @@ -84,8 +84,7 @@ def test_segment_gemm( ), ), rtol=1e-3, - atol=1e-3, - msg="assertion failed at batch {}".format(i), + atol=1e-3 ) else: torch.testing.assert_close(
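
Usage note: the SM90 path keeps the same public calling convention as the existing SM80 grouped GEMM; `SegmentGEMMWrapper.run()` now dispatches on the device's compute capability, calling `cutlass_segment_gemm_sm90` on Hopper and falling back to the SM80 kernel otherwise. Below is a minimal sketch, assuming the constructor and `run()` signatures exactly as they appear in the `python/flashinfer/gemm.py` hunk above; shapes and dtypes follow the docstring example in that hunk.

import torch
from flashinfer import SegmentGEMMWrapper

# 128MB float workspace buffer, as recommended in the updated docstring; the wrapper
# allocates a 1MB int workspace internally for per-segment problem sizes, pointers,
# and strides.
workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device="cuda")
segment_gemm = SegmentGEMMWrapper(workspace_buffer)

# Four segments of lengths 1 + 2 + 3 + 4 = 10 packed rows, d_in = 128, d_out = 1024.
seq_lens = torch.tensor([1, 2, 3, 4], dtype=torch.int64, device="cuda")
x = torch.randn(10, 128, dtype=torch.float16, device="cuda")
weights = torch.randn(4, 1024, 128, dtype=torch.float16, device="cuda")  # column-major weights

# On compute capability >= 9.0 this dispatches to cutlass_segment_gemm_sm90
# (built with -gencode arch=compute_90a,code=sm_90a); otherwise it falls back
# to the existing SM80 grouped GEMM kernel.
y = segment_gemm.run(x, weights, 4, True, seg_lens=seq_lens)
assert y.shape == (10, 1024)

If the workspace needs to be swapped at runtime, `reset_workspace_buffer(float_workspace_buffer, int_workspace_buffer)` now takes both buffers, matching the float/int workspace split introduced in this patch.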