Merged PR 1136: Add BatchNormalization operator.
Related work items: #57
Pranav Sharma authored and Pranav Sharma committed Mar 29, 2018
1 parent d32b606 commit cfe526a
Showing 6 changed files with 470 additions and 23 deletions.
99 changes: 99 additions & 0 deletions lotus/core/providers/cpu/nn/batch_norm.cc
@@ -0,0 +1,99 @@
#include "core/providers/cpu/nn/batch_norm.h"

namespace Lotus {
// spec: https://github.com/onnx/onnx/blob/master/docs/Operators.md#BatchNormalization
REGISTER_KERNEL(KernelDef("BatchNormalization")
                    .Domain(LotusIR::kOnnxDomain)
                    // This kernel serves version 6 of the default ONNX operator
                    // set and remains in use until the next BC-breaking change
                    // to this operator.
                    .SinceVersion(6, 7)
                    .Provider(LotusIR::kCpuExecutionProvider)
                    .TypeConstraint("X", DataTypeImpl::GetTensorType<float>())
                    .TypeConstraint("scale", DataTypeImpl::GetTensorType<float>())
                    .TypeConstraint("B", DataTypeImpl::GetTensorType<float>())
                    .TypeConstraint("mean", DataTypeImpl::GetTensorType<float>())
                    .TypeConstraint("var", DataTypeImpl::GetTensorType<float>()),
                BatchNorm<float>);

template <>
Status BatchNorm<float>::ValidateInputs(const Tensor* X,
                                        const Tensor* scale,
                                        const Tensor* B,
                                        const Tensor* mean,
                                        const Tensor* var) const {
  if (X->shape().NumDimensions() != kNumInputXDimensions) {
    std::ostringstream ostr;
    ostr << "Invalid input X: NumDimensions() != " << kNumInputXDimensions;
    return Status(LOTUS, INVALID_ARGUMENT, ostr.str());
  }
  if (scale->shape().NumDimensions() != kNumInputScaleDimensions) {
    std::ostringstream ostr;
    ostr << "Invalid input scale: NumDimensions() != " << kNumInputScaleDimensions;
    return Status(LOTUS, INVALID_ARGUMENT, ostr.str());
  }
  if (B->shape().NumDimensions() != kNumInputBiasDimensions) {
    std::ostringstream ostr;
    ostr << "Invalid input B: NumDimensions() != " << kNumInputBiasDimensions;
    return Status(LOTUS, INVALID_ARGUMENT, ostr.str());
  }
  if (mean->shape().NumDimensions() != kNumInputMeanDimensions) {
    std::ostringstream ostr;
    ostr << "Invalid input mean: NumDimensions() != " << kNumInputMeanDimensions;
    return Status(LOTUS, INVALID_ARGUMENT, ostr.str());
  }
  if (var->shape().NumDimensions() != kNumInputVarianceDimensions) {
    std::ostringstream ostr;
    ostr << "Invalid input var: NumDimensions() != " << kNumInputVarianceDimensions;
    return Status(LOTUS, INVALID_ARGUMENT, ostr.str());
  }

  return Status::OK();
}

template <>
Status BatchNorm<float>::compute(OpKernelContext* p_op_kernel_context) const {
  const Tensor* X = p_op_kernel_context->input<Tensor>(0);
  const Tensor* scale = p_op_kernel_context->input<Tensor>(1);
  const Tensor* B = p_op_kernel_context->input<Tensor>(2);
  const Tensor* mean = p_op_kernel_context->input<Tensor>(3);
  const Tensor* var = p_op_kernel_context->input<Tensor>(4);

  LOTUS_RETURN_IF_ERROR(ValidateInputs(X, scale, B, mean, var));

  const TensorShape& x_shape = X->shape();
  Tensor* Y = p_op_kernel_context->output(0, x_shape);

  const size_t N = x_shape[0];
  const size_t C = x_shape[1];  // assume NCHW as per the spec
  const size_t H = x_shape[2];
  const size_t W = x_shape[3];

  const size_t sample_size = H * W;

  ConstEigenVectorArrayMap<float> scale_arr(scale->data<float>(), C);
  ConstEigenVectorArrayMap<float> bias_arr(B->data<float>(), C);

  // Regardless of training or testing, we will apply the estimated mean
  // and standard deviation to the input. For testing, they are
  // specified directly by the input, and for training, they are computed
  // by the op.
  Eigen::Array<float, Eigen::Dynamic, 1> inv_std(C);
  ConstEigenVectorArrayMap<float> var_arr(var->data<float>(), C);
  inv_std = (var_arr + epsilon_).sqrt().inverse();
  ConstEigenVectorArrayMap<float> mean_arr(mean->data<float>(), C);
  // We can fuse the output computation as follows:
  //   (x - est_mean) * inv_std * scale + bias
  // becomes
  //   (x * inv_std * scale) + (bias - est_mean * inv_std * scale)
  Eigen::Array<float, Eigen::Dynamic, 1> new_scale = inv_std * scale_arr;
  Eigen::Array<float, Eigen::Dynamic, 1> new_bias = bias_arr - mean_arr * inv_std * scale_arr;
  EigenArrayMap<float> Y_arr(Y->mutable_data<float>(), sample_size, N * C);
  ConstEigenArrayMap<float> X_arr(X->data<float>(), sample_size, N * C);
  for (size_t nc = 0; nc < N * C; ++nc) {
    Y_arr.col(nc) = X_arr.col(nc) * new_scale(nc % C) + new_bias(nc % C);
  }

  return Status::OK();
}
}  // namespace Lotus
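Note on the fused form in compute() above: (x - mean) * inv_std * scale + bias is rewritten as x * new_scale + new_bias with per-channel constants, so the inner loop does a single multiply-add per element. Below is a minimal, self-contained sketch of that equivalence and of the (H*W) x (N*C) plane walk, in plain C++ with illustrative values (no Lotus or Eigen dependencies; all names and numbers are made up for the demonstration):

#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

int main() {
  // Tiny NCHW example: N=2 images, C=2 channels, H=W=2 (sample_size=4).
  const std::size_t N = 2, C = 2, H = 2, W = 2;
  const std::size_t sample_size = H * W;
  std::vector<float> X(N * C * sample_size);
  for (std::size_t i = 0; i < X.size(); ++i) X[i] = 0.1f * static_cast<float>(i);

  const float scale[] = {1.5f, 0.5f};
  const float bias[] = {0.25f, -0.25f};
  const float mean[] = {0.4f, 0.1f};
  const float var[] = {0.9f, 0.3f};
  const float epsilon = 1e-5f;

  // Fused per-channel constants, mirroring new_scale / new_bias in compute().
  float new_scale[2], new_bias[2];
  for (std::size_t c = 0; c < C; ++c) {
    const float inv_std = 1.0f / std::sqrt(var[c] + epsilon);
    new_scale[c] = inv_std * scale[c];
    new_bias[c] = bias[c] - mean[c] * inv_std * scale[c];
  }

  // Walk the buffer as N*C contiguous planes of sample_size elements each,
  // the same column layout the Eigen maps impose, with channel = nc % C.
  for (std::size_t nc = 0; nc < N * C; ++nc) {
    const std::size_t c = nc % C;
    for (std::size_t i = 0; i < sample_size; ++i) {
      const float x = X[nc * sample_size + i];
      const float direct = (x - mean[c]) / std::sqrt(var[c] + epsilon) * scale[c] + bias[c];
      const float fused = x * new_scale[c] + new_bias[c];
      assert(std::fabs(direct - fused) < 1e-5f);  // fused form matches the spec formula
    }
  }
  return 0;
}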
62 changes: 62 additions & 0 deletions lotus/core/providers/cpu/nn/batch_norm.h
@@ -0,0 +1,62 @@
/**
 * Copyright (c) 2016-present, Facebook, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
/* Modifications Copyright (c) Microsoft. */

#pragma once

#include "core/common/common.h"
#include "core/common/exceptions.h"
#include "core/framework/op_kernel.h"
#include "core/providers/cpu/nn/autopad_type.h"
#include "core/framework/tensor.h"
#include "core/util/math_cpuonly.h"

namespace Lotus {

template <typename T>
class BatchNorm final : public OpKernel {
 public:
  BatchNorm(const OpKernelInfo& op_kernel_info) : OpKernel(op_kernel_info) {
    LOTUS_ENFORCE(op_kernel_info.GetAttr<float>("epsilon", &epsilon_).IsOK());

    // Keeping the below for reference; we don't need these for inference.
    //LOTUS_ENFORCE(op_kernel_info.GetAttr<int64_t>("is_test", &is_test_).IsOK());
    //LOTUS_ENFORCE(op_kernel_info.GetAttr<float>("momentum", &momentum_).IsOK());
    //LOTUS_ENFORCE(op_kernel_info.GetAttr<int64_t>("spatial", &spatial_).IsOK());
  }

  Status ValidateInputs(const Tensor* X,
                        const Tensor* scale,
                        const Tensor* B,
                        const Tensor* mean,
                        const Tensor* var) const;

  Status compute(OpKernelContext* p_op_kernel_context) const override;

 private:
  float epsilon_;
  int64_t is_test_;  // ignored in this implementation since we do inference only
  float momentum_;   // ignored in this implementation since we do inference only
  int64_t spatial_;  // ignored in this implementation since we do inference only

  // defined as per the spec and used for input validation
  const int kNumInputXDimensions = 4;
  const int kNumInputScaleDimensions = 1;
  const int kNumInputBiasDimensions = 1;
  const int kNumInputMeanDimensions = 1;
  const int kNumInputVarianceDimensions = 1;
};
}  // namespace Lotus
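For reference, the dimension constants above encode the ONNX spec's shape contract: X is 4-D NCHW, while scale, B, mean, and var are each 1-D of length C. A minimal sketch of that contract outside the Lotus Tensor API follows (names are illustrative; note that ValidateInputs checks only the ranks, so the channel-count comparison here is a natural extension shown for clarity, not what the kernel does):

#include <cstddef>
#include <vector>

// Illustrative stand-in for a tensor shape: one entry per dimension.
using Shape = std::vector<std::size_t>;

// X must be 4-D (NCHW); each per-channel parameter must be 1-D. The kernel
// validates only these ranks; matching each length against C = x[1] is an
// extra check added here for illustration.
bool BatchNormShapesValid(const Shape& x, const Shape& scale,
                          const Shape& bias, const Shape& mean,
                          const Shape& var) {
  if (x.size() != 4) return false;
  const std::size_t C = x[1];
  for (const Shape* s : {&scale, &bias, &mean, &var}) {
    if (s->size() != 1 || s->front() != C) return false;
  }
  return true;
}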
4 changes: 2 additions & 2 deletions lotus/test/framework/allocation_planner_test.cc
@@ -125,7 +125,7 @@ TEST(AllocationPlannerTest, ChainNoShapeTest) {
 
   SequentialExecutionPlan plan;
   auto status = SequentialPlanner::CreatePlan(state, &plan);
-  ASSERT_TRUE(status.IsOK()) << status.ErrorMessage();
+  EXPECT_TRUE(status.IsOK()) << status.ErrorMessage();
 
   std::vector<AllocKind> expected_alloc({AllocKind::kAllocateStatically, AllocKind::kAllocate, AllocKind::kAllocate, AllocKind::kAllocate});
   AllocationPlanTestUtility::CheckAllocationKind(plan, expected_alloc);
@@ -167,7 +167,7 @@ TEST(AllocationPlannerTest, InputOutputTest) {
 
   SequentialExecutionPlan plan;
   auto status = SequentialPlanner::CreatePlan(state, &plan);
-  ASSERT_TRUE(status.IsOK()) << status.ErrorMessage();
+  EXPECT_TRUE(status.IsOK()) << status.ErrorMessage();
 
   // X1: kPreExisting, X2: kPreExisting, Y1: kAllocate, Y2: kAllocate
   std::vector<AllocKind> expected_alloc({AllocKind::kPreExisting, AllocKind::kPreExisting, AllocKind::kAllocate, AllocKind::kAllocate});
2 changes: 1 addition & 1 deletion lotus/test/lib/threadpool_test.cc
@@ -48,7 +48,7 @@ TEST(ThreadPool, DoWork) {
       }
     }
     for (int i = 0; i < kWorkItems; i++) {
-      ASSERT_TRUE(work[i]);
+      EXPECT_TRUE(work[i]);
     }
   }
 }
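On the test changes in this commit: googletest's ASSERT_* macros abort the current test function at the first failure, while EXPECT_* records the failure and continues, so later iterations and checks still run. A small illustration (a hypothetical test, not from this repository):

#include <vector>
#include "gtest/gtest.h"

TEST(AssertVsExpect, Demo) {
  std::vector<int> v{1, 2, 3};
  // ASSERT_*: abort this test immediately if the setup itself is broken.
  ASSERT_EQ(v.size(), 3u);
  for (int x : v) {
    // EXPECT_*: report every failing element instead of stopping at the first.
    EXPECT_GT(x, 0);
  }
}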