From 26f44b71d8de84bbc88af496ae0aeb7ce535312d Mon Sep 17 00:00:00 2001 From: Serge Panev Date: Wed, 10 Jun 2020 10:41:50 -0700 Subject: [PATCH] Add backward Type inference to main NN operators (#18378) * Add backward Type inference to main DNN operators Signed-off-by: Serge Panev * Add comments Signed-off-by: Serge Panev --- src/operator/contrib/batch_norm_relu.cc | 34 ++++++++++++++++++------- src/operator/nn/batch_norm.cc | 33 +++++++++++++++++------- src/operator/nn/convolution.cc | 13 +++++++--- src/operator/nn/deconvolution.cc | 18 ++++++++++--- src/operator/softmax_output.cc | 18 ++++++++++--- 5 files changed, 89 insertions(+), 27 deletions(-) diff --git a/src/operator/contrib/batch_norm_relu.cc b/src/operator/contrib/batch_norm_relu.cc index 0bb2f0b43693..51aa4c5ae25c 100644 --- a/src/operator/contrib/batch_norm_relu.cc +++ b/src/operator/contrib/batch_norm_relu.cc @@ -83,14 +83,36 @@ static bool BatchNormWithReLUType(const nnvm::NodeAttrs& attrs, std::vector *in_type, std::vector *out_type) { using namespace mshadow; CHECK_GE(in_type->size(), 1U); - const int dtype = (*in_type)[0]; - CHECK_NE(dtype, -1) << "First input must have specified type"; + const size_t n_out = 4; // For float16 input type beta, gamma, mean, and average are stored in float32. // For other input types, these parameters have the same type as input // NOTE: This requirement is from cuDNN (v. 4 and 5) int dtype_param; - MSHADOW_REAL_TYPE_SWITCH_EX(dtype, DTypeX, AccRealX, { + int dtype = (*in_type)[0]; + + if (type_is_none(dtype)) { + // Input type is undefined, we try backward inference + if (out_type->size() == 0 || type_is_none((*out_type)[0])) { + // Neither the input nor the output are defined, + // types cannot be infered for this op + return false; + } else { + // Input type is undefined but output type is: backward inference + dtype = (*out_type)[0]; + (*in_type)[0] = dtype; + MSHADOW_REAL_TYPE_SWITCH_EX(dtype, DTypeX, AccRealX, { + dtype_param = mshadow::DataType::kFlag; }); + } + } else { + // Input type is defined but output type is not: forward inference + MSHADOW_REAL_TYPE_SWITCH_EX(dtype, DTypeX, AccRealX, { dtype_param = mshadow::DataType::kFlag; }); + out_type->clear(); + out_type->push_back(dtype); + for (size_t i = 1; i < n_out; ++i) { + out_type->push_back(dtype_param); + } + } std::vector args{"data", "gamma", "beta", "mean", "var"}; CHECK_LE(in_type->size(), args.size()); for (size_t i = 1; i < in_type->size(); ++i) { @@ -100,12 +122,6 @@ static bool BatchNormWithReLUType(const nnvm::NodeAttrs& attrs, UNIFORM_TYPE_CHECK((*in_type)[i], dtype_param, args[i]); } } - const size_t n_out = 4; - out_type->clear(); - out_type->push_back(dtype); - for (size_t i = 1; i < n_out; ++i) { - out_type->push_back(dtype_param); - } return true; } diff --git a/src/operator/nn/batch_norm.cc b/src/operator/nn/batch_norm.cc index d4b03ae3fc17..8dbd27195bb5 100644 --- a/src/operator/nn/batch_norm.cc +++ b/src/operator/nn/batch_norm.cc @@ -392,14 +392,35 @@ static bool BatchNormType(const nnvm::NodeAttrs& attrs, std::vector *in_type, std::vector *out_type) { using namespace mshadow; CHECK_GE(in_type->size(), 1U); - const int dtype = (*in_type)[0]; - CHECK_NE(dtype, -1) << "First input must have specified type"; + const size_t n_out = 3; // For float16 input type beta, gamma, mean, and average are stored in float32. // For other input types, these parameters have the same type as input // NOTE: This requirement is from cuDNN (v. 4 and 5) int dtype_param; - MSHADOW_REAL_TYPE_SWITCH_EX(dtype, DTypeX, AccRealX, { + int dtype = (*in_type)[0]; + if (type_is_none(dtype)) { + // Input type is undefined, we try backward inference + if (out_type->size() == 0 || type_is_none((*out_type)[0])) { + // Neither the input nor the output are defined, + // types cannot be infered for this op + return false; + } else { + // Input type is undefined but output type is: backward inference + dtype = (*out_type)[0]; + (*in_type)[0] = dtype; + MSHADOW_REAL_TYPE_SWITCH_EX(dtype, DTypeX, AccRealX, { + dtype_param = mshadow::DataType::kFlag; }); + } + } else { + // Input type is defined but output type is not: forward inference + MSHADOW_REAL_TYPE_SWITCH_EX(dtype, DTypeX, AccRealX, { dtype_param = mshadow::DataType::kFlag; }); + out_type->clear(); + out_type->push_back(dtype); + for (size_t i = 1; i < n_out; ++i) { + out_type->push_back(dtype_param); + } + } std::vector args{"data", "gamma", "beta", "mean", "var"}; CHECK_LE(in_type->size(), args.size()); for (size_t i = 1; i < in_type->size(); ++i) { @@ -409,12 +430,6 @@ static bool BatchNormType(const nnvm::NodeAttrs& attrs, UNIFORM_TYPE_CHECK((*in_type)[i], dtype_param, args[i]); } } - const size_t n_out = 3; - out_type->clear(); - out_type->push_back(dtype); - for (size_t i = 1; i < n_out; ++i) { - out_type->push_back(dtype_param); - } return true; } diff --git a/src/operator/nn/convolution.cc b/src/operator/nn/convolution.cc index 8ff5ea75d5f7..3ebb67ad0aa0 100644 --- a/src/operator/nn/convolution.cc +++ b/src/operator/nn/convolution.cc @@ -285,7 +285,16 @@ static bool ConvolutionType(const nnvm::NodeAttrs& attrs, const ConvolutionParam& param_ = nnvm::get(attrs.parsed); CHECK_GE(in_type->size(), 1U); int dtype = (*in_type)[0]; - CHECK_NE(dtype, -1) << "First input must have specified type"; + if (type_is_none(dtype)) { + if (out_type->size() == 0 || type_is_none((*out_type)[0])) { + return false; + } else { + dtype = (*out_type)[0]; + } + } else { + out_type->clear(); + out_type->push_back(dtype); + } for (size_t i = 0; i < in_type->size(); ++i) { if ((*in_type)[i] == -1) { (*in_type)[i] = dtype; @@ -293,8 +302,6 @@ static bool ConvolutionType(const nnvm::NodeAttrs& attrs, UNIFORM_TYPE_CHECK((*in_type)[i], dtype, ListArguments(param_)[i]); } } - out_type->clear(); - out_type->push_back(dtype); return true; } diff --git a/src/operator/nn/deconvolution.cc b/src/operator/nn/deconvolution.cc index cd22aced0d03..08d6306730ef 100644 --- a/src/operator/nn/deconvolution.cc +++ b/src/operator/nn/deconvolution.cc @@ -332,7 +332,21 @@ static bool DeconvolutionType(const nnvm::NodeAttrs& attrs, const DeconvolutionParam& param_ = nnvm::get(attrs.parsed); CHECK_GE(in_type->size(), 1U); int dtype = (*in_type)[0]; - CHECK_NE(dtype, -1) << "First input must have specified type"; + if (type_is_none(dtype)) { + // Input type is undefined, we try backward inference + if (out_type->size() == 0 || type_is_none((*out_type)[0])) { + // Neither the input nor the output are defined, + // types cannot be infered for this op + return false; + } else { + // Input type is undefined but output type is: backward inference + dtype = (*out_type)[0]; + } + } else { + // Input type is defined but output type is not: forward inference + out_type->clear(); + out_type->push_back(dtype); + } for (size_t i = 0; i < in_type->size(); ++i) { if ((*in_type)[i] == -1) { (*in_type)[i] = dtype; @@ -340,8 +354,6 @@ static bool DeconvolutionType(const nnvm::NodeAttrs& attrs, UNIFORM_TYPE_CHECK((*in_type)[i], dtype, ListArguments(param_)[i]); } } - out_type->clear(); - out_type->push_back(dtype); return true; } diff --git a/src/operator/softmax_output.cc b/src/operator/softmax_output.cc index 13bb647f9d43..d87b78145e9e 100644 --- a/src/operator/softmax_output.cc +++ b/src/operator/softmax_output.cc @@ -66,7 +66,21 @@ static bool SoftmaxOutputType(const nnvm::NodeAttrs& attrs, std::vector *out_type) { CHECK_EQ(in_type->size(), 2U); int dtype = (*in_type)[0]; - CHECK_NE(dtype, -1) << "First input must have specified type"; + if (type_is_none(dtype)) { + // Input type is undefined, we try backward inference + if (out_type->size() == 0 || type_is_none((*out_type)[0])) { + // Neither the input nor the output are defined, + // types cannot be infered for this op + return false; + } else { + // Input type is undefined but output type is: backward inference + dtype = (*out_type)[0]; + } + } else { + // Input type is defined but output type is not: forward inference + out_type->clear(); + out_type->push_back(dtype); + } for (size_t i = 0; i < in_type->size(); ++i) { if ((*in_type)[i] == -1) { (*in_type)[i] = dtype; @@ -74,8 +88,6 @@ static bool SoftmaxOutputType(const nnvm::NodeAttrs& attrs, UNIFORM_TYPE_CHECK((*in_type)[i], dtype, ListArguments()[i]); } } - out_type->clear(); - out_type->push_back(dtype); return true; }