Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Add backward Type inference to main DNN operators
Browse files Browse the repository at this point in the history
Signed-off-by: Serge Panev <[email protected]>
  • Loading branch information
Kh4L committed May 21, 2020
1 parent 67b5d31 commit b9bb2ed
Show file tree
Hide file tree
Showing 5 changed files with 68 additions and 27 deletions.
28 changes: 19 additions & 9 deletions src/operator/contrib/batch_norm_relu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -84,14 +84,30 @@ static bool BatchNormWithReLUType(const nnvm::NodeAttrs& attrs,
std::vector<int> *in_type, std::vector<int> *out_type) {
using namespace mshadow;
CHECK_GE(in_type->size(), 1U);
const int dtype = (*in_type)[0];
CHECK_NE(dtype, -1) << "First input must have specified type";
const size_t n_out = 4;
// For float16 input type beta, gamma, mean, and average are stored in float32.
// For other input types, these parameters have the same type as input
// NOTE: This requirement is from cuDNN (v. 4 and 5)
int dtype_param;
MSHADOW_REAL_TYPE_SWITCH_EX(dtype, DTypeX, AccRealX, {
int dtype = (*in_type)[0];
if (type_is_none(dtype)) {
if (out_type->size() == 0 || type_is_none((*out_type)[0])) {
return false;
} else {
dtype = (*out_type)[0];
(*in_type)[0] = dtype;
MSHADOW_REAL_TYPE_SWITCH_EX(dtype, DTypeX, AccRealX, {
dtype_param = mshadow::DataType<AccRealX>::kFlag; });
}
} else {
MSHADOW_REAL_TYPE_SWITCH_EX(dtype, DTypeX, AccRealX, {
dtype_param = mshadow::DataType<AccRealX>::kFlag; });
out_type->clear();
out_type->push_back(dtype);
for (size_t i = 1; i < n_out; ++i) {
out_type->push_back(dtype_param);
}
}
std::vector<std::string> args{"data", "gamma", "beta", "mean", "var"};
CHECK_LE(in_type->size(), args.size());
for (size_t i = 1; i < in_type->size(); ++i) {
Expand All @@ -101,12 +117,6 @@ static bool BatchNormWithReLUType(const nnvm::NodeAttrs& attrs,
UNIFORM_TYPE_CHECK((*in_type)[i], dtype_param, args[i]);
}
}
const size_t n_out = 4;
out_type->clear();
out_type->push_back(dtype);
for (size_t i = 1; i < n_out; ++i) {
out_type->push_back(dtype_param);
}
return true;
}

Expand Down
28 changes: 19 additions & 9 deletions src/operator/nn/batch_norm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -352,14 +352,30 @@ static bool BatchNormType(const nnvm::NodeAttrs& attrs,
std::vector<int> *in_type, std::vector<int> *out_type) {
using namespace mshadow;
CHECK_GE(in_type->size(), 1U);
const int dtype = (*in_type)[0];
CHECK_NE(dtype, -1) << "First input must have specified type";
const size_t n_out = 3;
// For float16 input type beta, gamma, mean, and average are stored in float32.
// For other input types, these parameters have the same type as input
// NOTE: This requirement is from cuDNN (v. 4 and 5)
int dtype_param;
MSHADOW_REAL_TYPE_SWITCH_EX(dtype, DTypeX, AccRealX, {
int dtype = (*in_type)[0];
if (type_is_none(dtype)) {
if (out_type->size() == 0 || type_is_none((*out_type)[0])) {
return false;
} else {
dtype = (*out_type)[0];
(*in_type)[0] = dtype;
MSHADOW_REAL_TYPE_SWITCH_EX(dtype, DTypeX, AccRealX, {
dtype_param = mshadow::DataType<AccRealX>::kFlag; });
}
} else {
MSHADOW_REAL_TYPE_SWITCH_EX(dtype, DTypeX, AccRealX, {
dtype_param = mshadow::DataType<AccRealX>::kFlag; });
out_type->clear();
out_type->push_back(dtype);
for (size_t i = 1; i < n_out; ++i) {
out_type->push_back(dtype_param);
}
}
std::vector<std::string> args{"data", "gamma", "beta", "mean", "var"};
CHECK_LE(in_type->size(), args.size());
for (size_t i = 1; i < in_type->size(); ++i) {
Expand All @@ -369,12 +385,6 @@ static bool BatchNormType(const nnvm::NodeAttrs& attrs,
UNIFORM_TYPE_CHECK((*in_type)[i], dtype_param, args[i]);
}
}
const size_t n_out = 3;
out_type->clear();
out_type->push_back(dtype);
for (size_t i = 1; i < n_out; ++i) {
out_type->push_back(dtype_param);
}
return true;
}

Expand Down
13 changes: 10 additions & 3 deletions src/operator/nn/convolution.cc
Original file line number Diff line number Diff line change
Expand Up @@ -285,16 +285,23 @@ static bool ConvolutionType(const nnvm::NodeAttrs& attrs,
const ConvolutionParam& param_ = nnvm::get<ConvolutionParam>(attrs.parsed);
CHECK_GE(in_type->size(), 1U);
int dtype = (*in_type)[0];
CHECK_NE(dtype, -1) << "First input must have specified type";
if (type_is_none(dtype)) {
if (out_type->size() == 0 || type_is_none((*out_type)[0])) {
return false;
} else {
dtype = (*out_type)[0];
}
} else {
out_type->clear();
out_type->push_back(dtype);
}
for (size_t i = 0; i < in_type->size(); ++i) {
if ((*in_type)[i] == -1) {
(*in_type)[i] = dtype;
} else {
UNIFORM_TYPE_CHECK((*in_type)[i], dtype, ListArguments(param_)[i]);
}
}
out_type->clear();
out_type->push_back(dtype);
return true;
}

Expand Down
13 changes: 10 additions & 3 deletions src/operator/nn/deconvolution.cc
Original file line number Diff line number Diff line change
Expand Up @@ -332,16 +332,23 @@ static bool DeconvolutionType(const nnvm::NodeAttrs& attrs,
const DeconvolutionParam& param_ = nnvm::get<DeconvolutionParam>(attrs.parsed);
CHECK_GE(in_type->size(), 1U);
int dtype = (*in_type)[0];
CHECK_NE(dtype, -1) << "First input must have specified type";
if (type_is_none(dtype)) {
if (out_type->size() == 0 || type_is_none((*out_type)[0])) {
return false;
} else {
dtype = (*out_type)[0];
}
} else {
out_type->clear();
out_type->push_back(dtype);
}
for (size_t i = 0; i < in_type->size(); ++i) {
if ((*in_type)[i] == -1) {
(*in_type)[i] = dtype;
} else {
UNIFORM_TYPE_CHECK((*in_type)[i], dtype, ListArguments(param_)[i]);
}
}
out_type->clear();
out_type->push_back(dtype);
return true;
}

Expand Down
13 changes: 10 additions & 3 deletions src/operator/softmax_output.cc
Original file line number Diff line number Diff line change
Expand Up @@ -66,16 +66,23 @@ static bool SoftmaxOutputType(const nnvm::NodeAttrs& attrs,
std::vector<int> *out_type) {
CHECK_EQ(in_type->size(), 2U);
int dtype = (*in_type)[0];
CHECK_NE(dtype, -1) << "First input must have specified type";
if (type_is_none(dtype)) {
if (out_type->size() == 0 || type_is_none((*out_type)[0])) {
return false;
} else {
dtype = (*out_type)[0];
}
} else {
out_type->clear();
out_type->push_back(dtype);
}
for (size_t i = 0; i < in_type->size(); ++i) {
if ((*in_type)[i] == -1) {
(*in_type)[i] = dtype;
} else {
UNIFORM_TYPE_CHECK((*in_type)[i], dtype, ListArguments()[i]);
}
}
out_type->clear();
out_type->push_back(dtype);
return true;
}

Expand Down

0 comments on commit b9bb2ed

Please sign in to comment.