Skip to content

Commit

Permalink
Prefer TailC format over others if I8 preciosn is used.
Browse files Browse the repository at this point in the history
  • Loading branch information
maxnick committed Mar 25, 2021
1 parent 8834f04 commit 5fcd134
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 0 deletions.
10 changes: 10 additions & 0 deletions inference-engine/src/mkldnn_plugin/mkldnn_extension_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,16 @@ PartialBlkDesc PartialBlkDesc::makeCBlocked(const InferenceEngine::SizeVector &d
return res;
}


PartialBlkDesc PartialBlkDesc::makeTailC(const InferenceEngine::SizeVector &dims) {
PartialBlkDesc res = makePlain(dims);
if (dims.size() > 2) {
auto itr = res.outer_order.begin() + 1;
std::rotate(itr, itr + 1, res.outer_order.end());
}
return res;
}

PartialBlkDesc PartialBlkDesc::extractFrom(const InferenceEngine::TensorDesc &desc) {
if (desc.getLayout() == InferenceEngine::ANY)
IE_THROW() << "Cannot extract partial blocked descriptor for `ANY` layout";
Expand Down
3 changes: 3 additions & 0 deletions inference-engine/src/mkldnn_plugin/mkldnn_extension_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,9 @@ class PartialBlkDesc {
/** Construct blocked Channel PartialBlkDesc based on dims information */
static PartialBlkDesc makeCBlocked(const InferenceEngine::SizeVector &dims, size_t block_size);

/** Construct per Channel PartialBlkDesc based on dims information */
static PartialBlkDesc makeTailC(const InferenceEngine::SizeVector &dims);

/** Compare operators. Allow to use it as key for std::map */
bool operator == (const PartialBlkDesc& it) const;
bool operator < (const PartialBlkDesc& it) const;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,12 @@ void MKLDNNConcatNode::selectOptimalPrimitiveDescriptor() {
if (it.second > maxCount) {
maxCount = it.second;
convertTo = it.first;
} else if (it.second == maxCount) {
if (outputPrecision == Precision::I8 || outputPrecision == Precision::U8) {
if (it.first == PartialBlkDesc::makeTailC(getChildEdgeAt(0)->getDims().ToSizeVector())) {
convertTo = it.first;
}
}
}
}

Expand Down

0 comments on commit 5fcd134

Please sign in to comment.