diff --git a/include/flashinfer/group_gemm/wrapper.cuh b/include/flashinfer/group_gemm/wrapper.cuh index a7c5393d1..03f0b3eb5 100644 --- a/include/flashinfer/group_gemm/wrapper.cuh +++ b/include/flashinfer/group_gemm/wrapper.cuh @@ -85,7 +85,7 @@ cudaError_t CutlassSegmentGEMMWrapper(CutlassSegmentGEMMHandler* handler, DType* cutlass::gemm::GemmShape<16, 8, 16>, // Instruction Shape cutlass::epilogue::thread::LinearCombination, // Epilogue cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, // Swizzling Operator - 8 // Stages + 4 // Stages >::GemmKernel; using EpilogueOutputOp = typename GemmKernel::Epilogue::OutputOp;