implement conv+clip fusion #1412

Merged 7 commits on Jul 17, 2019
Changes from 5 commits

26 changes: 22 additions & 4 deletions onnxruntime/contrib_ops/cpu/fused_activation.cc
@@ -13,15 +13,33 @@ common::Status GetFusedActivationAttr(const OpKernelInfo& info, MLAS_ACTIVATION&
   if (info.GetAttr<std::string>("activation", &activation_type).IsOK()) {
     if (activation_type == "Relu") {
       activation.ActivationKind = MlasReluActivation;
-    } else if (activation_type == "LeakyRelu") {
-      activation.ActivationKind = MlasLeakyReluActivation;
-      activation.alpha = info.GetAttrOrDefault<float>("alpha", 0.01f);
     } else if (activation_type == "Tanh") {
       activation.ActivationKind = MlasTanhActivation;
     } else if (activation_type == "Sigmoid") {
       activation.ActivationKind = MlasLogisticActivation;
     } else {
-      return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Unimplemented activation: " + activation_type);
+      // The remaining activation types have additional parameters to be pulled out.
+      size_t activation_params_count;
+      if (activation_type == "LeakyRelu") {
+        activation.ActivationKind = MlasLeakyReluActivation;
+        activation_params_count = 1;
+      } else if (activation_type == "Clip") {
+        activation.ActivationKind = MlasClipActivation;
+        activation_params_count = 2;
+      } else {
+        return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "unimplemented activation: " + activation_type);
+      }
+
+      std::vector<float> activation_params;
+      common::Status status = info.GetAttrs<float>("activation_params", activation_params);
+      if (!status.IsOK()) {
+        return status;
+      } else if (activation_params_count != activation_params.size()) {
+        return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "activation_params count mismatch");
+      }
+      for (size_t i = 0; i < activation_params_count; i++) {
+        activation.Parameters.Values[i] = activation_params[i];
+      }
     }
   }

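
For context, the values parsed above land in the MLAS_ACTIVATION structure introduced in mlas.h further down in this PR. A minimal standalone sketch of that mapping, with a stand-in copy of the struct and illustrative clip bounds (the example values and the main() harness are mine, not part of the change):

#include <cstddef>
#include <cstdio>

// Stand-in mirroring the MLAS_ACTIVATION declaration added in mlas.h (sketch only).
enum MLAS_ACTIVATION_KIND {
    MlasReluActivation,
    MlasLeakyReluActivation,
    MlasClipActivation,
};

struct MLAS_ACTIVATION {
    MLAS_ACTIVATION_KIND ActivationKind;
    union {
        struct { float alpha; } LeakyRelu;
        struct { float minimum; float maximum; } Clip;
        float Values[2];
    } Parameters;
};

int main() {
    // What GetFusedActivationAttr would receive for activation="Clip" with
    // activation_params={0.0f, 6.0f} (example values only).
    const float activation_params[2] = {0.0f, 6.0f};

    MLAS_ACTIVATION activation;
    activation.ActivationKind = MlasClipActivation;
    for (std::size_t i = 0; i < 2; i++) {
        activation.Parameters.Values[i] = activation_params[i];
    }

    // Values aliases the named parameter structs, so the kernel side can read
    // the same data back through Parameters.Clip, as the MLAS code does.
    std::printf("clip range: [%g, %g]\n",
                activation.Parameters.Clip.minimum,
                activation.Parameters.Clip.maximum);
    return 0;
}
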
28 changes: 19 additions & 9 deletions onnxruntime/core/graph/contrib_ops/contrib_defs.cc
@@ -148,9 +148,9 @@ void RegisterNchwcSchemas() {
           AttributeProto::STRING,
           OPTIONAL)
       .Attr(
-          "alpha",
+          "activation_params",
           "",
-          AttributeProto::FLOAT,
+          AttributeProto::FLOATS,
           OPTIONAL)
       .Input(0, "X", "", "T")
       .Input(1, "W", "", "T")
@@ -862,10 +862,15 @@ activation.)DOC")
           AttributeProto::INTS,
           OPTIONAL)
       .Attr(
-          "strides", "", AttributeProto::INTS, OPTIONAL)
-      .Attr("pads",
-            "",
-            AttributeProto::INTS, OPTIONAL)
+          "strides",
+          "",
+          AttributeProto::INTS,
+          OPTIONAL)
+      .Attr(
+          "pads",
+          "",
+          AttributeProto::INTS,
+          OPTIONAL)
       .Attr(
           "group",
           "",
@@ -877,9 +882,9 @@
           AttributeProto::STRING,
           OPTIONAL)
       .Attr(
-          "alpha",
+          "activation_params",
           "",
-          AttributeProto::FLOAT,
+          AttributeProto::FLOATS,
           OPTIONAL)
       .Input(
           0,
@@ -891,7 +896,12 @@
"W",
"",
"T")
.Input(2, "B", "", "T", OpSchema::Optional)
.Input(
2,
"B",
"",
"T",
OpSchema::Optional)
.Output(
0,
"Y",
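
With these schema changes, the fused Conv ops advertise a generic activation_params float list instead of the single LeakyRelu-specific alpha float, so the same attribute can carry the Clip bounds. A rough sketch of the attribute pairs a fusion pass would attach (the holder type and helper functions below are illustrative, not onnxruntime APIs):

#include <string>
#include <vector>

// Illustrative holder for the two activation attributes declared by the fused
// Conv schemas above; not an onnxruntime type.
struct FusedActivationAttrs {
    std::string activation;
    std::vector<float> activation_params;
};

// Conv followed by Clip(min, max) folds into a two-element parameter list.
FusedActivationAttrs MakeClipAttrs(float clip_min, float clip_max) {
    return {"Clip", {clip_min, clip_max}};
}

// Conv followed by LeakyRelu keeps its slope, now as a one-element list
// rather than the removed scalar "alpha" attribute.
FusedActivationAttrs MakeLeakyReluAttrs(float alpha) {
    return {"LeakyRelu", {alpha}};
}
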
15 changes: 12 additions & 3 deletions onnxruntime/core/mlas/inc/mlas.h
@@ -78,21 +78,30 @@ enum MLAS_ACTIVATION_KIND {
     MlasLeakyReluActivation,
     MlasTanhActivation,
     MlasLogisticActivation,
+    MlasClipActivation,
 };
 
 struct MLAS_ACTIVATION {
     MLAS_ACTIVATION_KIND ActivationKind;
-    float alpha;
+    union {
+        struct {
+            float alpha;
+        } LeakyRelu;
+        struct {
+            float minimum;
+            float maximum;
+        } Clip;
+        float Values[2];
+    } Parameters;
 };

 void
 MLASCALL
 MlasActivation(
     const MLAS_ACTIVATION* Activation,
-    const float* Input,
+    float* Buffer,
     const float* Bias,
     size_t M,
-    float* Output,
     size_t N,
     size_t ldc
     );
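
For reference, the per-element effect of the new MlasClipActivation kind is an ordinary clamp of the bias-added output. A scalar sketch is below; the real MLAS kernel is vectorized and fused into the convolution output loop, and the helper name here is mine:

#include <algorithm>
#include <cstddef>

// Scalar reference for what the Clip activation computes per output element;
// minimum/maximum correspond to Parameters.Clip in MLAS_ACTIVATION above.
void ClipActivationReference(float* Buffer, std::size_t Count, float minimum, float maximum) {
    for (std::size_t i = 0; i < Count; i++) {
        Buffer[i] = std::min(std::max(Buffer[i], minimum), maximum);
    }
}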