From 7a81d8073f16e04b8943b578ebc1681e9e87d57d Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Thu, 10 Oct 2024 02:16:17 -0400 Subject: [PATCH] [Bugfix] Machete garbage results for some models (large K dim) (#9212) Signed-off-by: Sumit Dubey --- .../quantization/machete/machete_mainloop.cuh | 23 +++++++++++-------- tests/kernels/test_machete_gemm.py | 5 ++-- 2 files changed, 16 insertions(+), 12 deletions(-) diff --git a/csrc/quantization/machete/machete_mainloop.cuh b/csrc/quantization/machete/machete_mainloop.cuh index 3d574ad99efda..e8e7b14de0da1 100644 --- a/csrc/quantization/machete/machete_mainloop.cuh +++ b/csrc/quantization/machete/machete_mainloop.cuh @@ -591,24 +591,27 @@ struct MacheteCollectiveMma { tma_load_b = make_tma_copy_B( make_logical_tensor(ptr_B, make_shape(N, K, L), args.dB)); + int32_t scale_k = + (ModeHasScales) ? (K + args.group_size - 1) / args.group_size : 0; + int32_t group_size = (ModeHasScales) ? args.group_size : 0; + if constexpr (ModeHasScales) { - tma_load_scale = make_tma_copy_scale(make_logical_tensor( - args.ptr_S, make_shape(M, args.group_size, L), args.dS)); + tma_load_scale = make_tma_copy_scale( + make_logical_tensor(args.ptr_S, make_shape(M, scale_k, L), args.dS)); } if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { - tma_load_zero = make_tma_copy_zero(make_logical_tensor( - args.ptr_Z, make_shape(M, args.group_size, L), args.dS)); + tma_load_zero = make_tma_copy_zero( + make_logical_tensor(args.ptr_Z, make_shape(M, scale_k, L), args.dS)); } - if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { - return {tma_load_a, tma_load_b, tma_load_scale, tma_load_zero, 0, 0}; - } else if constexpr (ModeHasScales) { - auto scale_k = (K + args.group_size - 1) / args.group_size; - + if constexpr (KernelConversionMode == ConversionMode::DirectConvert || + KernelConversionMode == ConversionMode::ConvertAndScale || + KernelConversionMode == + ConversionMode::ConvertAndScaleWithZero) { return {tma_load_a, tma_load_b, tma_load_scale, - tma_load_zero, scale_k, args.group_size}; + tma_load_zero, scale_k, group_size}; } else { static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in to_underlying_arguments."); diff --git a/tests/kernels/test_machete_gemm.py b/tests/kernels/test_machete_gemm.py index 0dfa79e9af8ec..0fc2984a68ded 100644 --- a/tests/kernels/test_machete_gemm.py +++ b/tests/kernels/test_machete_gemm.py @@ -24,13 +24,14 @@ (1, 128, 128), (1, 512, 1024), (1, 4096, 4096), + (1, 8192, 28672), (13, 8192, 4096), (26, 4096, 8192), - (1, 4096, 4096), + (64, 4096, 4096), + (64, 8192, 28672), (257, 128, 4096), (257, 4224, 4160), (257, 4096, 4096), - (64, 4096, 4096), (1024, 4096, 8192), (1024, 8192, 4096), ]