From bbe579a9e3beb6ea6626d9227ec32d0dae119a49 Mon Sep 17 00:00:00 2001 From: ANIKET SHIVAM <3268307+ANIKET-SHIVAM@users.noreply.github.com> Date: Thu, 15 Feb 2024 12:48:34 -0800 Subject: [PATCH] Updates for CUTLASS 3.4.1 (#1346) * Updates for CUTLASS 3.4.1 * minor epi change --- CHANGELOG.md | 9 +- CMakeLists.txt | 27 ++- PUBLICATIONS.md | 3 +- README.md | 13 +- cmake/version.h.in | 38 --- cmake/version_extended.h.in | 34 +++ examples/02_dump_reg_shmem/CMakeLists.txt | 1 + .../56_hopper_ptr_array_batched_gemm.cu | 14 +- .../CMakeLists.txt | 18 +- .../57_hopper_grouped_gemm.cu | 145 +++++++---- .../57_hopper_grouped_gemm/CMakeLists.txt | 10 + include/cute/arch/copy_sm90_desc.hpp | 2 +- include/cute/atom/mma_atom.hpp | 2 +- include/cute/util/type_traits.hpp | 3 + include/cutlass/arch/mma_sm90.h | 4 + include/cutlass/detail/layout.hpp | 36 ++- .../collective/builders/sm90_builder.inl | 5 +- .../epilogue/collective/default_epilogue.hpp | 1 + .../collective/default_epilogue_array.hpp | 50 ++-- .../sm90_epilogue_tma_warpspecialized.hpp | 17 +- include/cutlass/epilogue/dispatch_policy.hpp | 3 +- .../sm90_callbacks_tma_warpspecialized.hpp | 5 +- .../epilogue/thread/linear_combination.h | 69 +++++- .../collective/builders/sm90_gmma_builder.inl | 12 +- ..._mma_array_tma_gmma_ss_warpspecialized.hpp | 74 +++--- include/cutlass/gemm/dispatch_policy.hpp | 11 +- .../gemm/group_array_problem_shape.hpp | 12 + ..._array_tma_warpspecialized_cooperative.hpp | 65 +++-- include/cutlass/gemm/kernel/sm90_gemm_tma.hpp | 12 +- .../kernel/sm90_gemm_tma_warpspecialized.hpp | 12 +- ...0_gemm_tma_warpspecialized_cooperative.hpp | 12 +- ...sm90_gemm_tma_warpspecialized_pingpong.hpp | 12 +- .../gemm/kernel/sm90_gemm_warpspecialized.hpp | 12 +- .../sm90_gemm_warpspecialized_cooperative.hpp | 12 +- .../sm90_gemm_warpspecialized_pingpong.hpp | 12 +- .../gemm/kernel/sm90_tile_scheduler_group.hpp | 226 +++++++++++------- .../gemm/kernel/tile_scheduler_params.h | 20 +- include/cutlass/version.h | 80 +++++++ pyproject.toml | 4 +- python/cutlass/__init__.py | 6 +- python/cutlass/backend/c_types.py | 8 +- python/cutlass/backend/epilogue.py | 24 +- .../backend/evt/frontend/frontend_base.py | 4 +- .../backend/evt/passes/graph_drawer.py | 16 -- python/cutlass/backend/gemm_operation.py | 46 ++-- python/setup_library.py | 2 +- python/setup_pycute.py | 2 +- ...r_warpspecialized_cooperative_aux_store.cu | 43 ++-- .../include/cutlass/util/packed_stride.hpp | 1 - 49 files changed, 799 insertions(+), 450 deletions(-) delete mode 100644 cmake/version.h.in create mode 100644 cmake/version_extended.h.in create mode 100644 include/cutlass/version.h diff --git a/CHANGELOG.md b/CHANGELOG.md index dbe0a3386e..6e00866683 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,11 @@ # NVIDIA CUTLASS Changelog -## [3.4](https://github.com/NVIDIA/cutlass/releases/tag/v3.4.0) (2024-01-12) +## [3.4.1](https://github.com/NVIDIA/cutlass/releases/tag/v3.4.1) (2024-02-14) + +- Statically available [CUTLASS Version macros](/include/cutlass/version.h) that allow for handling API changes between CUTLASS releases on the users' side. +- Improvements for Hopper [Group-GEMMs](/examples/57_hopper_grouped_gemm) and [Pointer-Array Batched GEMMs](/examples/56_hopper_ptr_array_batched_gemm). +- Updates and bugfixes from the community (thanks!). 
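+
+For the new version macros, a minimal compile-time gating sketch (illustrative only; it assumes just the `CUTLASS_MAJOR`, `CUTLASS_MINOR`, and `CUTLASS_PATCH` macros that the CMake changes below parse out of `include/cutlass/version.h`):
+
+```cpp
+#include "cutlass/version.h"
+
+#if (CUTLASS_MAJOR > 3) || (CUTLASS_MAJOR == 3 && CUTLASS_MINOR >= 4)
+  // Use interfaces available from CUTLASS 3.4 onwards.
+#else
+  // Fall back to the pre-3.4 API.
+#endif
+```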
+ +## [3.4.0](https://github.com/NVIDIA/cutlass/releases/tag/v3.4.0) (2024-01-12) * Expanded [Mixed-input Hopper GEMMs](/examples/55_hopper_mixed_dtype_gemm) support covering {16-bit, 8-bit} x {8-bit, 4-bit} input types with fast numerical converters and group scaling factors. * Performance improvements to [Mixed-input Hopper GEMMs](/examples/55_hopper_mixed_dtype_gemm) * Beta release of [Pointer-Array Batched GEMMs](/examples/56_hopper_ptr_array_batched_gemm) now available on Hopper GPUs utilizing TMA and WGMMA (requires CUDA 12.3 or above). @@ -8,7 +14,6 @@ * NamedBarriers usability improvement and list of [ReservedNamedBarriers](/include/cutlass/arch/barrier.h) has been officially released. * Improved [CuTe documentation](/media/docs/cute/) including improved clarity and depth of [Quickstart](/media/docs/cute/00_quickstart.md), [CuTe Layout](/media/docs/cute/01_layout.md), and [CuTe Layout Algebra](/media/docs/cute/02_layout_algebra.md). Associated code comments, post-conditions, and details in [CuTe Core Unit Tests](/test/unit/cute/core/) also improved. - ## [3.3](https://github.com/NVIDIA/cutlass/releases/tag/v3.3.0) (2023-10-31) * [Mixed-input Hopper GEMMs](/examples/55_hopper_mixed_dtype_gemm) support covering 16-bit x 8-bit input operand types. * [Mixed-input Ampere GEMMs](https://github.com/NVIDIA/cutlass/pull/1084) with support for canonical layouts (TN). The implementation supports upcast on operandB {fp16, bf16} x {s8, u8}, and upcast on operandA {s8, u8} x {fp16, bf16}. diff --git a/CMakeLists.txt b/CMakeLists.txt index 114d793616..ed75907329 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -40,7 +40,25 @@ endif() message(STATUS "CMake Version: ${CMAKE_VERSION}") set(IMPLICIT_CMAKE_CXX_STANDARD OFF CACHE BOOL "Do not explicitly specify -std=c++11 if set") -project(CUTLASS VERSION 3.4.0 LANGUAGES CXX) +# To reduce duplicate version locations, parse the version out of the +# main versions.h file and reuse it here. + +file(READ ${CMAKE_CURRENT_SOURCE_DIR}/include/cutlass/version.h VERSION_FILE_CONTENTS) +string(REGEX MATCH "#define CUTLASS_MAJOR ([0-9]+)" _CUTLASS_VERSION_MAJOR "${VERSION_FILE_CONTENTS}") +set(_CUTLASS_VERSION_MAJOR ${CMAKE_MATCH_1}) +string(REGEX MATCH "#define CUTLASS_MINOR ([0-9]+)" _CUTLASS_VERSION_MINOR "${VERSION_FILE_CONTENTS}") +set(_CUTLASS_VERSION_MINOR ${CMAKE_MATCH_1}) +string(REGEX MATCH "#define CUTLASS_PATCH ([0-9]+)" _CUTLASS_VERSION_PATCH "${VERSION_FILE_CONTENTS}") +set(_CUTLASS_VERSION_PATCH ${CMAKE_MATCH_1}) + +message(STATUS "CUTLASS ${_CUTLASS_VERSION_MAJOR}.${_CUTLASS_VERSION_MINOR}.${_CUTLASS_VERSION_PATCH}") + +## CUTLASS PROJECT ############################################################# + +project(CUTLASS VERSION ${_CUTLASS_VERSION_MAJOR}.${_CUTLASS_VERSION_MINOR}.${_CUTLASS_VERSION_PATCH} LANGUAGES CXX) + +################################################################################ + include(${CMAKE_CURRENT_SOURCE_DIR}/CUDA.cmake) if (CUDA_VERSION VERSION_LESS 11.3) @@ -178,6 +196,9 @@ if(WIN32) set(gtest_force_shared_crt ON CACHE BOOL "Use shared (DLL) run-time lib even when Google Test is built as static lib" FORCE) endif() +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DCUTLASS_VERSIONS_GENERATED") +set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -DCUTLASS_VERSIONS_GENERATED") + if (WIN32) # Enable more warnings. Add "-Xcompiler=/WX" to enable warnings as errors. 
list(APPEND CUTLASS_CUDA_NVCC_FLAGS -Xcompiler=/W3) @@ -589,8 +610,8 @@ if (NOT DEFINED CUTLASS_REVISION) endif() configure_file( - ${CMAKE_CURRENT_SOURCE_DIR}/cmake/version.h.in - ${CMAKE_CURRENT_BINARY_DIR}/include/cutlass/version.h + ${CMAKE_CURRENT_SOURCE_DIR}/cmake/version_extended.h.in + ${CMAKE_CURRENT_BINARY_DIR}/include/cutlass/version_extended.h @ONLY) target_include_directories( diff --git a/PUBLICATIONS.md b/PUBLICATIONS.md index a69dd70b27..487309f777 100644 --- a/PUBLICATIONS.md +++ b/PUBLICATIONS.md @@ -2,7 +2,8 @@ ## 2023 -- ["A Case Study in CUDA Kernel Fusion: Implementing FlashAttention-2 on NVIDIA Hopper Architecture using the CUTLASS Library"](https://arxiv.org/abs/2312.11918). Ganesh Bikshandi and Jay Shah. _arXiv_, December 2023. +- ["A Case Study in CUDA Kernel Fusion: Implementing FlashAttention-2 on NVIDIA Hopper Architecture using the CUTLASS Library"](https://arxiv.org/abs/2312.11918). Ganesh Bikshandi, Jay Shah. _arXiv_, December 2023. + - ["A Speed Odyssey for Deployable Quantization of LLMs"](https://arxiv.org/abs/2311.09550). Qingyuan Li, Ran Meng, Yiduo Li, Bo Zhang, Liang Li, Yifan Lu, Xiangxiang Chu, Yerui Sun, Yuchen Xie. _arXiv_, November 2023. diff --git a/README.md b/README.md index 462649fc67..d642ee7250 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ # CUTLASS 3.4 -_CUTLASS 3.4 - January 2024_ +_CUTLASS 3.4 - February 2024_ CUTLASS is a collection of CUDA C++ template abstractions for implementing high-performance matrix-matrix multiplication (GEMM) and related computations at all levels @@ -43,13 +43,18 @@ In addition to GEMMs, CUTLASS implements high-performance convolution via the im # What's New in CUTLASS 3.4 +CUTLASS 3.4.1 is an update to CUTLASS adding: +- Statically available [CUTLASS Version macros](/include/cutlass/version.h) that allow for handling API changes between CUTLASS releases on the users' side. +- Improvements for Hopper [Group-GEMM](/examples/57_hopper_grouped_gemm) and [Pointer-Array Batched GEMM](/examples/56_hopper_ptr_array_batched_gemm). +- Updates and bugfixes from the community (thanks!). + CUTLASS 3.4.0 is an update to CUTLASS adding: - Improved [Mixed-input Hopper GEMMs](/examples/55_hopper_mixed_dtype_gemm) supporting {16-bit, 8-bit} x {8-bit, 4-bit} input types with fast numerical converters and group scaling factors tuned for optimal performance on Hopper H100. - Beta release of [Pointer-Array Batched GEMMs](/examples/56_hopper_ptr_array_batched_gemm) utilizing TMA and Hopper H100 tensor cores now available. (Requires CUDA 12.3 or above) - Beta release of [Group-GEMM](/examples/57_hopper_grouped_gemm) - commonly used in optimization of Mixture-Of-Expert models, is now available on Hopper GPUs taking advantage of TMA and Hopper H100 tensor cores. (Requires CUDA 12.3 or above) - [Ampere Sparse GEMM](/examples/15_ampere_sparse_tensorop_gemm/ampere_sparse_tensorop_gemm_with_visitor.cu) supports Epilogue Visitor Tree (EVT) now. -- Impovements to NamedBarriers including details of [ReservedNamedBarriers](/include/cutlass/arch/barrier.h) used within the CUTLASS library. +- Improvements to NamedBarriers including details of [ReservedNamedBarriers](/include/cutlass/arch/barrier.h) used within the CUTLASS library. - Improved [CuTe documentation](/media/docs/cute/) including improved clarity and depth of [Quickstart](/media/docs/cute/00_quickstart.md), [CuTe Layout](/media/docs/cute/01_layout.md), and [CuTe Layout Algebra](/media/docs/cute/02_layout_algebra.md). 
Associated code comments, post-conditions, and details in [CuTe Core Unit Tests](/test/unit/cute/core/) also improved. Minimum requirements: @@ -93,8 +98,8 @@ as shown in the above figure. Tensor Core operations are implemented using CUDA # Compatibility CUTLASS requires a C++17 host compiler and -performs best when built with the [**CUDA 12.2.2 Toolkit**](https://developer.nvidia.com/cuda-toolkit-archive). -It is also compatible with CUDA 11.4, CUDA 11.5, CUDA 11.6, CUDA 11.7, CUDA 11.8, CUDA 12.0, CUDA 12.1, CUDA 12.2.2 and CUDA 12.3.1 +performs best when built with the [**CUDA 12.3.2 Toolkit**](https://developer.nvidia.com/cuda-downloads). +It is also compatible with CUDA 11.4, CUDA 11.5, CUDA 11.6, CUDA 11.7, CUDA 11.8, CUDA 12.0, CUDA 12.1, CUDA 12.2.2, CUDA 12.3.1 and CUDA 12.3.2. ## Operating Systems We have tested the following environments. diff --git a/cmake/version.h.in b/cmake/version.h.in deleted file mode 100644 index 1b48e1abc2..0000000000 --- a/cmake/version.h.in +++ /dev/null @@ -1,38 +0,0 @@ -#include -#include - -#define CUTLASS_MAJOR @CUTLASS_VERSION_MAJOR@ -#define CUTLASS_MINOR @CUTLASS_VERSION_MINOR@ -#define CUTLASS_PATCH @CUTLASS_VERSION_PATCH@ -#define CUTLASS_BUILD @CUTLASS_VERSION_BUILD@ -#define CUTLASS_VERSION ((CUTLASS_MAJOR)*100 + (CUTLASS_MINOR)*10 + CUTLASS_PATCH) - -namespace cutlass { - - inline uint32_t getVersion() { - return CUTLASS_VERSION; - } - inline uint32_t getVersionMajor() { - return CUTLASS_MAJOR; - } - inline uint32_t getVersionMinor() { - return CUTLASS_MINOR; - } - inline uint32_t getVersionPatch() { - return CUTLASS_PATCH; - } - inline uint32_t getVersionBuild() { - return CUTLASS_BUILD + 0; - } - inline std::string getVersionString() { - std::string version = "@CUTLASS_VERSION@"; - if (getVersionBuild()) { - version += "." + std::to_string(getVersionBuild()); - } - return version; - } - inline std::string getGitRevision() { - return "@CUTLASS_REVISION@"; - } - -} // namespace cutlass diff --git a/cmake/version_extended.h.in b/cmake/version_extended.h.in new file mode 100644 index 0000000000..3613063022 --- /dev/null +++ b/cmake/version_extended.h.in @@ -0,0 +1,34 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#define CUTLASS_BUILD @CUTLASS_VERSION_BUILD@ +#define CUTLASS_REVISION "@CUTLASS_REVISION@" diff --git a/examples/02_dump_reg_shmem/CMakeLists.txt b/examples/02_dump_reg_shmem/CMakeLists.txt index a6fce01e86..0216f2b480 100644 --- a/examples/02_dump_reg_shmem/CMakeLists.txt +++ b/examples/02_dump_reg_shmem/CMakeLists.txt @@ -31,4 +31,5 @@ cutlass_example_add_executable( 02_dump_reg_shmem dump_reg_shmem.cu + DISABLE_TESTS ON ) diff --git a/examples/56_hopper_ptr_array_batched_gemm/56_hopper_ptr_array_batched_gemm.cu b/examples/56_hopper_ptr_array_batched_gemm/56_hopper_ptr_array_batched_gemm.cu index e02e0a2851..465c0a41f2 100644 --- a/examples/56_hopper_ptr_array_batched_gemm/56_hopper_ptr_array_batched_gemm.cu +++ b/examples/56_hopper_ptr_array_batched_gemm/56_hopper_ptr_array_batched_gemm.cu @@ -70,7 +70,7 @@ using namespace cute; -#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) +#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED) ///////////////////////////////////////////////////////////////////////////////////////////////// /// GEMM kernel configurations @@ -98,8 +98,8 @@ using OperatorClass = cutlass::arch::OpClassTensorOp; // O using TileShape = Shape<_256,_128,_64>; // Threadblock-level tile size using ClusterShape = Shape<_1,_2,_1>; // Shape of the threadblocks in a cluster using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size -using KernelSchedule = cutlass::gemm::KernelArrayTmaWarpSpecializedCooperative; // Kernel to launch -using EpilogueSchedule = cutlass::epilogue::NoSmemWarpSpecializedArray; // Epilogue to launch +using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative; // Kernel to launch +using EpilogueSchedule = cutlass::epilogue::PtrArrayNoSmemWarpSpecialized; // Epilogue to launch using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, @@ -169,7 +169,7 @@ cutlass::DeviceAllocation ptr_C; cutlass::DeviceAllocation ptr_D; cutlass::DeviceAllocation ptr_ref_D; -#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) +#endif // defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED) ///////////////////////////////////////////////////////////////////////////////////////////////// /// Testbed utility types @@ -245,7 +245,7 @@ struct Result bool passed = false; }; -#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) +#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED) ///////////////////////////////////////////////////////////////////////////////////////////////// /// GEMM setup and evaluation @@ -468,7 +468,7 @@ int run(Options &options) return 0; } -#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) +#endif // defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED) 
/////////////////////////////////////////////////////////////////////////////////////////////////// @@ -510,7 +510,7 @@ int main(int argc, char const **args) { // Evaluate CUTLASS kernels // -#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) +#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED) run(options); #endif diff --git a/examples/56_hopper_ptr_array_batched_gemm/CMakeLists.txt b/examples/56_hopper_ptr_array_batched_gemm/CMakeLists.txt index cf02a48d89..0a4e69566a 100644 --- a/examples/56_hopper_ptr_array_batched_gemm/CMakeLists.txt +++ b/examples/56_hopper_ptr_array_batched_gemm/CMakeLists.txt @@ -27,17 +27,17 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# Note that we set --iterations=0 for all tests below to disable the performance benchmarking. -# Only the correctness check will be run by these commands. +set(TEST_SQUARE --m=2048 --n=2048 --k=2048 -l=10 --iterations=1) # Square problem sizes +set(TEST_SQUARE_LARGE_BATCH --m=2048 --n=2048 --k=2048 -l=500 --iterations=1) # Square problem sizes -set(TEST_SQUARE --m=2048 --n=2048 --k=2048 -l=10 --iterations=0) # Square problem sizes -set(TEST_SQUARE_LARGE_BATCH --m=2048 --n=2048 --k=2048 -l=500 --iterations=0) # Square problem sizes +set(TEST_EPILOGUE --alpha=0.5 --beta=0.7 --iterations=1) # Default problem sizes +set(TEST_EPILOGUE_LARGE_BATCH --alpha=1.5 --beta=2.0 -l=500 --iterations=1) # Default problem sizes -set(TEST_EPILOGUE --alpha=0.5 --beta=0.7 --iterations=0) # Default problem sizes -set(TEST_EPILOGUE_LARGE_BATCH --alpha=1.5 --beta=2.0 -l=500 --iterations=0) # Default problem sizes +set(TEST_EPILOGUE_OP --beta=0.7 --iterations=1) # Default problem sizes w/ Epilogue Op test +set(TEST_EPILOGUE_OP_LARGE_BATCH --alpha=1.5 -l=500 --iterations=1) # Default problem sizes w/ Epilogue Op test -set(TEST_SMALLK --m=2048 --n=5120 --k=128 --l=5 --iterations=0) # Small-k problem sizes -set(TEST_SMALLK_LARGE_BATCH --m=1024 --n=512 --k=64 --l=500 --iterations=0) # Small-k problem sizes +set(TEST_SMALLK --m=2048 --n=5120 --k=128 --l=5 --iterations=1) # Small-k problem sizes +set(TEST_SMALLK_LARGE_BATCH --m=1024 --n=512 --k=64 --l=500 --iterations=1) # Small-k problem sizes cutlass_example_add_executable( 56_hopper_ptr_array_batched_gemm @@ -47,6 +47,8 @@ cutlass_example_add_executable( TEST_SQUARE_LARGE_BATCH TEST_EPILOGUE TEST_EPILOGUE_LARGE_BATCH + TEST_EPILOGUE_OP + TEST_EPILOGUE_OP_LARGE_BATCH TEST_SMALLK TEST_SMALLK_LARGE_BATCH ) diff --git a/examples/57_hopper_grouped_gemm/57_hopper_grouped_gemm.cu b/examples/57_hopper_grouped_gemm/57_hopper_grouped_gemm.cu index 2985eb7746..2a737e1e98 100644 --- a/examples/57_hopper_grouped_gemm/57_hopper_grouped_gemm.cu +++ b/examples/57_hopper_grouped_gemm/57_hopper_grouped_gemm.cu @@ -44,6 +44,7 @@ The above example command makes all 10 groups to be sized at the given m, n, k sizes. Skipping any of the problem dimensions randomizes it across the different groups. + Same applies for alpha and beta values that are randomized across the different groups. 
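+    For instance, to apply one fixed alpha/beta to every group while still randomizing the
+    problem shapes, pass both scalars on the command line (an illustrative invocation; the
+    flag names are the ones read via cutlass::CommandLine in Options below):
+
+      $ ./examples/57_hopper_grouped_gemm/57_hopper_grouped_gemm --groups=10 --alpha=2.0 --beta=0.5
+
+    Omitting --alpha and --beta instead draws a separate alpha/beta for each group, which
+    exercises the per-group scale-factor pointer arrays set up in initialize() and
+    args_from_options() further down in this example.
+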
To run this example for a set of problems using the benchmark option: @@ -62,6 +63,7 @@ #include #include #include +#include #include "cutlass/cutlass.h" @@ -91,9 +93,9 @@ using namespace cute; using ProblemShape = cutlass::gemm::GroupProblemShape>; // per group using ElementA = cutlass::float_e4m3_t; // Element type for A matrix operand using ElementB = cutlass::float_e5m2_t; // Element type for B matrix operand -using ElementC = float; // Element type for C and D matrix operands +using ElementC = cutlass::half_t; // Element type for C and D matrix operands -#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) +#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED) ///////////////////////////////////////////////////////////////////////////////////////////////// /// GEMM kernel configurations @@ -101,40 +103,40 @@ using ElementC = float; // Element type // A matrix configuration using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand -constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Alignment of A matrix in units of elements (up to 16 bytes) // B matrix configuration using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand -constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Alignment of B matrix in units of elements (up to 16 bytes) // C/D matrix configuration using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C and D matrix operands -constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Alignment of C matrix in units of elements (up to 16 bytes) // Core kernel configurations using ElementAccumulator = float; // Element type for internal accumulation using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag using TileShape = Shape<_256,_128,_64>; // Threadblock-level tile size -using ClusterShape = Shape<_1,_2,_1>; // Shape of the threadblocks in a cluster +using ClusterShape = Shape<_2,_2,_1>; // Shape of the threadblocks in a cluster using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size -using KernelSchedule = cutlass::gemm::KernelGroupTmaWarpSpecializedCooperativeFP8FastAccum; // Kernel to launch -using EpilogueSchedule = cutlass::epilogue::NoSmemWarpSpecializedGroup; // Epilogue to launch +using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum; // Kernel to launch +using EpilogueSchedule = cutlass::epilogue::PtrArrayNoSmemWarpSpecialized; // Epilogue to launch using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape, ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator, ElementAccumulator, - ElementC, LayoutC, AlignmentC, - ElementC, LayoutC, AlignmentC, + ElementC, LayoutC *, AlignmentC, + ElementC, LayoutC *, AlignmentC, EpilogueSchedule >::CollectiveOp; using CollectiveMainloop = typename 
cutlass::gemm::collective::CollectiveBuilder< ArchTag, OperatorClass, - ElementA, LayoutA, AlignmentA, - ElementB, LayoutB, AlignmentB, + ElementA, LayoutA *, AlignmentA, + ElementB, LayoutB *, AlignmentB, ElementAccumulator, TileShape, ClusterShape, cutlass::gemm::collective::StageCountAutoCarveout< @@ -161,10 +163,10 @@ using DeviceGemmReference = cutlass::reference::device::Gemm< ElementAccumulator, ElementAccumulator>; -using StrideA = typename Gemm::GemmKernel::StrideA; -using StrideB = typename Gemm::GemmKernel::StrideB; -using StrideC = typename Gemm::GemmKernel::StrideC; -using StrideD = typename Gemm::GemmKernel::StrideD; +using StrideA = typename Gemm::GemmKernel::UnderlyingStrideA; +using StrideB = typename Gemm::GemmKernel::UnderlyingStrideB; +using StrideC = typename Gemm::GemmKernel::UnderlyingStrideC; +using StrideD = typename Gemm::GemmKernel::UnderlyingStrideD; // Host-side allocations std::vector offset_A; @@ -177,6 +179,9 @@ std::vector stride_B_host; std::vector stride_C_host; std::vector stride_D_host; +std::vector alpha_host; +std::vector beta_host; + // Device-side allocations cutlass::DeviceAllocation problem_sizes; @@ -197,7 +202,13 @@ cutlass::DeviceAllocation stride_B; cutlass::DeviceAllocation stride_C; cutlass::DeviceAllocation stride_D; -#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) +// Note, this is an array of pointers to alpha and beta scaling values per group +cutlass::DeviceAllocation alpha_device; +cutlass::DeviceAllocation beta_device; +cutlass::DeviceAllocation block_alpha; +cutlass::DeviceAllocation block_beta; + +#endif // defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED) ///////////////////////////////////////////////////////////////////////////////////////////////// /// Testbed utility types @@ -208,8 +219,8 @@ struct Options { bool help = false; - float alpha = 1.0f; - float beta = 0.0f; + float alpha = FLT_MAX; + float beta = FLT_MAX; int iterations = 10; int m = 1024, n = 2048, k = 512, groups = 10; std::string benchmark_path; @@ -230,8 +241,8 @@ struct Options { cmd.get_cmd_line_argument("n", n); cmd.get_cmd_line_argument("k", k); cmd.get_cmd_line_argument("groups", groups); - cmd.get_cmd_line_argument("alpha", alpha, 1.f); - cmd.get_cmd_line_argument("beta", beta, 0.f); + cmd.get_cmd_line_argument("alpha", alpha, FLT_MAX); + cmd.get_cmd_line_argument("beta", beta, FLT_MAX); cmd.get_cmd_line_argument("iterations", iterations); cmd.get_cmd_line_argument("benchmark", benchmark_path); @@ -248,10 +259,7 @@ struct Options { } void randomize_problems(cutlass::CommandLine &cmd) { - int cmd_line_m = -1; - int cmd_line_n = -1; - int cmd_line_k = -1; - + int cmd_line_m = -1, cmd_line_n = -1, cmd_line_k = -1; cmd.get_cmd_line_argument("m", cmd_line_m); cmd.get_cmd_line_argument("n", cmd_line_n); cmd.get_cmd_line_argument("k", cmd_line_k); @@ -259,19 +267,15 @@ struct Options { problem_sizes_host.reserve(groups); for (int i = groups; i > 0; i--) { - int m = cmd_line_m; int n = cmd_line_n; int k = cmd_line_k; - if (m < 1) { m = ((rand() % 512) + 1); } - if (n < 1) { n = ((rand() % 512) + 1); } - if (k < 1) { k = alignment * ((rand() % 64) + 1); } @@ -317,6 +321,7 @@ struct Options { problem_sizes_host.push_back({extent.m(), extent.n(), extent.k()}); } } + groups = static_cast(problem_sizes_host.size()); return true; } @@ -351,7 +356,9 @@ struct Options { uint64_t fmas = uint64_t(); for (auto const & problem : problem_sizes_host) { - fmas += cute::size(problem); + fmas += static_cast(get<0>(problem)) * + static_cast(get<1>(problem)) * + 
static_cast(get<2>(problem)); } // Two flops per multiply-add uint64_t flop = uint64_t(2) * uint64_t(fmas); @@ -370,7 +377,7 @@ struct Result bool passed = false; }; -#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) +#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED) ///////////////////////////////////////////////////////////////////////////////////////////////// /// GEMM setup and evaluation @@ -435,6 +442,7 @@ void allocate(const Options &options) { stride_B_host.push_back(cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, Int<1>{}))); stride_C_host.push_back(cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, Int<1>{}))); stride_D_host.push_back(cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, Int<1>{}))); + } block_A.reset(total_elements_A); @@ -442,6 +450,8 @@ void allocate(const Options &options) { block_C.reset(total_elements_C); block_D.reset(total_elements_D); block_ref_D.reset(total_elements_D); + block_alpha.reset(options.groups); + block_beta.reset(options.groups); } /// Initialize operands to be used in the GEMM and reference GEMM @@ -460,12 +470,18 @@ void initialize(const Options &options) { std::vector ptr_B_host(options.groups); std::vector ptr_C_host(options.groups); std::vector ptr_D_host(options.groups); + std::vector ptr_alpha_host(options.groups); + std::vector ptr_beta_host(options.groups); for (int32_t i = 0; i < options.groups; ++i) { ptr_A_host.at(i) = block_A.get() + offset_A.at(i); ptr_B_host.at(i) = block_B.get() + offset_B.at(i); ptr_C_host.at(i) = block_C.get() + offset_C.at(i); ptr_D_host.at(i) = block_D.get() + offset_D.at(i); + alpha_host.push_back((options.alpha == FLT_MAX) ? static_cast((rand() % 5) + 1) : options.alpha); + beta_host.push_back((options.beta == FLT_MAX) ? 
static_cast(rand() % 5) : options.beta); + ptr_alpha_host.at(i) = block_alpha.get() + i; + ptr_beta_host.at(i) = block_beta.get() + i; } ptr_A.reset(options.groups); @@ -492,13 +508,20 @@ void initialize(const Options &options) { stride_D.reset(options.groups); stride_D.copy_from_host(stride_D_host.data()); + alpha_device.reset(options.groups); + alpha_device.copy_from_host(ptr_alpha_host.data()); + beta_device.reset(options.groups); + beta_device.copy_from_host(ptr_beta_host.data()); + initialize_block(block_A, seed + 2023); initialize_block(block_B, seed + 2022); initialize_block(block_C, seed + 2021); + block_alpha.copy_from_host(alpha_host.data()); + block_beta.copy_from_host(beta_host.data()); } /// Populates a Gemm::Arguments structure from the given commandline options -typename Gemm::Arguments args_from_options(const Options &options) +typename Gemm::Arguments args_from_options(const Options &options, bool host_problem_shapes_available = true) { cutlass::KernelHardwareInfo hw_info; // Change device_id to another value if you are running on a machine with multiple GPUs and wish @@ -506,13 +529,36 @@ typename Gemm::Arguments args_from_options(const Options &options) hw_info.device_id = 0; hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); - typename Gemm::Arguments arguments{ - cutlass::gemm::GemmUniversalMode::kGrouped, - {options.groups, problem_sizes.get(), options.problem_sizes_host.data()}, - {ptr_A.get(), stride_A.get(), ptr_B.get(), stride_B.get()}, - {{options.alpha, options.beta}, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()}, - hw_info - }; + typename Gemm::EpilogueOutputOp::Params params; + if (options.alpha != FLT_MAX && options.beta != FLT_MAX) { + // If both alpha/beta are provided (via cmd line args) and are scalar, i.e., same alpha/beta applies to all batches. + params = typename Gemm::EpilogueOutputOp::Params( + ElementAccumulator(options.alpha), ElementAccumulator(options.beta)); + } + else { + // If pointers to alpha/beta are provided, i.e., alpha/beta can differ between batches/groups. 
+ params = typename Gemm::EpilogueOutputOp::Params(alpha_device.get(), beta_device.get()); + } + + typename Gemm::Arguments arguments; + if (host_problem_shapes_available) { + arguments = typename Gemm::Arguments { + cutlass::gemm::GemmUniversalMode::kGrouped, + {options.groups, problem_sizes.get(), options.problem_sizes_host.data()}, + {ptr_A.get(), stride_A.get(), ptr_B.get(), stride_B.get()}, + {params, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()}, + hw_info + }; + } + else { + arguments = typename Gemm::Arguments { + cutlass::gemm::GemmUniversalMode::kGrouped, + {options.groups, problem_sizes.get(), nullptr}, + {ptr_A.get(), stride_A.get(), ptr_B.get(), stride_B.get()}, + {params, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()}, + hw_info + }; + } return arguments; } @@ -539,10 +585,10 @@ bool verify(const Options &options) { // Launch device reference gemm kernel gemm_reference( {M, N, K}, - ElementAccumulator(options.alpha), + ElementAccumulator(alpha_host.at(i)), ref_A, ref_B, - ElementAccumulator(options.beta), + ElementAccumulator(beta_host.at(i)), ref_C, ref_D); @@ -560,7 +606,7 @@ bool verify(const Options &options) { /// Execute a given example GEMM computation template -int run(Options &options) +int run(Options &options, bool host_problem_shapes_available = true) { allocate(options); initialize(options); @@ -569,7 +615,7 @@ int run(Options &options) Gemm gemm; // Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm - auto arguments = args_from_options(options); + auto arguments = args_from_options(options, host_problem_shapes_available); // Using the arguments, query for extra workspace required for matrix multiplication computation size_t workspace_size = Gemm::get_workspace_size(arguments); @@ -612,12 +658,12 @@ int run(Options &options) result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations); result.gflops = options.gflops(result.avg_runtime_ms / 1000.0, options.problem_sizes_host); - std::cout << " Problem Sizes: " << std::endl; - for (auto const & problem : options.problem_sizes_host) { - std::cout << " " << problem << std::endl; + std::cout << " Problem Sizes, Alpha, Beta " << std::endl; + for (int32_t i = 0; i < options.groups; ++i) { + std::cout << " " << options.problem_sizes_host.at(i); + std::cout << ", " << alpha_host.at(i) << ", " << beta_host.at(i) << std::endl; } std::cout << " Groups : " << options.groups << std::endl; - std::cout << " Alpha, Beta : " << options.alpha << ',' << options.beta << std::endl; std::cout << " Avg runtime : " << result.avg_runtime_ms << " ms" << std::endl; std::cout << " GFLOPS : " << result.gflops << std::endl; } @@ -625,7 +671,7 @@ int run(Options &options) return 0; } -#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) +#endif // defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED) /////////////////////////////////////////////////////////////////////////////////////////////////// @@ -667,8 +713,9 @@ int main(int argc, char const **args) { // Evaluate CUTLASS kernels // -#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) +#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED) run(options); + run(options, false /*host_problem_shapes_available*/); #endif return 0; diff --git a/examples/57_hopper_grouped_gemm/CMakeLists.txt b/examples/57_hopper_grouped_gemm/CMakeLists.txt index 170331b5b2..2c3ff3a496 100644 --- a/examples/57_hopper_grouped_gemm/CMakeLists.txt +++ b/examples/57_hopper_grouped_gemm/CMakeLists.txt @@ -35,9 +35,15 @@ set(TEST_RANDOM_LARGE_GROUP 
--groups=500 --iterations=0) set(TEST_EPILOGUE --alpha=0.5 --beta=0.7 --iterations=0) # Random problem sizes set(TEST_EPILOGUE_LARGE_GROUP --alpha=1.5 --beta=2.0 --groups=500 --iterations=0) # Random problem sizes +set(TEST_EPILOGUE_OP --beta=0.7 --iterations=1) # Random problem sizes +set(TEST_EPILOGUE_OP_LARGE_GROUP --alpha=1.5 --iterations=1) # Random problem sizes + set(TEST_FIXED --m=2048 --n=5120 --k=8192 --groups=50 --iterations=0) # Fixed problem sizes set(TEST_FIXED_LARGE_GROUP --m=2048 --n=512 --k=512 --groups=512 --iterations=0) # Fixed problem sizes +set(TEST_SMALL --m=256 --n=128 --iterations=0) # Small problem sizes +set(TEST_SMALL_LARGE_GROUP --m=128 --n=128 --groups=500 --iterations=0) # Small problem sizes + set(TEST_RANDOM_PERF --iterations=10) # Random problem sizes set(TEST_RANDOM_PERF_LARGE_GROUP --groups=500 --iterations=10) # Random problem sizes @@ -49,8 +55,12 @@ cutlass_example_add_executable( TEST_RANDOM_LARGE_GROUP TEST_EPILOGUE TEST_EPILOGUE_LARGE_GROUP + TEST_EPILOGUE_OP + TEST_EPILOGUE_OP_LARGE_GROUP TEST_FIXED TEST_FIXED_LARGE_GROUP + TEST_SMALL + TEST_SMALL_LARGE_GROUP TEST_RANDOM_PERF TEST_RANDOM_PERF_LARGE_GROUP ) diff --git a/include/cute/arch/copy_sm90_desc.hpp b/include/cute/arch/copy_sm90_desc.hpp index 8cad74488a..9fa4f34b24 100644 --- a/include/cute/arch/copy_sm90_desc.hpp +++ b/include/cute/arch/copy_sm90_desc.hpp @@ -265,7 +265,7 @@ tma_descriptor_replace_dims_strides_in_shared_mem(TmaDescriptor asm volatile ( "tensormap.replace.tile.global_dim.shared::cta.b1024.b32 [%0], 2, %1;" :: "l"(smem_int64_desc), "r"(prob_shape[2])); - // Strides must be a multiple of 16. Also, stride for the intermost dimension is implicitly 1 + // Strides must be a multiple of 16. Also, stride for the intermost dimension is implicitly 1 asm volatile ( "tensormap.replace.tile.global_stride.shared::cta.b1024.b64 [%0], 0, %1;" :: "l"(smem_int64_desc), "l"(prob_stride[1] >> 4)); diff --git a/include/cute/atom/mma_atom.hpp b/include/cute/atom/mma_atom.hpp index ae606153fe..098efd4af8 100644 --- a/include/cute/atom/mma_atom.hpp +++ b/include/cute/atom/mma_atom.hpp @@ -391,7 +391,7 @@ struct TiledMMA : MMA_Atom } else { return cute::max(core_size, perm_size); } - + CUTE_GCC_UNREACHABLE; } diff --git a/include/cute/util/type_traits.hpp b/include/cute/util/type_traits.hpp index a46e1c4df8..3107b81161 100644 --- a/include/cute/util/type_traits.hpp +++ b/include/cute/util/type_traits.hpp @@ -125,6 +125,9 @@ using CUTE_STL_NAMESPACE::invoke_result_t; using CUTE_STL_NAMESPACE::common_type; using CUTE_STL_NAMESPACE::common_type_t; +using CUTE_STL_NAMESPACE::remove_pointer; +using CUTE_STL_NAMESPACE::remove_pointer_t; + // using CUTE_STL_NAMESPACE::declval; diff --git a/include/cutlass/arch/mma_sm90.h b/include/cutlass/arch/mma_sm90.h index 35d6d70474..d2b167a7ce 100644 --- a/include/cutlass/arch/mma_sm90.h +++ b/include/cutlass/arch/mma_sm90.h @@ -64,6 +64,10 @@ #endif #endif +#if ((__CUDACC_VER_MAJOR__ > 12) || ((__CUDACC_VER_MAJOR__ == 12) && (__CUDACC_VER_MINOR__ >= 3))) + #define CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED +#endif + //////////////////////////////////////////////////////////////////////////////// namespace cutlass { diff --git a/include/cutlass/detail/layout.hpp b/include/cutlass/detail/layout.hpp index f46d9c1404..f3b9963850 100644 --- a/include/cutlass/detail/layout.hpp +++ b/include/cutlass/detail/layout.hpp @@ -80,6 +80,40 @@ struct TagToStrideB { using tag = layout::ColumnMajor; }; +// For each cutlass::layout *, provides its corresponding cute stride 
types, 64b by default +// Used by pointer array and grouped gemm +// Maps to modes [M, K, L] +template <> +struct TagToStrideA { + using UnderlyingType = cute::Stride, int64_t>; + using type = UnderlyingType*; + using tag = layout::RowMajor; +}; + +// Maps to modes [M, K, L] +template <> +struct TagToStrideA { + using UnderlyingType = cute::Stride, int64_t, int64_t>; + using type = UnderlyingType*; + using tag = layout::ColumnMajor; +}; + +// Maps to modes [N, K, L] +template <> +struct TagToStrideB { + using UnderlyingType = cute::Stride, int64_t, int64_t>; + using type = UnderlyingType*; + using tag = layout::RowMajor; +}; + +// Maps to modes [N, K, L] +template <> +struct TagToStrideB { + using UnderlyingType = cute::Stride, int64_t>; + using type = UnderlyingType*; + using tag = layout::ColumnMajor; +}; + // Maps to modes [M, N, L] template struct TagToStrideC : TagToStrideA { }; @@ -101,7 +135,7 @@ template constexpr bool is_major(Stride = {}) { // Account for stride types with and without batch mode and batch modes with static zero stride - return cute::is_constant<1, decltype(cute::front(cute::get(Stride{})))>::value; + return cute::is_constant<1, decltype(cute::front(cute::get(cute::remove_pointer_t{})))>::value; } // Note : This method can be used for deducing the Layout Tag of A, C, D Matrices diff --git a/include/cutlass/epilogue/collective/builders/sm90_builder.inl b/include/cutlass/epilogue/collective/builders/sm90_builder.inl index c5f148d41c..f1c47f4400 100644 --- a/include/cutlass/epilogue/collective/builders/sm90_builder.inl +++ b/include/cutlass/epilogue/collective/builders/sm90_builder.inl @@ -268,7 +268,7 @@ struct Sm90TmaBuilderImpl { // Passing void C disables source load + smem allocation using ElementC = cute::conditional_t,ElementD,ElementC_>; // prevents void ref breakages using GmemLayoutTagC = cute::conditional_t,GmemLayoutTagD,GmemLayoutTagC_>; - + using GmemStrideTypeC = cutlass::detail::TagToStrideC_t; using GmemStrideTypeD = cutlass::detail::TagToStrideC_t; @@ -434,8 +434,7 @@ struct CollectiveBuilder< Schedule, fusion::LinearCombination, cute::enable_if_t || - cute::is_same_v || - cute::is_same_v >> { + cute::is_same_v >> { // Passing void C disables source load using ElementC = cute::conditional_t, diff --git a/include/cutlass/epilogue/collective/default_epilogue.hpp b/include/cutlass/epilogue/collective/default_epilogue.hpp index e9951b1731..aad96de74a 100644 --- a/include/cutlass/epilogue/collective/default_epilogue.hpp +++ b/include/cutlass/epilogue/collective/default_epilogue.hpp @@ -63,6 +63,7 @@ class DefaultEpilogue { // Type Aliases // using EpilogueSchedule = EpilogueSchedule_; + using DispatchPolicy = EpilogueSchedule_; // derived types of output thread level operator using ThreadEpilogueOp = ThreadEpilogueOp_; diff --git a/include/cutlass/epilogue/collective/default_epilogue_array.hpp b/include/cutlass/epilogue/collective/default_epilogue_array.hpp index de699395c1..5807f87eb8 100644 --- a/include/cutlass/epilogue/collective/default_epilogue_array.hpp +++ b/include/cutlass/epilogue/collective/default_epilogue_array.hpp @@ -73,12 +73,10 @@ class DefaultEpilogueArray { using ElementScalar = ElementCompute; using ElementC = typename ThreadEpilogueOp::ElementC; using StrideC = StrideC_; + using UnderlyingStrideC = cute::remove_pointer_t; using ElementD = typename ThreadEpilogueOp::ElementD; using StrideD = StrideD_; - using StridesC = cute::conditional_t, - StrideC const*, StrideC>; - using StridesD = cute::conditional_t, - StrideD const*, StrideD>; 
+ using UnderlyingStrideD = cute::remove_pointer_t; using GmemTiledCopyC = void; using GmemTiledCopyD = void; @@ -86,10 +84,9 @@ class DefaultEpilogueArray { static const int kOutputAlignment = ThreadEpilogueOp::kCount; using AlignmentType = typename cute::uint_bit::value * kOutputAlignment>::type; - static_assert(cute::is_same_v || - cute::is_same_v, "Incompatible epilogue schedule."); - static_assert(rank(StrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]"); - static_assert(rank(StrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]"); + static_assert(cute::is_same_v, "Incompatible epilogue schedule."); + static_assert(rank(UnderlyingStrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]"); + static_assert(rank(UnderlyingStrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]"); struct SharedStorage { }; @@ -97,9 +94,9 @@ class DefaultEpilogueArray { struct Arguments { typename ThreadEpilogueOp::Params thread{}; ElementC const** ptr_C = nullptr; - StridesC dC{}; + StrideC dC{}; ElementD** ptr_D = nullptr; - StridesD dD{}; + StrideD dD{}; }; // Device side epilogue params @@ -140,12 +137,13 @@ class DefaultEpilogueArray { CUTLASS_HOST_DEVICE DefaultEpilogueArray(Params const& params_) - : params(params_), epilogue_op(params_.thread) { } + : params(params_) { } CUTLASS_DEVICE bool is_source_needed() { - return epilogue_op.is_source_needed(); + // For Ptr-Array or Grouped Gemm we cannot determine if source is needed based on first beta. + return true; } template< @@ -185,10 +183,23 @@ class DefaultEpilogueArray { // Slice to get the tile this CTA is responsible for auto [m_coord, n_coord, k_coord, l_coord] = blk_coord_mnkl; - StrideC stride_c; - StrideD stride_d; - if constexpr (cute::is_same_v) { - stride_c = detail::get_epilogue_stride(params.dC[l_coord]); + // If scalar alpha/beta are provided, i.e., same alpha/beta applies to all batches/groups. + // If pointers to alpha/beta are provided, i.e., alpha/beta can differ between batches/groups, + // we get the correct alpha/beta values for the current batch/group using group index. 
+ ThreadEpilogueOp epilogue_op = ThreadEpilogueOp(params.thread, l_coord); + + if (epilogue_op.is_source_needed() && params.dC == nullptr) { + // Beta value is non-zero while pointer to C is a nullptr + assert(0); + } + + UnderlyingStrideC stride_c; + UnderlyingStrideD stride_d; + if constexpr (!cute::is_same_v) { + // If grouped gemm + if (epilogue_op.is_source_needed()) { + stride_c = detail::get_epilogue_stride(params.dC[l_coord]); + } stride_d = detail::get_epilogue_stride(params.dD[l_coord]); } else { @@ -197,7 +208,11 @@ class DefaultEpilogueArray { } // Represent the full output tensor - Tensor mC_mnl = make_tensor(make_gmem_ptr(params.ptr_C[l_coord]), make_shape(M,N,mock_L), stride_c); // (m,n,l) + ElementC const* ptr_C_l = nullptr; + if (epilogue_op.is_source_needed()) { + ptr_C_l = params.ptr_C[l_coord]; + } + Tensor mC_mnl = make_tensor(make_gmem_ptr(ptr_C_l), make_shape(M,N,mock_L), stride_c); // (m,n,l) Tensor mD_mnl = make_tensor(make_gmem_ptr(params.ptr_D[l_coord]), make_shape(M,N,mock_L), stride_d); // (m,n,l) Tensor gC_mnl = local_tile(mC_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) Tensor gD_mnl = local_tile(mD_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) @@ -242,7 +257,6 @@ class DefaultEpilogueArray { private: Params params; - ThreadEpilogueOp epilogue_op; }; ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized.hpp b/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized.hpp index 1d904b0581..a463463b58 100644 --- a/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized.hpp @@ -148,12 +148,12 @@ class CollectiveEpilogue< constexpr static size_t SmemAlignmentD = cutlass::detail::alignment_for_swizzle(SmemLayoutD{}); constexpr static size_t SmemAlignmentC = cutlass::detail::alignment_for_swizzle(SmemLayoutC{}); - + using EmptyType = cute::tuple<>; - using SmemCStorage = cute::conditional_t, EmptyType>; - using SmemDStorage = cute::conditional_t, EmptyType>; @@ -189,6 +189,7 @@ class CollectiveEpilogue< struct SharedStorage { using TensorStorage = TensorStorageImpl; + TensorStorage tensors; using PipelineStorage = typename LoadPipeline::SharedStorage; @@ -249,12 +250,12 @@ class CollectiveEpilogue< Tensor tensor_c = make_tensor(make_gmem_ptr(args.ptr_C), make_layout(make_shape(M_C,N,L), args.dC)); tma_load_c = make_tma_copy(CopyOpG2S{}, tensor_c, SmemLayoutC{}(_,_,0)); } - + typename Params::TMA_D tma_store_d; if constexpr (is_destination_supported) { Tensor tensor_d = make_tensor(make_gmem_ptr(args.ptr_D), make_layout(make_shape(M_D,N,L), args.dD)); tma_store_d = make_tma_copy(CopyOpS2G{}, tensor_d, SmemLayoutD{}(_,_,0)); - } + } return { FusionCallbacks::to_underlying_arguments(problem_shape, args.thread, workspace), @@ -385,13 +386,13 @@ class CollectiveEpilogue< // Apply epilogue subtile, get matching smem tensor SmemElementC* ptr_sC = nullptr; - + if constexpr (is_source_supported) { if constexpr (ReuseSmemC) { ptr_sC = reinterpret_cast(shared_tensors.smem_D().data()); } else { ptr_sC = shared_tensors.smem_C().data(); - } + } } Tensor gC_epi = flat_divide(gC, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) Tensor sC_epi = make_tensor(make_smem_ptr(ptr_sC), SmemLayoutC{}); // (EPI_TILE_M,EPI_TILE_N,PIPE_C) @@ -559,7 +560,7 @@ class CollectiveEpilogue< // 
Vectorized fragment view constexpr int FragmentSize = DispatchPolicy::FragmentSize; Tensor tRS_rAcc_frg = recast>(tRS_rAcc); - Tensor tRS_rD_frg = recast>(tRS_rD); + Tensor tRS_rD_frg = recast>(tRS_rD); CUTE_STATIC_ASSERT(size<0>(tRS_rAcc) % FragmentSize == 0, "Fragment size does not vectorize properly"); // (t)hread-partition for (s)mem to (r)egister copy (tSR_) diff --git a/include/cutlass/epilogue/dispatch_policy.hpp b/include/cutlass/epilogue/dispatch_policy.hpp index 3a9eb85c41..2d36314e8f 100644 --- a/include/cutlass/epilogue/dispatch_policy.hpp +++ b/include/cutlass/epilogue/dispatch_policy.hpp @@ -46,8 +46,7 @@ namespace cutlass::epilogue { ////////////////////////////////////////////////////////////////////////////// struct NoSmemWarpSpecialized {}; -struct NoSmemWarpSpecializedArray {}; -struct NoSmemWarpSpecializedGroup {}; +struct PtrArrayNoSmemWarpSpecialized {}; struct TmaWarpSpecialized {}; struct TmaWarpSpecializedCooperative {}; // DEPRECATED schedules, will be removed in next release diff --git a/include/cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp b/include/cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp index 84ca9e35eb..fec0cd505f 100644 --- a/include/cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp @@ -1247,6 +1247,7 @@ struct FusionCallbacks< }; ///////////////////////////////////////////////////////////////////////////////////////////////// + namespace detail { template > struct get_element_aux { @@ -1257,7 +1258,7 @@ template struct get_element_aux> { using type = typename FusionOpOrCallbacks::ElementAux; }; - + template struct get_element_aux, cute::void_t<>> { using type = typename get_element_aux::type; @@ -1270,7 +1271,7 @@ struct get_element_aux, cute::void_t::type; }; -} +} // namespace cutlass:epilogue::fusion::detail template using get_element_aux_t = typename detail::get_element_aux::type; diff --git a/include/cutlass/epilogue/thread/linear_combination.h b/include/cutlass/epilogue/thread/linear_combination.h index f17878b47b..f74a36af4b 100644 --- a/include/cutlass/epilogue/thread/linear_combination.h +++ b/include/cutlass/epilogue/thread/linear_combination.h @@ -88,43 +88,72 @@ class LinearCombination { /// Host-constructable parameters structure struct Params { - ElementCompute alpha; ///< scales accumulators - ElementCompute beta; ///< scales source tensor - ElementCompute const *alpha_ptr; ///< pointer to accumulator scalar - if not null, loads it from memory - ElementCompute const *beta_ptr; ///< pointer to source scalar - if not null, loads it from memory + ElementCompute alpha; ///< scales accumulators + ElementCompute beta; ///< scales source tensor + ElementCompute const *alpha_ptr; ///< pointer to accumulator scalar - if not null, loads it from memory + ElementCompute const *beta_ptr; ///< pointer to source scalar - if not null, loads it from memory + ElementCompute const* const* alpha_ptr_array; ///< array of pointers to accumulator scalar per group/batch + ElementCompute const* const* beta_ptr_array; ///< array of pointers to source scalar per group/batch CUTLASS_HOST_DEVICE Params(): alpha(ElementCompute(1)), beta(ElementCompute(0)), alpha_ptr(nullptr), - beta_ptr(nullptr) { } + beta_ptr(nullptr), + alpha_ptr_array(nullptr), + beta_ptr_array(nullptr) { } CUTLASS_HOST_DEVICE Params( ElementCompute alpha, ElementCompute beta ): - alpha(alpha), beta(beta), alpha_ptr(nullptr), beta_ptr(nullptr) { } + alpha(alpha), beta(beta), 
+ alpha_ptr(nullptr), beta_ptr(nullptr), + alpha_ptr_array(nullptr), beta_ptr_array(nullptr) { } CUTLASS_HOST_DEVICE Params( ElementCompute alpha ): - alpha(alpha), beta(0), alpha_ptr(nullptr), beta_ptr(nullptr) { } + alpha(alpha), beta(0), + alpha_ptr(nullptr), beta_ptr(nullptr), + alpha_ptr_array(nullptr), beta_ptr_array(nullptr) { } CUTLASS_HOST_DEVICE Params( ElementCompute const *alpha_ptr, ElementCompute const *beta_ptr ): - alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) { } + alpha(0), beta(0), + alpha_ptr(alpha_ptr), beta_ptr(beta_ptr), + alpha_ptr_array(nullptr), beta_ptr_array(nullptr) { } CUTLASS_HOST_DEVICE Params( ElementCompute const *alpha_ptr ): - alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(nullptr) { } + alpha(0), beta(0), + alpha_ptr(alpha_ptr), beta_ptr(nullptr), + alpha_ptr_array(nullptr), beta_ptr_array(nullptr) { } + + CUTLASS_HOST_DEVICE + Params( + ElementCompute const* const* alpha_ptr_array, + ElementCompute const* const* beta_ptr_array + ): + alpha(0), beta(0), + alpha_ptr(nullptr), beta_ptr(nullptr), + alpha_ptr_array(alpha_ptr_array), beta_ptr_array(beta_ptr_array) { } + + CUTLASS_HOST_DEVICE + Params( + ElementCompute const* const* alpha_ptr_array + ): + alpha(0), beta(0), + alpha_ptr(nullptr), beta_ptr(nullptr), + alpha_ptr_array(alpha_ptr_array), beta_ptr_array(nullptr) { } }; private: @@ -140,9 +169,25 @@ class LinearCombination { /// Constructs the function object, possibly loading from pointers in host memory CUTLASS_HOST_DEVICE - LinearCombination(Params const ¶ms) { - alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha); - beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta); + LinearCombination(Params const ¶ms, int group_idx = 0) { + if (params.alpha_ptr_array != nullptr && params.alpha_ptr_array[group_idx] != nullptr) { + alpha_ = *(params.alpha_ptr_array[group_idx]); + } + else if (params.alpha_ptr != nullptr) { + alpha_ = *params.alpha_ptr; + } + else { + alpha_ = params.alpha; + } + if (params.beta_ptr_array != nullptr && params.beta_ptr_array[group_idx] != nullptr) { + beta_ = *(params.beta_ptr_array[group_idx]); + } + else if (params.beta_ptr != nullptr) { + beta_ = *params.beta_ptr; + } + else { + beta_ = params.beta; + } } /// Returns true if source is needed diff --git a/include/cutlass/gemm/collective/builders/sm90_gmma_builder.inl b/include/cutlass/gemm/collective/builders/sm90_gmma_builder.inl index 557173c776..082ef9157f 100644 --- a/include/cutlass/gemm/collective/builders/sm90_gmma_builder.inl +++ b/include/cutlass/gemm/collective/builders/sm90_gmma_builder.inl @@ -185,8 +185,7 @@ struct CollectiveBuilder< (cute::is_same_v || cute::is_same_v || cute::is_same_v || - cute::is_same_v || - cute::is_same_v) && + cute::is_same_v) && not detail::is_use_rmem_A()> > { static_assert(is_static::value); @@ -197,8 +196,7 @@ struct CollectiveBuilder< static_assert(detail::is_aligned(), "Should meet TMA alignment requirement\n"); - static constexpr bool IsArrayOfPointersGemm = (cute::is_same_v || - cute::is_same_v); + static constexpr bool IsArrayOfPointersGemm = (cute::is_same_v); static constexpr bool IsFP8Input = detail::is_input_fp8(); static_assert(!IsFP8Input || (IsFP8Input && !IsArrayOfPointersGemm), "Kernel[Array/Group]TmaWarpSpecializedCooperative is only compatible with FP8 FastAccum version right now\n"); @@ -515,8 +513,7 @@ struct CollectiveBuilder< cute::is_same_v || cute::is_same_v || cute::is_same_v || - cute::is_same_v || - cute::is_same_v> + cute::is_same_v> > { static_assert(is_static::value); 
static_assert(is_static::value); @@ -534,8 +531,7 @@ struct CollectiveBuilder< static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_ss_tag_to_major_A(); static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_ss_tag_to_major_B(); - static constexpr bool IsArrayOfPointersGemm = (cute::is_same_v || - cute::is_same_v); + static constexpr bool IsArrayOfPointersGemm = (cute::is_same_v); using AtomLayoutMNK = cute::conditional_t || IsArrayOfPointersGemm, Layout>, Layout>>; diff --git a/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized.hpp b/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized.hpp index c3ebccae8b..5ae843076d 100644 --- a/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized.hpp +++ b/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized.hpp @@ -93,8 +93,10 @@ struct CollectiveMma< using TileShape = TileShape_; using ElementA = ElementA_; using StrideA = StrideA_; + using UnderlyingStrideA = cute::remove_pointer_t; using ElementB = ElementB_; using StrideB = StrideB_; + using UnderlyingStrideB = cute::remove_pointer_t; using TiledMma = TiledMma_; using ElementAccumulator = typename TiledMma::ValTypeC; using GmemTiledCopyA = GmemTiledCopyA_; @@ -149,14 +151,14 @@ struct CollectiveMma< // Assumption: StrideA is congruent with Problem_MK using TMA_A = decltype(make_tma_copy( GmemTiledCopyA{}, - make_tensor(static_cast(nullptr), repeat_like(StrideA{}, int32_t(0)), StrideA{}), + make_tensor(static_cast(nullptr), repeat_like(UnderlyingStrideA{}, int32_t(0)), UnderlyingStrideA{}), SmemLayoutA{}(_,_,cute::Int<0>{}), make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), size<1>(ClusterShape{}))); // mcast along N mode for this M load, if any // Assumption: StrideB is congruent with Problem_NK using TMA_B = decltype(make_tma_copy( GmemTiledCopyB{}, - make_tensor(static_cast(nullptr), repeat_like(StrideB{}, int32_t(0)), StrideB{}), + make_tensor(static_cast(nullptr), repeat_like(UnderlyingStrideB{}, int32_t(0)), UnderlyingStrideB{}), SmemLayoutB{}(_,_,cute::Int<0>{}), make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any @@ -179,16 +181,14 @@ struct CollectiveMma< using TensorMapStorage = typename SharedStorage::TensorMapStorage; using PipelineStorage = typename SharedStorage::PipelineStorage; - static constexpr bool IsGroupedGemmKernel = cute::is_base_of_v; - using StridesA = cute::conditional_t; - using StridesB = cute::conditional_t; + static constexpr bool IsGroupedGemmKernel = !cute::is_same_v; // Host side kernel arguments struct Arguments { ElementA const** ptr_A; - StridesA dA; + StrideA dA; ElementB const** ptr_B; - StridesB dB; + StrideB dB; }; // Device side kernel params @@ -197,9 +197,9 @@ struct CollectiveMma< TMA_B tma_load_b; void* tensormaps; InternalElementA const** ptr_A; - StridesA dA; + StrideA dA; InternalElementB const** ptr_B; - StridesB dB; + StrideB dB; }; // @@ -212,30 +212,36 @@ struct CollectiveMma< ProblemShape problem_shapes, Arguments const& args, void* workspace) { - // Optionally append 1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK) - auto problem_shape_MNKL = append<4>(problem_shapes.get_host_problem_shape(0), 1); - auto [M,N,K,L] = problem_shape_MNKL; + // These tensor shapes (only applicable for grouped gemm) and pointers are only used to create tensormap/tma desc. + // These will be replaced with correct values before the initial tma load. 
+ auto init_shape = repeat_like(typename ProblemShape::UnderlyingProblemShape{}, int32_t(1)); + auto init_M = get<0>(init_shape); + auto init_N = get<1>(init_shape); + auto init_K = get<2>(init_shape); + // Batches/Groups are managed by using appropriate pointers to input matrices const uint32_t mock_L = 1; - - // These tensor pointers are only used to create tensormap/tma desc. - // This address to the tensor will be replaced with correct address before the initial tma load InternalElementA const* ptr_A_first_batch = reinterpret_cast(args.ptr_A); InternalElementB const* ptr_B_first_batch = reinterpret_cast(args.ptr_B); - cudaError_t cuda_error = cudaGetLastError(); // clear previous error - StrideA stride_a; - StrideB stride_b; + UnderlyingStrideA stride_a; + UnderlyingStrideB stride_b; if constexpr (IsGroupedGemmKernel) { - // Strides for Grouped Gemm will be replaced prior to the first access regardless - stride_a = StrideA{}; - stride_b = StrideB{}; + // Strides for Grouped Gemm will be replaced prior to the first access regardless. + stride_a = UnderlyingStrideA{}; + stride_b = UnderlyingStrideB{}; } else { + // Tensor shapes for Ptr-Array are initialized correctly only here. + auto problem_shape_MNK = problem_shapes.get_host_problem_shape(0); + init_M = get<0>(problem_shape_MNK); + init_N = get<1>(problem_shape_MNK); + init_K = get<2>(problem_shape_MNK); + stride_a = args.dA; stride_b = args.dB; } - Tensor tensor_a = make_tensor(ptr_A_first_batch, make_layout(make_shape(M,K,mock_L), stride_a)); - Tensor tensor_b = make_tensor(ptr_B_first_batch, make_layout(make_shape(N,K,mock_L), stride_b)); + Tensor tensor_a = make_tensor(ptr_A_first_batch, make_layout(make_shape(init_M,init_K,mock_L), stride_a)); + Tensor tensor_b = make_tensor(ptr_B_first_batch, make_layout(make_shape(init_N,init_K,mock_L), stride_b)); TMA_A tma_load_a = make_tma_copy( GmemTiledCopyA{}, tensor_a, @@ -287,12 +293,14 @@ struct CollectiveMma< constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits::value; bool implementable = true; - // Check alignment for all problem sizes - for (int i = 0; i < problem_shapes.groups(); i++) { - auto problem_shape_MNKL = append<4>(problem_shapes.get_host_problem_shape(i), 1); - auto [M,N,K,L] = problem_shape_MNKL; - implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,K,L), StrideA{}); - implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(N,K,L), StrideB{}); + if (problem_shapes.is_host_problem_shape_available()) { + // Check alignment for all problem sizes + for (int i = 0; i < problem_shapes.groups(); i++) { + auto problem_shape_MNKL = append<4>(problem_shapes.get_host_problem_shape(i), 1); + auto [M,N,K,L] = problem_shape_MNKL; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,K,L), UnderlyingStrideA{}); + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(N,K,L), UnderlyingStrideB{}); + } } if (!implementable) { @@ -676,6 +684,14 @@ struct CollectiveMma< cute::detail::fill_tma_gmem_shape_stride(mainloop_params.tma_load_b, tensor_b, prob_shape_B, prob_stride_B); + // Convert strides to byte strides + for (uint64_t& stride : prob_stride_A) { + stride = (stride * sizeof_bits_v) / 8; + } + for (uint64_t& stride : prob_stride_B) { + stride = (stride * sizeof_bits_v) / 8; + } + cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormap.smem_tensormap_A, prob_shape_A, prob_stride_A); diff --git 
a/include/cutlass/gemm/dispatch_policy.hpp b/include/cutlass/gemm/dispatch_policy.hpp index 9f3225c4f6..c8ba9ba109 100644 --- a/include/cutlass/gemm/dispatch_policy.hpp +++ b/include/cutlass/gemm/dispatch_policy.hpp @@ -53,8 +53,7 @@ struct KernelTma { }; struct KernelTmaWarpSpecialized { }; struct KernelTmaWarpSpecializedPingpong { }; struct KernelTmaWarpSpecializedCooperative { }; -struct KernelArrayTmaWarpSpecializedCooperative { }; -struct KernelGroupTmaWarpSpecializedCooperative { }; +struct KernelPtrArrayTmaWarpSpecializedCooperative { }; ////////////////////////////////////////////////////////////////////////////// @@ -67,8 +66,7 @@ struct KernelGroupTmaWarpSpecializedCooperative { }; struct KernelTmaWarpSpecializedFP8FastAccum : KernelTmaWarpSpecialized { }; struct KernelTmaWarpSpecializedPingpongFP8FastAccum : KernelTmaWarpSpecializedPingpong { }; struct KernelTmaWarpSpecializedCooperativeFP8FastAccum: KernelTmaWarpSpecializedCooperative { }; -struct KernelArrayTmaWarpSpecializedCooperativeFP8FastAccum : KernelArrayTmaWarpSpecializedCooperative { }; -struct KernelGroupTmaWarpSpecializedCooperativeFP8FastAccum : KernelGroupTmaWarpSpecializedCooperative { }; +struct KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum : KernelPtrArrayTmaWarpSpecializedCooperative { }; // Policies to opt into mixed type GEMMs struct KernelTmaWarpSpecializedMixedInput : KernelTmaWarpSpecialized { }; @@ -233,7 +231,7 @@ struct MainloopSm90TmaGmmaWarpSpecializedFP8 template< int Stages_, class ClusterShape_ = Shape<_1,_1,_1>, - class KernelSchedule = KernelGroupTmaWarpSpecializedCooperative + class KernelSchedule = KernelPtrArrayTmaWarpSpecializedCooperative > struct MainloopSm90ArrayTmaGmmaWarpSpecialized { constexpr static int Stages = Stages_; @@ -241,8 +239,7 @@ struct MainloopSm90ArrayTmaGmmaWarpSpecialized { using ArchTag = arch::Sm90; using Schedule = KernelSchedule; static_assert( - cute::is_base_of_v || - cute::is_base_of_v, + cute::is_base_of_v, "KernelSchedule must be one of the Ptr-Array or Grouped Gemm TMA Warp Specialized Cooperative policies"); }; diff --git a/include/cutlass/gemm/group_array_problem_shape.hpp b/include/cutlass/gemm/group_array_problem_shape.hpp index 5f1c162915..4a90a1d06d 100644 --- a/include/cutlass/gemm/group_array_problem_shape.hpp +++ b/include/cutlass/gemm/group_array_problem_shape.hpp @@ -71,6 +71,12 @@ struct GroupProblemShape { get_host_problem_shape(int32_t group_idx) const { return host_problem_shapes[group_idx]; } + + CUTLASS_HOST_DEVICE + bool + is_host_problem_shape_available() { + return host_problem_shapes != nullptr; + } }; template @@ -104,6 +110,12 @@ class ArrayProblemShape { get_host_problem_shape(int32_t /* unused */ = 0) const { return problem_shape_; } + + CUTLASS_HOST_DEVICE + bool + is_host_problem_shape_available() { + return true; + } private: UnderlyingProblemShape problem_shape_{}; }; diff --git a/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_cooperative.hpp b/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_cooperative.hpp index 17c2f12142..dd816b0ea1 100644 --- a/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_cooperative.hpp +++ b/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_cooperative.hpp @@ -62,8 +62,7 @@ class GemmUniversal< CollectiveMainloop_, CollectiveEpilogue_, TileScheduler_, - cute::enable_if_t || - cute::is_base_of_v> + cute::enable_if_t> > { public: @@ -80,7 +79,9 @@ class GemmUniversal< using ArchTag = typename CollectiveMainloop::ArchTag; using ElementA 
= typename CollectiveMainloop::ElementA; using StrideA = typename CollectiveMainloop::StrideA; + using UnderlyingStrideA = typename CollectiveMainloop::UnderlyingStrideA; using ElementB = typename CollectiveMainloop::ElementB; + using UnderlyingStrideB = typename CollectiveMainloop::UnderlyingStrideB; using StrideB = typename CollectiveMainloop::StrideB; using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy; using Schedule = typename DispatchPolicy::Schedule; @@ -93,8 +94,10 @@ class GemmUniversal< using CollectiveEpilogue = CollectiveEpilogue_; using ElementC = typename CollectiveEpilogue::ElementC; using StrideC = typename CollectiveEpilogue::StrideC; + using UnderlyingStrideC = typename CollectiveEpilogue::UnderlyingStrideC; using ElementD = typename CollectiveEpilogue::ElementD; using StrideD = typename CollectiveEpilogue::StrideD; + using UnderlyingStrideD = typename CollectiveEpilogue::UnderlyingStrideD; using EpilogueArguments = typename CollectiveEpilogue::Arguments; using EpilogueParams = typename CollectiveEpilogue::Params; @@ -102,7 +105,7 @@ class GemmUniversal< static_assert(cute::is_void_v, "Ptr-Array Cooperative and Grouped Gemm Cooperative kernel only supports the default scheduler."); - static constexpr bool IsGroupedGemmKernel = cute::is_base_of_v; + static constexpr bool IsGroupedGemmKernel = !cute::is_same_v; using TileScheduler = cute::conditional_t( - args.scheduler, problem_shapes.get_host_problem_shape(0), args.hw_info, NumMmaWarpGroups); + args.scheduler, typename ProblemShape::UnderlyingProblemShape{}, args.hw_info, NumMmaWarpGroups); workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); void* epilogue_workspace = workspace_ptr + workspace_offset; @@ -244,14 +247,11 @@ class GemmUniversal< bool can_implement(Arguments const& args) { bool implementable = true; - if constexpr (cute::is_base_of_v) { - implementable &= (args.mode == GemmUniversalMode::kArray && rank(typename ProblemShape::UnderlyingProblemShape{}) == 4); - } else if constexpr (IsGroupedGemmKernel) { + if constexpr (IsGroupedGemmKernel) { // Group GEMM currently only supports rank-3 problem shapes implementable &= (args.mode == GemmUniversalMode::kGrouped && rank(typename ProblemShape::UnderlyingProblemShape{}) == 3); - } - else { - implementable = false; + } else { + implementable &= (args.mode == GemmUniversalMode::kArray && rank(typename ProblemShape::UnderlyingProblemShape{}) == 4); } if (!implementable) { CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements for Ptr Array Gemm or Grouped Gemm.\n"); @@ -269,7 +269,7 @@ class GemmUniversal< constexpr uint32_t NumEpilogueSubTiles = CollectiveEpilogue::get_store_pipe_increment(TileShape{}); workspace_size += TileScheduler::template get_workspace_size( - args.scheduler, args.problem_shape.get_host_problem_shape(0), args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles); + args.scheduler, typename ProblemShape::UnderlyingProblemShape{}, args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles); workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); workspace_size += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue); @@ -297,9 +297,9 @@ class GemmUniversal< constexpr uint32_t NumEpilogueSubTiles = CollectiveEpilogue::get_store_pipe_increment(TileShape{}); status = TileScheduler::template initialize_workspace( - args.scheduler, workspace_ptr + workspace_offset, stream, args.problem_shape.get_host_problem_shape(0), args.hw_info, NumMmaWarpGroups, 
NumEpilogueSubTiles); + args.scheduler, workspace_ptr + workspace_offset, stream, typename ProblemShape::UnderlyingProblemShape{}, args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles); workspace_offset += TileScheduler::template get_workspace_size( - args.scheduler, args.problem_shape.get_host_problem_shape(0), args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles); + args.scheduler, typename ProblemShape::UnderlyingProblemShape{}, args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles); workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); if (status != Status::kSuccess) { return status; @@ -350,23 +350,20 @@ class GemmUniversal< using namespace cute; using X = Underscore; - // Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a. - #if ! defined(__CUDA_ARCH_FEAT_SM90_ALL) - if constexpr(size<0>(typename TiledMma::AtomShape_MNK{}) == 64) { - printf("ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. Aborting.\n"); - return; - } - #endif +// Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a. +#if ! defined(__CUDA_ARCH_FEAT_SM90_ALL) + printf("ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. Aborting.\n"); +#else // Preconditions static_assert(size(TiledMma{}) == 256, "Cooperative kernel must have TiledMMA operating using 256 threads."); static_assert(size<0>(TileShape{}) >= 128, "Cooperative kernel requires Tile Size to be greater than or equal to 128 along the M-dimension."); - static_assert(cute::rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>."); - static_assert(cute::rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>."); - static_assert(cute::rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); - static_assert(cute::rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(cute::rank(UnderlyingStrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(cute::rank(UnderlyingStrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(cute::rank(UnderlyingStrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(cute::rank(UnderlyingStrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. 
If batch mode is not needed, set L stride to Int<0>."); /* In the Cooperative kernel, Consumer0 and Consumer1 collaborate on the same tile */ enum class WarpGroupRole { @@ -441,8 +438,6 @@ class GemmUniversal< // Epilogue store pipe is producer-only (consumer is TMA unit, waits via scoreboarding) typename CollectiveMainloop::PipelineState mainloop_pipe_consumer_state; typename CollectiveEpilogue::LoadPipelineState epi_load_pipe_consumer_state; - // Purpose of maintaining this pipeline state is to make sure TMA loads have finished before doing descriptor updates - typename CollectiveMainloop::PipelineState mainloop_pipe_tma_consumer_state; // For the DMA Load (producer) we start with an opposite phase // i.e., we skip all waits since we know that the buffer is indeed empty @@ -554,7 +549,8 @@ class GemmUniversal< shared_storage.tensors.mainloop ); // Update starting pipeline state for the next tile - mainloop_pipe_producer_state.advance(work_k_tile_count); + // Wait for the last TMA stage to complete loading, before issuing tensormap updates + mainloop_pipe_producer_state.advance(work_k_tile_count - 1); // Signal for the epilogue load warp to begin if (do_load_order_arrive) { @@ -570,8 +566,10 @@ class GemmUniversal< if constexpr (IsGroupedGemmKernel) { problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(next_batch), Int<1>{}); } - // Wait for the last TMA stage to complete loading, before issuing tensormap updates - mainloop_pipe_tma_consumer_state.advance(work_k_tile_count-1); + // Purpose of this pipeline state is to make sure TMA loads have finished before doing descriptor updates + // Since this state is waiting for loads to finish, it must start in the inverted phase. + typename CollectiveMainloop::PipelineState mainloop_pipe_tma_consumer_state = + {mainloop_pipe_producer_state.index(), !mainloop_pipe_producer_state.phase(), mainloop_pipe_producer_state.count()}; mainloop_pipeline.consumer_wait(mainloop_pipe_tma_consumer_state); collective_mainloop.tensormaps_perform_update( shared_storage.tensormaps.mainloop, @@ -585,13 +583,9 @@ class GemmUniversal< // Entire warp must do this (ie its aligned) collective_mainloop.tensormaps_cp_fence_release(shared_storage.tensormaps.mainloop, input_tensormaps); curr_batch = next_batch; - // Advance the TMA consumer state for the last remaining stage that was being waited for above - mainloop_pipe_tma_consumer_state.advance(1); - } - else if (work_tile_info.is_valid()) { // case where batch/group didn't change between tiles - // Advance the TMA consumer state for all the stages to be in sync - mainloop_pipe_tma_consumer_state.advance(work_k_tile_count); } + // Advance the producer state for the last remaining stage that was being waited for above + mainloop_pipe_producer_state.advance(1); } // Scheduler work fetch loop // Make sure all Consumer Warp Groups have been waited upon @@ -720,6 +714,7 @@ class GemmUniversal< ); } } // Consumer Warp Groups End +#endif } private: diff --git a/include/cutlass/gemm/kernel/sm90_gemm_tma.hpp b/include/cutlass/gemm/kernel/sm90_gemm_tma.hpp index a2c6015b36..a9694a2c4c 100644 --- a/include/cutlass/gemm/kernel/sm90_gemm_tma.hpp +++ b/include/cutlass/gemm/kernel/sm90_gemm_tma.hpp @@ -211,13 +211,10 @@ class GemmUniversal< using namespace cute; using X = Underscore; - // Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a. - #if ! 
defined(__CUDA_ARCH_FEAT_SM90_ALL) - if constexpr(size<0>(typename TiledMma::AtomShape_MNK{}) == 64) { - printf("ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. Aborting.\n"); - return; - } - #endif +// Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a. +#if ! defined(__CUDA_ARCH_FEAT_SM90_ALL) + printf("ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. Aborting.\n"); +#else // Preconditions static_assert(cute::rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>."); @@ -311,6 +308,7 @@ class GemmUniversal< thread_idx, smem_buf ); +#endif } }; diff --git a/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized.hpp b/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized.hpp index a272d630f8..558d6379f3 100644 --- a/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized.hpp +++ b/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized.hpp @@ -219,13 +219,10 @@ class GemmUniversal< using namespace cute; using X = Underscore; - // Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a. - #if ! defined(__CUDA_ARCH_FEAT_SM90_ALL) - if constexpr(size<0>(typename TiledMma::AtomShape_MNK{}) == 64) { - printf("ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. Aborting.\n"); - return; - } - #endif +// Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a. +#if ! defined(__CUDA_ARCH_FEAT_SM90_ALL) + printf("ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. Aborting.\n"); +#else enum class WarpGroupRole { Producer = 0, @@ -435,6 +432,7 @@ class GemmUniversal< epi_store_pipe_producer_state_next ); } +#endif } }; diff --git a/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp b/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp index 708d91563c..f0928edb14 100644 --- a/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp +++ b/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp @@ -298,13 +298,10 @@ class GemmUniversal< using namespace cute; using X = Underscore; - // Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a. - #if ! defined(__CUDA_ARCH_FEAT_SM90_ALL) - if constexpr(size<0>(typename TiledMma::AtomShape_MNK{}) == 64) { - printf("ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. Aborting.\n"); - return; - } - #endif +// Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a. +#if ! defined(__CUDA_ARCH_FEAT_SM90_ALL) + printf("ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. Aborting.\n"); +#else // Preconditions static_assert(size(TiledMma{}) == 256, "Cooperative kernel must have TiledMMA operating using 256 threads."); @@ -610,6 +607,7 @@ class GemmUniversal< ); } } // Consumer Warp Groups End +#endif } private: diff --git a/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp b/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp index 9f7221d89d..46a865b0aa 100644 --- a/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp +++ b/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp @@ -296,13 +296,10 @@ class GemmUniversal< using namespace cute; using X = Underscore; - // Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a. - #if ! 
defined(__CUDA_ARCH_FEAT_SM90_ALL) - if constexpr(size<0>(typename TiledMma::AtomShape_MNK{}) == 64) { - printf("ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. Aborting.\n"); - return; - } - #endif +// Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a. +#if ! defined(__CUDA_ARCH_FEAT_SM90_ALL) + printf("ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. Aborting.\n"); +#else // Preconditions static_assert(cute::rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>."); @@ -612,6 +609,7 @@ class GemmUniversal< work_tile_info = scheduler.get_current_work(); } // Scheduler work fetch loop } // Consumer Warp Groups End +#endif } }; diff --git a/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized.hpp b/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized.hpp index 54f6e9617d..293cd67fd0 100644 --- a/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized.hpp +++ b/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized.hpp @@ -223,13 +223,10 @@ class GemmUniversal< using namespace cute; using X = Underscore; - // Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a. - #if ! defined(__CUDA_ARCH_FEAT_SM90_ALL) - if constexpr(size<0>(typename TiledMma::AtomShape_MNK{}) == 64) { - printf("ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. Aborting.\n"); - return; - } - #endif +// Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a. +#if ! defined(__CUDA_ARCH_FEAT_SM90_ALL) + printf("ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. Aborting.\n"); +#else enum class WarpGroupRole { Producer = 0, @@ -409,6 +406,7 @@ class GemmUniversal< shared_storage.tensors.epilogue ); } +#endif } }; diff --git a/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized_cooperative.hpp b/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized_cooperative.hpp index 14c9be4f60..2b6ff5880d 100644 --- a/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized_cooperative.hpp +++ b/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized_cooperative.hpp @@ -250,13 +250,10 @@ class GemmUniversal< using namespace cute; using X = Underscore; - // Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a. - #if ! defined(__CUDA_ARCH_FEAT_SM90_ALL) - if constexpr(size<0>(typename TiledMma::AtomShape_MNK{}) == 64) { - printf("ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. Aborting.\n"); - return; - } - #endif +// Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a. +#if ! defined(__CUDA_ARCH_FEAT_SM90_ALL) + printf("ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. Aborting.\n"); +#else static_assert(cute::rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>."); static_assert(cute::rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. 
If batch mode is not needed, set L stride to Int<0>."); @@ -493,6 +490,7 @@ class GemmUniversal< ); } } // Consumer Warp Groups End +#endif } private: diff --git a/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized_pingpong.hpp b/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized_pingpong.hpp index 40c925ac48..1ae7ccaa78 100644 --- a/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized_pingpong.hpp +++ b/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized_pingpong.hpp @@ -257,13 +257,10 @@ class GemmUniversal< using namespace cute; using X = Underscore; - // Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a. - #if ! defined(__CUDA_ARCH_FEAT_SM90_ALL) - if constexpr(size<0>(typename TiledMma::AtomShape_MNK{}) == 64) { - printf("ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. Aborting.\n"); - return; - } - #endif +// Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a. +#if ! defined(__CUDA_ARCH_FEAT_SM90_ALL) + printf("ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. Aborting.\n"); +#else // Preconditions static_assert(cute::rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>."); @@ -509,6 +506,7 @@ class GemmUniversal< work_tile_info = scheduler.get_current_work(); } // Scheduler work fetch loop } // Consumer Warp Groups End +#endif } }; diff --git a/include/cutlass/gemm/kernel/sm90_tile_scheduler_group.hpp b/include/cutlass/gemm/kernel/sm90_tile_scheduler_group.hpp index bb0a5b486c..ce431dafec 100644 --- a/include/cutlass/gemm/kernel/sm90_tile_scheduler_group.hpp +++ b/include/cutlass/gemm/kernel/sm90_tile_scheduler_group.hpp @@ -55,7 +55,7 @@ class PersistentTileSchedulerSm90Group { // Tracking current group, its starting linear idx and total tiles struct GroupInfo { - uint64_t group = 0; + int group_idx = 0; uint64_t start_linear_idx = 0; uint64_t total_tiles = 0; } current_group_info_; @@ -115,7 +115,7 @@ class PersistentTileSchedulerSm90Group { GroupProblemShape problem_shapes, TileShape tile_shape, ClusterShape cluster_shape, - [[maybe_unused]] KernelHardwareInfo const& hw_info, + KernelHardwareInfo const& hw_info, Arguments const& arguments, [[maybe_unused]] void* workspace=nullptr, [[maybe_unused]] const uint32_t epilogue_subtile = 1) { @@ -126,14 +126,16 @@ class PersistentTileSchedulerSm90Group { dim3 problem_blocks = get_tiled_cta_shape_mnl( problem_shapes.groups(), - reinterpret_cast(problem_shapes.host_problem_shapes), + problem_shapes, + hw_info, tile_shape, cluster_shape); Params params; params.initialize( problem_blocks, problem_shapes.groups(), - reinterpret_cast(problem_shapes.problem_shapes), + problem_shapes.problem_shapes, + problem_shapes.host_problem_shapes, to_gemm_coord(tile_shape), to_gemm_coord(cluster_shape), hw_info, @@ -144,6 +146,64 @@ class PersistentTileSchedulerSm90Group { return params; } + // Given the inputs, computes the physical grid we should launch. 
+ template + CUTLASS_HOST_DEVICE static + dim3 + get_grid_shape( + GroupProblemShape problem_shapes, + TileShape tile_shape, + ClusterShape cluster_shape, + KernelHardwareInfo hw_info, + Arguments arguments, + bool truncate_by_problem_size=true) { + + dim3 problem_blocks = get_tiled_cta_shape_mnl( + problem_shapes.groups(), + problem_shapes, + hw_info, + tile_shape, cluster_shape); + + return Params::get_grid_shape( + problem_blocks, + to_gemm_coord(cluster_shape), + hw_info, + arguments.max_swizzle_size, + arguments.raster_order, + /* truncate_by_problem_size = */true + ); + } + + // Given the inputs, computes the total number of output blocks this problem will compute over + // Note that this is only the logical size of our grid, not the physical grid we will actually launch. + template + CUTLASS_HOST_DEVICE static + dim3 + get_tiled_cta_shape_mnl(int groups, GroupProblemShape problem_shapes, KernelHardwareInfo hw_info, BlockShape cta_shape, ClusterShape cluster_shape) { + uint32_t total_ctas = 0; + uint32_t cta_in_N_dim = 1; // We linearize the blocks across all the problems here + + // If host problem shapes are not provided. + if (!problem_shapes.is_host_problem_shape_available()) { + total_ctas = hw_info.sm_count; + } + // If host problem shapes are provided, make a better decision about possibility to launch smaller grid. + else { + for (int group = 0; group < groups; group++) { + auto ctas_along_m = cute::size(cute::ceil_div(cute::shape<0>(problem_shapes.get_host_problem_shape(group)), cute::shape<0>(cta_shape))); + auto ctas_along_n = cute::size(cute::ceil_div(cute::shape<1>(problem_shapes.get_host_problem_shape(group)), cute::shape<1>(cta_shape))); + auto problem_blocks_m = round_up(ctas_along_m, cute::get<0>(cluster_shape)); + auto problem_blocks_n = round_up(ctas_along_n, cute::get<1>(cluster_shape)); + total_ctas += problem_blocks_m * problem_blocks_n; + } + } + + return Params::get_tiled_cta_shape_mnl( + to_gemm_coord(cluster_shape), + total_ctas, cta_in_N_dim + ); + } + CUTLASS_HOST_DEVICE static bool can_implement(Arguments const& args) { @@ -156,7 +216,7 @@ class PersistentTileSchedulerSm90Group { // MSVC requires protecting use of CUDA-specific nonstandard syntax, // like blockIdx and gridDim, with __CUDA_ARCH__. 
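As background for the comment above, here is a minimal, hypothetical illustration (not CUTLASS code) of guarding the CUDA built-ins blockIdx and gridDim with __CUDA_ARCH__, so the host compilation pass of a host/device function never sees them.

// Device pass computes the linearized block index; the host pass, which has
// no blockIdx/gridDim, falls back to a placeholder value.
__host__ __device__
inline unsigned int linear_block_idx() {
#if defined(__CUDA_ARCH__)
  return blockIdx.x + blockIdx.y * gridDim.x;
#else
  return 0u;
#endif
}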
#if defined(__CUDA_ARCH__) - if (params_.raster_order_ == RasterOrder::AlongN) { + if (scheduler_params.raster_order_ == RasterOrder::AlongN) { current_work_linear_idx_ = uint64_t(blockIdx.x) + uint64_t(blockIdx.y) * uint64_t(gridDim.x); } else { @@ -165,9 +225,19 @@ class PersistentTileSchedulerSm90Group { total_grid_size_ = uint64_t(gridDim.x) * uint64_t(gridDim.y) * uint64_t(gridDim.z); - auto cta_m = cute::size(cute::ceil_div(cute::shape<0>(params_.problem_shapes_[0]), params_.cta_shape_.m())); - auto cta_n = cute::size(cute::ceil_div(cute::shape<1>(params_.problem_shapes_[0]), params_.cta_shape_.n())); - current_group_info_.total_tiles = cta_m * cta_n; + uint64_t ctas_along_m, ctas_along_n; + if (is_tuple(params_.problem_shapes_[0]))>::value || + is_tuple(params_.problem_shapes_[0]))>::value) { + ctas_along_m = cute::size(cute::ceil_div(cute::shape<0>(params_.problem_shapes_[0]), scheduler_params.cta_shape_.m())); + ctas_along_n = cute::size(cute::ceil_div(cute::shape<1>(params_.problem_shapes_[0]), scheduler_params.cta_shape_.n())); + } + else { + ctas_along_m = scheduler_params.divmod_cta_shape_m_.divide(cute::shape<0>(params_.problem_shapes_[0]) + scheduler_params.divmod_cta_shape_m_.divisor - 1); + ctas_along_n = scheduler_params.divmod_cta_shape_n_.divide(cute::shape<1>(params_.problem_shapes_[0]) + scheduler_params.divmod_cta_shape_n_.divisor - 1); + } + auto problem_blocks_m = round_up(ctas_along_m, (1 << params_.log_swizzle_size_) * params_.cluster_shape_.m()); + auto problem_blocks_n = round_up(ctas_along_n, (1 << params_.log_swizzle_size_) * params_.cluster_shape_.n()); + current_group_info_.total_tiles = problem_blocks_m * problem_blocks_n; #else CUTLASS_ASSERT(false && "This line should never be reached"); #endif @@ -182,24 +252,22 @@ class PersistentTileSchedulerSm90Group { CUTLASS_DEVICE WorkTileInfo get_current_work_for_linear_idx(uint64_t linear_idx) { - if (linear_idx >= scheduler_params.blocks_per_problem_) { + if (scheduler_params.pre_processed_problem_shapes && linear_idx >= scheduler_params.blocks_across_problem_) { return WorkTileInfo::invalid_work_tile(); } - uint64_t blk_per_grid_dim = scheduler_params.divmod_cluster_shape_minor_.divide(linear_idx); - - auto [work_idx_m, work_idx_n, new_group_info, valid_tile] = get_work_idx_m_and_n(blk_per_grid_dim, - current_group_info_, - scheduler_params.groups_, - scheduler_params.problem_shapes_, - scheduler_params.cta_shape_, - scheduler_params.divmod_cluster_shape_major_, - scheduler_params.divmod_cluster_shape_minor_, - scheduler_params.log_swizzle_size_, - scheduler_params.raster_order_); - - current_group_info_ = new_group_info; - return {work_idx_m, work_idx_n, static_cast(current_group_info_.group), valid_tile}; + return get_work_idx_m_and_n(linear_idx, + current_group_info_, + scheduler_params.groups_, + scheduler_params.problem_shapes_, + scheduler_params.cta_shape_, + scheduler_params.cluster_shape_, + scheduler_params.divmod_cluster_shape_major_, + scheduler_params.divmod_cluster_shape_minor_, + scheduler_params.divmod_cta_shape_m_, + scheduler_params.divmod_cta_shape_n_, + scheduler_params.log_swizzle_size_, + scheduler_params.raster_order_); } CUTLASS_DEVICE @@ -208,34 +276,62 @@ class PersistentTileSchedulerSm90Group { current_work_linear_idx_ += total_grid_size_ * uint64_t(advance_count); } - // get work_idx_m, work_idx_n from blk_per_grid_dim while applying swizzle + // get work_idx_m, work_idx_n from linear_idx while applying swizzle static CUTLASS_DEVICE - cute::tuple + WorkTileInfo get_work_idx_m_and_n( 
- uint64_t blk_per_grid_dim, - struct GroupInfo group_info, + uint64_t linear_idx, + struct GroupInfo& group_info, int32_t total_problem_groups, ProblemShape* problem_shapes, GemmCoord cta_shape, + GemmCoord cluster_shape, FastDivmodU64Pow2 const& divmod_cluster_shape_major, FastDivmodU64Pow2 const& divmod_cluster_shape_minor, + FastDivmodU64 const& divmod_cta_shape_m, + FastDivmodU64 const& divmod_cta_shape_n, int32_t log_swizzle_size, RasterOrder raster_order) { bool valid_tile = true; - int cta_m = cute::size(cute::ceil_div(cute::shape<0>(problem_shapes[group_info.group]), cta_shape.m())); - int cta_n = cute::size(cute::ceil_div(cute::shape<1>(problem_shapes[group_info.group]), cta_shape.n())); + uint64_t ctas_along_m, ctas_along_n; + if (is_tuple(problem_shapes[group_info.group_idx]))>::value || + is_tuple(problem_shapes[group_info.group_idx]))>::value) { + ctas_along_m = cute::size(cute::ceil_div(cute::shape<0>(problem_shapes[group_info.group_idx]), cta_shape.m())); + ctas_along_n = cute::size(cute::ceil_div(cute::shape<1>(problem_shapes[group_info.group_idx]), cta_shape.n())); + } + else { + ctas_along_m = divmod_cta_shape_m.divide(cute::shape<0>(problem_shapes[group_info.group_idx]) + divmod_cta_shape_m.divisor - 1); + ctas_along_n = divmod_cta_shape_n.divide(cute::shape<1>(problem_shapes[group_info.group_idx]) + divmod_cta_shape_n.divisor - 1); + } + auto problem_blocks_m = round_up(ctas_along_m, (1 << log_swizzle_size) * cluster_shape.m()); + auto problem_blocks_n = round_up(ctas_along_n, (1 << log_swizzle_size) * cluster_shape.n()); + group_info.total_tiles = problem_blocks_m * problem_blocks_n; + + while (group_info.start_linear_idx + group_info.total_tiles <= linear_idx) { + group_info.group_idx++; + + if (group_info.group_idx >= total_problem_groups) + return WorkTileInfo::invalid_work_tile(); - while (group_info.start_linear_idx + group_info.total_tiles <= blk_per_grid_dim) { - group_info.group++; group_info.start_linear_idx += group_info.total_tiles; - cta_m = cute::size(cute::ceil_div(cute::shape<0>(problem_shapes[group_info.group]), cta_shape.m())); - cta_n = cute::size(cute::ceil_div(cute::shape<1>(problem_shapes[group_info.group]), cta_shape.n())); - group_info.total_tiles = cta_m * cta_n; + if (is_tuple(problem_shapes[group_info.group_idx]))>::value || + is_tuple(problem_shapes[group_info.group_idx]))>::value) { + ctas_along_m = cute::size(cute::ceil_div(cute::shape<0>(problem_shapes[group_info.group_idx]), cta_shape.m())); + ctas_along_n = cute::size(cute::ceil_div(cute::shape<1>(problem_shapes[group_info.group_idx]), cta_shape.n())); + } + else { + ctas_along_m = divmod_cta_shape_m.divide(cute::shape<0>(problem_shapes[group_info.group_idx]) + divmod_cta_shape_m.divisor - 1); + ctas_along_n = divmod_cta_shape_n.divide(cute::shape<1>(problem_shapes[group_info.group_idx]) + divmod_cta_shape_n.divisor - 1); + } + problem_blocks_m = round_up(ctas_along_m, (1 << log_swizzle_size) * cluster_shape.m()); + problem_blocks_n = round_up(ctas_along_n, (1 << log_swizzle_size) * cluster_shape.n()); + group_info.total_tiles = problem_blocks_m * problem_blocks_n; } uint64_t cluster_id, cluster_major_offset = 0, cluster_minor_offset = 0; - divmod_cluster_shape_major(cluster_id, cluster_major_offset, blk_per_grid_dim - group_info.start_linear_idx); + uint64_t blk_per_grid_dim = divmod_cluster_shape_minor.divide(linear_idx - group_info.start_linear_idx); + divmod_cluster_shape_major(cluster_id, cluster_major_offset, blk_per_grid_dim); auto [cta_m_in_cluster, cta_n_in_cluster, _] = 
cute::block_id_in_cluster(); if (raster_order == RasterOrder::AlongN) { @@ -252,8 +348,13 @@ class PersistentTileSchedulerSm90Group { offset = cluster_id & ((1 << log_swizzle_size) - 1); extra = cluster_id >> log_swizzle_size; - uint64_t curr_group_cluster_blk_major, remainder; - divmod_cluster_shape_major(curr_group_cluster_blk_major, remainder, cta_m); + uint64_t curr_group_cluster_blk_major; + if (raster_order == RasterOrder::AlongN) { + curr_group_cluster_blk_major = divmod_cluster_shape_major.divide(problem_blocks_n); + } + else { + curr_group_cluster_blk_major = divmod_cluster_shape_major.divide(problem_blocks_m); + } cluster_idx_minor_div_swizzle = extra / curr_group_cluster_blk_major; cluster_idx_major = extra % curr_group_cluster_blk_major; @@ -265,61 +366,14 @@ class PersistentTileSchedulerSm90Group { cluster_major_offset); if (raster_order == RasterOrder::AlongN) { - return {minor_work_idx, major_work_idx, group_info, valid_tile}; + return {minor_work_idx, major_work_idx, group_info.group_idx, valid_tile}; } else { - return {major_work_idx, minor_work_idx, group_info, valid_tile}; + return {major_work_idx, minor_work_idx, group_info.group_idx, valid_tile}; } } - // Given the inputs, computes the total number of output blocks this problem will compute over - // Note that this is only the logical size of our grid, not the physical grid we will actually launch. - template - CUTLASS_HOST_DEVICE static - dim3 - get_tiled_cta_shape_mnl(int groups, ProblemShape const* problem_shapes, BlockShape cta_shape, ClusterShape cluster_shape) { - uint32_t total_ctas = 0; - uint32_t cta_in_N_dim = 1; // We linearize the blocks across all the problems here - for (int group = 0; group < groups; group++) { - auto cta_m = cute::size(cute::ceil_div(cute::shape<0>(problem_shapes[group]), cute::shape<0>(cta_shape))); - auto cta_n = cute::size(cute::ceil_div(cute::shape<1>(problem_shapes[group]), cute::shape<1>(cta_shape))); - total_ctas += cta_m * cta_n; - } - - return Params::get_tiled_cta_shape_mnl( - to_gemm_coord(cluster_shape), - total_ctas, cta_in_N_dim - ); - } - - // Given the inputs, computes the physical grid we should launch. - template - CUTLASS_HOST_DEVICE static - dim3 - get_grid_shape( - GroupProblemShape problem_shapes, - BlockShape cta_shape, - ClusterShape cluster_shape, - KernelHardwareInfo hw_info, - Arguments arguments, - bool truncate_by_problem_size=true) { - - dim3 problem_blocks = get_tiled_cta_shape_mnl( - problem_shapes.groups(), - reinterpret_cast(problem_shapes.host_problem_shapes), - cta_shape, cluster_shape); - - return Params::get_grid_shape( - problem_blocks, - to_gemm_coord(cluster_shape), - hw_info, - arguments.max_swizzle_size, - arguments.raster_order, - /* truncate_by_problem_size = */true - ); - } - // Returns whether the block assigned this work should compute the epilogue for the corresponding // output tile. For the basic tile scheduler, this is always true. 
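The scheduler above sizes each group by rounding its CTA counts up to the swizzle-scaled cluster shape before accumulating tiles. A small worked sketch with assumed values (hypothetical group dimensions, not taken from this patch):

#include <cstdint>

constexpr uint64_t ceil_div(uint64_t a, uint64_t b) { return (a + b - 1) / b; }
constexpr uint64_t round_up(uint64_t a, uint64_t b) { return ceil_div(a, b) * b; }

// Assumed: M=300, N=520, CTA tile 128x128, cluster 2x1, log_swizzle_size=0.
constexpr uint64_t ctas_along_m     = ceil_div(300, 128);                   // 3
constexpr uint64_t ctas_along_n     = ceil_div(520, 128);                   // 5
constexpr uint64_t problem_blocks_m = round_up(ctas_along_m, (1 << 0) * 2); // 4
constexpr uint64_t problem_blocks_n = round_up(ctas_along_n, (1 << 0) * 1); // 5
constexpr uint64_t total_tiles      = problem_blocks_m * problem_blocks_n;  // 20
static_assert(total_tiles == 20, "this group contributes 20 tiles to the linearized grid");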
CUTLASS_HOST_DEVICE diff --git a/include/cutlass/gemm/kernel/tile_scheduler_params.h b/include/cutlass/gemm/kernel/tile_scheduler_params.h index 97e5c5437f..25a95529f5 100644 --- a/include/cutlass/gemm/kernel/tile_scheduler_params.h +++ b/include/cutlass/gemm/kernel/tile_scheduler_params.h @@ -1273,15 +1273,18 @@ struct PersistentTileSchedulerSm90GroupParams { FastDivmodU64Pow2 divmod_cluster_shape_major_{}; FastDivmodU64Pow2 divmod_cluster_shape_minor_{}; - FastDivmodU64 divmod_batch_{}; + FastDivmodU64 divmod_cta_shape_m_{}; + FastDivmodU64 divmod_cta_shape_n_{}; - uint64_t blocks_per_problem_ = 0; + uint64_t blocks_across_problem_ = 0; + bool pre_processed_problem_shapes = true; int32_t log_swizzle_size_ = 0; RasterOrder raster_order_ = RasterOrder::AlongN; int32_t groups_ = 0; ProblemShape* problem_shapes_ = nullptr; GemmCoord cta_shape_; + GemmCoord cluster_shape_; // Version of initialize that takes in as input the number of CTAs in the M and N and L dimensions. // This is useful for calculating the tiled shape when a mode of problem and/or CTA shape has rank > 1, @@ -1291,6 +1294,7 @@ struct PersistentTileSchedulerSm90GroupParams { dim3 problem_blocks, int32_t groups, ProblemShape* problem_shapes, + ProblemShape const* host_problem_shapes, GemmCoord cta_shape, GemmCoord cluster_shape, KernelHardwareInfo const& hw_info, @@ -1317,11 +1321,12 @@ struct PersistentTileSchedulerSm90GroupParams { groups_ = groups; problem_shapes_ = problem_shapes; cta_shape_ = cta_shape; + cluster_shape_ = cluster_shape; - blocks_per_problem_ = problem_blocks_m * problem_blocks_n * problem_blocks.z; + blocks_across_problem_ = problem_blocks.x * problem_blocks.y * problem_blocks.z; + pre_processed_problem_shapes = (host_problem_shapes == nullptr) ? false : true; log_swizzle_size_ = log_swizzle_size; raster_order_ = raster_order; - divmod_batch_ = FastDivmodU64(problem_blocks_m * problem_blocks_n); if (raster_order == RasterOrder::AlongN) { divmod_cluster_shape_major_ = FastDivmodU64Pow2(cluster_shape.n()); @@ -1331,6 +1336,9 @@ struct PersistentTileSchedulerSm90GroupParams { divmod_cluster_shape_major_ = FastDivmodU64Pow2(cluster_shape.m()); divmod_cluster_shape_minor_ = FastDivmodU64Pow2(cluster_shape.n()); } + + divmod_cta_shape_m_ = FastDivmodU64(cta_shape_.m()); + divmod_cta_shape_n_ = FastDivmodU64(cta_shape_.n()); } // Version of get_tiled_cta_shape_mnl that takes in as input the number of CTAs in the M and N dimensions. @@ -1344,8 +1352,8 @@ struct PersistentTileSchedulerSm90GroupParams { auto problem_blocks_n = ((cta_n + cluster_shape.n() - 1) / cluster_shape.n()) * cluster_shape.n(); return { - static_cast(problem_blocks_m), - static_cast(problem_blocks_n), + static_cast(cta_m), + static_cast(cta_n), static_cast(1) // Only a single batch per group is currently supported }; } diff --git a/include/cutlass/version.h b/include/cutlass/version.h new file mode 100644 index 0000000000..ca0f052907 --- /dev/null +++ b/include/cutlass/version.h @@ -0,0 +1,80 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include +#include + +#define CUTLASS_MAJOR 3 +#define CUTLASS_MINOR 4 +#define CUTLASS_PATCH 1 + +#ifdef CUTLASS_VERSIONS_GENERATED +#include "cutlass/version_extended.h" +#else +#define CUTLASS_BUILD 0 +#define CUTLASS_REVISION "" +#endif + +#define CUTLASS_VERSION ((CUTLASS_MAJOR)*100 + (CUTLASS_MINOR)*10 + CUTLASS_PATCH) + +namespace cutlass { + + inline constexpr uint32_t getVersion() { + return CUTLASS_VERSION; + } + inline constexpr uint32_t getVersionMajor() { + return CUTLASS_MAJOR; + } + inline constexpr uint32_t getVersionMinor() { + return CUTLASS_MINOR; + } + inline constexpr uint32_t getVersionPatch() { + return CUTLASS_PATCH; + } + inline constexpr uint32_t getVersionBuild() { + return CUTLASS_BUILD + 0; + } + + inline std::string getVersionString() { + std::string version = "@CUTLASS_VERSION@"; + if (getVersionBuild()) { + version += "." 
+ std::to_string(getVersionBuild()); + } + return version; + } + + inline std::string getGitRevision() { + return "@CUTLASS_REVISION@"; + } + +} // namespace cutlass diff --git a/pyproject.toml b/pyproject.toml index 3a91d3943f..16536f3561 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,11 +4,11 @@ build-backend = "setuptools.build_meta" [project] name = "nvidia-cutlass" -version = "3.4.0.0" +version = "3.4.1.0" description = "CUTLASS" readme = "README.md" requires-python = ">=3.8" -license = {file = "LICENSE.txt"} +license = {text = "BSD-3-Clause"} classifiers = [ "Programming Language :: Python :: 3", "License :: OSI Approved :: BSD License", diff --git a/python/cutlass/__init__.py b/python/cutlass/__init__.py index e29f503b6c..cfa9af2533 100644 --- a/python/cutlass/__init__.py +++ b/python/cutlass/__init__.py @@ -40,7 +40,7 @@ def _cuda_install_path_from_nvcc() -> str: import subprocess # Attempt to detect CUDA_INSTALL_PATH based on location of NVCC - result = subprocess.run(['which', 'nvcc'], capture_output=True) + result = subprocess.run(['/usr/bin/which', 'nvcc'], capture_output=True) if result.returncode != 0: raise Exception(f'Unable to find nvcc via `which` utility.') @@ -121,7 +121,7 @@ def get_option_registry(): this._option_registry = OptionRegistry(device_cc()) return this._option_registry -this.__version__ = '3.4.0' +this.__version__ = '3.4.1' from cutlass.backend import create_memory_pool from cutlass.emit.pytorch import pytorch @@ -169,7 +169,7 @@ def initialize_cuda_context(): raise Exception("No CUDA devices found") device_id = 0 - this._device_id = device_id + this._device_id = int(device_id) def device_id() -> int: diff --git a/python/cutlass/backend/c_types.py b/python/cutlass/backend/c_types.py index 331594d6d1..a0035b6786 100644 --- a/python/cutlass/backend/c_types.py +++ b/python/cutlass/backend/c_types.py @@ -213,8 +213,12 @@ def from_generic_mainloop_args(args: GenericMainloopArguments3x_): return _MainloopArgumentsTma -def get_gemm_arguments_3x(mainloop_arguments, epilogue_functor, scheduler_args): - _EpilogueOutputOpParams = epilogue_functor.epilogue_type +def get_gemm_arguments_3x(mainloop_arguments, epilogue_functor, scheduler_args, default_epilogue): + if not default_epilogue and hasattr(epilogue_functor, "epilogue_type_evt"): + _EpilogueOutputOpParams = epilogue_functor.epilogue_type_evt + else: + _EpilogueOutputOpParams = epilogue_functor.epilogue_type + if hasattr(epilogue_functor, "visitor"): class _EpilogueArguments(ctypes.Structure): _fields_ = [ diff --git a/python/cutlass/backend/epilogue.py b/python/cutlass/backend/epilogue.py index 214a094208..b7bdc34854 100644 --- a/python/cutlass/backend/epilogue.py +++ b/python/cutlass/backend/epilogue.py @@ -157,19 +157,41 @@ def __init__( c_element_epilogue = dtype2ctype[self.element_epilogue] element_epilogue = self.element_epilogue + class _EpilogueOutputOpParamsEVT(ctypes.Structure): + """ + Epilogue params when using the default linear combination of EVT, which + does not currently use {alpha,beta}_ptr_array + """ + _fields_ = [ + ("alpha", c_element_epilogue), + ("beta", c_element_epilogue), + ("alpha_ptr", ctypes.c_void_p), + ("beta_ptr", ctypes.c_void_p), + ] + + def __init__(self, alpha, beta, *args) -> None: + self.alpha = to_ctype_value(alpha, element_epilogue) + self.beta = to_ctype_value(beta, element_epilogue) + class _EpilogueOutputOpParams(ctypes.Structure): _fields_ = [ ("alpha", c_element_epilogue), ("beta", c_element_epilogue), ("alpha_ptr", ctypes.c_void_p), - ("beta_ptr", 
ctypes.c_void_p) + ("beta_ptr", ctypes.c_void_p), + ("alpha_ptr_array", ctypes.c_void_p), + ("beta_ptr_array", ctypes.c_void_p), ] def __init__(self, alpha, beta, *args) -> None: self.alpha = to_ctype_value(alpha, element_epilogue) self.beta = to_ctype_value(beta, element_epilogue) + def to_evt_params(self) -> _EpilogueOutputOpParamsEVT: + return _EpilogueOutputOpParamsEVT(self.alpha, self.beta) + self.epilogue_type = _EpilogueOutputOpParams + self.epilogue_type_evt = _EpilogueOutputOpParamsEVT def emit(self): return super().emit(self.tag, self.template_arguments) diff --git a/python/cutlass/backend/evt/frontend/frontend_base.py b/python/cutlass/backend/evt/frontend/frontend_base.py index bd283d7df6..5c63c14121 100644 --- a/python/cutlass/backend/evt/frontend/frontend_base.py +++ b/python/cutlass/backend/evt/frontend/frontend_base.py @@ -241,10 +241,10 @@ def visualize(self, name="dag_ir"): :param name: the name of the graph """ drawer = EVTGraphDrawer(self.dag_ir, name) - if drawer.dot_available: + try: for name, graph in drawer.get_dot_graph(): graph.write_svg(f"./{name}.svg") - else: + except: raise RuntimeError( "'dot' is not found in path. GraphDrawer is disabled. " "Please install it with 'sudo apt-get install graphviz'." diff --git a/python/cutlass/backend/evt/passes/graph_drawer.py b/python/cutlass/backend/evt/passes/graph_drawer.py index 0f4353a860..a2a73640c1 100644 --- a/python/cutlass/backend/evt/passes/graph_drawer.py +++ b/python/cutlass/backend/evt/passes/graph_drawer.py @@ -61,22 +61,6 @@ def __init__( self._dot_graphs = {} self._dot_graphs[name] = self._to_dot(graph, name) - self.dot_available = self._check_dot_availability() - - def _check_dot_availability(self): - """ - Check if graphviz is installed - """ - try: - # Run the 'dot' command and capture its output - result = subprocess.run( - ["dot", "-V"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) - # Check if the command was successful and the output contains version information - if result.returncode == 0 and "dot - graphviz" in result.stderr: - return True - except FileNotFoundError: - pass - return False def _get_node_style(self, node): template = { diff --git a/python/cutlass/backend/gemm_operation.py b/python/cutlass/backend/gemm_operation.py index bfa83bb869..29ab291a8a 100644 --- a/python/cutlass/backend/gemm_operation.py +++ b/python/cutlass/backend/gemm_operation.py @@ -325,7 +325,7 @@ def get_arguments(self): def initialize(self): launch_config = self.operation.rt_module.plan(self) - # Get the host and evice workspace + # Get the host and device workspace device_workspace_size = self.operation.rt_module.get_device_workspace_size(self) if device_workspace_size > 0: @@ -512,6 +512,18 @@ def __init__(self, operation, problem_size, A, B, C, D, gemm_mode=GemmUniversalM super().__init__(operation, problem_size, A, B, C, D, gemm_mode, **kwargs) def get_arguments(self): + mainloop_args = get_mainloop_arguments_3x( + self.operation.tile_description.kernel_schedule, + self.operation.A.element, + self.operation.B.element, + self.operation.A.alignment, + self.operation.B.alignment + ) + scheduler_args = get_tile_scheduler_arguments_3x(self.operation.tile_description.tile_scheduler) + uses_default_epilogue = self.operation.rt_module.uses_default_epilogue() + argument_type, epilogue_args, epilogue_type, hw_info = get_gemm_arguments_3x( + mainloop_args, self.operation.epilogue_functor, scheduler_args, uses_default_epilogue) + problem_size_ = GemmCoordBatched_(self.problem_size, self.batch_count) if 
self.batch_count > 1: @@ -539,9 +551,12 @@ def get_arguments(self): ) # Set of mainloop arguments needed for this kernel - mainloop = self.operation.rt_module.mainloop_args.from_generic_mainloop_args(generic_args) + mainloop = mainloop_args.from_generic_mainloop_args(generic_args) + + if not uses_default_epilogue and hasattr(self.output_op, "to_evt_params"): + self.output_op = self.output_op.to_evt_params() - epilogue = self.operation.rt_module.epilogue_args( + epilogue = epilogue_args( self.output_op, int(self.ptr_C), stride_C, @@ -550,15 +565,15 @@ def get_arguments(self): ) # Set hardware info - hw_info = self.operation.rt_module.hw_info(0, device_sm_count()) + hw_info_ = hw_info(0, device_sm_count()) - self.arguments = self.operation.argument_type( + self.arguments = argument_type( int(self.gemm_mode), problem_size_, mainloop, epilogue, - hw_info, - self.operation.rt_module.scheduler_args + hw_info_, + scheduler_args ) return self.arguments @@ -1119,6 +1134,10 @@ class GemmRTUniversal3x(GemmRTUniversal): using GemmType = ${operation_name}_base; + bool ${operation_name}_uses_default_epilogue() { + return std::is_same_v; + } + // Get the workspace size uint64_t ${operation_name}_get_kernel_workspace_size(GemmType::Arguments* argument) { return GemmType::get_workspace_size(*argument); @@ -1163,19 +1182,10 @@ def __init__(self, operation): "get_grid_shape": dim3_, "get_block_shape": dim3_, "get_persistent_tiled_blk_shape_mnl": ctypes.c_uint64, - "get_kernel_workspace_size": ctypes.c_uint64 + "get_kernel_workspace_size": ctypes.c_uint64, + "uses_default_epilogue": ctypes.c_bool, } self.emitter = EmitGemmUniversalInstance3x("_type") - self.mainloop_args = get_mainloop_arguments_3x( - operation.tile_description.kernel_schedule, - operation.A.element, - operation.B.element, - operation.A.alignment, - operation.B.alignment - ) - self.scheduler_args = get_tile_scheduler_arguments_3x(operation.tile_description.tile_scheduler) - self.argument_type, self.epilogue_args, self.epilogue_type, self.hw_info = get_gemm_arguments_3x( - self.mainloop_args, operation.epilogue_functor, self.scheduler_args) def get_device_workspace_size(self, arguments: GemmArguments3x): return self.get_kernel_workspace_size(ctypes.byref(arguments.get_arguments())) diff --git a/python/setup_library.py b/python/setup_library.py index 92f6430e8d..f95bf6df7c 100644 --- a/python/setup_library.py +++ b/python/setup_library.py @@ -36,7 +36,7 @@ def perform_setup(): setup( name='cutlass_library', - version='3.4.0', + version='3.4.1', description='CUTLASS library generation scripts', packages=['cutlass_library'] ) diff --git a/python/setup_pycute.py b/python/setup_pycute.py index fb858fd8f7..625036680e 100644 --- a/python/setup_pycute.py +++ b/python/setup_pycute.py @@ -36,7 +36,7 @@ def perform_setup(): setup( name='pycute', - version='3.4.0', + version='3.4.1', description='Python implementation of CuTe', packages=['pycute'], ) diff --git a/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative_aux_store.cu b/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative_aux_store.cu index 2b542f2b04..57b71757bb 100644 --- a/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative_aux_store.cu +++ b/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative_aux_store.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * 
Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -56,9 +56,6 @@ #include "gemm_testbed_3x_evt.hpp" #include "sm90_evt_operations.hpp" - -#define CUTLASS_ARCH_MMA_SM90_SUPPORTED - #if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) using namespace cute; @@ -132,7 +129,7 @@ bool testEVTAuxStoreWithoutD() { D_block.reset(m * n); aux_store_D_block.reset(m * n); Gemm gemm_op_base; - + auto stride_A = cutlass::make_cute_packed_stride( typename GemmKernel::StrideA{}, cute::make_shape(m, k, cute::Int<1>{})); auto stride_B = cutlass::make_cute_packed_stride( @@ -141,7 +138,7 @@ bool testEVTAuxStoreWithoutD() { typename GemmKernel::StrideC{}, cute::make_shape(m, n, cute::Int<1>{})); auto stride_D = cutlass::make_cute_packed_stride( typename GemmKernel::StrideD{}, cute::make_shape(m, n, cute::Int<1>{})); - + auto arguments_base = typename Gemm::Arguments { cutlass::gemm::GemmUniversalMode::kGemm, problem_size, @@ -178,12 +175,12 @@ bool testEVTAuxStoreWithoutD() { /*hw_info=*/{}, /*scheduler_args=*/{} }; - + constexpr float beta [[maybe_unused]] = 1.0; constexpr float alpha [[maybe_unused]] = 1.0; - + using ElementC = typename GemmWithoutD::ElementC; - + if constexpr (not has_c) { arguments_base.epilogue.thread = { // binary op : alpha * acc @@ -282,7 +279,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 25 using AuxStoreDescriptor = cutlass::epilogue::collective::detail::AuxStoreDescriptor< EpilogueDescriptor, cutlass::layout::RowMajor, cutlass::half_t >; - + using namespace cutlass::epilogue::fusion; constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest; @@ -324,10 +321,10 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 25 return *(GemmKernel *)(nullptr); }; - + using GemmKernel = decltype(select_kernel(cute::C{}, cute::C{})); using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - + using GemmKernelWithoutD = decltype(select_kernel(cute::C{}, cute::C{})); using GemmWithoutD = cutlass::gemm::device::GemmUniversalAdapter; @@ -352,7 +349,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_cooperative_epilogue, 25 using AuxStoreDescriptor = cutlass::epilogue::collective::detail::AuxStoreDescriptor< EpilogueDescriptor, cutlass::layout::ColumnMajor, cutlass::half_t >; - + using namespace cutlass::epilogue::fusion; constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest; @@ -394,10 +391,10 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_cooperative_epilogue, 25 return *(GemmKernel *)(nullptr); }; - + using GemmKernel = decltype(select_kernel(cute::C{}, cute::C{})); using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - + using GemmKernelWithoutD = decltype(select_kernel(cute::C{}, cute::C{})); using GemmWithoutD = cutlass::gemm::device::GemmUniversalAdapter; @@ -467,7 +464,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 12 using GemmKernel = decltype(select_kernel(cute::C{}, cute::C{})); using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - + using GemmKernelWithoutD = decltype(select_kernel(cute::C{}, cute::C{})); using GemmWithoutD = cutlass::gemm::device::GemmUniversalAdapter; @@ -492,7 +489,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 25 using AuxStoreDescriptor = 
cutlass::epilogue::collective::detail::AuxStoreDescriptor< EpilogueDescriptor, cutlass::layout::RowMajor, cutlass::half_t >; - + using namespace cutlass::epilogue::fusion; constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest; @@ -534,10 +531,10 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 25 return *(GemmKernel *)(nullptr); }; - + using GemmKernel = decltype(select_kernel(cute::C{}, cute::C{})); using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - + using GemmKernelWithoutD = decltype(select_kernel(cute::C{}, cute::C{})); using GemmWithoutD = cutlass::gemm::device::GemmUniversalAdapter; @@ -562,7 +559,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_cooperative_epilogue, 25 using AuxStoreDescriptor = cutlass::epilogue::collective::detail::AuxStoreDescriptor< EpilogueDescriptor, cutlass::layout::ColumnMajor, cutlass::half_t >; - + using namespace cutlass::epilogue::fusion; constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest; @@ -604,10 +601,10 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_cooperative_epilogue, 25 return *(GemmKernel *)(nullptr); }; - + using GemmKernel = decltype(select_kernel(cute::C{}, cute::C{})); using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - + using GemmKernelWithoutD = decltype(select_kernel(cute::C{}, cute::C{})); using GemmWithoutD = cutlass::gemm::device::GemmUniversalAdapter; @@ -677,7 +674,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 12 using GemmKernel = decltype(select_kernel(cute::C{}, cute::C{})); using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - + using GemmKernelWithoutD = decltype(select_kernel(cute::C{}, cute::C{})); using GemmWithoutD = cutlass::gemm::device::GemmUniversalAdapter; diff --git a/tools/util/include/cutlass/util/packed_stride.hpp b/tools/util/include/cutlass/util/packed_stride.hpp index 405da035bd..5802866682 100644 --- a/tools/util/include/cutlass/util/packed_stride.hpp +++ b/tools/util/include/cutlass/util/packed_stride.hpp @@ -35,7 +35,6 @@ #pragma once #include "cute/layout.hpp" -#include "cute/stride.hpp" /////////////////////////////////////////////////////////////////////////////////////////////////
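A closing usage note: the new include/cutlass/version.h added in this patch makes the release statically inspectable. A short sketch (assuming the CUTLASS include directory is on the include path) of how downstream code might gate on it:

#include <iostream>
#include "cutlass/version.h"

int main() {
  // CUTLASS_VERSION encodes MAJOR*100 + MINOR*10 + PATCH, so 3.4.1 -> 341.
#if CUTLASS_VERSION >= 341
  std::cout << "Built against CUTLASS "
            << cutlass::getVersionMajor() << "."
            << cutlass::getVersionMinor() << "."
            << cutlass::getVersionPatch() << "\n";
#else
  std::cout << "Older CUTLASS release, numeric version "
            << cutlass::getVersion() << "\n";
#endif
  return 0;
}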