diff --git a/docs/developer-guide/operators.md b/docs/developer-guide/operators.md
index 05996f8d735..7594c0843ac 100644
--- a/docs/developer-guide/operators.md
+++ b/docs/developer-guide/operators.md
@@ -71,6 +71,7 @@
 * [Reorg](#reorg)
 * [Requantize](#requantize)
 * [Reshape](#reshape)
+* [RMSNorm](#rmsnorm)
 * [RNN](#rnn)
 * [Scale](#scale)
 * [SELU](#selu)
@@ -1670,6 +1671,26 @@ Reshape flag:
 - -1 = remaining
 - -233 = drop this dim(default)

+# RMSNorm
+```
+split x along outermost axis into parts x0, x1 ...
+root mean square normalize each part x0, x1 ...
+y = x * gamma elementwise
+```
+
+* one_blob_only
+* support_inplace
+
+| param id | name        | type  | default | description                      |
+| -------- | ----------- | ----- | ------- | -------------------------------- |
+| 0        | affine_size | int   | 0       | length of the normalized vector  |
+| 1        | eps         | float | 0.001f  | x = x / sqrt(mean(x^2) + eps)    |
+| 2        | affine      | int   | 1       | whether gamma scaling is applied |
+
+| weight     | type  | shape         |
+| ---------- | ----- | ------------- |
+| gamma_data | float | [affine_size] |
+
 # RNN
 Apply a single-layer RNN to a feature sequence of `T` timesteps. The input blob shape is `[w=input_size, h=T]` and the output blob shape is `[w=num_output, h=T]`.
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index d3f55ce7790..803c34a780d 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -166,6 +166,7 @@ ncnn_add_layer(Erf)
 ncnn_add_layer(Diag)
 ncnn_add_layer(CELU)
 ncnn_add_layer(Shrink)
+ncnn_add_layer(RMSNorm)

 if(NCNN_VULKAN)
     ncnn_add_shader(${CMAKE_CURRENT_SOURCE_DIR}/convert_ycbcr.comp)
diff --git a/src/layer/rmsnorm.cpp b/src/layer/rmsnorm.cpp
new file mode 100644
index 00000000000..77c74c6bccb
--- /dev/null
+++ b/src/layer/rmsnorm.cpp
@@ -0,0 +1,200 @@
+// Tencent is pleased to support the open source community by making ncnn available.
+//
+// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
+// in compliance with the License. You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software distributed
+// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
+// CONDITIONS OF ANY KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations under the License.
+ +#include "rmsnorm.h" + +namespace ncnn { + +RMSNorm::RMSNorm() +{ + one_blob_only = true; + support_inplace = true; +} + +int RMSNorm::load_param(const ParamDict& pd) +{ + affine_size = pd.get(0, 0); + eps = pd.get(1, 0.001f); + affine = pd.get(2, 1); + + return 0; +} + +int RMSNorm::load_model(const ModelBin& mb) +{ + if (affine == 0) + return 0; + + gamma_data = mb.load(affine_size, 1); + if (gamma_data.empty()) + return -100; + + return 0; +} + +int RMSNorm::forward_inplace(Mat& bottom_top_blob, const Option& opt) const +{ + // x = x / sqrt(rms + eps) * gamma + + int dims = bottom_top_blob.dims; + + if (dims == 1) + { + int w = bottom_top_blob.w; + // assert affine_size == w + + float* ptr = bottom_top_blob; + + float sqsum = 0.f; + for (int i = 0; i < w; i++) + { + sqsum += ptr[i] * ptr[i]; + } + float rms = sqrtf(sqsum / w + eps); + + float a = 1.f / rms; + + if (affine) + { + for (int i = 0; i < w; i++) + { + ptr[i] = (ptr[i] * a) * gamma_data[i]; + } + } + else + { + for (int i = 0; i < w; i++) + { + ptr[i] = ptr[i] * a; + } + } + } + + if (dims == 2) + { + int w = bottom_top_blob.w; + int h = bottom_top_blob.h; + // assert affine_size == w + + #pragma omp parallel for num_threads(opt.num_threads) + for (int i = 0; i < h; i++) + { + float* ptr = bottom_top_blob.row(i); + + float sqsum = 0.f; + for (int j = 0; j < w; j++) + { + sqsum += ptr[j] * ptr[j]; + } + float rms = sqrtf(sqsum / w + eps); + + float a = 1.f / rms; + + if (affine) + { + for (int j = 0; j < w; j++) + { + ptr[j] = (ptr[j] * a) * gamma_data[j]; + } + } + else + { + for (int j = 0; j < w; j++) + { + ptr[j] = ptr[j] * a; + } + } + } + } + + if (dims == 3) + { + int w = bottom_top_blob.w; + int h = bottom_top_blob.h; + int channels = bottom_top_blob.c; + int size = w * h; + + if (affine_size == w) + { + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + for (int i = 0; i < h; i++) + { + float* ptr = bottom_top_blob.channel(q).row(i); + + float sqsum = 0.f; + for (int j = 0; j < w; j++) + { + sqsum += ptr[j] * ptr[j]; + } + float rms = sqrtf(sqsum / w + eps); + + float a = 1.f / rms; + + if (affine) + { + for (int j = 0; j < w; j++) + { + ptr[j] = (ptr[j] * a) * gamma_data[j]; + } + } + else + { + for (int j = 0; j < w; j++) + { + ptr[j] = ptr[j] * a; + } + } + } + } + } + else // if (affine_size == size) + { + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < channels; q++) + { + float* ptr = bottom_top_blob.channel(q); + + float sqsum = 0.f; + for (int i = 0; i < size; i++) + { + sqsum += ptr[i] * ptr[i]; + } + float rms = sqrtf(sqsum / size + eps); + + float a = 1.f / rms; + + if (affine) + { + for (int i = 0; i < size; i++) + { + ptr[i] = (ptr[i] * a) * gamma_data[i]; + } + } + else + { + for (int i = 0; i < size; i++) + { + ptr[i] = ptr[i] * a; + } + } + } + } + } + + return 0; +} + +} // namespace ncnn diff --git a/src/layer/rmsnorm.h b/src/layer/rmsnorm.h new file mode 100644 index 00000000000..4a09f2548bd --- /dev/null +++ b/src/layer/rmsnorm.h @@ -0,0 +1,43 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. 
You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#ifndef LAYER_RMSNORM_H +#define LAYER_RMSNORM_H + +#include "layer.h" + +namespace ncnn { + +class RMSNorm : public Layer +{ +public: + RMSNorm(); + + virtual int load_param(const ParamDict& pd); + + virtual int load_model(const ModelBin& mb); + + virtual int forward_inplace(Mat& bottom_top_blob, const Option& opt) const; + +public: + int affine_size; + float eps; + int affine; + + Mat gamma_data; +}; + +} // namespace ncnn + +#endif // LAYER_RMSNORM_H diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index d30229b870c..6c8939fc7c7 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -141,6 +141,7 @@ ncnn_add_layer_test(ReLU) ncnn_add_layer_test(Reorg) ncnn_add_layer_test(Requantize) ncnn_add_layer_test(Reshape) +ncnn_add_layer_test(RMSNorm) ncnn_add_layer_test(RNN) ncnn_add_layer_test(ROIPooling) ncnn_add_layer_test(ROIAlign) diff --git a/tests/test_rmsnorm.cpp b/tests/test_rmsnorm.cpp new file mode 100644 index 00000000000..2d88c162d8b --- /dev/null +++ b/tests/test_rmsnorm.cpp @@ -0,0 +1,121 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. 
+
+#include "testutil.h"
+
+static int test_rmsnorm(const ncnn::Mat& a, int affine_size, float eps, int affine)
+{
+    ncnn::ParamDict pd;
+    pd.set(0, affine_size);
+    pd.set(1, eps);
+    pd.set(2, affine);
+
+    std::vector<ncnn::Mat> weights(1);
+    weights[0] = RandomMat(affine_size);
+
+    int ret = test_layer("RMSNorm", pd, weights, a);
+    if (ret != 0)
+    {
+        fprintf(stderr, "test_rmsnorm failed a.dims=%d a=(%d %d %d) affine_size=%d eps=%f affine=%d\n", a.dims, a.w, a.h, a.c, affine_size, eps, affine);
+    }
+
+    return ret;
+}
+
+static int test_rmsnorm_0()
+{
+    return 0
+           || test_rmsnorm(RandomMat(6, 4, 2), 6, 0.01f, 0)
+           || test_rmsnorm(RandomMat(4, 5, 6), 4, 0.01f, 0)
+           || test_rmsnorm(RandomMat(3, 3, 8), 3, 0.002f, 0)
+           || test_rmsnorm(RandomMat(5, 6, 12), 5, 0.02f, 0)
+           || test_rmsnorm(RandomMat(4, 7, 16), 4, 0.02f, 0)
+           || test_rmsnorm(RandomMat(6, 7, 24), 6, 0.001f, 0)
+           || test_rmsnorm(RandomMat(5, 8, 32), 5, 0.001f, 0)
+           || test_rmsnorm(RandomMat(6, 4, 2), 6, 0.01f, 1)
+           || test_rmsnorm(RandomMat(4, 5, 6), 4, 0.01f, 1)
+           || test_rmsnorm(RandomMat(3, 3, 8), 3, 0.002f, 1)
+           || test_rmsnorm(RandomMat(5, 6, 12), 5, 0.02f, 1)
+           || test_rmsnorm(RandomMat(4, 7, 16), 4, 0.02f, 1)
+           || test_rmsnorm(RandomMat(6, 7, 24), 6, 0.001f, 1)
+           || test_rmsnorm(RandomMat(5, 8, 32), 5, 0.001f, 1);
+}
+
+static int test_rmsnorm_1()
+{
+    return 0
+           || test_rmsnorm(RandomMat(6, 4, 2), 24, 0.01f, 0)
+           || test_rmsnorm(RandomMat(4, 5, 6), 20, 0.01f, 0)
+           || test_rmsnorm(RandomMat(3, 3, 8), 9, 0.002f, 0)
+           || test_rmsnorm(RandomMat(5, 6, 12), 30, 0.02f, 0)
+           || test_rmsnorm(RandomMat(4, 7, 16), 28, 0.02f, 0)
+           || test_rmsnorm(RandomMat(6, 7, 24), 42, 0.001f, 0)
+           || test_rmsnorm(RandomMat(5, 8, 32), 40, 0.001f, 0)
+           || test_rmsnorm(RandomMat(6, 4, 2), 24, 0.01f, 1)
+           || test_rmsnorm(RandomMat(4, 5, 6), 20, 0.01f, 1)
+           || test_rmsnorm(RandomMat(3, 3, 8), 9, 0.002f, 1)
+           || test_rmsnorm(RandomMat(5, 6, 12), 30, 0.02f, 1)
+           || test_rmsnorm(RandomMat(4, 7, 16), 28, 0.02f, 1)
+           || test_rmsnorm(RandomMat(6, 7, 24), 42, 0.001f, 1)
+           || test_rmsnorm(RandomMat(5, 8, 32), 40, 0.001f, 1);
+}
+
+static int test_rmsnorm_2()
+{
+    return 0
+           || test_rmsnorm(RandomMat(4, 2), 4, 0.01f, 0)
+           || test_rmsnorm(RandomMat(5, 6), 5, 0.01f, 0)
+           || test_rmsnorm(RandomMat(3, 8), 3, 0.002f, 0)
+           || test_rmsnorm(RandomMat(6, 12), 6, 0.02f, 0)
+           || test_rmsnorm(RandomMat(4, 16), 4, 0.02f, 0)
+           || test_rmsnorm(RandomMat(7, 24), 7, 0.001f, 0)
+           || test_rmsnorm(RandomMat(8, 32), 8, 0.001f, 0)
+           || test_rmsnorm(RandomMat(4, 2), 4, 0.01f, 1)
+           || test_rmsnorm(RandomMat(5, 6), 5, 0.01f, 1)
+           || test_rmsnorm(RandomMat(3, 8), 3, 0.002f, 1)
+           || test_rmsnorm(RandomMat(6, 12), 6, 0.02f, 1)
+           || test_rmsnorm(RandomMat(4, 16), 4, 0.02f, 1)
+           || test_rmsnorm(RandomMat(7, 24), 7, 0.001f, 1)
+           || test_rmsnorm(RandomMat(8, 32), 8, 0.001f, 1);
+}
+
+static int test_rmsnorm_3()
+{
+    return 0
+           || test_rmsnorm(RandomMat(2), 2, 0.01f, 0)
+           || test_rmsnorm(RandomMat(6), 6, 0.01f, 0)
+           || test_rmsnorm(RandomMat(8), 8, 0.002f, 0)
+           || test_rmsnorm(RandomMat(12), 12, 0.02f, 0)
+           || test_rmsnorm(RandomMat(16), 16, 0.02f, 0)
+           || test_rmsnorm(RandomMat(24), 24, 0.001f, 0)
+           || test_rmsnorm(RandomMat(32), 32, 0.001f, 0)
+           || test_rmsnorm(RandomMat(2), 2, 0.01f, 1)
+           || test_rmsnorm(RandomMat(6), 6, 0.01f, 1)
+           || test_rmsnorm(RandomMat(8), 8, 0.002f, 1)
+           || test_rmsnorm(RandomMat(12), 12, 0.02f, 1)
+           || test_rmsnorm(RandomMat(16), 16, 0.02f, 1)
+           || test_rmsnorm(RandomMat(24), 24, 0.001f, 1)
+           || test_rmsnorm(RandomMat(32), 32, 0.001f, 1);
+}
+
+int main()
+{
+    SRAND(7767517);
+
+    return 0
+           || test_rmsnorm_0()
+           || test_rmsnorm_1()
+           || test_rmsnorm_2()
+           || test_rmsnorm_3();
+}
diff --git a/tools/pnnx/src/CMakeLists.txt b/tools/pnnx/src/CMakeLists.txt
index 9834fabe069..2c814bd486c 100644
--- a/tools/pnnx/src/CMakeLists.txt
+++ b/tools/pnnx/src/CMakeLists.txt
@@ -475,6 +475,7 @@ set(pnnx_pass_ncnn_SRCS
     pass_ncnn/F_prelu.cpp
     pass_ncnn/F_relu.cpp
     pass_ncnn/F_relu6.cpp
+    pass_ncnn/F_rms_norm.cpp
     pass_ncnn/F_scaled_dot_product_attention.cpp
     pass_ncnn/F_selu.cpp
     pass_ncnn/F_sigmoid.cpp
@@ -541,6 +542,7 @@ set(pnnx_pass_ncnn_SRCS
     pass_ncnn/nn_ReplicationPad1d.cpp
     pass_ncnn/nn_ReplicationPad2d.cpp
     pass_ncnn/nn_ReplicationPad3d.cpp
+    pass_ncnn/nn_RMSNorm.cpp
     pass_ncnn/nn_RNN.cpp
     pass_ncnn/nn_SELU.cpp
     pass_ncnn/nn_Sigmoid.cpp
diff --git a/tools/pnnx/src/pass_level1/nn_RMSNorm.cpp b/tools/pnnx/src/pass_level1/nn_RMSNorm.cpp
index 4433f598935..498f0453c14 100644
--- a/tools/pnnx/src/pass_level1/nn_RMSNorm.cpp
+++ b/tools/pnnx/src/pass_level1/nn_RMSNorm.cpp
@@ -37,7 +37,7 @@ class RMSNorm : public FuseModulePass
         op->params["normalized_shape"] = rmsn->namedInput("normalized_shape");
         op->params["eps"] = rmsn->namedInput("eps");
-        op->params["elementwise_affine"] = mod.hasattr("weight") && mod.hasattr("bias");
+        op->params["elementwise_affine"] = mod.hasattr("weight");

         if (mod.hasattr("weight"))
         {
diff --git a/tools/pnnx/src/pass_ncnn/F_rms_norm.cpp b/tools/pnnx/src/pass_ncnn/F_rms_norm.cpp
new file mode 100644
index 00000000000..8230168312c
--- /dev/null
+++ b/tools/pnnx/src/pass_ncnn/F_rms_norm.cpp
@@ -0,0 +1,65 @@
+// Tencent is pleased to support the open source community by making ncnn available.
+//
+// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
+// in compliance with the License. You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software distributed
+// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
+// CONDITIONS OF ANY KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations under the License.
+
+#include "pass_ncnn.h"
+
+namespace pnnx {
+
+namespace ncnn {
+
+class F_rms_norm : public GraphRewriterPass
+{
+public:
+    const char* match_pattern_graph() const
+    {
+        return R"PNNXIR(7767517
+3 2
+pnnx.Input input 0 1 input
+F.rms_norm op_0 1 1 input out weight=None normalized_shape=%normalized_shape eps=%eps
+pnnx.Output output 1 0 out
+)PNNXIR";
+    }
+
+    const char* type_str() const
+    {
+        return "RMSNorm";
+    }
+
+    const char* name_str() const
+    {
+        return "rmsn";
+    }
+
+    void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
+    {
+        const std::vector<int>& normalized_shape = captured_params.at("normalized_shape").ai;
+        int affine_size = normalized_shape[0];
+        for (size_t i = 1; i < normalized_shape.size(); i++)
+        {
+            affine_size *= normalized_shape[i];
+        }
+
+        const float eps = captured_params.at("eps").type == 0 ? 0.f : captured_params.at("eps").f;
+
+        op->params["0"] = affine_size;
+        op->params["1"] = eps;
+        op->params["2"] = 0;
+    }
+};
+
+REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_rms_norm, 20)
+
+} // namespace ncnn
+
+} // namespace pnnx
diff --git a/tools/pnnx/src/pass_ncnn/nn_RMSNorm.cpp b/tools/pnnx/src/pass_ncnn/nn_RMSNorm.cpp
new file mode 100644
index 00000000000..7fda637c5ca
--- /dev/null
+++ b/tools/pnnx/src/pass_ncnn/nn_RMSNorm.cpp
@@ -0,0 +1,70 @@
+// Tencent is pleased to support the open source community by making ncnn available.
+//
+// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
+// in compliance with the License. You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software distributed
+// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
+// CONDITIONS OF ANY KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations under the License.
+
+#include "pass_ncnn.h"
+
+namespace pnnx {
+
+namespace ncnn {
+
+class nn_RMSNorm : public GraphRewriterPass
+{
+public:
+    const char* match_pattern_graph() const
+    {
+        return R"PNNXIR(7767517
+3 2
+pnnx.Input input 0 1 input
+nn.RMSNorm op_0 1 1 input out normalized_shape=%normalized_shape eps=%eps elementwise_affine=%elementwise_affine @weight
+pnnx.Output output 1 0 out
+)PNNXIR";
+    }
+
+    const char* type_str() const
+    {
+        return "RMSNorm";
+    }
+
+    const char* name_str() const
+    {
+        return "rmsn";
+    }
+
+    void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
+    {
+        const std::vector<int>& normalized_shape = captured_params.at("normalized_shape").ai;
+        int affine_size = normalized_shape[0];
+        for (size_t i = 1; i < normalized_shape.size(); i++)
+        {
+            affine_size *= normalized_shape[i];
+        }
+
+        const float eps = captured_params.at("eps").type == 0 ? 0.f : captured_params.at("eps").f;
+
+        op->params["0"] = affine_size;
+        op->params["1"] = eps;
+        op->params["2"] = captured_params.at("elementwise_affine").b ?
1 : 0; + + if (captured_params.at("elementwise_affine").b) + { + op->attrs["0"] = captured_attrs.at("op_0.weight"); + } + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(nn_RMSNorm, 20) + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/tests/ncnn/CMakeLists.txt b/tools/pnnx/tests/ncnn/CMakeLists.txt index a60e63eb54b..49cb063f335 100644 --- a/tools/pnnx/tests/ncnn/CMakeLists.txt +++ b/tools/pnnx/tests/ncnn/CMakeLists.txt @@ -53,6 +53,7 @@ pnnx_ncnn_add_test(F_pixel_unshuffle) pnnx_ncnn_add_test(F_prelu) pnnx_ncnn_add_test(F_relu) pnnx_ncnn_add_test(F_relu6) +pnnx_ncnn_add_test(F_rms_norm) pnnx_ncnn_add_test(F_selu) pnnx_ncnn_add_test(F_sigmoid) pnnx_ncnn_add_test(F_silu) @@ -123,6 +124,7 @@ pnnx_ncnn_add_test(nn_ReLU6) pnnx_ncnn_add_test(nn_ReplicationPad1d) pnnx_ncnn_add_test(nn_ReplicationPad2d) pnnx_ncnn_add_test(nn_ReplicationPad3d) +pnnx_ncnn_add_test(nn_RMSNorm) pnnx_ncnn_add_test(nn_RNN) pnnx_ncnn_add_test(nn_SELU) pnnx_ncnn_add_test(nn_Sigmoid) diff --git a/tools/pnnx/tests/ncnn/test_F_layer_norm.py b/tools/pnnx/tests/ncnn/test_F_layer_norm.py index 92244f17910..9d590aa76dd 100644 --- a/tools/pnnx/tests/ncnn/test_F_layer_norm.py +++ b/tools/pnnx/tests/ncnn/test_F_layer_norm.py @@ -37,8 +37,8 @@ def test(): net.eval() torch.manual_seed(0) - x = torch.rand(12, 24) - y = torch.rand(3, 12, 16) + x = torch.rand(1, 12, 24) + y = torch.rand(1, 3, 12, 16) a = net(x, y) @@ -48,7 +48,7 @@ def test(): # torchscript to pnnx import os - os.system("../../src/pnnx test_F_layer_norm.pt inputshape=[12,24],[3,12,16]") + os.system("../../src/pnnx test_F_layer_norm.pt inputshape=[1,12,24],[1,3,12,16]") # ncnn inference import test_F_layer_norm_ncnn diff --git a/tools/pnnx/tests/ncnn/test_F_rms_norm.py b/tools/pnnx/tests/ncnn/test_F_rms_norm.py new file mode 100644 index 00000000000..4e60d9314aa --- /dev/null +++ b/tools/pnnx/tests/ncnn/test_F_rms_norm.py @@ -0,0 +1,68 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. 
+ +import torch +import torch.nn as nn +import torch.nn.functional as F +from packaging import version + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.w3 = nn.Parameter(torch.rand(24)) + self.w4 = nn.Parameter(torch.rand(12, 16)) + + def forward(self, x, y): + x = F.rms_norm(x, (24,), self.w3) + + y = F.rms_norm(y, (16,), None) + z = F.rms_norm(y, (12,16), self.w4, eps=1e-3) + return x, y, z + +def test(): + if version.parse(torch.__version__) < version.parse('2.4'): + return True + + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 12, 24) + y = torch.rand(1, 3, 12, 16) + + a = net(x, y) + + # export torchscript + mod = torch.jit.trace(net, (x, y)) + mod.save("test_F_rms_norm.pt") + + # torchscript to pnnx + import os + os.system("../../src/pnnx test_F_rms_norm.pt inputshape=[1,12,24],[1,3,12,16]") + + # ncnn inference + import test_F_rms_norm_ncnn + b = test_F_rms_norm_ncnn.test_inference() + + for a0, b0 in zip(a, b): + if not torch.allclose(a0, b0, 1e-4, 1e-4): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/ncnn/test_nn_LayerNorm.py b/tools/pnnx/tests/ncnn/test_nn_LayerNorm.py index a45444060d0..d409bdfba3a 100644 --- a/tools/pnnx/tests/ncnn/test_nn_LayerNorm.py +++ b/tools/pnnx/tests/ncnn/test_nn_LayerNorm.py @@ -36,8 +36,8 @@ def test(): net.eval() torch.manual_seed(0) - x = torch.rand(24, 64) - y = torch.rand(12, 24, 64) + x = torch.rand(1, 24, 64) + y = torch.rand(1, 12, 24, 64) a = net(x, y) @@ -47,7 +47,7 @@ def test(): # torchscript to pnnx import os - os.system("../../src/pnnx test_nn_LayerNorm.pt inputshape=[24,64],[12,24,64]") + os.system("../../src/pnnx test_nn_LayerNorm.pt inputshape=[1,24,64],[1,12,24,64]") # ncnn inference import test_nn_LayerNorm_ncnn diff --git a/tools/pnnx/tests/ncnn/test_nn_RMSNorm.py b/tools/pnnx/tests/ncnn/test_nn_RMSNorm.py new file mode 100644 index 00000000000..0d5efa211e4 --- /dev/null +++ b/tools/pnnx/tests/ncnn/test_nn_RMSNorm.py @@ -0,0 +1,68 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. 
+ +import torch +import torch.nn as nn +import torch.nn.functional as F +from packaging import version + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.rmsn_0 = nn.RMSNorm(64) + self.rmsn_0.weight = nn.Parameter(torch.rand(64)) + self.rmsn_1 = nn.RMSNorm(normalized_shape=(24,64), eps=1e-2, elementwise_affine=False) + + def forward(self, x, y): + x = self.rmsn_0(x) + y = self.rmsn_0(y) + z = self.rmsn_1(y) + return x, y, z + +def test(): + if version.parse(torch.__version__) < version.parse('2.4'): + return True + + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 24, 64) + y = torch.rand(1, 12, 24, 64) + + a = net(x, y) + + # export torchscript + mod = torch.jit.trace(net, (x, y)) + mod.save("test_nn_RMSNorm.pt") + + # torchscript to pnnx + import os + os.system("../../src/pnnx test_nn_RMSNorm.pt inputshape=[1,24,64],[1,12,24,64]") + + # ncnn inference + import test_nn_RMSNorm_ncnn + b = test_nn_RMSNorm_ncnn.test_inference() + + for a0, b0 in zip(a, b): + if not torch.allclose(a0, b0, 1e-4, 1e-4): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1)
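
For reviewers who want a quick sanity check of the semantics this patch adds, here is a minimal NumPy sketch of the normalization performed by the new `RMSNorm` layer: each vector of length `affine_size` is divided by the square root of its mean square plus `eps`, then optionally scaled by `gamma`, matching `RMSNorm::forward_inplace` and the operator doc above. The function name and the NumPy dependency are illustrative only and are not part of this PR.

```python
import numpy as np

def rms_norm_reference(x, gamma=None, eps=1e-3):
    # normalize over the last axis, mirroring the per-row / per-channel
    # loops in src/layer/rmsnorm.cpp: x / sqrt(mean(x^2) + eps) [* gamma]
    mean_square = np.mean(x * x, axis=-1, keepdims=True)
    y = x / np.sqrt(mean_square + eps)
    if gamma is not None:  # corresponds to affine=1 with gamma_data loaded
        y = y * gamma
    return y

# toy usage: an (h, w) blob normalized per row, like the dims == 2 branch
x = np.random.rand(5, 8).astype(np.float32)
gamma = np.random.rand(8).astype(np.float32)
print(rms_norm_reference(x, gamma).shape)  # (5, 8)
```

A converted param file would carry the documented ids in the usual `id=value` form, for example `0=8 1=1.000000e-03 2=1` for an affine RMSNorm over vectors of length 8; this particular line is only an illustration of the mapping from the parameter table, not output taken from the converter.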