Skip to content

Commit

Permalink
RMSNorm (#5630)
Browse files Browse the repository at this point in the history
  • Loading branch information
nihui authored Aug 15, 2024
1 parent abad90c commit fdf0df3
Show file tree
Hide file tree
Showing 15 changed files with 669 additions and 7 deletions.
21 changes: 21 additions & 0 deletions docs/developer-guide/operators.md
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@
* [Reorg](#reorg)
* [Requantize](#requantize)
* [Reshape](#reshape)
* [RMSNorm](#rmsnorm)
* [RNN](#rnn)
* [Scale](#scale)
* [SELU](#selu)
Expand Down Expand Up @@ -1670,6 +1671,26 @@ Reshape flag:
- -1 = remaining
- -233 = drop this dim(default)

# RMSNorm
```
split x along outermost axis into part x0, x1 ...
root mean square normalize for each part x0, x1 ...
y = x * gamma elementwise
```

* one_blob_only
* support_inplace

| param id | name | type | default | description |
| --------- | ------------- | ----- | --------- | ----------------- |
| 0 | affine_size | int | 0 | |
| 1 | eps | float | 0.001f | x = x / sqrt(mean(x*x) + eps) |
| 2 | affine | int | 1 | |

| weight | type | shape |
| ------------- | ----- | --------------------- |
| gamma_data | float | [affine_size] |

# RNN
Apply a single-layer RNN to a feature sequence of `T` timesteps. The input blob shape is `[w=input_size, h=T]` and the output blob shape is `[w=num_output, h=T]`.

Expand Down
1 change: 1 addition & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ ncnn_add_layer(Erf)
ncnn_add_layer(Diag)
ncnn_add_layer(CELU)
ncnn_add_layer(Shrink)
ncnn_add_layer(RMSNorm)

if(NCNN_VULKAN)
ncnn_add_shader(${CMAKE_CURRENT_SOURCE_DIR}/convert_ycbcr.comp)
Expand Down
200 changes: 200 additions & 0 deletions src/layer/rmsnorm.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,200 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "rmsnorm.h"

namespace ncnn {

RMSNorm::RMSNorm()
{
    // The layer transforms a single input blob and may write the result
    // back into the same blob, so both flags are enabled.
    support_inplace = true;
    one_blob_only = true;
}

// Read layer parameters from the param file.
// id 0: affine_size — length of the gamma vector (default 0)
// id 1: eps         — epsilon added before the sqrt (default 0.001f)
// id 2: affine      — non-zero to apply the learned gamma scale (default 1)
int RMSNorm::load_param(const ParamDict& pd)
{
    eps = pd.get(1, 0.001f);
    affine = pd.get(2, 1);
    affine_size = pd.get(0, 0);

    return 0;
}

// Load the gamma weight blob. Only needed when affine is enabled;
// without affine the layer has no weights at all.
int RMSNorm::load_model(const ModelBin& mb)
{
    if (affine)
    {
        gamma_data = mb.load(affine_size, 1);
        if (gamma_data.empty())
            return -100; // weight blob missing or truncated
    }

    return 0;
}

// Normalize `size` consecutive floats in place:
//     x = x / sqrt(mean(x*x) + eps)
// and, when `gamma` is non-null, scale elementwise by gamma afterwards.
// This is the single normalization kernel shared by all dims paths below.
static void rmsnorm_kernel(float* ptr, int size, float eps, const float* gamma)
{
    float sqsum = 0.f;
    for (int i = 0; i < size; i++)
    {
        sqsum += ptr[i] * ptr[i];
    }

    // reciprocal of the root-mean-square, epsilon keeps it finite
    const float a = 1.f / sqrtf(sqsum / size + eps);

    if (gamma)
    {
        for (int i = 0; i < size; i++)
        {
            ptr[i] = (ptr[i] * a) * gamma[i];
        }
    }
    else
    {
        for (int i = 0; i < size; i++)
        {
            ptr[i] = ptr[i] * a;
        }
    }
}

int RMSNorm::forward_inplace(Mat& bottom_top_blob, const Option& opt) const
{
    // x = x / sqrt(rms + eps) * gamma

    // Resolve the affine decision once: a null gamma disables the scale.
    const float* gamma = affine ? (const float*)gamma_data : 0;

    int dims = bottom_top_blob.dims;

    if (dims == 1)
    {
        int w = bottom_top_blob.w;
        // assert affine_size == w

        rmsnorm_kernel(bottom_top_blob, w, eps, gamma);
    }

    if (dims == 2)
    {
        int w = bottom_top_blob.w;
        int h = bottom_top_blob.h;
        // assert affine_size == w

        // each row is normalized independently
        #pragma omp parallel for num_threads(opt.num_threads)
        for (int i = 0; i < h; i++)
        {
            rmsnorm_kernel(bottom_top_blob.row(i), w, eps, gamma);
        }
    }

    if (dims == 3)
    {
        int w = bottom_top_blob.w;
        int h = bottom_top_blob.h;
        int channels = bottom_top_blob.c;
        int size = w * h;

        if (affine_size == w)
        {
            // normalize per row within each channel
            #pragma omp parallel for num_threads(opt.num_threads)
            for (int q = 0; q < channels; q++)
            {
                for (int i = 0; i < h; i++)
                {
                    rmsnorm_kernel(bottom_top_blob.channel(q).row(i), w, eps, gamma);
                }
            }
        }
        else // if (affine_size == size)
        {
            // normalize over the whole channel plane
            #pragma omp parallel for num_threads(opt.num_threads)
            for (int q = 0; q < channels; q++)
            {
                rmsnorm_kernel(bottom_top_blob.channel(q), size, eps, gamma);
            }
        }
    }

    return 0;
}

} // namespace ncnn
43 changes: 43 additions & 0 deletions src/layer/rmsnorm.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#ifndef LAYER_RMSNORM_H
#define LAYER_RMSNORM_H

#include "layer.h"

namespace ncnn {

// Root-mean-square normalization layer.
// Normalizes each part of the input by its root mean square
// (x = x / sqrt(mean(x*x) + eps)) and optionally applies a learned
// elementwise scale gamma. Operates on one blob, in place.
class RMSNorm : public Layer
{
public:
    RMSNorm();

    // param id 0: affine_size, 1: eps, 2: affine
    virtual int load_param(const ParamDict& pd);

    // loads gamma_data ([affine_size] floats) when affine is enabled
    virtual int load_model(const ModelBin& mb);

    virtual int forward_inplace(Mat& bottom_top_blob, const Option& opt) const;

public:
    int affine_size; // length of gamma; selects per-row vs per-channel normalization for 3-d blobs
    float eps;       // added inside the sqrt to avoid division by zero
    int affine;      // non-zero: multiply the normalized value by gamma

    Mat gamma_data;  // learned scale, shape [affine_size]
};

} // namespace ncnn

#endif // LAYER_RMSNORM_H
1 change: 1 addition & 0 deletions tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ ncnn_add_layer_test(ReLU)
ncnn_add_layer_test(Reorg)
ncnn_add_layer_test(Requantize)
ncnn_add_layer_test(Reshape)
ncnn_add_layer_test(RMSNorm)
ncnn_add_layer_test(RNN)
ncnn_add_layer_test(ROIPooling)
ncnn_add_layer_test(ROIAlign)
Expand Down
Loading

0 comments on commit fdf0df3

Please sign in to comment.