-
Notifications
You must be signed in to change notification settings - Fork 36
/
Copy pathlayernorm.cpp
95 lines (82 loc) · 4.73 KB
/
layernorm.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
// Inspired by TRT-LLM.
// Modified by Shang Yang and Haotian Tang.
// @article{lin2024qserve,
// title={QServe: W4A8KV4 Quantization and System Co-design for Efficient LLM Serving},
// author={Lin*, Yujun and Tang*, Haotian and Yang*, Shang and Zhang, Zhekai and Xiao, Guangxuan and Gan, Chuang and Han, Song},
// journal={arXiv preprint arXiv:2405.04532},
// year={2024}
// }
#include <torch/extension.h>
#include <cuda_fp16.h>
// Forward declaration — the kernel is defined in a separate CUDA source file
// (not visible here). Applies RMS normalization row-wise over `input` scaled by
// `weight`; when `use_quant` is true the output is presumably written in a
// quantized format into `out` — TODO confirm against the kernel definition.
void rms_norm(torch::Tensor &out, // [num_tokens, hidden_size]
torch::Tensor &input, // [num_tokens, hidden_size]
torch::Tensor &weight, // [hidden_size]
float epsilon, bool use_quant);
// Forward declaration — defined in the CUDA implementation file.
// LayerNorm variant (modified from a TRT-LLM kernel, per the bindings below).
// `scaling` holds either one scale per token or a single scale, selected by
// `use_per_token_quant` — assumption from the shape comments; verify in kernel.
void layer_norm_general(torch::Tensor &out, // [..., hidden_size]
torch::Tensor &input, // [..., hidden_size]
torch::Tensor &weight, // [hidden_size]
torch::Tensor &scaling, // [tokens] or [1]
float epsilon,
bool use_per_token_quant);
// Forward declaration — defined in the CUDA implementation file.
// RMSNorm counterpart of layer_norm_general (TRT-LLM-derived kernel); same
// parameter conventions: `scaling` is per-token ([tokens]) or per-tensor ([1]).
void rms_norm_general(torch::Tensor &out, // [..., hidden_size]
torch::Tensor &input, // [..., hidden_size]
torch::Tensor &weight, // [hidden_size]
torch::Tensor &scaling, // [tokens] or [1]
float epsilon,
bool use_per_token_quant);
// Forward declaration — defined in the CUDA implementation file.
// Same as layer_norm_general but additionally produces `input_sum`
// (presumably the per-token sum of the input, fused into one kernel launch —
// TODO confirm against the kernel definition).
void layer_norm_general_fuse_sum(torch::Tensor &out, // [..., hidden_size]
torch::Tensor &input, // [..., hidden_size]
torch::Tensor &weight, // [hidden_size]
torch::Tensor &input_sum, // [tokens] or [1]
torch::Tensor &scaling, // [tokens] or [1]
float epsilon,
bool use_per_token_quant);
// Forward declaration — defined in the CUDA implementation file.
// RMSNorm counterpart of layer_norm_general_fuse_sum: normalizes and also
// emits `input_sum` alongside the quantization `scaling` factors.
void rms_norm_general_fuse_sum(torch::Tensor &out, // [..., hidden_size]
torch::Tensor &input, // [..., hidden_size]
torch::Tensor &weight, // [hidden_size]
torch::Tensor &input_sum, // [tokens] or [1]
torch::Tensor &scaling, // [tokens] or [1]
float epsilon,
bool use_per_token_quant);
// Forward declaration — defined in the CUDA implementation file.
// Per-tensor variant: `scale` is a single half-precision dequantization scale
// applied to `input` before adding `residual`, RMS-normalizing with `gamma`,
// and re-quantizing into `out` (per the binding docstring below).
void invoke_dequant_add_residual_rms_norm_quant(
torch::Tensor &out, // [..., hidden_size]
torch::Tensor &input, // [..., hidden_size]
torch::Tensor &residual, // [..., hidden_size]
torch::Tensor &gamma, // [hidden_size]
at::Half scale, float epsilon);
// Forward declaration — defined in the CUDA implementation file.
// Per-token overload: `scale` carries one dequantization factor per token
// ([num_tokens]) instead of a single scalar. Overload resolution in the
// pybind11 module below dispatches on the type of `scale`.
void invoke_dequant_add_residual_rms_norm_quant(
torch::Tensor &out, // [..., hidden_size]
torch::Tensor &input, // [..., hidden_size]
torch::Tensor &residual, // [..., hidden_size]
torch::Tensor &gamma, // [hidden_size]
torch::Tensor &scale, // [num_tokens]
float epsilon);
// Python bindings for the normalization kernels declared above.
// All entry points are exposed with named arguments so they are
// keyword-callable from Python; the quant flags default to false.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("rms_norm", &rms_norm, py::arg("out"), py::arg("input"),
        py::arg("weight"), py::arg("epsilon"), py::arg("use_quant") = false,
        "Apply Root Mean Square (RMS) Normalization to the input tensor.");
  m.def("layer_norm_general", &layer_norm_general, py::arg("out"), py::arg("input"),
        py::arg("weight"), py::arg("scaling"), py::arg("epsilon"), py::arg("use_per_token_quant") = false,
        "Apply Layer Normalization to the input tensor (modified from TRTLLM kernel).");
  m.def("rms_norm_general", &rms_norm_general, py::arg("out"), py::arg("input"),
        py::arg("weight"), py::arg("scaling"), py::arg("epsilon"), py::arg("use_per_token_quant") = false,
        "Apply Root Mean Square (RMS) Normalization to the input tensor (TRTLLM kernel).");
  m.def("layer_norm_general_fuse_sum", &layer_norm_general_fuse_sum, py::arg("out"), py::arg("input"),
        py::arg("weight"), py::arg("input_sum"), py::arg("scaling"), py::arg("epsilon"), py::arg("use_per_token_quant") = false,
        "Apply Layer Normalization to the input tensor & get input sum (modified from TRTLLM kernel).");
  m.def("rms_norm_general_fuse_sum", &rms_norm_general_fuse_sum, py::arg("out"), py::arg("input"),
        py::arg("weight"), py::arg("input_sum"), py::arg("scaling"), py::arg("epsilon"), py::arg("use_per_token_quant") = false,
        "Apply Root Mean Square (RMS) Normalization to the input tensor & get input sum (TRTLLM kernel).");
  // Overloaded C++ entry point; pybind11 dispatches on the type of `scale`
  // (scalar at::Half = per-tensor, torch::Tensor = per-token).
  // py::arg names are supplied here for consistency with the bindings above,
  // and so Python callers can disambiguate the overloads by keyword.
  m.def("invoke_dequant_add_residual_rms_norm_quant",
        py::overload_cast<torch::Tensor &, torch::Tensor &, torch::Tensor &,
                          torch::Tensor &, at::Half, float>(
            &invoke_dequant_add_residual_rms_norm_quant),
        py::arg("out"), py::arg("input"), py::arg("residual"),
        py::arg("gamma"), py::arg("scale"), py::arg("epsilon"),
        "Add the dequanted result and residual, then use RMS norm and quant "
        "output.");
  m.def("invoke_dequant_add_residual_rms_norm_quant",
        py::overload_cast<torch::Tensor &, torch::Tensor &, torch::Tensor &,
                          torch::Tensor &, torch::Tensor &, float>(
            &invoke_dequant_add_residual_rms_norm_quant),
        py::arg("out"), py::arg("input"), py::arg("residual"),
        py::arg("gamma"), py::arg("scale"), py::arg("epsilon"),
        "Add the dequanted result and residual, then use RMS norm and quant "
        "output.");
}