Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Improve mkldnn fallback. (#12663)
Browse files Browse the repository at this point in the history
  • Loading branch information
ZhennanQin authored and eric-haibin-lin committed Oct 10, 2018
1 parent 064c87c commit 443ded4
Show file tree
Hide file tree
Showing 5 changed files with 19 additions and 7 deletions.
22 changes: 15 additions & 7 deletions src/executor/attach_op_execs_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -159,9 +159,13 @@ class StatefulComputeExExecutor : public OpExecutor {
op_ctx.run_ctx = rctx;
#if MXNET_USE_MKLDNN == 1
InvalidateOutputs(out_array, req);
CreateDefaultInputs(in_array, &in_array_fallback);
fcompute_(state_, op_ctx, in_array_fallback, req, out_array);
return;
// TODO(alex): (MXNET-847) Remove this fallback feature after subgraph implemented
const auto is_mkldnn = Op::GetAttr<bool>("TIsMKLDNN");
if (!is_mkldnn.get(attrs_.op, false)) {
CreateDefaultInputs(in_array, &in_array_fallback);
fcompute_(state_, op_ctx, in_array_fallback, req, out_array);
return;
}
#endif
fcompute_(state_, op_ctx, in_array, req, out_array);
}
Expand All @@ -180,12 +184,14 @@ class StatefulComputeExExecutor : public OpExecutor {
return state_;
}

explicit StatefulComputeExExecutor(const OpStatePtr& state,
explicit StatefulComputeExExecutor(const NodeAttrs& attrs,
const OpStatePtr& state,
const FStatefulComputeEx& fcompute,
ExecType exec_type)
: state_(state), fcompute_(fcompute), exec_type_(exec_type) {}
: attrs_(attrs), state_(state), fcompute_(fcompute), exec_type_(exec_type) {}

private:
NodeAttrs attrs_;
OpStatePtr state_;
FStatefulComputeEx fcompute_;
ExecType exec_type_;
Expand Down Expand Up @@ -302,7 +308,8 @@ void CreateOpExecs(const Graph& g, OpExecVector* p_ret, size_t i) {
op, "FStatefulComputeEx", vctx[i]);
// FStatefulComputeEx is dispatched only when dispatch_mode is DispatchMode::kFComputeEx
if (fcompute_ex != nullptr && dispatch_modes[i] == DispatchMode::kFComputeEx) {
ret[i] = std::make_shared<StatefulComputeExExecutor>(state, fcompute_ex, exec_type);
ret[i] = std::make_shared<StatefulComputeExExecutor>(inode.source->attrs, state,
fcompute_ex, exec_type);
} else {
FStatefulCompute fcompute = common::GetFCompute<FStatefulCompute>(
op, "FStatefulCompute", vctx[i]);
Expand All @@ -322,7 +329,8 @@ void CreateOpExecs(const Graph& g, OpExecVector* p_ret, size_t i) {
// FStatefulComputeEx is dispatched only when dispatch_mode is DispatchMode::kFComputeEx
if (fcompute_ex != nullptr && dispatch_modes[i] == DispatchMode::kFComputeEx) {
ret[i] = std::make_shared<StatefulComputeExExecutor>(
ret[fwd_id].get()->state(), fcompute_ex, exec_type);
inode.source->attrs, ret[fwd_id].get()->state(), fcompute_ex,
exec_type);
} else {
FStatefulCompute fcompute = common::GetFCompute<FStatefulCompute>(
op, "FStatefulCompute", vctx[i]);
Expand Down
1 change: 1 addition & 0 deletions src/operator/quantization/dequantize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ by keep zero centered for the quantized value:
.set_attr<nnvm::FInferType>("FInferType", DequantizeType)
.set_attr<FInferStorageType>("FInferStorageType", DequantizeStorageType)
#if MXNET_USE_MKLDNN == 1
.set_attr<bool>("TIsMKLDNN", true)
.set_attr<FComputeEx>("FComputeEx<cpu>", MKLDNNDequantizeCompute)
#endif
.set_attr<FCompute>("FCompute<cpu>", DequantizeCompute<cpu>)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ static void MKLDNNQuantizedPoolingForward(const nnvm::NodeAttrs& attrs, const Op
}

NNVM_REGISTER_OP(_contrib_quantized_pooling)
.set_attr<bool>("TIsMKLDNN", true)
.set_attr<FComputeEx>("FComputeEx<cpu>", MKLDNNQuantizedPoolingForward);

} // namespace op
Expand Down
1 change: 1 addition & 0 deletions src/operator/quantization/quantize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ where
.set_attr<nnvm::FInferType>("FInferType", QuantizeType)
.set_attr<FInferStorageType>("FInferStorageType", QuantizeStorageType)
#if MXNET_USE_MKLDNN == 1
.set_attr<bool>("TIsMKLDNN", true)
.set_attr<FComputeEx>("FComputeEx<cpu>", MKLDNNQuantizeCompute)
#endif
.set_attr<FCompute>("FCompute<cpu>", QuantizeCompute<cpu>)
Expand Down
1 change: 1 addition & 0 deletions src/operator/quantization/requantize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ inference accuracy.
.set_attr<nnvm::FInferType>("FInferType", RequantizeType)
.set_attr<FInferStorageType>("FInferStorageType", RequantizeStorageType)
#if MXNET_USE_MKLDNN == 1
.set_attr<bool>("TIsMKLDNN", true)
.set_attr<FComputeEx>("FComputeEx<cpu>", MKLDNNRequantizeForward)
#else
.set_attr<FCompute>("FCompute<cpu>", RequantizeForward<cpu>)
Expand Down

0 comments on commit 443ded4

Please sign in to comment.