Feat/blockwise fp8 quant #1668

Open · wants to merge 10 commits into base: main
121 changes: 121 additions & 0 deletions benchmarks/benchmark_blockwise_scaled_linear_triton.py
@@ -0,0 +1,121 @@
import pandas as pd
import torch
from tqdm import tqdm
from triton.testing import do_bench

from torchao.ops import rowwise_scaled_linear_cutlass_s8s4
from torchao.prototype.blockwise_fp8.blockwise_fp8_gemm_triton import blockwise_fp8_gemm
from torchao.prototype.blockwise_fp8.blockwise_quantization import (
    fp8_blockwise_act_quant,
    fp8_blockwise_weight_dequant,
    fp8_blockwise_weight_quant,
)


def benchmark_microseconds(f, *args):
    return do_bench(lambda: f(*args), return_mode="median") * 1e3

def get_rowwise_problem(m: int, n: int, k: int, A_nbits: int, B_nbits: int):
    assert A_nbits in (4, 8) and B_nbits in (4, 8)

    dev = torch.device("cuda")
    # sub-byte operands are packed into int8, hence k * nbits // 8 columns
    A = torch.randint(-128, 127, (m, k * A_nbits // 8), dtype=torch.int8, device=dev)
    A_scale = torch.randn((m,), dtype=torch.half, device=dev)
    B = torch.randint(
        -128, 127, size=(n, k * B_nbits // 8), dtype=torch.int8, device=dev
    )
    B_scale = torch.randn((n,), dtype=torch.half, device=dev)
    C = None

    return A, A_scale, B, B_scale, C

def get_blockwise_problem(m: int, n: int, k: int, block_size: int):
    assert n % block_size == 0 and k % block_size == 0, "N and K dims must be divisible by block_size"
    dev = torch.device("cuda")
    # scale uniform random data into the fp8 e4m3 representable range (max 448)
    A = (448.0 * (2 * torch.rand(m, k, device=dev) - 1)).to(torch.float8_e4m3fn)
    # activations carry one scale per (row, k-block); weights one per (n-block, k-block)
    A_scale = torch.randn((m, k // block_size), dtype=torch.half, device=dev)
    B = (448.0 * (2 * torch.rand(n, k, device=dev) - 1)).to(torch.float8_e4m3fn)
    B_scale = torch.randn((n // block_size, k // block_size), dtype=torch.half, device=dev)

    return A, A_scale, B, B_scale

def benchmark(m: int, k: int, n: int, block_size: int):
    dev = torch.device("cuda")
    A_ref = torch.randn((m, k), dtype=torch.half, device=dev)
    B_ref = torch.randn((n, k), dtype=torch.half, device=dev)
    fp16_time = benchmark_microseconds(torch.nn.functional.linear, A_ref, B_ref)

    A, A_scale, B, B_scale, C = get_rowwise_problem(m, n, k, 8, 8)
    rowwise_scaled_linear_cutlass_s8s4_time = benchmark_microseconds(
        rowwise_scaled_linear_cutlass_s8s4, A, A_scale, B, B_scale, C
    )

    A, A_scale, B, B_scale = get_blockwise_problem(m, n, k, block_size)
    blockwise_fp8_gemm_time = benchmark_microseconds(
        blockwise_fp8_gemm, A, A_scale, B, B_scale
    )

    # TODO: add precision tests here:
    # - take two sets of random matrices
    # - quantize them rowwise to int8/int4
    # - quantize them blockwise to float8
    # (see the commented-out sketch below)
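    # A possible sketch of the blockwise half of that comparison, using the
    # quantization helpers imported at the top of this file; it is kept
    # commented out so it does not affect the timing runs, and the rowwise
    # int8/int4 helpers needed for the other half are not assumed here:
    #
    #   W = torch.randn(n, k, device=dev)
    #   W_q, W_s = fp8_blockwise_weight_quant(W, block_size=block_size)
    #   W_dq = fp8_blockwise_weight_dequant(W_q, W_s, block_size=block_size)
    #   blockwise_rel_err = torch.norm(W - W_dq) / torch.norm(W)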


    return {
        "m": m,
        "k": k,
        "n": n,
        "fp16_latency (us)": fp16_time,
        "rowwise_scaled_linear_cutlass_s8s4 latency (us)": rowwise_scaled_linear_cutlass_s8s4_time,
        "rowwise s8s4 speedup": fp16_time / rowwise_scaled_linear_cutlass_s8s4_time,
        "blockwise_fp8_gemm latency (us)": blockwise_fp8_gemm_time,
        "blockwise fp8 speedup": fp16_time / blockwise_fp8_gemm_time,
    }


def test_quant_dequant():
    torch.manual_seed(0)
    x = torch.randn(256, 256).cuda()
    qx, s = fp8_blockwise_weight_quant(x, block_size=128)
    x_reconstructed = fp8_blockwise_weight_dequant(qx, s, block_size=128)

    error = torch.norm(x - x_reconstructed) / torch.norm(x)
    print(f"Relative Error: {error.item():.6f}")

    assert error < 0.05, "Quant-Dequant error too high!"

def test_blockwise_fp8_gemm():
    torch.manual_seed(0)
    M, N, K = 256, 256, 128
    A = torch.randn(M, K).cuda()
    B = torch.randn(N, K).cuda()

    C = A @ B.T

    A_q, A_s = fp8_blockwise_act_quant(A, block_size=128)
    B_q, B_s = fp8_blockwise_weight_quant(B, block_size=128)

    C_q = blockwise_fp8_gemm(A_q, A_s, B_q, B_s)

    error = torch.norm(C - C_q) / torch.norm(C)
    print(f"Relative Error: {error.item():.6f}")

    assert error < 0.05, "Quantized GEMM error is too high!"


if __name__ == "__main__":
    # quick accuracy sanity checks before the timing runs
    test_quant_dequant()
    test_blockwise_fp8_gemm()

    k_vals = (8192, 8192, 8192, 28672)
    n_vals = (8192, 10240, 57344, 8192)
    block_size_vals = (128, 128, 128, 128)

    results = []
    for m in tqdm([1 << i for i in range(10)]):
        for n, k, block_size in zip(n_vals, k_vals, block_size_vals):
            results.append(benchmark(m, k, n, block_size))

    df = pd.DataFrame(results)
    df.to_csv("blockwise_scaled_linear_triton_time_results.csv", index=False)
    print(df.to_markdown(index=False))
7 changes: 5 additions & 2 deletions benchmarks/float8/bench_matmul.py
@@ -124,10 +124,13 @@ def run(
 if scaling_granularity == ScalingGranularity.TENSORWISE:
     scale_a = torch.tensor([1.0], device=device)
     scale_b = torch.tensor([1.0], device=device)
-else:
-    assert scaling_granularity == ScalingGranularity.AXISWISE, "unsupported"
+elif scaling_granularity == ScalingGranularity.AXISWISE:
     scale_a = torch.ones(M, 1, device=device)
     scale_b = torch.ones(1, N, device=device)
+else:
+    assert scaling_granularity == ScalingGranularity.BLOCKWISE, "unsupported"
Contributor: this file is benchmarking torch._scaled_mm which does not support blockwise scaling, is this change intended?

Author: Unintended, but I will rework this PR. There were some details that I had missed when I initially worked on it.

+    scale_a = torch.ones(M, N, device=device)
+    scale_b = torch.ones(M, N, device=device)
 
     def do_matmul(A, B):
         nonlocal scale_a
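As noted in the review thread above, torch._scaled_mm does not support blockwise scaling. For context, here is a minimal sketch, assuming the prototype torchao.prototype.blockwise_fp8 APIs used in the new benchmark above and a CUDA device, of how a blockwise-scaled FP8 matmul is exercised through the Triton kernel instead:

import torch

from torchao.prototype.blockwise_fp8.blockwise_fp8_gemm_triton import blockwise_fp8_gemm
from torchao.prototype.blockwise_fp8.blockwise_quantization import (
    fp8_blockwise_act_quant,
    fp8_blockwise_weight_quant,
)

M, N, K, block_size = 1024, 1024, 1024, 128
A = torch.randn(M, K, device="cuda")
B = torch.randn(N, K, device="cuda")

# quantize both operands to fp8 with per-block scales, then run the Triton
# blockwise GEMM; the result approximates A @ B.T (see test_blockwise_fp8_gemm above)
A_q, A_s = fp8_blockwise_act_quant(A, block_size=block_size)
B_q, B_s = fp8_blockwise_weight_quant(B, block_size=block_size)
C = blockwise_fp8_gemm(A_q, A_s, B_q, B_s)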
25 changes: 21 additions & 4 deletions benchmarks/float8/float8_roofline.py
@@ -354,14 +354,30 @@ def run(
     m_fp8_dyn_axs = torch.compile(m_fp8_dyn_axs)
     fp8_dyn_axs_time_actual_s = get_gpu_kernel_time(m_fp8_dyn_axs, x)
 
-    # get the lw recipe scaling gpu kernel time
+    # get the float8 dynamic blockwise scaling gpu kernel time
+    torch._dynamo.reset()
+    config = recipe_name_to_linear_config(Float8LinearRecipeName.ALL_BLOCKWISE)
+    m_fp8_dyn_blk = convert_to_float8_training(copy.deepcopy(m_orig), config=config)
+    m_fp8_dyn_blk = torch.compile(m_fp8_dyn_blk)
+    fp8_dyn_blk_time_actual_s = get_gpu_kernel_time(m_fp8_dyn_blk, x)
+
+    # get the lw_axs recipe scaling gpu kernel time
     # TODO(future PR): enable below once basic performance issues
     # are fixed
     # torch._dynamo.reset()
     # config = recipe_name_to_linear_config(Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP)
-    # m_fp8_lw = convert_to_float8_training(m_orig, config=config)
-    # m_fp8_lw = torch.compile(m_fp8_lw)
-    # fp8_lw_time_actual_s = get_gpu_kernel_time(m_fp8_lw, x)
+    # m_fp8_lw_axs = convert_to_float8_training(m_orig, config=config)
+    # m_fp8_lw_axs = torch.compile(m_fp8_lw_axs)
+    # fp8_lw_axs_time_actual_s = get_gpu_kernel_time(m_fp8_lw_axs, x)
+
+    # get the lw_blk recipe scaling gpu kernel time
+    # TODO(future PR): enable below once basic performance issues
+    # are fixed
+    # torch._dynamo.reset()
+    # config = recipe_name_to_linear_config(Float8LinearRecipeName.LW_BLOCKWISE_WITH_GW_HP)
+    # m_fp8_lw_blk = convert_to_float8_training(m_orig, config=config)
+    # m_fp8_lw_blk = torch.compile(m_fp8_lw_blk)
+    # fp8_lw_blk_time_actual_s = get_gpu_kernel_time(m_fp8_lw_blk, x)
 
     results.append(
         [
@@ -382,6 +398,7 @@ def run(
             fp8_dyn_time_actual_s,
             fp8_del_time_actual_s,
             fp8_dyn_axs_time_actual_s,
+            fp8_dyn_blk_time_actual_s,
             # fp8_lw_time_actual_s,
             bf16_time_actual_s / fp8_dyn_time_actual_s,
             bf16_time_actual_s / fp8_del_time_actual_s,