Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

perf: reduce the read and write of shared memory in the FusedAddRMSNormKernel #592

Merged
merged 4 commits into from
Nov 9, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 52 additions & 0 deletions benchmarks/bench_fused_add_rmsnorm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import argparse
from typing import cast

import torch
from triton.testing import do_bench

import flashinfer

@torch.inference_mode()
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--batch-sizes", nargs='+', type=int, default=[1, 19, 99, 989])
parser.add_argument("--hidden-sizes", nargs='+', type=int, default=[111, 500, 1024, 3072, 4096, 8192])
parser.add_argument("--dtypes", nargs='+', choices=["float16", "float32"], default=["float16"])
args = parser.parse_args()

eps = 1e-6

# Loop over each combination of batch_size, hidden_size, and dtype
for batch_size in args.batch_sizes:
for hidden_size in args.hidden_sizes:
for dtype_str in args.dtypes:
dtype = getattr(torch, dtype_str)

# Define tensors with the correct dtype
x = torch.randn((batch_size, hidden_size), dtype=dtype, device="cuda")
residual = torch.randn_like(x)
weight = torch.randn(hidden_size, dtype=dtype, device="cuda")

@torch.cuda.nvtx.range(f"fused_add_rmsnorm batch_size={batch_size}, hidden_size={hidden_size}, dtype={dtype_str}")
def fn() -> None:
flashinfer.fused_add_rmsnorm(x, residual, weight, eps)

# Run benchmarking
latency_ms = cast(float, do_bench(fn))
throughput = (
(x.numel() * x.element_size() * 2 + weight.numel() * weight.element_size()) / (latency_ms * 1e-3)
)
print(
f"batch_size: {batch_size:3},",
f"hidden_size: {hidden_size:5},",
f"dtype: {dtype_str:2},",
f"latency: {latency_ms*1e3:2.0f}us,",
f"throughput: {throughput*1e-9:7.3f}GB/s",
)

print("---")

torch.cuda.profiler.stop()

if __name__ == "__main__":
main()
11 changes: 8 additions & 3 deletions include/flashinfer/norm.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,8 @@ __global__ void FusedAddRMSNormKernel(T* __restrict__ input, T* __restrict__ res
input_vec.fill(0.f);
vec_t<T, VEC_SIZE> residual_vec;
residual_vec.fill(0.f);
vec_t<float, VEC_SIZE> x_vec;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wrote this kernel in August #419, and you can actually use https://pytorch.org/docs/stable/benchmark_utils.html to add a benchmark. This way, you can know whether there is a performance improvement before and after the changes.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wrote this kernel in August #419, and you can actually use https://pytorch.org/docs/stable/benchmark_utils.html to add a benchmark. This way, you can know whether there is a performance improvement before and after the changes.

Okay, I'll look into this, but I've analyzed this PR using Nsign Compute and found that the performance is about the same as the code before the precision improvement(#587).

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay,I'll try to write a benchmark test like this.

x_vec.fill(0.f);
if ((i * num_threads + thread_id) * VEC_SIZE < d) {
input_vec.load(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
residual_vec.load(residual + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
Expand All @@ -143,10 +145,11 @@ __global__ void FusedAddRMSNormKernel(T* __restrict__ input, T* __restrict__ res
x += float(residual_vec[j]);
sum_sq += x * x;
residual_vec[j] = (T)x;
smem[num_warps + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE + j] = x;
x_vec[j] = x;
}
if ((i * num_threads + thread_id) * VEC_SIZE < d) {
residual_vec.store(residual + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
x_vec.store(smem + num_warps + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
}
}

Expand Down Expand Up @@ -174,15 +177,17 @@ __global__ void FusedAddRMSNormKernel(T* __restrict__ input, T* __restrict__ res
for (uint32_t i = 0; i < rounds; i++) {
vec_t<T, VEC_SIZE> input_vec;
vec_t<T, VEC_SIZE> weight_vec;
vec_t<float, VEC_SIZE> x_vec;
input_vec.fill(0.f);
weight_vec.fill(0.f);
x_vec.fill(0.f);
if ((i * num_threads + thread_id) * VEC_SIZE < d) {
weight_vec.load(weight + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
x_vec.load(smem + num_warps + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
}
#pragma unroll
for (uint32_t j = 0; j < VEC_SIZE; j++) {
float x = smem[num_warps + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE + j];
input_vec[j] = x * rms_rcp * float(weight_vec[j]);
input_vec[j] = x_vec[j] * rms_rcp * float(weight_vec[j]);
}
if ((i * num_threads + thread_id) * VEC_SIZE < d) {
input_vec.store(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
Expand Down