Skip to content

Commit

Permalink
implement conv+clip fusion (#1412)
Browse files Browse the repository at this point in the history
This change implements Conv+Clip activation fusion for FusedConv and NCHWc convolutions. The Clip operation runs in the thread context that is producing the convolution output.
  • Loading branch information
tracysh authored Jul 17, 2019
1 parent d2cc086 commit 4383615
Show file tree
Hide file tree
Showing 14 changed files with 301 additions and 189 deletions.
26 changes: 22 additions & 4 deletions onnxruntime/contrib_ops/cpu/fused_activation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,33 @@ common::Status GetFusedActivationAttr(const OpKernelInfo& info, MLAS_ACTIVATION&
if (info.GetAttr<std::string>("activation", &activation_type).IsOK()) {
if (activation_type == "Relu") {
activation.ActivationKind = MlasReluActivation;
} else if (activation_type == "LeakyRelu") {
activation.ActivationKind = MlasLeakyReluActivation;
activation.alpha = info.GetAttrOrDefault<float>("alpha", 0.01f);
} else if (activation_type == "Tanh") {
activation.ActivationKind = MlasTanhActivation;
} else if (activation_type == "Sigmoid") {
activation.ActivationKind = MlasLogisticActivation;
} else {
return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Unimplemented activation: " + activation_type);
// The remaining activation types have additional parameters to be pulled out.
size_t activation_params_count;
if (activation_type == "LeakyRelu") {
activation.ActivationKind = MlasLeakyReluActivation;
activation_params_count = 1;
} else if (activation_type == "Clip") {
activation.ActivationKind = MlasClipActivation;
activation_params_count = 2;
} else {
return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "unimplemented activation: " + activation_type);
}

std::vector<float> activation_params;
common::Status status = info.GetAttrs<float>("activation_params", activation_params);
if (!status.IsOK()) {
return status;
} else if (activation_params_count != activation_params.size()) {
return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "activation_params count mismatch");
}
for (size_t i = 0; i < activation_params_count; i++) {
activation.Parameters.Values[i] = activation_params[i];
}
}
}

Expand Down
28 changes: 19 additions & 9 deletions onnxruntime/core/graph/contrib_ops/contrib_defs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -148,9 +148,9 @@ void RegisterNchwcSchemas() {
AttributeProto::STRING,
OPTIONAL)
.Attr(
"alpha",
"activation_params",
"",
AttributeProto::FLOAT,
AttributeProto::FLOATS,
OPTIONAL)
.Input(0, "X", "", "T")
.Input(1, "W", "", "T")
Expand Down Expand Up @@ -862,10 +862,15 @@ activation.)DOC")
AttributeProto::INTS,
OPTIONAL)
.Attr(
"strides", "", AttributeProto::INTS, OPTIONAL)
.Attr("pads",
"",
AttributeProto::INTS, OPTIONAL)
"strides",
"",
AttributeProto::INTS,
OPTIONAL)
.Attr(
"pads",
"",
AttributeProto::INTS,
OPTIONAL)
.Attr(
"group",
"",
Expand All @@ -877,9 +882,9 @@ activation.)DOC")
AttributeProto::STRING,
OPTIONAL)
.Attr(
"alpha",
"activation_params",
"",
AttributeProto::FLOAT,
AttributeProto::FLOATS,
OPTIONAL)
.Input(
0,
Expand All @@ -891,7 +896,12 @@ activation.)DOC")
"W",
"",
"T")
.Input(2, "B", "", "T", OpSchema::Optional)
.Input(
2,
"B",
"",
"T",
OpSchema::Optional)
.Output(
0,
"Y",
Expand Down
15 changes: 12 additions & 3 deletions onnxruntime/core/mlas/inc/mlas.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,21 +78,30 @@ enum MLAS_ACTIVATION_KIND {
MlasLeakyReluActivation,
MlasTanhActivation,
MlasLogisticActivation,
MlasClipActivation,
};

// NOTE(review): this span is a unified-diff rendering scraped from a commit
// page; it interleaves the pre-change member (`float alpha;`) with the
// post-change `Parameters` union added by this commit, so the struct as shown
// is a diff artifact rather than the exact on-disk definition -- confirm
// against the repository before relying on the layout.
//
// Describes a fused activation that MLAS applies to convolution output
// (see MlasActivation below and the MLAS_ACTIVATION_KIND enum above).
struct MLAS_ACTIVATION {
// Which activation to apply (Relu, LeakyRelu, Tanh, Logistic, Clip, ...).
MLAS_ACTIVATION_KIND ActivationKind;
float alpha; // pre-change scalar parameter; superseded by Parameters below
// Per-kind parameters. All union views alias the same storage, so the
// Values[] array provides generic indexed access to the named fields
// (the fused_activation.cc change in this commit fills Values[i] from the
// "activation_params" attribute).
union {
struct {
float alpha; // LeakyRelu: slope applied to negative inputs
} LeakyRelu;
struct {
float minimum; // Clip: lower bound of the output range
float maximum; // Clip: upper bound of the output range
} Clip;
float Values[2]; // generic view: Values[0]/Values[1] alias the fields above
} Parameters;
};

void
MLASCALL
MlasActivation(
const MLAS_ACTIVATION* Activation,
const float* Input,
float* Buffer,
const float* Bias,
size_t M,
float* Output,
size_t N,
size_t ldc
);
Expand Down
Loading

0 comments on commit 4383615

Please sign in to comment.