From 60233aac6f6e32fff465e59017dd0b723bed091b Mon Sep 17 00:00:00 2001 From: jason-huang03 Date: Sat, 7 Sep 2024 09:32:56 +0800 Subject: [PATCH] [Update] Add rms norm general kernel and update sampler condition --- kernels/csrc/layernorm.cpp | 23 +++ kernels/csrc/layernorm_kernels.cu | 271 +++++++++++++++++++++++++++++- qserve/modeling/layers/sampler.py | 2 +- qserve_benchmark.py | 2 +- 4 files changed, 293 insertions(+), 5 deletions(-) diff --git a/kernels/csrc/layernorm.cpp b/kernels/csrc/layernorm.cpp index e494236..d37073c 100644 --- a/kernels/csrc/layernorm.cpp +++ b/kernels/csrc/layernorm.cpp @@ -14,6 +14,13 @@ void rms_norm(torch::Tensor &out, // [num_tokens, hidden_size] torch::Tensor &weight, // [hidden_size] float epsilon, bool use_quant); +void layer_norm_general(torch::Tensor &out, // [..., hidden_size] + torch::Tensor &input, // [..., hidden_size] + torch::Tensor &weight, // [hidden_size] + torch::Tensor &scaling, // [tokens] or [1] + float epsilon, + bool use_per_token_quant); + void rms_norm_general(torch::Tensor &out, // [..., hidden_size] torch::Tensor &input, // [..., hidden_size] torch::Tensor &weight, // [hidden_size] @@ -21,6 +28,14 @@ void rms_norm_general(torch::Tensor &out, // [..., hidden_size] float epsilon, bool use_per_token_quant); +void layer_norm_general_fuse_sum(torch::Tensor &out, // [..., hidden_size] + torch::Tensor &input, // [..., hidden_size] + torch::Tensor &weight, // [hidden_size] + torch::Tensor &input_sum, // [tokens] or [1] + torch::Tensor &scaling, // [tokens] or [1] + float epsilon, + bool use_per_token_quant); + void rms_norm_general_fuse_sum(torch::Tensor &out, // [..., hidden_size] torch::Tensor &input, // [..., hidden_size] torch::Tensor &weight, // [hidden_size] @@ -49,10 +64,18 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::arg("weight"), py::arg("epsilon"), py::arg("use_quant") = false, "Apply Root Mean Square (RMS) Normalization to the input tensor."); + m.def("layer_norm_general", &layer_norm_general, py::arg("out"), py::arg("input"), + py::arg("weight"), py::arg("scaling"), py::arg("epsilon"), py::arg("use_per_token_quant") = false, + "Apply Layer Normalization to the input tensor (modified from TRTLLM kernel)."); + m.def("rms_norm_general", &rms_norm_general, py::arg("out"), py::arg("input"), py::arg("weight"), py::arg("scaling"), py::arg("epsilon"), py::arg("use_per_token_quant") = false, "Apply Root Mean Square (RMS) Normalization to the input tensor (TRTLLM kernel)."); + m.def("layer_norm_general_fuse_sum", &layer_norm_general_fuse_sum, py::arg("out"), py::arg("input"), + py::arg("weight"), py::arg("input_sum"), py::arg("scaling"), py::arg("epsilon"), py::arg("use_per_token_quant") = false, + "Apply Layer Normalization to the input tensor & get input sum (modified from TRTLLM kernel)."); + m.def("rms_norm_general_fuse_sum", &rms_norm_general_fuse_sum, py::arg("out"), py::arg("input"), py::arg("weight"), py::arg("input_sum"), py::arg("scaling"), py::arg("epsilon"), py::arg("use_per_token_quant") = false, "Apply Root Mean Square (RMS) Normalization to the input tensor & get input sum (TRTLLM kernel)."); diff --git a/kernels/csrc/layernorm_kernels.cu b/kernels/csrc/layernorm_kernels.cu index acd282d..ca44575 100644 --- a/kernels/csrc/layernorm_kernels.cu +++ b/kernels/csrc/layernorm_kernels.cu @@ -185,6 +185,96 @@ __global__ void generalLayerNorm(const T* input, const T* gamma, const T* beta, } } +template +__global__ void generalRMSNorm(const T* input, const T* gamma, const T* beta, T* normed_output, const float eps, + int 
tokens, int hidden_dim, const scale_type* scale_orig_quant_per_tensor, scale_type* scale_orig_quant_per_token, + int8_t* normed_output_quant, bool use_shmem) +{ + constexpr auto num_elems_T = num_elems::value; + using int8_packed_t = typename packed_as::type; + using float_packed_t = typename packed_as::type; + using T_scalar = typename packed_as::type; + + extern __shared__ __align__(sizeof(float)) char _shmem[]; + T* shmem = reinterpret_cast(_shmem); + __shared__ float s_mean; + __shared__ float s_variance; + + const int tidx = threadIdx.x; + const int bidx = blockIdx.x; + + float variance = 0.0f; + float local_var_sum = 0.0f; + + const int n_elems = hidden_dim / num_elems_T; + + for (int i = tidx; i < n_elems; i += blockDim.x) + { + const T val = use_shmem ? shmem[i] : input[bidx * n_elems + i]; + float_packed_t diff = cuda_cast(val); // no mean + local_var_sum += cuda_sum(diff * diff); + } + variance = blockReduceSum(local_var_sum); + + if (threadIdx.x == 0) + { + s_variance = rsqrtf(variance / hidden_dim + eps); + } + __syncthreads(); + + const bool with_per_token_scaling = scale_orig_quant_per_token != nullptr; + const bool with_per_tensor_scaling = scale_orig_quant_per_tensor != nullptr; + const float_packed_t scale_orig_quant + = cuda_cast(with_per_tensor_scaling ? __half2float(*scale_orig_quant_per_tensor) : 0.0f); + T_scalar amax = 1e-6f; + + for (int i = tidx; i < n_elems; i += blockDim.x) + { + const int index = bidx * n_elems + i; + const float_packed_t val_f = cuda_cast(use_shmem ? shmem[i] : input[index]); + const T val = cuda_cast(compute_layernorm(val_f, 0.0f, s_variance, gamma, beta, i)); + + if (with_per_token_scaling) + { + amax = cuda_max(cuda_max(cuda_abs(val)), amax); + if (use_shmem) + { + shmem[i] = val; + } + } + else if (with_per_tensor_scaling) + { + reinterpret_cast(normed_output_quant)[index] + = cuda_cast(cuda_cast(val) * scale_orig_quant); + } + else + { + normed_output[index] = val; + } + } + + if (with_per_token_scaling) + { + float abs_max_f = blockAllReduceMax(cuda_cast(amax)); + const float dynamic_per_token_scale = 127.f / abs_max_f; + for (int i = tidx; i < n_elems; i += blockDim.x) + { + const int index = bidx * n_elems + i; + float_packed_t val_f = cuda_cast(use_shmem ? 
shmem[i] : input[index]); + if (!use_shmem) + { + val_f = compute_layernorm(val_f, 0.0f, s_variance, gamma, beta, i); + } + + reinterpret_cast(normed_output_quant)[index] + = cuda_cast(val_f * cuda_cast(dynamic_per_token_scale)); + } + if (tidx == 0) + { + scale_orig_quant_per_token[bidx] = abs_max_f / 127.f; + } + } +} template __global__ void generalLayerNorm_fuse_sum(const T* input, const T* gamma, const T* beta, T* normed_output, const float eps, @@ -325,6 +415,100 @@ __global__ void generalLayerNorm_fuse_sum(const T* input, const T* gamma, const } } +template +__global__ void generalRMSNorm_fuse_sum(const T* input, const T* gamma, const T* beta, T* normed_output, const float eps, + int tokens, int hidden_dim, scale_type* input_sum, const scale_type* scale_orig_quant_per_tensor, scale_type* scale_orig_quant_per_token, + int8_t* normed_output_quant, bool use_shmem) +{ + constexpr auto num_elems_T = num_elems::value; + using int8_packed_t = typename packed_as::type; + using float_packed_t = typename packed_as::type; + using T_scalar = typename packed_as::type; + + extern __shared__ __align__(sizeof(float)) char _shmem[]; + T* shmem = reinterpret_cast(_shmem); + __shared__ float s_mean; + __shared__ float s_variance; + + const int tidx = threadIdx.x; + const int bidx = blockIdx.x; + + float variance = 0.0f; + float local_var_sum = 0.0f; + + const int n_elems = hidden_dim / num_elems_T; + + for (int i = tidx; i < n_elems; i += blockDim.x) + { + const T val = use_shmem ? shmem[i] : input[bidx * n_elems + i]; + float_packed_t diff = cuda_cast(val); // no mean + local_var_sum += cuda_sum(diff * diff); + } + variance = blockReduceSum(local_var_sum); + + if (threadIdx.x == 0) + { + s_variance = rsqrtf(variance / hidden_dim + eps); + } + __syncthreads(); + + const bool with_per_token_scaling = scale_orig_quant_per_token != nullptr; + const bool with_per_tensor_scaling = scale_orig_quant_per_tensor != nullptr; + const float_packed_t scale_orig_quant + = cuda_cast(with_per_tensor_scaling ? __half2float(*scale_orig_quant_per_tensor) : 0.0f); + T_scalar amax = 1e-6f; + T_scalar sum = 0.0f; + + for (int i = tidx; i < n_elems; i += blockDim.x) + { + const int index = bidx * n_elems + i; + const float_packed_t val_f = cuda_cast(use_shmem ? shmem[i] : input[index]); + const T val = cuda_cast(compute_layernorm(val_f, 0.0f, s_variance, gamma, beta, i)); + + if (with_per_token_scaling) + { + amax = cuda_max(cuda_max(cuda_abs(val)), amax); + sum += cuda_sum(val); + if (use_shmem) + { + shmem[i] = val; + } + } + else if (with_per_tensor_scaling) + { + reinterpret_cast(normed_output_quant)[index] + = cuda_cast(cuda_cast(val) * scale_orig_quant); + } + else + { + normed_output[index] = val; + } + } + + if (with_per_token_scaling) + { + float abs_max_f = blockAllReduceMax(cuda_cast(amax)); + float sum_f = blockAllReduceSum(cuda_cast(sum)); + const float dynamic_per_token_scale = 127.f / abs_max_f; + for (int i = tidx; i < n_elems; i += blockDim.x) + { + const int index = bidx * n_elems + i; + float_packed_t val_f = cuda_cast(use_shmem ? shmem[i] : input[index]); + if (!use_shmem) + { + val_f = compute_layernorm(val_f, 0.0f, s_variance, gamma, beta, i); + } + + reinterpret_cast(normed_output_quant)[index] + = cuda_cast(val_f * cuda_cast(dynamic_per_token_scale)); + } + if (tidx == 0) + { + scale_orig_quant_per_token[bidx] = abs_max_f / 127.f; + input_sum[bidx] = sum_f; + } + } +} // TODO(woosuk): Further optimize this kernel. 
template @@ -424,7 +608,7 @@ void rms_norm(torch::Tensor &out, // [..., hidden_size] }); } -void rms_norm_general(torch::Tensor &out, // [..., hidden_size] +void layer_norm_general(torch::Tensor &out, // [..., hidden_size] torch::Tensor &input, // [..., hidden_size] torch::Tensor &weight, // [hidden_size] torch::Tensor &scaling, // [tokens] or [1] @@ -463,7 +647,46 @@ void rms_norm_general(torch::Tensor &out, // [..., hidden_size] }); } -void rms_norm_general_fuse_sum(torch::Tensor &out, // [..., hidden_size] +void rms_norm_general(torch::Tensor &out, // [..., hidden_size] + torch::Tensor &input, // [..., hidden_size] + torch::Tensor &weight, // [hidden_size] + torch::Tensor &scaling, // [tokens] or [1] + float epsilon, + bool use_per_token_quant) { + int hidden_size = input.size(-1); + int num_tokens = input.numel() / hidden_size; + dim3 grid(num_tokens); + dim3 block(std::min(hidden_size, 1024)); + block.x = 32 * ((block.x + 31) / 32); + + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "generalRMSNorm", [&] { + using T = typename FloatTypeConverter::Type; + if (use_per_token_quant) { + // per-token + vllm::generalRMSNorm<<>>( + reinterpret_cast(input.data_ptr()), + reinterpret_cast(weight.data_ptr()), nullptr, + nullptr, epsilon, num_tokens, hidden_size, nullptr, scaling.data_ptr(), + out.data_ptr(), false + ); + // input, gamma, beta, normed_output, eps, tokens, hidden_dim, per_tensor_scale, per_token_scale + // normed_output_quant, use_shmem + // out.data_ptr(), input.data_ptr(), + // weight.data_ptr(), epsilon, num_tokens, hidden_size); + } else { + // per-tensor + vllm::generalRMSNorm<<>>( + reinterpret_cast(input.data_ptr()), + reinterpret_cast(weight.data_ptr()), nullptr, + nullptr, epsilon, num_tokens, hidden_size, scaling.data_ptr(), nullptr, + out.data_ptr(), false + ); + } + }); +} + +void layer_norm_general_fuse_sum(torch::Tensor &out, // [..., hidden_size] torch::Tensor &input, // [..., hidden_size] torch::Tensor &weight, // [hidden_size] torch::Tensor &input_sum, // [tokens] or [1] @@ -507,7 +730,49 @@ void rms_norm_general_fuse_sum(torch::Tensor &out, // [..., hidden_size] }); } - +void rms_norm_general_fuse_sum(torch::Tensor &out, // [..., hidden_size] + torch::Tensor &input, // [..., hidden_size] + torch::Tensor &weight, // [hidden_size] + torch::Tensor &input_sum, // [tokens] or [1] + torch::Tensor &scaling, // [tokens] or [1] + float epsilon, + bool use_per_token_quant) { + int hidden_size = input.size(-1); + int num_tokens = input.numel() / hidden_size; + dim3 grid(num_tokens); + dim3 block(std::min(hidden_size, 1024)); + block.x = 32 * ((block.x + 31) / 32); + + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "generalRMSNorm_fuse_sum", [&] { + using T = typename FloatTypeConverter::Type; + if (use_per_token_quant) { + // per-token + vllm::generalRMSNorm_fuse_sum<<>>( + reinterpret_cast(input.data_ptr()), + reinterpret_cast(weight.data_ptr()), nullptr, + nullptr, epsilon, num_tokens, hidden_size, input_sum.data_ptr(), nullptr, scaling.data_ptr(), + out.data_ptr(), false + ); + // input, gamma, beta, normed_output, eps, tokens, hidden_dim, per_tensor_scale, per_token_scale + // normed_output_quant, use_shmem + // out.data_ptr(), input.data_ptr(), + // weight.data_ptr(), epsilon, num_tokens, hidden_size); + } else { + // per-tensor + // Rasing error here + // Not implemented per-tensor input_sum + assert(false); + + 
vllm::generalRMSNorm_fuse_sum<<>>( + reinterpret_cast(input.data_ptr()), + reinterpret_cast(weight.data_ptr()), nullptr, + nullptr, epsilon, num_tokens, hidden_size, nullptr, scaling.data_ptr(), nullptr, + out.data_ptr(), false + ); + } + }); +} void invoke_dequant_add_residual_rms_norm_quant( torch::Tensor &out, // [..., hidden_size] diff --git a/qserve/modeling/layers/sampler.py b/qserve/modeling/layers/sampler.py index 07817c5..b46b345 100644 --- a/qserve/modeling/layers/sampler.py +++ b/qserve/modeling/layers/sampler.py @@ -84,7 +84,7 @@ def forward( else: last_token_logits = logits if ( - self.sampling_params.temperature < 1e-5 or self.sampling_params.top_p < 1e-8 # greedy + self.sampling_params.temperature < 1e-5 or self.sampling_params.top_p < 1e-8 or self.sampling_params.top_k == 1 # greedy ): token = torch.argmax(last_token_logits, dim=-1) else: diff --git a/qserve_benchmark.py b/qserve_benchmark.py index 23348bb..0acdfd4 100644 --- a/qserve_benchmark.py +++ b/qserve_benchmark.py @@ -32,7 +32,7 @@ def process_requests( str(b), prompt=None, profiling_config=profiling_config, - sampling_params=SamplingParams(top_p=0.95, top_k=40, temperature=0.7), + sampling_params=SamplingParams(temperature=0.0), ) if engine.ifb_mode == False:
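
Reviewer note on the kernel change (not part of the patch): the launchers previously exported as rms_norm_general / rms_norm_general_fuse_sum dispatch TRT-LLM's generalLayerNorm, which subtracts the per-token mean, so they are renamed here to layer_norm_general / layer_norm_general_fuse_sum, and true RMSNorm counterparts are added on top of the new generalRMSNorm / generalRMSNorm_fuse_sum kernels. In the per-token path each token is scaled by rsqrt(mean(x^2) + eps), multiplied by gamma (beta is passed as nullptr by the launchers), then symmetrically quantized to int8 with a dynamic scale of amax / 127; the _fuse_sum variant additionally writes the per-token sum of the normalized values into input_sum. On the Python side, the sampler now also takes the greedy argmax path when top_k == 1, and the benchmark switches to temperature=0.0 so it exercises that path. A rough PyTorch emulation of the per-token quantization path, handy for unit-testing the CUDA output, is sketched below; the function name is mine and the result is not expected to be bit-exact (the kernel mixes fp16/fp32 accumulation and uses a saturating round-to-nearest int8 cast).

# Reference sketch (assumption: not part of this patch) emulating the per-token
# path of generalRMSNorm / generalRMSNorm_fuse_sum.
import torch

def rms_norm_per_token_quant_ref(x: torch.Tensor, weight: torch.Tensor, eps: float):
    """RMSNorm (no mean subtraction, no beta) + per-token symmetric int8 quantization."""
    x_f = x.float()
    # RMS normalization: divide by sqrt(mean(x^2) + eps), then apply gamma.
    inv_rms = torch.rsqrt(x_f.pow(2).mean(dim=-1, keepdim=True) + eps)
    normed = x_f * inv_rms * weight.float()                         # beta is nullptr in the launchers
    # Dynamic per-token scale; the kernel seeds amax with 1e-6f, mirrored by clamp_min.
    amax = normed.abs().amax(dim=-1, keepdim=True).clamp_min(1e-6)
    q = torch.round(normed * (127.0 / amax)).clamp_(-128, 127).to(torch.int8)
    per_token_scale = (amax / 127.0).squeeze(-1)                    # written to scale_orig_quant_per_token
    fused_sum = normed.sum(dim=-1)                                  # written to input_sum by the _fuse_sum variant
    return q, per_token_scale, fused_sum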
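
For completeness, a hypothetical call site for the bindings exposed in layernorm.cpp; the extension import name and the dtypes assumed for the scaling / sum tensors are not taken from this patch.

# Hypothetical usage sketch; adjust the import to however the PYBIND11_MODULE
# extension is actually built and imported in QServe.
import torch
import qserve_cuda_kernels as ops  # assumed module name

tokens, hidden = 8, 4096
x = torch.randn(tokens, hidden, dtype=torch.float16, device="cuda")
w = torch.ones(hidden, dtype=torch.float16, device="cuda")
q_out = torch.empty(tokens, hidden, dtype=torch.int8, device="cuda")
scales = torch.empty(tokens, dtype=torch.float16, device="cuda")  # per-token scales (dtype assumed)
sums = torch.empty(tokens, dtype=torch.float16, device="cuda")    # per-token sums for the fuse_sum variant

# Per-token quantized RMSNorm; the LayerNorm behaviour of the old entry points
# remains available under the new layer_norm_general* names.
ops.rms_norm_general(q_out, x, w, scales, 1e-6, use_per_token_quant=True)
ops.rms_norm_general_fuse_sum(q_out, x, w, sums, scales, 1e-6, use_per_token_quant=True)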