shape fine tune for unified 5d procedure
chenhu-wang committed May 7, 2021
1 parent 57189ab commit b75ee1f
Showing 2 changed files with 40 additions and 31 deletions.
67 changes: 37 additions & 30 deletions inference-engine/src/mkldnn_plugin/nodes/mkldnn_mvn_node.cpp
@@ -506,7 +506,7 @@ struct jit_uni_mvn_kernel_f32 : public jit_uni_mvn_kernel, public jit_generator
Xbyak::Reg64 reg_d_bias = rdx;

Xbyak::Reg64 reg_load_table = r15;
Xbyak::Reg64 reg_load_store_mask = rcx;
Xbyak::Reg64 reg_load_store_mask = rbp;

Vmm vmm_val = Vmm(1);
Vmm vmm_mean = Vmm(0);
@@ -690,6 +690,38 @@ MKLDNNMVNNode::MKLDNNMVNNode(const std::shared_ptr<ngraph::Node>& op, const mkld
epsMode_ = INSIDE_SQRT;
acrossChannels_ = mvnOp->get_across_channels();
}

SizeVector inShape = getParentEdgeAt(0)->getDims().ToSizeVector();
transformTo5DCase(inShape);
}

void MKLDNNMVNNode::transformTo5DCase(const SizeVector& dims) {
switch (dims.size()) {
// for rank 1 and 2, if acrossChannels_ is true, adjust the shape to fully vectorize under the unified 5d procedure.
// otherwise there is not enough data in the spatial dimension to process in one kernel.
case 1 : // C
if (acrossChannels_) {
shape5D = std::make_tuple(1, 1, 1, 1, dims[0]);
acrossChannels_ = false;
break;
} else {
shape5D = std::make_tuple(1, dims[0], 1, 1, 1);
break;
}
case 2 : // NC
if (acrossChannels_) {
shape5D = std::make_tuple(1, dims[0], 1, dims[1], 1);
acrossChannels_ = false;
break;
} else {
shape5D = std::make_tuple(dims[0], dims[1], 1, 1, 1);
break;
}
case 3 : { shape5D = std::make_tuple(dims[0], dims[1], 1, dims[2], 1); break; }
case 4 : { shape5D = std::make_tuple(dims[0], dims[1], 1, dims[2], dims[3]); break; }
case 5 : { shape5D = std::make_tuple(dims[0], dims[1], dims[2], dims[3], dims[4]); break; }
default : { IE_THROW() << "MVN layer with name '" << getName() << "' doesn't support planar layout with rank: " << dims.size(); }
}
}
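For reference, a minimal standalone sketch of the mapping above, useful for checking what the unified 5D shape looks like for low-rank inputs. The to5D helper, the Shape5D alias, and the example dims below are illustrative only and not part of this patch:

    #include <cstddef>
    #include <iostream>
    #include <stdexcept>
    #include <tuple>
    #include <vector>

    using Shape5D = std::tuple<size_t, size_t, size_t, size_t, size_t>;

    // Mirrors the switch above: for rank 1 and 2 with an across-channels
    // reduction, the data is folded so that each reshaped channel holds one
    // original reduction group, and the across-channels flag is cleared.
    Shape5D to5D(const std::vector<size_t>& dims, bool& acrossChannels) {
        switch (dims.size()) {
        case 1:  // C
            if (acrossChannels) { acrossChannels = false; return Shape5D{1, 1, 1, 1, dims[0]}; }
            return Shape5D{1, dims[0], 1, 1, 1};
        case 2:  // NC
            if (acrossChannels) { acrossChannels = false; return Shape5D{1, dims[0], 1, dims[1], 1}; }
            return Shape5D{dims[0], dims[1], 1, 1, 1};
        case 3: return Shape5D{dims[0], dims[1], 1, dims[2], 1};
        case 4: return Shape5D{dims[0], dims[1], 1, dims[2], dims[3]};
        case 5: return Shape5D{dims[0], dims[1], dims[2], dims[3], dims[4]};
        default: throw std::logic_error("rank > 5 is not supported");
        }
    }

    int main() {
        bool acrossChannels = true;
        auto s = to5D({8, 32}, acrossChannels);  // rank-2 NC input, across_channels = true
        std::cout << std::get<0>(s) << ' ' << std::get<1>(s) << ' ' << std::get<2>(s) << ' '
                  << std::get<3>(s) << ' ' << std::get<4>(s)
                  << " acrossChannels=" << acrossChannels << '\n';  // prints: 1 8 1 32 1 acrossChannels=0
        return 0;
    }

Clearing the flag after the fold appears to keep the semantics intact: each channel of the reshaped tensor corresponds to one original reduction group, so per-channel statistics match the across-channels statistics of the original layout, while the group's elements now sit in the spatial dimension where the kernel can vectorize over them.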

void MKLDNNMVNNode::getSupportedDescriptors() {
@@ -798,31 +830,6 @@ void MKLDNNMVNNode::initSupportedPrimitiveDescriptors() {
pushDesc(MKLDNNMemory::GetPlainFormat(getChildEdgeAt(0)->getDims()), impl_type);
}

std::tuple<size_t, size_t, size_t, size_t, size_t> MKLDNNMVNNode::get5dShapes(const SizeVector& dims) {
std::tuple<size_t, size_t, size_t, size_t, size_t> shapes;
switch (dims.size()) {
// for rank 1 and 2, if acrossChannels_ is true, adjust the shape to use the unified 5d procedure and fully vectorize.
// otherwise there is only one element (in the spatial dimension) to process in one kernel.
case 1 : // C
if (acrossChannels_) {
shapes = std::make_tuple(1, 1, 1, 1, dims[0]); break;
} else {
shapes = std::make_tuple(1, dims[0], 1, 1, 1); break;
}
case 2 : // NC
if (acrossChannels_) {
shapes = std::make_tuple(dims[0], 1, 1, 1, dims[1]); break;
} else {
shapes = std::make_tuple(dims[0], dims[1], 1, 1, 1); break;
}
case 3 : { shapes = std::make_tuple(dims[0], dims[1], 1, dims[2], 1); break; }
case 4 : { shapes = std::make_tuple(dims[0], dims[1], 1, dims[2], dims[3]); break; }
case 5 : { shapes = std::make_tuple(dims[0], dims[1], dims[2], dims[3], dims[4]); break; }
default : { IE_THROW() << "MVN layer with name '" << getName() << "' doesn't support planar layout with rank: " << dims.size(); }
}
return shapes;
}

void MKLDNNMVNNode::createPrimitive() {
auto& dstMemPtr = getChildEdgeAt(0)->getMemoryPtr();
auto& srcMemPtr = getParentEdgeAt(0)->getMemoryPtr();
@@ -844,7 +851,7 @@ void MKLDNNMVNNode::createPrimitive() {
jcp.across_channels = acrossChannels_;
SizeVector in_dims = getParentEdgeAt(0)->getDims().ToSizeVector();
int N = 0;
std::tie(N, jcp.C, jcp.D, jcp.H, jcp.W) = get5dShapes(in_dims);
std::tie(N, jcp.C, jcp.D, jcp.H, jcp.W) = shape5D;

if (mayiuse(cpu::x64::avx512_common)) {
mvn_kernel.reset(new jit_uni_mvn_kernel_f32<cpu::x64::avx512_common>(jcp, *attr.get()));
@@ -938,7 +945,7 @@ void MKLDNNMVNNode::mvn_pln(const uint8_t* src_data, uint8_t* dst_data, const Si
}

size_t N = 0; size_t C = 0; size_t D = 0; size_t H = 0; size_t W = 0;
std::tie(N, C, D, H, W) = get5dShapes(dims);
std::tie(N, C, D, H, W) = shape5D;

size_t C1 = H * W;
size_t C2 = C1 * D;
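C1 and C2 are the per-plane and per-channel strides of the planar layout, derived from the cached shape5D. Below is a hedged sketch of the flat-offset arithmetic they support, assuming the usual NCDHW ordering; the element_offset helper and its names are illustrative, not taken from this file:

    #include <cstddef>

    // Assuming planar NCDHW storage: W is contiguous, then H, D, C, N.
    // C1 = H * W is one spatial plane, C2 = C1 * D is one channel's volume,
    // and C3 = C2 * C is one batch item's volume.
    inline size_t element_offset(size_t n, size_t c, size_t d, size_t h, size_t w,
                                 size_t C, size_t D, size_t H, size_t W) {
        const size_t C1 = H * W;
        const size_t C2 = C1 * D;
        const size_t C3 = C2 * C;
        return n * C3 + c * C2 + d * C1 + h * W + w;
    }

Since createPrimitive, mvn_pln, mvn_ref, and mvn_blk now all read N, C, D, H, W from the same cached shape5D member instead of calling get5dShapes separately, these strides stay consistent across the planar, reference, and blocked paths.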
Expand Down Expand Up @@ -1066,7 +1073,7 @@ void MKLDNNMVNNode::mvn_ref(const uint8_t* src_data, uint8_t* dst_data, const Si
const float *src_data_ptr = reinterpret_cast<const float *>(src_data);
float *dst_data_ptr = reinterpret_cast<float *>(dst_data);
size_t N = 0; size_t C = 0; size_t D = 0; size_t H = 0; size_t W = 0;
std::tie(N, C, D, H, W) = get5dShapes(dims);
std::tie(N, C, D, H, W) = shape5D;

size_t C1 = H * W;
size_t C2 = C1 * D;
Expand Down Expand Up @@ -1169,7 +1176,7 @@ void MKLDNNMVNNode::mvn_blk(const uint8_t* src_data, uint8_t* dst_data, const Si
}

size_t N = 1; size_t C = 1; size_t D = 1; size_t H = 1; size_t W = 1;
std::tie(N, C, D, H, W) = get5dShapes(dims);
std::tie(N, C, D, H, W) = shape5D;

bool is_nhwc = false;
Layout layout = getParentEdgeAt(0)->getDesc().getLayout();
4 changes: 3 additions & 1 deletion inference-engine/src/mkldnn_plugin/nodes/mkldnn_mvn_node.h
@@ -104,7 +104,9 @@ class MKLDNNMVNNode : public MKLDNNNode {

void setPostOps(mkldnn::primitive_attr &attr, bool initWeights = false);

std::tuple<size_t, size_t, size_t, size_t, size_t> get5dShapes(const InferenceEngine::SizeVector& dims);
void transformTo5DCase(const InferenceEngine::SizeVector& dims);

std::tuple<size_t, size_t, size_t, size_t, size_t> shape5D;

bool acrossChannels_ = false;
bool normalizeVariance_ = true;