-
Notifications
You must be signed in to change notification settings - Fork 251
Feat/blockwise fp8 quant #1668
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
Degnel
wants to merge
10
commits into
pytorch:main
Choose a base branch
from
Degnel:feat/blockwise_fp8_quant
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Feat/blockwise fp8 quant #1668
Changes from 9 commits
Commits
Show all changes
10 commits
Select commit
Hold shift + click to select a range
55dab5f
Feat: blockwise fp8 quantizer
Degnel 5ab1eb2
Feat: fp8 linear layer with blockwise quantization
Degnel 4af3c34
Merge branch 'pytorch:main' into feat/blockwise_fp8_quant
Degnel aa7dc87
Feat: adding assertions in the ops file
Degnel 167fdce
Feat: adding some tests for blockwise fp8 quant
Degnel 9e9d16e
Fix: fixes for the blockwise_fp8_quantization
Degnel cf5802a
Merge branch 'pytorch:main' into feat/blockwise_fp8_quant
Degnel 6c9246a
linting
Degnel 89c6ed0
Feat/test: quant/dequant weight/act + test
Degnel 91b368d
linting
Degnel File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,121 @@ | ||
import pandas as pd | ||
import torch | ||
from tqdm import tqdm | ||
from triton.testing import do_bench | ||
|
||
from torchao.prototype.blockwise_fp8.blockwise_fp8_gemm_triton import blockwise_fp8_gemm | ||
from torchao.ops import rowwise_scaled_linear_cutlass_s8s4 | ||
|
||
|
||
def benchmark_microseconds(f, *args): | ||
return do_bench(lambda: f(*args), return_mode="median") * 1e3 | ||
|
||
def get_rowwise_problem(m: int, n: int, k: int, A_nbits: int, B_nbits: int): | ||
assert A_nbits in (4, 8) and B_nbits in (4, 8) | ||
|
||
dev = torch.device("cuda") | ||
A = torch.randint(-128, 127, (m, k * A_nbits // 8), dtype=torch.int8, device=dev) | ||
A_scale = torch.randn((m,), dtype=torch.half, device=dev) | ||
B = torch.randint( | ||
-128, 127, size=(n, k * B_nbits // 8), dtype=torch.int8, device=dev | ||
) | ||
B_scale = torch.randn((n,), dtype=torch.half, device=dev) | ||
C = None | ||
|
||
return A, A_scale, B, B_scale, C | ||
|
||
def get_blockwise_problem(m: int, n: int, k: int, block_size: int): | ||
assert n % block_size == 0 and k % block_size == 0, "N and K dims must be divisible by block_size" | ||
dev = torch.device("cuda") | ||
A = (448.0 * (2 * torch.rand(m, k, device=dev) - 1)).to(torch.float8_e4m3fn) | ||
A_scale = torch.randn((m, k // block_size), dtype=torch.half, device=dev) | ||
B = (448.0 * (2 * torch.rand(n, k, device=dev) - 1)).to(torch.float8_e4m3fn) | ||
B_scale = torch.randn((n // block_size, k // block_size), dtype=torch.half, device=dev) | ||
|
||
return A, A_scale, B, B_scale | ||
|
||
def benchmark(m: int, k: int, n: int, block_size: int): | ||
dev = torch.device("cuda") | ||
A_ref = torch.randn((m, k), dtype=torch.half, device=dev) | ||
B_ref = torch.randn((n, k), dtype=torch.half, device=dev) | ||
fp16_time = benchmark_microseconds(torch.nn.functional.linear, A_ref, B_ref) | ||
|
||
A, A_scale, B, B_scale, C = get_rowwise_problem(m, n, k, 8, 8) | ||
rowwise_scaled_linear_cutlass_s8s4_time = benchmark_microseconds( | ||
rowwise_scaled_linear_cutlass_s8s4, A, A_scale, B, B_scale, C | ||
) | ||
|
||
A, A_scale, B, B_scale = get_blockwise_problem(m, n, k, block_size) | ||
blockwise_fp8_gemm_time = benchmark_microseconds( | ||
blockwise_fp8_gemm, A, A_scale, B, B_scale | ||
) | ||
|
||
# Add precision tests | ||
# On prend 2 sets de matrices aléatoires | ||
# On les quantise en int8/int4 rowwise | ||
# On les quantise en en float8 blockwise | ||
# | ||
|
||
|
||
return { | ||
"m": m, | ||
"k": k, | ||
"n": n, | ||
"fp16_latency (ms)": fp16_time, | ||
"rowwise_scaled_linear_cutlass_s8s4 latency (ms)": rowwise_scaled_linear_cutlass_s8s4_time, | ||
"rowwise s8s4 speedup (d/s)": fp16_time / rowwise_scaled_linear_cutlass_s8s4_time, | ||
"blockwise_fp8_gemm latency (ms)": blockwise_fp8_gemm_time, | ||
"blockwise fp8 speedup (d/s)": fp16_time / blockwise_fp8_gemm_time, | ||
} | ||
|
||
|
||
from torchao.prototype.blockwise_fp8.blockwise_quantization import fp8_blockwise_weight_quant, fp8_blockwise_act_quant, fp8_blockwise_weight_dequant | ||
from torchao.prototype.blockwise_fp8.blockwise_fp8_gemm_triton import blockwise_fp8_gemm | ||
|
||
def test_quant_dequant(): | ||
torch.manual_seed(0) | ||
x = torch.randn(256, 256).cuda() | ||
qx, s = fp8_blockwise_weight_quant(x, block_size=128) | ||
x_reconstructed = fp8_blockwise_weight_dequant(qx, s, block_size=128) | ||
|
||
error = torch.norm(x - x_reconstructed) / torch.norm(x) | ||
print(f"Relative Error: {error.item():.6f}") | ||
|
||
assert error < 0.05, "Quant-Dequant error too high!" | ||
|
||
def test_blockwise_fp8_gemm(): | ||
torch.manual_seed(0) | ||
M, N, K = 256, 256, 128 | ||
A = torch.randn(M, K).cuda() | ||
B = torch.randn(N, K).cuda() | ||
|
||
C = A @ B.T | ||
|
||
A_q, A_s = fp8_blockwise_act_quant(A, block_size=128) | ||
B_q, B_s = fp8_blockwise_weight_quant(B, block_size=128) | ||
|
||
C_q = blockwise_fp8_gemm(A_q, A_s, B_q, B_s) | ||
|
||
error = torch.norm(C - C_q) / torch.norm(C) | ||
print(f"Relative Error: {error.item():.6f}") | ||
|
||
assert error < 0.05, "Quantized GEMM error is too high!" | ||
|
||
|
||
test_quant_dequant() | ||
test_blockwise_fp8_gemm() | ||
|
||
|
||
if __name__ == "__main__": | ||
k_vals = (8192, 8192, 8192, 28672) | ||
n_vals = (8192, 10240, 57344, 8192) | ||
block_size_vals = (128, 128, 128, 128) | ||
|
||
results = [] | ||
for m in tqdm([1 << i for i in range(10)]): | ||
for n, k, block_size in zip(n_vals, k_vals, block_size_vals): | ||
results.append(benchmark(m, k, n, block_size)) | ||
|
||
df = pd.DataFrame(results) | ||
df.to_csv("blockwise_scaled_linear_triton_time_results.csv", index=False) | ||
print(df.to_markdown(index=False)) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this file is benchmarking
torch._scaled_mm
which does not support blockwise scaling, is this change intended?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Unintended, but I will rework this PR. There were some details that I had missed when I initially worked on it.