From 6c9ca56a3894c02a9feafb49579d8d1382a731bb Mon Sep 17 00:00:00 2001 From: Serge Panev Date: Tue, 11 Aug 2020 11:07:26 -0700 Subject: [PATCH] Backporting backward inference from 2.x #18348 and #18378 Signed-off-by: Serge Panev --- src/operator/contrib/batch_norm_relu.cc | 41 +++++++++++++++++-------- src/operator/nn/batch_norm.cc | 40 ++++++++++++++++-------- src/operator/nn/convolution.cc | 13 ++++++-- src/operator/nn/deconvolution.cc | 18 +++++++++-- src/operator/nn/group_norm.cc | 8 ++--- src/operator/nn/layer_norm.cc | 7 +++-- src/operator/nn/pooling.cc | 7 +++-- src/operator/softmax_output.cc | 18 +++++++++-- src/operator/tensor/matrix_op-inl.h | 22 ++++++++++--- 9 files changed, 125 insertions(+), 49 deletions(-) diff --git a/src/operator/contrib/batch_norm_relu.cc b/src/operator/contrib/batch_norm_relu.cc index 14452cc96729..51aa4c5ae25c 100644 --- a/src/operator/contrib/batch_norm_relu.cc +++ b/src/operator/contrib/batch_norm_relu.cc @@ -55,6 +55,9 @@ static bool BatchNormWithReLUShape(const nnvm::NodeAttrs& attrs, CHECK_EQ(in_shape->size(), 5U) << "Input:[data, gamma, beta, MovingMean, MovingVar]"; CHECK_EQ(out_shape->size(), 4U); const mxnet::TShape &dshape = in_shape->at(batchnormrelu::kData); + if (!mxnet::ndim_is_known(dshape)) { + return false; + } const size_t channelAxis = static_cast(param.axis < 0 ? static_cast(dshape.ndim()) + param.axis @@ -63,10 +66,6 @@ static bool BatchNormWithReLUShape(const nnvm::NodeAttrs& attrs, const int channelCount = dshape[channelAxis]; - if (!mxnet::ndim_is_known(dshape)) { - return false; - } - in_shape->at(batchnormrelu::kGamma) = mxnet::TShape(Shape1(channelCount)); in_shape->at(batchnormrelu::kBeta) = mxnet::TShape(Shape1(channelCount)); in_shape->at(batchnormrelu::kInMovingMean) = mxnet::TShape(Shape1(channelCount)); // kMovingMean @@ -84,14 +83,36 @@ static bool BatchNormWithReLUType(const nnvm::NodeAttrs& attrs, std::vector *in_type, std::vector *out_type) { using namespace mshadow; CHECK_GE(in_type->size(), 1U); - const int dtype = (*in_type)[0]; - CHECK_NE(dtype, -1) << "First input must have specified type"; + const size_t n_out = 4; // For float16 input type beta, gamma, mean, and average are stored in float32. // For other input types, these parameters have the same type as input // NOTE: This requirement is from cuDNN (v. 4 and 5) int dtype_param; - MSHADOW_REAL_TYPE_SWITCH_EX(dtype, DTypeX, AccRealX, { + int dtype = (*in_type)[0]; + + if (type_is_none(dtype)) { + // Input type is undefined, we try backward inference + if (out_type->size() == 0 || type_is_none((*out_type)[0])) { + // Neither the input nor the output are defined, + // types cannot be infered for this op + return false; + } else { + // Input type is undefined but output type is: backward inference + dtype = (*out_type)[0]; + (*in_type)[0] = dtype; + MSHADOW_REAL_TYPE_SWITCH_EX(dtype, DTypeX, AccRealX, { + dtype_param = mshadow::DataType::kFlag; }); + } + } else { + // Input type is defined but output type is not: forward inference + MSHADOW_REAL_TYPE_SWITCH_EX(dtype, DTypeX, AccRealX, { dtype_param = mshadow::DataType::kFlag; }); + out_type->clear(); + out_type->push_back(dtype); + for (size_t i = 1; i < n_out; ++i) { + out_type->push_back(dtype_param); + } + } std::vector args{"data", "gamma", "beta", "mean", "var"}; CHECK_LE(in_type->size(), args.size()); for (size_t i = 1; i < in_type->size(); ++i) { @@ -101,12 +122,6 @@ static bool BatchNormWithReLUType(const nnvm::NodeAttrs& attrs, UNIFORM_TYPE_CHECK((*in_type)[i], dtype_param, args[i]); } } - const size_t n_out = 4; - out_type->clear(); - out_type->push_back(dtype); - for (size_t i = 1; i < n_out; ++i) { - out_type->push_back(dtype_param); - } return true; } diff --git a/src/operator/nn/batch_norm.cc b/src/operator/nn/batch_norm.cc index 3e36559c0a7c..3bc16597b7d8 100644 --- a/src/operator/nn/batch_norm.cc +++ b/src/operator/nn/batch_norm.cc @@ -365,6 +365,9 @@ static bool BatchNormShape(const nnvm::NodeAttrs& attrs, CHECK_EQ(in_shape->size(), 5U) << "Input:[data, gamma, beta, MovingMean, MovingVar]"; CHECK_EQ(out_shape->size(), 3U); const mxnet::TShape &dshape = in_shape->at(batchnorm::kData); + if (!mxnet::ndim_is_known(dshape)) { + return false; + } const size_t channelAxis = static_cast(param.axis < 0 ? static_cast(dshape.ndim()) + param.axis @@ -373,10 +376,6 @@ static bool BatchNormShape(const nnvm::NodeAttrs& attrs, const int channelCount = dshape[channelAxis]; - if (!mxnet::ndim_is_known(dshape)) { - return false; - } - in_shape->at(batchnorm::kGamma) = mxnet::TShape(Shape1(channelCount)); in_shape->at(batchnorm::kBeta) = mxnet::TShape(Shape1(channelCount)); in_shape->at(batchnorm::kInMovingMean) = mxnet::TShape(Shape1(channelCount)); // kMovingMean @@ -394,14 +393,35 @@ static bool BatchNormType(const nnvm::NodeAttrs& attrs, std::vector *in_type, std::vector *out_type) { using namespace mshadow; CHECK_GE(in_type->size(), 1U); - const int dtype = (*in_type)[0]; - CHECK_NE(dtype, -1) << "First input must have specified type"; + const size_t n_out = 3; // For float16 input type beta, gamma, mean, and average are stored in float32. // For other input types, these parameters have the same type as input // NOTE: This requirement is from cuDNN (v. 4 and 5) int dtype_param; - MSHADOW_REAL_TYPE_SWITCH_EX(dtype, DTypeX, AccRealX, { + int dtype = (*in_type)[0]; + if (type_is_none(dtype)) { + // Input type is undefined, we try backward inference + if (out_type->size() == 0 || type_is_none((*out_type)[0])) { + // Neither the input nor the output are defined, + // types cannot be infered for this op + return false; + } else { + // Input type is undefined but output type is: backward inference + dtype = (*out_type)[0]; + (*in_type)[0] = dtype; + MSHADOW_REAL_TYPE_SWITCH_EX(dtype, DTypeX, AccRealX, { + dtype_param = mshadow::DataType::kFlag; }); + } + } else { + // Input type is defined but output type is not: forward inference + MSHADOW_REAL_TYPE_SWITCH_EX(dtype, DTypeX, AccRealX, { dtype_param = mshadow::DataType::kFlag; }); + out_type->clear(); + out_type->push_back(dtype); + for (size_t i = 1; i < n_out; ++i) { + out_type->push_back(dtype_param); + } + } std::vector args{"data", "gamma", "beta", "mean", "var"}; CHECK_LE(in_type->size(), args.size()); for (size_t i = 1; i < in_type->size(); ++i) { @@ -411,12 +431,6 @@ static bool BatchNormType(const nnvm::NodeAttrs& attrs, UNIFORM_TYPE_CHECK((*in_type)[i], dtype_param, args[i]); } } - const size_t n_out = 3; - out_type->clear(); - out_type->push_back(dtype); - for (size_t i = 1; i < n_out; ++i) { - out_type->push_back(dtype_param); - } return true; } diff --git a/src/operator/nn/convolution.cc b/src/operator/nn/convolution.cc index 8ff5ea75d5f7..3ebb67ad0aa0 100644 --- a/src/operator/nn/convolution.cc +++ b/src/operator/nn/convolution.cc @@ -285,7 +285,16 @@ static bool ConvolutionType(const nnvm::NodeAttrs& attrs, const ConvolutionParam& param_ = nnvm::get(attrs.parsed); CHECK_GE(in_type->size(), 1U); int dtype = (*in_type)[0]; - CHECK_NE(dtype, -1) << "First input must have specified type"; + if (type_is_none(dtype)) { + if (out_type->size() == 0 || type_is_none((*out_type)[0])) { + return false; + } else { + dtype = (*out_type)[0]; + } + } else { + out_type->clear(); + out_type->push_back(dtype); + } for (size_t i = 0; i < in_type->size(); ++i) { if ((*in_type)[i] == -1) { (*in_type)[i] = dtype; @@ -293,8 +302,6 @@ static bool ConvolutionType(const nnvm::NodeAttrs& attrs, UNIFORM_TYPE_CHECK((*in_type)[i], dtype, ListArguments(param_)[i]); } } - out_type->clear(); - out_type->push_back(dtype); return true; } diff --git a/src/operator/nn/deconvolution.cc b/src/operator/nn/deconvolution.cc index cd22aced0d03..08d6306730ef 100644 --- a/src/operator/nn/deconvolution.cc +++ b/src/operator/nn/deconvolution.cc @@ -332,7 +332,21 @@ static bool DeconvolutionType(const nnvm::NodeAttrs& attrs, const DeconvolutionParam& param_ = nnvm::get(attrs.parsed); CHECK_GE(in_type->size(), 1U); int dtype = (*in_type)[0]; - CHECK_NE(dtype, -1) << "First input must have specified type"; + if (type_is_none(dtype)) { + // Input type is undefined, we try backward inference + if (out_type->size() == 0 || type_is_none((*out_type)[0])) { + // Neither the input nor the output are defined, + // types cannot be infered for this op + return false; + } else { + // Input type is undefined but output type is: backward inference + dtype = (*out_type)[0]; + } + } else { + // Input type is defined but output type is not: forward inference + out_type->clear(); + out_type->push_back(dtype); + } for (size_t i = 0; i < in_type->size(); ++i) { if ((*in_type)[i] == -1) { (*in_type)[i] = dtype; @@ -340,8 +354,6 @@ static bool DeconvolutionType(const nnvm::NodeAttrs& attrs, UNIFORM_TYPE_CHECK((*in_type)[i], dtype, ListArguments(param_)[i]); } } - out_type->clear(); - out_type->push_back(dtype); return true; } diff --git a/src/operator/nn/group_norm.cc b/src/operator/nn/group_norm.cc index 6b8fe9bbd4c9..a92ac3113082 100644 --- a/src/operator/nn/group_norm.cc +++ b/src/operator/nn/group_norm.cc @@ -39,14 +39,14 @@ static bool GroupNormShape(const nnvm::NodeAttrs& attrs, using namespace mshadow; CHECK_EQ(in_shape->size(), 3U) << "Input:[data, gamma, beta]"; const mxnet::TShape &dshape = in_shape->at(groupnorm::kData); - CHECK_GE(dshape.ndim(), 3U); - const int num_groups = param.num_groups; - CHECK_EQ(dshape[1] % num_groups, 0) << "# of channels must be divisible by # of groups"; - if (!mxnet::ndim_is_known(dshape)) { return false; } + CHECK_GE(dshape.ndim(), 3U); + const int num_groups = param.num_groups; + CHECK_EQ(dshape[1] % num_groups, 0) << "# of channels must be divisible by # of groups"; + in_shape->at(groupnorm::kGamma) = mxnet::TShape(Shape1(num_groups)); in_shape->at(groupnorm::kBeta) = mxnet::TShape(Shape1(num_groups)); diff --git a/src/operator/nn/layer_norm.cc b/src/operator/nn/layer_norm.cc index d385b93e9cff..c3ccd0d7a6bc 100644 --- a/src/operator/nn/layer_norm.cc +++ b/src/operator/nn/layer_norm.cc @@ -43,15 +43,16 @@ static bool LayerNormShape(const nnvm::NodeAttrs& attrs, using namespace mshadow; CHECK_EQ(in_shape->size(), 3U) << "Input:[data, gamma, beta]"; const mxnet::TShape &dshape = in_shape->at(layernorm::kData); + if (!mxnet::ndim_is_known(dshape)) { + return false; + } + int axis = GetRealAxis(param.axis, dshape.ndim()); CHECK(axis >= 0 && axis < dshape.ndim()) << "Channel axis out of range: axis=" << param.axis; const int channelCount = dshape[axis]; - if (!mxnet::ndim_is_known(dshape)) { - return false; - } SHAPE_ASSIGN_CHECK(*in_shape, layernorm::kGamma, mxnet::TShape(Shape1(channelCount))); diff --git a/src/operator/nn/pooling.cc b/src/operator/nn/pooling.cc index 03787f42b038..c81cae358422 100644 --- a/src/operator/nn/pooling.cc +++ b/src/operator/nn/pooling.cc @@ -95,10 +95,14 @@ static bool PoolingShape(const nnvm::NodeAttrs &attrs, mxnet::ShapeVector *out_shape) { const PoolingParam ¶m = nnvm::get(attrs.parsed); CHECK_EQ(in_shape->size(), 1U); + const mxnet::TShape &dshape = (*in_shape)[0]; + if (!mxnet::ndim_is_known(dshape)) { + return false; + } if (param.pool_type == pool_enum::kLpPooling) { CHECK(param.p_value.has_value()); } - const mxnet::TShape &dshape = (*in_shape)[0]; + if (param.pooling_convention == pool_enum::kSame) { CHECK_EQ(dshape.ndim(), 3U) << "Pooling: Input data should be 3D in (batch, channel, x)" @@ -114,7 +118,6 @@ static bool PoolingShape(const nnvm::NodeAttrs &attrs, << "Pooling: Input data should be 3D in (batch, channel, x)" << " Or 4D in (batch, channel, y, x) " << " Or 5D in (batch, channel, d, y, x)"; - if (!mxnet::ndim_is_known(dshape)) return false; int layout = param.GetLayout(dshape.ndim()); if (param.global_pool) { mxnet::TShape oshape = dshape; diff --git a/src/operator/softmax_output.cc b/src/operator/softmax_output.cc index 13bb647f9d43..d87b78145e9e 100644 --- a/src/operator/softmax_output.cc +++ b/src/operator/softmax_output.cc @@ -66,7 +66,21 @@ static bool SoftmaxOutputType(const nnvm::NodeAttrs& attrs, std::vector *out_type) { CHECK_EQ(in_type->size(), 2U); int dtype = (*in_type)[0]; - CHECK_NE(dtype, -1) << "First input must have specified type"; + if (type_is_none(dtype)) { + // Input type is undefined, we try backward inference + if (out_type->size() == 0 || type_is_none((*out_type)[0])) { + // Neither the input nor the output are defined, + // types cannot be infered for this op + return false; + } else { + // Input type is undefined but output type is: backward inference + dtype = (*out_type)[0]; + } + } else { + // Input type is defined but output type is not: forward inference + out_type->clear(); + out_type->push_back(dtype); + } for (size_t i = 0; i < in_type->size(); ++i) { if ((*in_type)[i] == -1) { (*in_type)[i] = dtype; @@ -74,8 +88,6 @@ static bool SoftmaxOutputType(const nnvm::NodeAttrs& attrs, UNIFORM_TYPE_CHECK((*in_type)[i], dtype, ListArguments()[i]); } } - out_type->clear(); - out_type->push_back(dtype); return true; } diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h index fa7b8a10b212..217bf10398ad 100644 --- a/src/operator/tensor/matrix_op-inl.h +++ b/src/operator/tensor/matrix_op-inl.h @@ -455,9 +455,9 @@ inline bool TransposeShape(const nnvm::NodeAttrs& attrs, CHECK_EQ(out_attrs->size(), 1U); mxnet::TShape& shp = (*in_attrs)[0]; mxnet::TShape& out_shp = (*out_attrs)[0]; - CHECK_LE(shp.ndim(), 6) << "Transpose support at most 6 dimensions"; - if (shp.ndim() == -1 && out_shp.ndim() == -1) + if (!mxnet::ndim_is_known(shp) && !mxnet::ndim_is_known(out_shp)) return false; // none of the shapes is known + CHECK_LE(shp.ndim(), 6) << "Transpose support at most 6 dimensions"; if (out_shp.ndim() >= 0 && shp.ndim() >= 0) CHECK_EQ(out_shp.ndim(), shp.ndim()); mxnet::TShape get(std::max(shp.ndim(), out_shp.ndim()), -1); @@ -506,12 +506,12 @@ inline bool ExpandDimShape(const nnvm::NodeAttrs& attrs, const ExpandDimParam& param = nnvm::get(attrs.parsed); CHECK_EQ(in_attrs->size(), 1U); CHECK_EQ(out_attrs->size(), 1U); - if (!mxnet::ndim_is_known(in_attrs->at(0)) && !mxnet::ndim_is_known(out_attrs->at(0))) { + mxnet::TShape& ishape = (*in_attrs)[0]; + mxnet::TShape& oshape = (*out_attrs)[0]; + if (!mxnet::ndim_is_known(ishape) && !mxnet::ndim_is_known(oshape)) { return false; } - mxnet::TShape& ishape = (*in_attrs)[0]; - mxnet::TShape& oshape = (*out_attrs)[0]; int indim = ishape.ndim(); bool unknown_ishape = false; if (-1 == indim) { @@ -1434,6 +1434,9 @@ inline bool SliceLikeShape(const nnvm::NodeAttrs& attrs, CHECK_EQ(out_attrs->size(), 1U); mxnet::TShape& ishape = (*in_attrs)[0]; mxnet::TShape& from_shape = (*in_attrs)[1]; + if (!mxnet::ndim_is_known(ishape) || !mxnet::ndim_is_known(from_shape)) { + return false; + } if (param.axes.ndim() == 0) { CHECK_EQ(ishape.ndim(), from_shape.ndim()) << "By default slice_axis performs slice on all axes, but ndim mismatch " @@ -1727,6 +1730,9 @@ inline bool RepeatOpShape(const nnvm::NodeAttrs& attrs, CHECK_EQ(in_attrs->size(), 1U); CHECK_EQ(out_attrs->size(), 1U); const mxnet::TShape& ishape = (*in_attrs)[0]; + if (!mxnet::ndim_is_known(ishape)) { + return false; + } int repeats = 0; dmlc::optional axisOpt; GetRepeatParams(param, ishape, &repeats, &axisOpt); @@ -2395,6 +2401,9 @@ inline bool DepthToSpaceOpShape(const nnvm::NodeAttrs& attrs, mxnet::TShape expected_out(4, -1); mxnet::TShape& in_shape = in_attrs->at(0); + if (!mxnet::ndim_is_known(in_shape)) { + return false; + } int block = param.block_size; CHECK_NE(block, 0) << "block_size must be a positive integer value"; CHECK_NE(in_shape[1], 0) << "Depth dimension:1 cannot be 0"; @@ -2559,6 +2568,9 @@ inline bool SpaceToDepthOpShape(const nnvm::NodeAttrs& attrs, mxnet::TShape expected_out(in_attrs->at(0).ndim(), -1); mxnet::TShape& in_shape = in_attrs->at(0); + if (!mxnet::ndim_is_known(in_shape)) { + return false; + } int block = param.block_size; CHECK_NE(block, 0) << "block_size must be a positive integer value"; CHECK_NE(in_shape[0], 0)