diff --git a/inference-engine/src/mkldnn_plugin/nodes/mkldnn_mvn_node.cpp b/inference-engine/src/mkldnn_plugin/nodes/mkldnn_mvn_node.cpp
index edbd4e426b1fd1..36d9942b09932f 100644
--- a/inference-engine/src/mkldnn_plugin/nodes/mkldnn_mvn_node.cpp
+++ b/inference-engine/src/mkldnn_plugin/nodes/mkldnn_mvn_node.cpp
@@ -506,7 +506,7 @@ struct jit_uni_mvn_kernel_f32 : public jit_uni_mvn_kernel, public jit_generator
     Xbyak::Reg64 reg_d_bias = rdx;
 
     Xbyak::Reg64 reg_load_table = r15;
-    Xbyak::Reg64 reg_load_store_mask = rcx;
+    Xbyak::Reg64 reg_load_store_mask = rbp;
 
     Vmm vmm_val = Vmm(1);
     Vmm vmm_mean = Vmm(0);
@@ -672,6 +672,7 @@ MKLDNNMVNNode::MKLDNNMVNNode(const std::shared_ptr<ngraph::Node>& op, const mkld
         IE_THROW(NotImplemented) << errorMessage;
     }
 
+    const ngraph::Shape& inDataShape = op->input_value(0).get_shape();
     if (auto mvnOp = ngraph::as_type_ptr<ngraph::op::v6::MVN>(op)) {
         normalizeVariance_ = mvnOp->get_normalize_variance();
         epsValue_ = mvnOp->get_eps();
@@ -681,7 +682,7 @@ MKLDNNMVNNode::MKLDNNMVNNode(const std::shared_ptr<ngraph::Node>& op, const mkld
 
         }
         acrossChannels_ = false;
-        const auto& inDataShapeSize = op->input_value(0).get_shape().size();
+        const auto& inDataShapeSize = inDataShape.size();
         if (inDataShapeSize == mvnOp->input_value(1).get_shape()[0] + 1 || inDataShapeSize == 1)
             acrossChannels_ = true;
     } else if (auto mvnOp = ngraph::as_type_ptr<ngraph::op::v0::MVN>(op)) {
@@ -690,6 +691,37 @@ MKLDNNMVNNode::MKLDNNMVNNode(const std::shared_ptr<ngraph::Node>& op, const mkld
         epsMode_ = INSIDE_SQRT;
         acrossChannels_ = mvnOp->get_across_channels();
     }
+
+    transformTo5DCase(inDataShape);
+}
+
+void MKLDNNMVNNode::transformTo5DCase(const ngraph::Shape& shape) {
+    switch (shape.size()) {
+        // for 1 and 2 rank, if acrossChannels_ is true, adjust shape to fully vectorize under unified 5d procedure.
+        // otherwise there are not enough data in spatial dimension to process in one kernel.
+        case 1 :  // C
+            if (acrossChannels_) {
+                shape5D = std::make_tuple(1, 1, 1, 1, shape[0]);
+                acrossChannels_ = false;
+                break;
+            } else {
+                shape5D = std::make_tuple(1, shape[0], 1, 1, 1);
+                break;
+            }
+        case 2 :  // NC
+            if (acrossChannels_) {
+                shape5D = std::make_tuple(1, shape[0], 1, shape[1], 1);
+                acrossChannels_ = false;
+                break;
+            } else {
+                shape5D = std::make_tuple(shape[0], shape[1], 1, 1, 1);
+                break;
+            }
+        case 3 : { shape5D = std::make_tuple(shape[0], shape[1], 1, shape[2], 1); break; }
+        case 4 : { shape5D = std::make_tuple(shape[0], shape[1], 1, shape[2], shape[3]); break; }
+        case 5 : { shape5D = std::make_tuple(shape[0], shape[1], shape[2], shape[3], shape[4]); break; }
+        default : { IE_THROW() << "MVN layer with name '" << getName() << "' doesn't support planar layout with rank: " << shape.size(); }
+    }
 }
 
 void MKLDNNMVNNode::getSupportedDescriptors() {
@@ -798,19 +830,6 @@ void MKLDNNMVNNode::initSupportedPrimitiveDescriptors() {
     pushDesc(MKLDNNMemory::GetPlainFormat(getChildEdgeAt(0)->getDims()), impl_type);
 }
 
-std::tuple<size_t, size_t, size_t, size_t, size_t> MKLDNNMVNNode::get5dShapes(const SizeVector& dims) {
-    std::tuple<size_t, size_t, size_t, size_t, size_t> shapes;
-    switch (dims.size()) {
-        case 1 : { shapes = std::make_tuple(1, dims[0], 1, 1, 1); break; }
-        case 2 : { shapes = std::make_tuple(dims[0], dims[1], 1, 1, 1); break; }
-        case 3 : { shapes = std::make_tuple(dims[0], dims[1], 1, dims[2], 1); break; }
-        case 4 : { shapes = std::make_tuple(dims[0], dims[1], 1, dims[2], dims[3]); break; }
-        case 5 : { shapes = std::make_tuple(dims[0], dims[1], dims[2], dims[3], dims[4]); break; }
-        default : { IE_THROW() << "MVN layer with name '" << getName() << "' doesn't support planar layout with rank: " << dims.size(); }
-    }
-    return shapes;
-}
-
 void MKLDNNMVNNode::createPrimitive() {
     auto& dstMemPtr = getChildEdgeAt(0)->getMemoryPtr();
     auto& srcMemPtr = getParentEdgeAt(0)->getMemoryPtr();
@@ -832,7 +851,7 @@ void MKLDNNMVNNode::createPrimitive() {
         jcp.across_channels = acrossChannels_;
         SizeVector in_dims = getParentEdgeAt(0)->getDims().ToSizeVector();
         int N = 0;
-        std::tie(N, jcp.C, jcp.D, jcp.H, jcp.W) = get5dShapes(in_dims);
+        std::tie(N, jcp.C, jcp.D, jcp.H, jcp.W) = shape5D;
 
         if (mayiuse(cpu::x64::avx512_common)) {
             mvn_kernel.reset(new jit_uni_mvn_kernel_f32<cpu::x64::avx512_common>(jcp, *attr.get()));
@@ -926,7 +945,7 @@ void MKLDNNMVNNode::mvn_pln(const uint8_t* src_data, uint8_t* dst_data, const Si
     }
 
     size_t N = 0; size_t C = 0; size_t D = 0; size_t H = 0; size_t W = 0;
-    std::tie(N, C, D, H, W) = get5dShapes(dims);
+    std::tie(N, C, D, H, W) = shape5D;
 
     size_t C1 = H * W;
     size_t C2 = C1 * D;
@@ -1054,7 +1073,7 @@ void MKLDNNMVNNode::mvn_ref(const uint8_t* src_data, uint8_t* dst_data, const Si
     const float *src_data_ptr = reinterpret_cast<const float *>(src_data);
     float *dst_data_ptr = reinterpret_cast<float *>(dst_data);
     size_t N = 0; size_t C = 0; size_t D = 0; size_t H = 0; size_t W = 0;
-    std::tie(N, C, D, H, W) = get5dShapes(dims);
+    std::tie(N, C, D, H, W) = shape5D;
 
     size_t C1 = H * W;
     size_t C2 = C1 * D;
@@ -1157,7 +1176,7 @@ void MKLDNNMVNNode::mvn_blk(const uint8_t* src_data, uint8_t* dst_data, const Si
     }
 
     size_t N = 1; size_t C = 1; size_t D = 1; size_t H = 1; size_t W = 1;
-    std::tie(N, C, D, H, W) = get5dShapes(dims);
+    std::tie(N, C, D, H, W) = shape5D;
 
     bool is_nhwc = false;
     Layout layout = getParentEdgeAt(0)->getDesc().getLayout();
diff --git a/inference-engine/src/mkldnn_plugin/nodes/mkldnn_mvn_node.h b/inference-engine/src/mkldnn_plugin/nodes/mkldnn_mvn_node.h
index c23da5e0c11fbb..515d17bcb5bc3a 100644
--- a/inference-engine/src/mkldnn_plugin/nodes/mkldnn_mvn_node.h
+++ b/inference-engine/src/mkldnn_plugin/nodes/mkldnn_mvn_node.h
@@ -104,7 +104,9 @@ class MKLDNNMVNNode : public MKLDNNNode {
 
     void setPostOps(mkldnn::primitive_attr &attr, bool initWeights = false);
 
-    std::tuple<size_t, size_t, size_t, size_t, size_t> get5dShapes(const InferenceEngine::SizeVector& dims);
+    void transformTo5DCase(const ngraph::Shape& shape);
+
+    std::tuple<size_t, size_t, size_t, size_t, size_t> shape5D;
 
     bool acrossChannels_ = false;
     bool normalizeVariance_ = true;