Skip to content

Commit

Permalink
pnnx convert onnx multiheadattention (#5575)
Browse files Browse the repository at this point in the history
* pnnx convert onnx multiheadattention

* onnx reducemean reducesum

* reducemax reducemin reduceprod

* mask buggy torch

* avoid shadow output
  • Loading branch information
nihui authored Jul 11, 2024
1 parent 854678b commit c59885a
Show file tree
Hide file tree
Showing 16 changed files with 893 additions and 89 deletions.
25 changes: 21 additions & 4 deletions tools/pnnx/src/pass_level2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -738,15 +738,25 @@ static bool match_operator(const Operator* a, const Operator* b, std::map<std::s
return true;
}

static bool match(const Operator* anchor, const Operator* pattern, std::map<std::string, const Operator*>& matched_operators, std::map<std::string, const Operand*>& matched_inputs, std::map<std::string, Parameter>& captured_params, std::map<std::string, Attribute>& captured_attrs)
static bool match(const Operator* anchor, const Operator* pattern, std::map<std::string, const Operator*>& matched_operators, std::map<std::string, const Operand*>& matched_inputs, std::map<std::string, const Operand*>& matched_outputs, std::map<std::string, Parameter>& captured_params, std::map<std::string, Attribute>& captured_attrs)
{
if (!match_operator(anchor, pattern, captured_params, captured_attrs))
return false;

for (size_t i = 0; i < pattern->outputs.size(); i++)
{
if (pattern->outputs[i]->consumers.size() == 1 && pattern->outputs[i]->consumers[0]->type == "pnnx.Output")
{
if (matched_outputs.find(pattern->outputs[i]->name) == matched_outputs.end())
{
matched_outputs[pattern->outputs[i]->name] = anchor->outputs[i];
}
else if (matched_outputs[pattern->outputs[i]->name] != anchor->outputs[i])
{
return false;
}
continue;
}

if (anchor->outputs[i]->consumers.size() != pattern->outputs[i]->consumers.size())
return false;
Expand All @@ -773,7 +783,7 @@ static bool match(const Operator* anchor, const Operator* pattern, std::map<std:
continue;
}

if (!match(anchor2, pattern2, matched_operators, matched_inputs, captured_params, captured_attrs))
if (!match(anchor2, pattern2, matched_operators, matched_inputs, matched_outputs, captured_params, captured_attrs))
return false;
}

Expand Down Expand Up @@ -838,9 +848,10 @@ void pnnx_graph_rewrite(Graph& graph, const GraphRewriterPass* pass, int& opinde

std::map<std::string, const Operator*> matched_operators2;
std::map<std::string, const Operand*> matched_inputs2;
std::map<std::string, const Operand*> matched_outputs2;
std::map<std::string, Parameter> captured_params2;
std::map<std::string, Attribute> captured_attrs2;
if (!match(anchor, pattern2, matched_operators2, matched_inputs2, captured_params2, captured_attrs2))
if (!match(anchor, pattern2, matched_operators2, matched_inputs2, matched_outputs2, captured_params2, captured_attrs2))
continue;

bool submatch_matched = true;
Expand Down Expand Up @@ -872,6 +883,13 @@ void pnnx_graph_rewrite(Graph& graph, const GraphRewriterPass* pass, int& opinde
matched_inputs[x.first] = x.second;
}
}
for (auto x : matched_outputs2)
{
if (matched_outputs.find(x.first) == matched_outputs.end())
{
matched_outputs[x.first] = x.second;
}
}
for (auto x : captured_params2)
{
captured_params[x.first] = x.second;
Expand All @@ -882,7 +900,6 @@ void pnnx_graph_rewrite(Graph& graph, const GraphRewriterPass* pass, int& opinde
}

// match !
matched_outputs[pattern->inputs[i]->name] = anchor->outputs[i];
break;
}

Expand Down
26 changes: 26 additions & 0 deletions tools/pnnx/src/pass_level2/F_hardswish.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -317,4 +317,30 @@ pnnx.Output output 1 0 out

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_hardswish_onnx_1, 9)

// Rewrites the aten op chain that ONNX export produces for hardswish,
//   out = input * clamp(input + 3, 0, 6) / 6
// into a single F.hardswish operator.
class F_hardswish_onnx_2 : public GraphRewriterPass
{
public:
// Pattern graph: add(input, 3) -> clamp(min=0, max=6) -> div(6) -> mul(input, .)
// The two prim::Constant ops pin the 3 and 6 literals of the hardswish formula.
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
8 7
pnnx.Input input 0 1 input
prim::Constant op_0 0 1 20 value=3
aten::add op_1 2 1 input 20 8
aten::clamp op_2 1 1 8 9 max=6 min=0
prim::Constant op_3 0 1 23 value=6
aten::div op_4 2 1 9 23 10
aten::mul op_5 2 1 input 10 out
pnnx.Output output 1 0 out
)PNNXIR";
}

// Operator type the matched subgraph is replaced with.
const char* type_str() const
{
return "F.hardswish";
}
};

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_hardswish_onnx_2, 9)

} // namespace pnnx
49 changes: 49 additions & 0 deletions tools/pnnx/src/pass_level2/torch_max.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,4 +60,53 @@ pnnx.Output output 1 0 out

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_max_1, 20)

// Converts a standalone ONNX ReduceMax operator into torch.max.
class torch_max_onnx : public GraphRewriterPass
{
public:
    // Match a single ReduceMax regardless of which attributes it carries (%*=%*).
    const char* match_pattern_graph() const
    {
        return R"PNNXIR(7767517
3 2
pnnx.Input input 0 1 input
ReduceMax op_0 1 1 input out %*=%*
pnnx.Output output 1 0 out
)PNNXIR";
    }

    const char* type_str() const
    {
        return "torch.max";
    }

    void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
    {
        // dim: forward the captured axes attribute when the exporter emitted one;
        // otherwise ONNX semantics say the reduction spans every input dimension.
        const auto axes_it = captured_params.find("op_0.axes");
        if (axes_it != captured_params.end())
        {
            op->params["dim"] = axes_it->second;
        }
        else
        {
            const int rank = (int)op->inputs[0]->shape.size();
            std::vector<int> all_dims;
            for (int d = 0; d < rank; d++)
                all_dims.push_back(d);
            op->params["dim"] = all_dims;
        }

        // keepdim: ONNX keepdims defaults to 1 when the attribute is absent.
        const auto keep_it = captured_params.find("op_0.keepdims");
        op->params["keepdim"] = (keep_it == captured_params.end()) ? true : (keep_it->second.i != 0);
    }
};

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_max_onnx, 20)

} // namespace pnnx
109 changes: 25 additions & 84 deletions tools/pnnx/src/pass_level2/torch_mean.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ class torch_mean_onnx : public GraphRewriterPass
return R"PNNXIR(7767517
3 2
pnnx.Input input 0 1 input
ReduceMean op_0 1 1 input out axes=%axes keepdims=%keepdims
ReduceMean op_0 1 1 input out %*=%*
pnnx.Output output 1 0 out
)PNNXIR";
}
Expand All @@ -119,92 +119,33 @@ pnnx.Output output 1 0 out

void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
{
op->params["dim"] = captured_params.at("axes");
op->params["keepdim"] = captured_params.at("keepdims").i ? true : false;
}
};

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_mean_onnx, 20)

class torch_mean_onnx_1 : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
3 2
pnnx.Input input 0 1 input
ReduceMean op_0 1 1 input out axes=%axes
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "torch.mean";
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
{
op->params["dim"] = captured_params.at("axes");
op->params["keepdim"] = true;
}
};

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_mean_onnx_1, 20)

class torch_mean_onnx_2 : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
4 3
pnnx.Input input_0 0 1 input
pnnx.Input input_1 0 1 dim
ReduceMean op_0 2 1 input dim out keepdims=%keepdims noop_with_empty_axes=0
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "torch.mean";
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
{
op->params["keepdim"] = captured_params.at("keepdims").i ? true : false;
}
};

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_mean_onnx_2, 20)

class torch_mean_onnx_3 : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
3 2
pnnx.Input input 0 1 input
ReduceMean op_0 1 1 input out axes=%axes keepdims=%keepdims noop_with_empty_axes=0
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "torch.mean";
}
if (captured_params.find("op_0.axes") != captured_params.end())
{
op->params["dim"] = captured_params.at("op_0.axes");
}
else
{
// reduce all
const int input_rank = (int)op->inputs[0]->shape.size();
std::vector<int> dim(input_rank);
for (int i = 0; i < input_rank; i++)
{
dim[i] = i;
}
op->params["dim"] = dim;
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
{
op->params["dim"] = captured_params.at("axes");
op->params["keepdim"] = captured_params.at("keepdims").i ? true : false;
if (captured_params.find("op_0.keepdims") != captured_params.end())
{
op->params["keepdim"] = captured_params.at("op_0.keepdims").i ? true : false;
}
else
{
op->params["keepdim"] = true;
}
}
};

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_mean_onnx_3, 20)
REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_mean_onnx, 20)

} // namespace pnnx
49 changes: 49 additions & 0 deletions tools/pnnx/src/pass_level2/torch_min.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,4 +60,53 @@ pnnx.Output output 1 0 out

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_min_1, 20)

// Converts a standalone ONNX ReduceMin operator into torch.min.
class torch_min_onnx : public GraphRewriterPass
{
public:
    // Match a single ReduceMin regardless of which attributes it carries (%*=%*).
    const char* match_pattern_graph() const
    {
        return R"PNNXIR(7767517
3 2
pnnx.Input input 0 1 input
ReduceMin op_0 1 1 input out %*=%*
pnnx.Output output 1 0 out
)PNNXIR";
    }

    const char* type_str() const
    {
        return "torch.min";
    }

    void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
    {
        // dim: forward the captured axes attribute when the exporter emitted one;
        // otherwise ONNX semantics say the reduction spans every input dimension.
        const auto axes_it = captured_params.find("op_0.axes");
        if (axes_it != captured_params.end())
        {
            op->params["dim"] = axes_it->second;
        }
        else
        {
            const int rank = (int)op->inputs[0]->shape.size();
            std::vector<int> all_dims;
            for (int d = 0; d < rank; d++)
                all_dims.push_back(d);
            op->params["dim"] = all_dims;
        }

        // keepdim: ONNX keepdims defaults to 1 when the attribute is absent.
        const auto keep_it = captured_params.find("op_0.keepdims");
        op->params["keepdim"] = (keep_it == captured_params.end()) ? true : (keep_it->second.i != 0);
    }
};

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_min_onnx, 20)

} // namespace pnnx
49 changes: 49 additions & 0 deletions tools/pnnx/src/pass_level2/torch_prod.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,4 +40,53 @@ pnnx.Output output 1 0 out

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_prod, 20)

// Converts a standalone ONNX ReduceProd operator into torch.prod.
class torch_prod_onnx : public GraphRewriterPass
{
public:
    // Match a single ReduceProd regardless of which attributes it carries (%*=%*).
    const char* match_pattern_graph() const
    {
        return R"PNNXIR(7767517
3 2
pnnx.Input input 0 1 input
ReduceProd op_0 1 1 input out %*=%*
pnnx.Output output 1 0 out
)PNNXIR";
    }

    const char* type_str() const
    {
        return "torch.prod";
    }

    void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
    {
        // dim: forward the captured axes attribute when the exporter emitted one;
        // otherwise ONNX semantics say the reduction spans every input dimension.
        const auto axes_it = captured_params.find("op_0.axes");
        if (axes_it != captured_params.end())
        {
            op->params["dim"] = axes_it->second;
        }
        else
        {
            const int rank = (int)op->inputs[0]->shape.size();
            std::vector<int> all_dims;
            for (int d = 0; d < rank; d++)
                all_dims.push_back(d);
            op->params["dim"] = all_dims;
        }

        // keepdim: ONNX keepdims defaults to 1 when the attribute is absent.
        const auto keep_it = captured_params.find("op_0.keepdims");
        op->params["keepdim"] = (keep_it == captured_params.end()) ? true : (keep_it->second.i != 0);
    }
};

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_prod_onnx, 20)

} // namespace pnnx
Loading

0 comments on commit c59885a

Please sign in to comment.