-
Notifications
You must be signed in to change notification settings - Fork 3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merged PR 1136: Add BatchNormalization operator.
Related work items: #57
- Loading branch information
Pranav Sharma
authored and
Pranav Sharma
committed
Mar 29, 2018
1 parent
d32b606
commit cfe526a
Showing
6 changed files
with
470 additions
and
23 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
#include "core/providers/cpu/nn/batch_norm.h" | ||
|
||
namespace Lotus { | ||
// spec: https://github.com/onnx/onnx/blob/master/docs/Operators.md#BatchNormalization | ||
// Registers BatchNorm<float> as the kernel for the "BatchNormalization" operator | ||
// in the ONNX domain on the CPU execution provider. | ||
REGISTER_KERNEL(KernelDef("BatchNormalization") | ||
.Domain(LotusIR::kOnnxDomain) | ||
// This operator is used if you are using version 6 of the default ONNX operator | ||
// set until the next BC-breaking change to this operator | ||
// NOTE(review): the two arguments below look like a [start, end] opset version | ||
// range (6 through 7) -- confirm against the KernelDef::SinceVersion declaration. | ||
.SinceVersion(6, 7) | ||
.Provider(LotusIR::kCpuExecutionProvider) | ||
// Every tensor input is constrained to float, matching the BatchNorm<float> | ||
// specialization registered as the implementation below. | ||
.TypeConstraint("X", DataTypeImpl::GetTensorType<float>()) | ||
.TypeConstraint("scale", DataTypeImpl::GetTensorType<float>()) | ||
.TypeConstraint("B", DataTypeImpl::GetTensorType<float>()) | ||
.TypeConstraint("mean", DataTypeImpl::GetTensorType<float>()) | ||
.TypeConstraint("var", DataTypeImpl::GetTensorType<float>()), | ||
BatchNorm<float>); | ||
|
||
template <> | ||
Status BatchNorm<float>::ValidateInputs(const Tensor* X, | ||
const Tensor* scale, | ||
const Tensor* B, | ||
const Tensor* mean, | ||
const Tensor* var) const { | ||
if (X->shape().NumDimensions() != kNumInputXDimensions) { | ||
std::ostringstream ostr; | ||
ostr << "Invalid input X: NumDimensions() != " << kNumInputXDimensions; | ||
return Status(LOTUS, INVALID_ARGUMENT, ostr.str()); | ||
} | ||
if (scale->shape().NumDimensions() != kNumInputScaleDimensions) { | ||
std::ostringstream ostr; | ||
ostr << "Invalid input scale: NumDimensions() != " << kNumInputScaleDimensions; | ||
return Status(LOTUS, INVALID_ARGUMENT, ostr.str()); | ||
} | ||
if (B->shape().NumDimensions() != kNumInputBiasDimensions) { | ||
std::ostringstream ostr; | ||
ostr << "Invalid input B: NumDimensions() != " << kNumInputBiasDimensions; | ||
return Status(LOTUS, INVALID_ARGUMENT, ostr.str()); | ||
} | ||
if (mean->shape().NumDimensions() != kNumInputMeanDimensions) { | ||
std::ostringstream ostr; | ||
ostr << "Invalid input mean: NumDimensions() != " << kNumInputMeanDimensions; | ||
return Status(LOTUS, INVALID_ARGUMENT, ostr.str()); | ||
} | ||
if (var->shape().NumDimensions() != kNumInputVarianceDimensions) { | ||
std::ostringstream ostr; | ||
ostr << "Invalid input var: NumDimensions() != " << kNumInputVarianceDimensions; | ||
return Status(LOTUS, INVALID_ARGUMENT, ostr.str()); | ||
} | ||
|
||
return Status::OK(); | ||
} | ||
|
||
template <> | ||
Status BatchNorm<float>::compute(OpKernelContext* p_op_kernel_context) const { | ||
const Tensor* X = p_op_kernel_context->input<Tensor>(0); | ||
const Tensor* scale = p_op_kernel_context->input<Tensor>(1); | ||
const Tensor* B = p_op_kernel_context->input<Tensor>(2); | ||
const Tensor* mean = p_op_kernel_context->input<Tensor>(3); | ||
const Tensor* var = p_op_kernel_context->input<Tensor>(4); | ||
|
||
LOTUS_RETURN_IF_ERROR(ValidateInputs(X, scale, B, mean, var)); | ||
|
||
const TensorShape& x_shape = X->shape(); | ||
Tensor* Y = p_op_kernel_context->output(0, x_shape); | ||
|
||
const size_t N = x_shape[0]; | ||
const size_t C = x_shape[1]; // assume NCHW as per the spec | ||
const size_t H = x_shape[2]; | ||
const size_t W = x_shape[3]; | ||
|
||
const size_t sample_size = H * W; | ||
|
||
ConstEigenVectorArrayMap<float> scale_arr(scale->data<float>(), C); | ||
ConstEigenVectorArrayMap<float> bias_arr(B->data<float>(), C); | ||
|
||
// Regardless of training or testing, we will apply the estimated mean | ||
// and standard deviation to the input. For testing, they are | ||
// specified directly by the input, and for training, they are computed | ||
// by the op. | ||
Eigen::Array<float, Eigen::Dynamic, 1> inv_std(C); | ||
ConstEigenVectorArrayMap<float> var_arr(var->data<float>(), C); | ||
inv_std = (var_arr + epsilon_).sqrt().inverse(); | ||
ConstEigenVectorArrayMap<float> mean_arr(mean->data<float>(), C); | ||
// We can fuse the output computation as follows: | ||
// ((x - est_mean) * (inv_var) * scale + bias | ||
// to | ||
// (x * inv_var * scale) + (bias - est_mean * inv_var * scale) | ||
Eigen::Array<float, Eigen::Dynamic, 1> new_scale = inv_std * scale_arr; | ||
Eigen::Array<float, Eigen::Dynamic, 1> new_bias = | ||
bias_arr - mean_arr * inv_std * scale_arr; | ||
EigenArrayMap<float> Y_arr(Y->mutable_data<float>(), sample_size, N * C); | ||
ConstEigenArrayMap<float> X_arr(X->data<float>(), sample_size, N * C); | ||
for (int nc = 0; nc < N * C; ++nc) { | ||
Y_arr.col(nc) = X_arr.col(nc) * new_scale(nc % C) + new_bias(nc % C); | ||
} | ||
|
||
return Status::OK(); | ||
} | ||
} // namespace Lotus |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
/** | ||
* Copyright (c) 2016-present, Facebook, Inc. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
/* Modifications Copyright (c) Microsoft. */ | ||
|
||
#pragma once | ||
|
||
#include "core/common/common.h" | ||
#include "core/common/exceptions.h" | ||
#include "core/framework/op_kernel.h" | ||
#include "core/providers/cpu/nn/autopad_type.h" | ||
#include "core/framework/tensor.h" | ||
#include "core/util/math_cpuonly.h" | ||
|
||
namespace Lotus { | ||
|
||
template <typename T> | ||
class BatchNorm final : public OpKernel { | ||
public: | ||
BatchNorm(const OpKernelInfo& op_kernel_info) : OpKernel(op_kernel_info) { | ||
LOTUS_ENFORCE(op_kernel_info.GetAttr<float>("epsilon", &epsilon_).IsOK()); | ||
|
||
// keeping the below for reference. we don't need these for inference. | ||
//LOTUS_ENFORCE(op_kernel_info.GetAttr<int64_t>("is_test", &is_test_).IsOK()); | ||
//LOTUS_ENFORCE(op_kernel_info.GetAttr<float>("momentum", &momentum_).IsOK()); | ||
//LOTUS_ENFORCE(op_kernel_info.GetAttr<int64_t>("spatial", &spatial_).IsOK()); | ||
} | ||
|
||
Status ValidateInputs(const Tensor* X, | ||
const Tensor* scale, | ||
const Tensor* B, | ||
const Tensor* mean, | ||
const Tensor* var) const; | ||
|
||
Status compute(OpKernelContext* p_op_kernel_context) const override; | ||
|
||
private: | ||
float epsilon_; | ||
int64_t is_test_; // ignored in this implementation since we're doing inferencing only. | ||
float momentum_; // ignored in this implementation since we're doing inferencing only. | ||
int64_t spatial_; // ignored in this implementation since we're doing inferencing only. | ||
|
||
// defined as per spec and used for validation | ||
const int kNumInputXDimensions = 4; | ||
const int kNumInputScaleDimensions = 1; | ||
const int kNumInputBiasDimensions = 1; | ||
const int kNumInputMeanDimensions = 1; | ||
const int kNumInputVarianceDimensions = 1; | ||
}; | ||
} // namespace Lotus |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.