From a8b75a44c309aa32867b6e8bf91a5c73e988b2ab Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Tue, 18 Jun 2024 20:33:25 -0400 Subject: [PATCH] [Bugfix] Fix w8a8 benchmarks for int8 case (#5643) --- benchmarks/cutlass_benchmarks/w8a8_benchmarks.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py b/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py index 523e970c2c9be..5cc0fbbd49b8e 100644 --- a/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py +++ b/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py @@ -120,9 +120,8 @@ def bench_int8(dtype: torch.dtype, m: int, k: int, n: int, label: str, # cutlass impl timers.append( - bench_fn(a, b, scale_a.to(device="cpu"), scale_b.to(device="cpu"), - torch.bfloat16, label, sub_label, cutlass_impl, - "cutlass_i8_i8_bf16_scaled_mm")) + bench_fn(a, b, scale_a, scale_b, torch.bfloat16, label, sub_label, + cutlass_impl, "cutlass_i8_i8_bf16_scaled_mm")) return timers