From c59885aeac6cec0dbfa010efc0b5c25bed5208b7 Mon Sep 17 00:00:00 2001
From: nihui
Date: Thu, 11 Jul 2024 15:53:27 +0800
Subject: [PATCH] pnnx convert onnx multiheadattention (#5575)

* pnnx convert onnx multiheadattention

* onnx reducemean reducesum

* reducemax reducemin reduceprod

* mask buggy torch

* avoid shadow output
---
 tools/pnnx/src/pass_level2.cpp                |  25 +-
 tools/pnnx/src/pass_level2/F_hardswish.cpp    |  26 +
 tools/pnnx/src/pass_level2/torch_max.cpp      |  49 ++
 tools/pnnx/src/pass_level2/torch_mean.cpp     | 109 +----
 tools/pnnx/src/pass_level2/torch_min.cpp      |  49 ++
 tools/pnnx/src/pass_level2/torch_prod.cpp     |  49 ++
 tools/pnnx/src/pass_level2/torch_sum.cpp      |  49 ++
 .../pass_level5/fuse_multiheadattention.cpp   | 454 ++++++++++++++++++
 .../pass_onnx/fuse_constant_as_attribute.cpp  |   4 +
 tools/pnnx/tests/onnx/CMakeLists.txt          |  10 +-
 tools/pnnx/tests/onnx/test_F_hardshrink.py    |   4 +
 tools/pnnx/tests/onnx/test_F_hardsigmoid.py   |   4 +
 tools/pnnx/tests/onnx/test_F_hardswish.py     |   4 +
 tools/pnnx/tests/onnx/test_nn_Hardshrink.py   |   4 +
 tools/pnnx/tests/onnx/test_nn_Hardsigmoid.py  |   4 +
 .../tests/onnx/test_nn_MultiheadAttention.py  | 138 ++++++
 16 files changed, 893 insertions(+), 89 deletions(-)
 create mode 100644 tools/pnnx/tests/onnx/test_nn_MultiheadAttention.py

diff --git a/tools/pnnx/src/pass_level2.cpp b/tools/pnnx/src/pass_level2.cpp
index 107613d3dea..bc7e51b8d5d 100644
--- a/tools/pnnx/src/pass_level2.cpp
+++ b/tools/pnnx/src/pass_level2.cpp
@@ -738,7 +738,7 @@ static bool match_operator(const Operator* a, const Operator* b, std::map<std::
     return true;
 }
 
-static bool match(const Operator* anchor, const Operator* pattern, std::map<std::string, const Operator*>& matched_operators, std::map<std::string, Operand*>& matched_inputs, std::map<std::string, Parameter>& captured_params, std::map<std::string, Attribute>& captured_attrs)
+static bool match(const Operator* anchor, const Operator* pattern, std::map<std::string, const Operator*>& matched_operators, std::map<std::string, Operand*>& matched_inputs, std::map<std::string, Operand*>& matched_outputs, std::map<std::string, Parameter>& captured_params, std::map<std::string, Attribute>& captured_attrs)
 {
     if (!match_operator(anchor, pattern, captured_params, captured_attrs))
         return false;
@@ -746,7 +746,17 @@ static bool match(const Operator* anchor, const Operator* pattern, std::map
     for (size_t i = 0; i < pattern->outputs.size(); i++)
     {
         if (pattern->outputs[i]->consumers.size() == 1 && pattern->outputs[i]->consumers[0]->type == "pnnx.Output")
+        {
+            if (matched_outputs.find(pattern->outputs[i]->name) == matched_outputs.end())
+            {
+                matched_outputs[pattern->outputs[i]->name] = anchor->outputs[i];
+            }
+            else if (matched_outputs[pattern->outputs[i]->name] != anchor->outputs[i])
+            {
+                return false;
+            }
             continue;
+        }
 
         if (anchor->outputs[i]->consumers.size() != pattern->outputs[i]->consumers.size())
             return false;
@@ -773,7 +783,7 @@ static bool match(const Operator* anchor, const Operator* pattern, std::map
             std::map<std::string, const Operator*> matched_operators2;
             std::map<std::string, Operand*> matched_inputs2;
+            std::map<std::string, Operand*> matched_outputs2;
             std::map<std::string, Parameter> captured_params2;
             std::map<std::string, Attribute> captured_attrs2;
-            if (!match(anchor, pattern2, matched_operators2, matched_inputs2, captured_params2, captured_attrs2))
+            if (!match(anchor, pattern2, matched_operators2, matched_inputs2, matched_outputs2, captured_params2, captured_attrs2))
                 continue;
 
             bool submatch_matched = true;
@@ -872,6 +883,13 @@ void pnnx_graph_rewrite(Graph& graph, const GraphRewriterPass* pass, int& opinde
                     matched_inputs[x.first] = x.second;
                 }
             }
+            for (auto x : matched_outputs2)
+            {
+                if (matched_outputs.find(x.first) == matched_outputs.end())
+                {
+                    matched_outputs[x.first] = x.second;
+                }
+            }
             for (auto x : captured_params2)
             {
                 captured_params[x.first] = x.second;
@@ -882,7 +900,6 @@ void pnnx_graph_rewrite(Graph& graph, const GraphRewriterPass* pass, int& opinde
             }
 
             // match !
-            matched_outputs[pattern->inputs[i]->name] = anchor->outputs[i];
             break;
         }
 
diff --git a/tools/pnnx/src/pass_level2/F_hardswish.cpp b/tools/pnnx/src/pass_level2/F_hardswish.cpp
index 7d44efedf6b..caa724f55a7 100644
--- a/tools/pnnx/src/pass_level2/F_hardswish.cpp
+++ b/tools/pnnx/src/pass_level2/F_hardswish.cpp
@@ -317,4 +317,30 @@ pnnx.Output output 1 0 out
 
 REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_hardswish_onnx_1, 9)
 
+class F_hardswish_onnx_2 : public GraphRewriterPass
+{
+public:
+    const char* match_pattern_graph() const
+    {
+        return R"PNNXIR(7767517
+8 7
+pnnx.Input input 0 1 input
+prim::Constant op_0 0 1 20 value=3
+aten::add op_1 2 1 input 20 8
+aten::clamp op_2 1 1 8 9 max=6 min=0
+prim::Constant op_3 0 1 23 value=6
+aten::div op_4 2 1 9 23 10
+aten::mul op_5 2 1 input 10 out
+pnnx.Output output 1 0 out
+)PNNXIR";
+    }
+
+    const char* type_str() const
+    {
+        return "F.hardswish";
+    }
+};
+
+REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_hardswish_onnx_2, 9)
+
 } // namespace pnnx
diff --git a/tools/pnnx/src/pass_level2/torch_max.cpp b/tools/pnnx/src/pass_level2/torch_max.cpp
index 3448b5b939f..68479b85d5b 100644
--- a/tools/pnnx/src/pass_level2/torch_max.cpp
+++ b/tools/pnnx/src/pass_level2/torch_max.cpp
@@ -60,4 +60,53 @@ pnnx.Output output 1 0 out
 
 REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_max_1, 20)
 
+class torch_max_onnx : public GraphRewriterPass
+{
+public:
+    const char* match_pattern_graph() const
+    {
+        return R"PNNXIR(7767517
+3 2
+pnnx.Input input 0 1 input
+ReduceMax op_0 1 1 input out %*=%*
+pnnx.Output output 1 0 out
+)PNNXIR";
+    }
+
+    const char* type_str() const
+    {
+        return "torch.max";
+    }
+
+    void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
+    {
+        if (captured_params.find("op_0.axes") != captured_params.end())
+        {
+            op->params["dim"] = captured_params.at("op_0.axes");
+        }
+        else
+        {
+            // reduce all
+            const int input_rank = (int)op->inputs[0]->shape.size();
+            std::vector<int> dim(input_rank);
+            for (int i = 0; i < input_rank; i++)
+            {
+                dim[i] = i;
+            }
+            op->params["dim"] = dim;
+        }
+
+        if (captured_params.find("op_0.keepdims") != captured_params.end())
+        {
+            op->params["keepdim"] = captured_params.at("op_0.keepdims").i ? true : false;
+        }
+        else
+        {
+            op->params["keepdim"] = true;
+        }
+    }
+};
+
+REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_max_onnx, 20)
+
 } // namespace pnnx
diff --git a/tools/pnnx/src/pass_level2/torch_mean.cpp b/tools/pnnx/src/pass_level2/torch_mean.cpp
index 18944fac345..39f4423243d 100644
--- a/tools/pnnx/src/pass_level2/torch_mean.cpp
+++ b/tools/pnnx/src/pass_level2/torch_mean.cpp
@@ -107,7 +107,7 @@ class torch_mean_onnx : public GraphRewriterPass
         return R"PNNXIR(7767517
 3 2
 pnnx.Input input 0 1 input
-ReduceMean op_0 1 1 input out axes=%axes keepdims=%keepdims
+ReduceMean op_0 1 1 input out %*=%*
 pnnx.Output output 1 0 out
 )PNNXIR";
     }
@@ -119,92 +119,33 @@ pnnx.Output output 1 0 out
 
     void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
     {
-        op->params["dim"] = captured_params.at("axes");
-        op->params["keepdim"] = captured_params.at("keepdims").i ? true : false;
-    }
-};
-
-REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_mean_onnx, 20)
-
-class torch_mean_onnx_1 : public GraphRewriterPass
-{
-public:
-    const char* match_pattern_graph() const
-    {
-        return R"PNNXIR(7767517
-3 2
-pnnx.Input input 0 1 input
-ReduceMean op_0 1 1 input out axes=%axes
-pnnx.Output output 1 0 out
-)PNNXIR";
-    }
-
-    const char* type_str() const
-    {
-        return "torch.mean";
-    }
-
-    void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
-    {
-        op->params["dim"] = captured_params.at("axes");
-        op->params["keepdim"] = true;
-    }
-};
-
-REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_mean_onnx_1, 20)
-
-class torch_mean_onnx_2 : public GraphRewriterPass
-{
-public:
-    const char* match_pattern_graph() const
-    {
-        return R"PNNXIR(7767517
-4 3
-pnnx.Input input_0 0 1 input
-pnnx.Input input_1 0 1 dim
-ReduceMean op_0 2 1 input dim out keepdims=%keepdims noop_with_empty_axes=0
-pnnx.Output output 1 0 out
-)PNNXIR";
-    }
-
-    const char* type_str() const
-    {
-        return "torch.mean";
-    }
-
-    void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
-    {
-        op->params["keepdim"] = captured_params.at("keepdims").i ? true : false;
-    }
-};
-
-REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_mean_onnx_2, 20)
-
-class torch_mean_onnx_3 : public GraphRewriterPass
-{
-public:
-    const char* match_pattern_graph() const
-    {
-        return R"PNNXIR(7767517
-3 2
-pnnx.Input input 0 1 input
-ReduceMean op_0 1 1 input out axes=%axes keepdims=%keepdims noop_with_empty_axes=0
-pnnx.Output output 1 0 out
-)PNNXIR";
-    }
-
-    const char* type_str() const
-    {
-        return "torch.mean";
-    }
-
-    void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
-    {
-        op->params["dim"] = captured_params.at("axes");
-        op->params["keepdim"] = captured_params.at("keepdims").i ? true : false;
+        if (captured_params.find("op_0.axes") != captured_params.end())
+        {
+            op->params["dim"] = captured_params.at("op_0.axes");
+        }
+        else
+        {
+            // reduce all
+            const int input_rank = (int)op->inputs[0]->shape.size();
+            std::vector<int> dim(input_rank);
+            for (int i = 0; i < input_rank; i++)
+            {
+                dim[i] = i;
+            }
+            op->params["dim"] = dim;
+        }
+
+        if (captured_params.find("op_0.keepdims") != captured_params.end())
+        {
+            op->params["keepdim"] = captured_params.at("op_0.keepdims").i ? true : false;
+        }
+        else
+        {
+            op->params["keepdim"] = true;
+        }
     }
 };
 
-REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_mean_onnx_3, 20)
+REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_mean_onnx, 20)
 
 } // namespace pnnx
diff --git a/tools/pnnx/src/pass_level2/torch_min.cpp b/tools/pnnx/src/pass_level2/torch_min.cpp
index 119b442e11d..c5e48bbc64b 100644
--- a/tools/pnnx/src/pass_level2/torch_min.cpp
+++ b/tools/pnnx/src/pass_level2/torch_min.cpp
@@ -60,4 +60,53 @@ pnnx.Output output 1 0 out
 
 REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_min_1, 20)
 
+class torch_min_onnx : public GraphRewriterPass
+{
+public:
+    const char* match_pattern_graph() const
+    {
+        return R"PNNXIR(7767517
+3 2
+pnnx.Input input 0 1 input
+ReduceMin op_0 1 1 input out %*=%*
+pnnx.Output output 1 0 out
+)PNNXIR";
+    }
+
+    const char* type_str() const
+    {
+        return "torch.min";
+    }
+
+    void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
+    {
+        if (captured_params.find("op_0.axes") != captured_params.end())
+        {
+            op->params["dim"] = captured_params.at("op_0.axes");
+        }
+        else
+        {
+            // reduce all
+            const int input_rank = (int)op->inputs[0]->shape.size();
+            std::vector<int> dim(input_rank);
+            for (int i = 0; i < input_rank; i++)
+            {
+                dim[i] = i;
+            }
+            op->params["dim"] = dim;
+        }
+
+        if (captured_params.find("op_0.keepdims") != captured_params.end())
+        {
+            op->params["keepdim"] = captured_params.at("op_0.keepdims").i ? true : false;
+        }
+        else
+        {
+            op->params["keepdim"] = true;
+        }
+    }
+};
+
+REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_min_onnx, 20)
+
 } // namespace pnnx
diff --git a/tools/pnnx/src/pass_level2/torch_prod.cpp b/tools/pnnx/src/pass_level2/torch_prod.cpp
index bd3e49b8cb1..7f15c2ba88a 100644
--- a/tools/pnnx/src/pass_level2/torch_prod.cpp
+++ b/tools/pnnx/src/pass_level2/torch_prod.cpp
@@ -40,4 +40,53 @@ pnnx.Output output 1 0 out
 
 REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_prod, 20)
 
+class torch_prod_onnx : public GraphRewriterPass
+{
+public:
+    const char* match_pattern_graph() const
+    {
+        return R"PNNXIR(7767517
+3 2
+pnnx.Input input 0 1 input
+ReduceProd op_0 1 1 input out %*=%*
+pnnx.Output output 1 0 out
+)PNNXIR";
+    }
+
+    const char* type_str() const
+    {
+        return "torch.prod";
+    }
+
+    void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
+    {
+        if (captured_params.find("op_0.axes") != captured_params.end())
+        {
+            op->params["dim"] = captured_params.at("op_0.axes");
+        }
+        else
+        {
+            // reduce all
+            const int input_rank = (int)op->inputs[0]->shape.size();
+            std::vector<int> dim(input_rank);
+            for (int i = 0; i < input_rank; i++)
+            {
+                dim[i] = i;
+            }
+            op->params["dim"] = dim;
+        }
+
+        if (captured_params.find("op_0.keepdims") != captured_params.end())
+        {
+            op->params["keepdim"] = captured_params.at("op_0.keepdims").i ? true : false;
+        }
+        else
+        {
+            op->params["keepdim"] = true;
+        }
+    }
+};
+
+REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_prod_onnx, 20)
+
 } // namespace pnnx
diff --git a/tools/pnnx/src/pass_level2/torch_sum.cpp b/tools/pnnx/src/pass_level2/torch_sum.cpp
index 730ffcce3b8..51803d5d01e 100644
--- a/tools/pnnx/src/pass_level2/torch_sum.cpp
+++ b/tools/pnnx/src/pass_level2/torch_sum.cpp
@@ -62,4 +62,53 @@ pnnx.Output output 1 0 out
 
 REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_sum_1, 20)
 
+class torch_sum_onnx : public GraphRewriterPass
+{
+public:
+    const char* match_pattern_graph() const
+    {
+        return R"PNNXIR(7767517
+3 2
+pnnx.Input input 0 1 input
+ReduceSum op_0 1 1 input out %*=%*
+pnnx.Output output 1 0 out
+)PNNXIR";
+    }
+
+    const char* type_str() const
+    {
+        return "torch.sum";
+    }
+
+    void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
+    {
+        if (captured_params.find("op_0.axes") != captured_params.end())
+        {
+            op->params["dim"] = captured_params.at("op_0.axes");
+        }
+        else
+        {
+            // reduce all
+            const int input_rank = (int)op->inputs[0]->shape.size();
+            std::vector<int> dim(input_rank);
+            for (int i = 0; i < input_rank; i++)
+            {
+                dim[i] = i;
+            }
+            op->params["dim"] = dim;
+        }
+
+        if (captured_params.find("op_0.keepdims") != captured_params.end())
+        {
+            op->params["keepdim"] = captured_params.at("op_0.keepdims").i ? true : false;
+        }
+        else
+        {
+            op->params["keepdim"] = true;
+        }
+    }
+};
+
+REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_sum_onnx, 20)
+
 } // namespace pnnx
diff --git a/tools/pnnx/src/pass_level5/fuse_multiheadattention.cpp b/tools/pnnx/src/pass_level5/fuse_multiheadattention.cpp
index 55661366c59..2a9f3b837b1 100644
--- a/tools/pnnx/src/pass_level5/fuse_multiheadattention.cpp
+++ b/tools/pnnx/src/pass_level5/fuse_multiheadattention.cpp
@@ -1574,6 +1574,444 @@ pnnx.Output output 1 0 out
     }
 };
 
+class fuse_multiheadattention_pass_onnx : public fuse_multiheadattention_pass_qkv
+{
+public:
+    const char* match_pattern_graph() const
+    {
+        return R"PNNXIR(7767517
+23 22
+pnnx.Input input_q 0 1 query
+pnnx.Input input_k 0 1 key
+pnnx.Input input_v 0 1 value
+nn.Linear op_0 1 1 query 10 bias=%qbias in_features=%embed_dim out_features=%embed_dim @bias @weight
+nn.Linear op_1 1 1 key 11 bias=%kbias in_features=%kdim out_features=%embed_dim @bias @weight
+nn.Linear op_2 1 1 value 12 bias=%vbias in_features=%vdim out_features=%embed_dim @bias @weight
+Tensor.reshape op_3 1 1 10 13 shape=(%qsize,%num_heads,%feat_per_head)
+Tensor.reshape op_4 1 1 11 15 shape=(%kvsize,%num_heads,%feat_per_head)
+Tensor.reshape op_5 1 1 12 16 shape=(%kvsize,%num_heads,%feat_per_head)
+Tensor.permute op_6 1 1 13 14 dims=(1,0,2)
+Tensor.permute op_7 1 1 15 19 dims=(1,2,0)
+Tensor.permute op_8 1 1 16 17 dims=(1,0,2)
+pnnx.Expression op_9 1 1 14 18 expr=mul(@0,%inv_sqrt_embed_dim_per_head)
+torch.matmul op_10 2 1 18 19 20
+F.softmax softmax 1 1 20 21 dim=%softmax_dim
+torch.matmul op_12 2 1 21 17 22
+Tensor.permute op_13 1 1 22 23 dims=(1,0,2)
+Tensor.reshape op_14 1 1 23 24 shape=(%qsize,%embed_dim)
+nn.Linear out_proj 1 1 24 25 bias=%outbias in_features=%embed_dim out_features=%embed_dim @bias @weight
+Tensor.reshape op_16 1 1 25 out shape=(%qsize,%batch,%embed_dim)
+Tensor.reshape op_17 1 1 21 27 shape=(%batch,%num_heads,%qsize,%kvsize)
+torch.mean op_18 1 1 27 outweight dim=(1) keepdim=False
+pnnx.Output output 2 0 out outweight
+)PNNXIR";
+    }
+
+    const char* replace_pattern_graph() const
+    {
+        return R"PNNXIR(7767517
+5 5
+pnnx.Input input_0 0 1 query
+pnnx.Input input_1 0 1 key
+pnnx.Input input_2 0 1 value
+nn.MultiheadAttention attention 3 2 query key value out outweight embed_dim=%embed_dim kdim=%kdim vdim=%vdim num_heads=%num_heads batch_first=False add_zero_attn=False add_bias_kv=False
+pnnx.Output output 2 0 out outweight
+)PNNXIR";
+    }
+};
+
+class fuse_multiheadattention_pass_onnx_1 : public fuse_multiheadattention_pass
+{
+public:
+    const char* match_pattern_graph() const
+    {
+        return R"PNNXIR(7767517
+26 25
+pnnx.Input input_q 0 1 input
+nn.Linear op_0 1 1 input 14 bias=%qkvbias in_features=%embed_dim out_features=%qkv_out_features @bias @weight
+Tensor.reshape op_1 1 1 14 15 shape=(%batch,%size,1,3,%embed_dim)
+Tensor.permute op_2 1 1 15 16 dims=(3,1,2,0,4)
+torch.squeeze op_3 1 1 16 17 dim=3
+torch.unbind op_4 1 3 17 18 19 20 dim=0
+Tensor.reshape op_5 1 1 18 21 shape=(%size,%num_heads,%feat_per_head)
+Tensor.reshape op_6 1 1 19 23 shape=(%size,%num_heads,%feat_per_head)
+Tensor.reshape op_7 1 1 20 25 shape=(%size,%num_heads,%feat_per_head)
+Tensor.permute op_8 1 1 21 22 dims=(1,0,2)
+Tensor.permute op_9 1 1 23 24 dims=(1,0,2)
+Tensor.permute op_10 1 1 25 26 dims=(1,0,2)
+Tensor.reshape op_11 1 1 22 27 shape=(%batch,%num_heads,%size,%feat_per_head)
+Tensor.reshape op_12 1 1 24 28 shape=(%batch,%num_heads,%size,%feat_per_head)
+Tensor.reshape op_13 1 1 26 29 shape=(%batch,%num_heads,%size,%feat_per_head)
+Tensor.permute op_14 1 1 28 30 dims=(0,1,3,2)
+pnnx.Expression op_15 1 1 27 31 expr=mul(@0,%sqrt_inv_sqrt_embed_dim_per_head)
+pnnx.Expression op_16 1 1 30 32 expr=mul(@0,%sqrt_inv_sqrt_embed_dim_per_head)
+torch.matmul op_17 2 1 31 32 33
+F.softmax softmax 1 1 33 34 dim=%softmax_dim
+torch.matmul op_19 2 1 34 29 35
+Tensor.permute op_20 1 1 35 36 dims=(2,0,1,3)
+Tensor.reshape op_21 1 1 36 37 shape=(%size,%embed_dim)
+nn.Linear out_proj 1 1 37 38 bias=%outbias in_features=%embed_dim out_features=%embed_dim @bias @weight
+Tensor.reshape op_23 1 1 38 out shape=(%size,%batch,%embed_dim)
+pnnx.Output output 1 0 out
+)PNNXIR";
+    }
+
+    const char* replace_pattern_graph() const
+    {
+        return R"PNNXIR(7767517
+3 2
+pnnx.Input input 0 1 input
+nn.MultiheadAttention attention 1 1 input out embed_dim=%embed_dim kdim=%embed_dim vdim=%embed_dim num_heads=%num_heads batch_first=False add_zero_attn=False add_bias_kv=False
+pnnx.Output output 1 0 out
+)PNNXIR";
+    }
+
+    bool match(const std::map<std::string, const Operator*>& matched_operators, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& /*captured_attrs*/) const
+    {
+        const int embed_dim = captured_params.at("embed_dim").i;
+        const int qkv_out_features = captured_params.at("qkv_out_features").i;
+        const int num_heads = captured_params.at("num_heads").i;
+        const int feat_per_head = captured_params.at("feat_per_head").i;
+        const float sqrt_inv_sqrt_embed_dim_per_head = captured_params.at("sqrt_inv_sqrt_embed_dim_per_head").f;
+        const int softmax_dim = captured_params.at("softmax_dim").i;
+
+        if (qkv_out_features != embed_dim * 3)
+            return false;
+
+        if (embed_dim != num_heads * feat_per_head)
+            return false;
+
+        if (!NearlyEqual(sqrt_inv_sqrt_embed_dim_per_head, sqrt(1.f / sqrt(feat_per_head)), 0.001))
+            return false;
+
+        int softmax_input_rank = (int)matched_operators.at("softmax")->inputs[0]->shape.size();
+        if (softmax_dim != -1 && softmax_dim != softmax_input_rank - 1)
+            return false;
+
+        return true;
+    }
+};
+
+class fuse_multiheadattention_pass_onnx_1_1 : public fuse_multiheadattention_pass
+{
+public:
+    const char* match_pattern_graph() const
+    {
+        return R"PNNXIR(7767517
+21 20
+pnnx.Input input_q 0 1 input
+nn.Linear op_0 1 1 input 33 bias=%qkvbias in_features=%embed_dim out_features=%qkv_out_features @bias @weight
+Tensor.reshape op_1 1 1 33 34 shape=(%batch,%size,1,3,%embed_dim)
+Tensor.permute op_2 1 1 34 35 dims=(3,1,2,0,4)
+torch.squeeze op_3 1 1 35 36 dim=3
+torch.unbind op_4 1 3 36 37 38 39 dim=0
+Tensor.reshape op_5 1 1 37 40 shape=(%size,%num_heads,%feat_per_head)
+Tensor.reshape op_6 1 1 38 42 shape=(%size,%num_heads,%feat_per_head)
+Tensor.reshape op_7 1 1 39 43 shape=(%size,%num_heads,%feat_per_head)
+Tensor.permute op_8 1 1 40 41 dims=(1,0,2)
+Tensor.permute op_9 1 1 42 46 dims=(1,2,0)
+Tensor.permute op_10 1 1 43 44 dims=(1,0,2)
+pnnx.Expression op_11 1 1 41 45 expr=mul(@0,%inv_sqrt_embed_dim_per_head)
+torch.matmul op_12 2 1 45 46 47
+F.softmax softmax 1 1 47 48 dim=%softmax_dim
+torch.matmul op_14 2 1 48 44 49
+Tensor.permute op_15 1 1 49 50 dims=(1,0,2)
+Tensor.reshape op_16 1 1 50 51 shape=(%size,%embed_dim)
+nn.Linear out_proj 1 1 51 52 bias=%outbias in_features=%embed_dim out_features=%embed_dim @bias @weight
+Tensor.reshape op_18 1 1 52 out shape=(%size,%batch,%embed_dim)
+pnnx.Output output 1 0 out
+)PNNXIR";
+    }
+
+    const char* replace_pattern_graph() const
+    {
+        return R"PNNXIR(7767517
+3 2
+pnnx.Input input 0 1 input
+nn.MultiheadAttention attention 1 1 input out embed_dim=%embed_dim kdim=%embed_dim vdim=%embed_dim num_heads=%num_heads batch_first=False add_zero_attn=False add_bias_kv=False
+pnnx.Output output 1 0 out
+)PNNXIR";
+    }
+};
+
+class fuse_multiheadattention_pass_onnx_2 : public fuse_multiheadattention_pass
+{
+public:
+    const char* match_pattern_graph() const
+    {
+        return R"PNNXIR(7767517
+24 23
+pnnx.Input input_0 0 1 input
+pnnx.Input input_1 0 1 attn_mask
+nn.Linear op_0 1 1 input 15 bias=%qkvbias in_features=%embed_dim out_features=%qkv_out_features @bias @weight
+Tensor.reshape op_1 1 1 15 16 shape=(%batch,%size,1,3,%embed_dim)
+Tensor.permute op_2 1 1 16 17 dims=(3,1,2,0,4)
+torch.squeeze op_3 1 1 17 18 dim=3
+torch.unbind op_4 1 3 18 19 20 21 dim=0
+Tensor.reshape op_5 1 1 19 23 shape=(%size,%num_heads,%feat_per_head)
+Tensor.reshape op_7 1 1 20 25 shape=(%size,%num_heads,%feat_per_head)
+Tensor.reshape op_8 1 1 21 26 shape=(%size,%num_heads,%feat_per_head)
+Tensor.permute op_6 1 1 23 24 dims=(1,0,2)
+Tensor.permute op_11 1 1 25 29 dims=(1,2,0)
+Tensor.permute op_9 1 1 26 27 dims=(1,0,2)
+pnnx.Expression op_10 1 1 24 28 expr=mul(@0,%inv_sqrt_embed_dim_per_head)
+torch.matmul op_12 2 1 28 29 30
+torch.unsqueeze op_13 1 1 attn_mask 22 dim=0
+pnnx.Expression op_14 2 1 30 22 31 expr=add(@0,@1)
+F.softmax softmax 1 1 31 32 dim=%softmax_dim
+torch.matmul op_16 2 1 32 27 33
+Tensor.permute op_17 1 1 33 34 dims=(1,0,2)
+Tensor.reshape op_18 1 1 34 35 shape=(%size,%embed_dim)
+nn.Linear out_proj 1 1 35 36 bias=%outbias in_features=%embed_dim out_features=%embed_dim @bias @weight
+Tensor.reshape op_20 1 1 36 out shape=(%size,%batch,%embed_dim)
+pnnx.Output output 1 0 out
+)PNNXIR";
+    }
+
+    const char* replace_pattern_graph() const
+    {
+        return R"PNNXIR(7767517
+4 3
+pnnx.Input input_0 0 1 input
+pnnx.Input input_1 0 1 attn_mask
+nn.MultiheadAttention attention 2 1 input attn_mask out embed_dim=%embed_dim kdim=%embed_dim vdim=%embed_dim num_heads=%num_heads batch_first=False add_zero_attn=False add_bias_kv=False $attn_mask=attn_mask
+pnnx.Output output 1 0 out
+)PNNXIR";
+    }
+};
+
+class fuse_multiheadattention_pass_onnx_2_1 : public fuse_multiheadattention_pass_onnx_2
+{
+public:
+    const char* match_pattern_graph() const
+    {
+        return R"PNNXIR(7767517
+24 23
+pnnx.Input input_0 0 1 input
+pnnx.Input input_1 0 1 attn_mask
+nn.Linear op_0 1 1 input 15 bias=%qkvbias in_features=%embed_dim out_features=%qkv_out_features @bias @weight
+Tensor.reshape op_1 1 1 15 16 shape=(%batch,%size,1,3,%embed_dim)
+Tensor.permute op_2 1 1 16 17 dims=(3,1,2,0,4)
+torch.squeeze op_3 1 1 17 18 dim=3
+torch.unbind op_4 1 3 18 19 20 21 dim=0
+Tensor.reshape op_5 1 1 19 23 shape=(%size,%num_heads,%feat_per_head)
+Tensor.reshape op_7 1 1 20 25 shape=(%size,%num_heads,%feat_per_head)
+Tensor.reshape op_8 1 1 21 26 shape=(%size,%num_heads,%feat_per_head)
+Tensor.permute op_6 1 1 23 24 dims=(1,0,2)
+Tensor.permute op_11 1 1 25 29 dims=(1,2,0)
+Tensor.permute op_9 1 1 26 27 dims=(1,0,2)
+pnnx.Expression op_10 1 1 24 28 expr=mul(@0,%inv_sqrt_embed_dim_per_head)
+torch.matmul op_12 2 1 28 29 30
+torch.unsqueeze op_13 1 1 attn_mask 22 dim=0
+pnnx.Expression op_14 2 1 30 22 31 expr=add(@0,@1)
+F.softmax softmax 1 1 31 32 dim=%softmax_dim
+torch.matmul op_16 2 1 32 27 33
+Tensor.permute op_17 1 1 33 34 dims=(1,0,2)
+Tensor.reshape op_18 1 1 34 35 shape=(%size,%embed_dim)
+nn.Linear out_proj 1 1 35 36 bias=%outbias in_features=%embed_dim out_features=%embed_dim @bias @weight
+Tensor.reshape op_20 1 1 36 out shape=(%size,%batch,%embed_dim)
+pnnx.Output output 1 0 out
+)PNNXIR";
+    }
+};
+
+class fuse_multiheadattention_pass_onnx_3 : public fuse_multiheadattention_pass_qkv
+{
+public:
+    const char* match_pattern_graph() const
+    {
+        return R"PNNXIR(7767517
+25 24
+pnnx.Input input_q 0 1 query
+pnnx.Input input_kv 0 1 kv
+nn.Linear op_0 1 1 query 14 bias=%qbias in_features=%embed_dim out_features=%embed_dim @bias @weight
+nn.Linear op_1 1 1 kv 15 bias=%kvbias in_features=%kvdim out_features=%kv_embed_dim @bias @weight
+Tensor.reshape op_2 1 1 15 16 shape=(%batch,%kvsize,1,2,%embed_dim)
+Tensor.permute op_3 1 1 16 17 dims=(3,1,2,0,4)
+torch.squeeze op_4 1 1 17 18 dim=3
+torch.unbind op_5 1 2 18 19 20 dim=0
+Tensor.reshape op_6 1 1 14 21 shape=(%qsize,%num_heads,%feat_per_head)
+Tensor.reshape op_7 1 1 19 23 shape=(%kvsize,%num_heads,%feat_per_head)
+Tensor.reshape op_8 1 1 20 24 shape=(%kvsize,%num_heads,%feat_per_head)
+Tensor.permute op_9 1 1 21 22 dims=(1,0,2)
+Tensor.permute op_10 1 1 24 25 dims=(1,0,2)
+Tensor.permute op_11 1 1 23 27 dims=(1,2,0)
+pnnx.Expression op_12 1 1 22 26 expr=mul(@0,%inv_sqrt_embed_dim_per_head)
+torch.matmul op_13 2 1 26 27 28
+F.softmax softmax 1 1 28 29 dim=%softmax_dim
+torch.matmul op_15 2 1 29 25 30
+Tensor.permute op_16 1 1 30 31 dims=(1,0,2)
+Tensor.reshape op_17 1 1 31 32 shape=(%qsize,%embed_dim)
+nn.Linear out_proj 1 1 32 33 bias=%outbias in_features=%embed_dim out_features=%embed_dim @bias @weight
+Tensor.reshape op_19 1 1 33 out shape=(%qsize,1,%embed_dim)
+Tensor.reshape op_20 1 1 29 35 shape=(1,%num_heads,%qsize,%kvsize)
+torch.mean op_21 1 1 35 outweight dim=(1) keepdim=False
+pnnx.Output output 2 0 out outweight
+)PNNXIR";
+    }
+
+    const char* replace_pattern_graph() const
+    {
+        return R"PNNXIR(7767517
+4 4
+pnnx.Input input_0 0 1 query
+pnnx.Input input_1 0 1 kv
+nn.MultiheadAttention attention 2 2 query kv out outweight embed_dim=%embed_dim kdim=%kvdim vdim=%kvdim num_heads=%num_heads batch_first=False add_zero_attn=False add_bias_kv=False
+pnnx.Output output 2 0 out outweight
+)PNNXIR";
+    }
+
+    void write(const std::map<std::string, Operator*>& ops, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
+    {
+        GraphRewriterPass::write(ops, captured_params, captured_attrs);
+
+        Operator* op = ops.at("attention");
+
+        const int embed_dim = captured_params.at("embed_dim").i;
+        const bool qbias = captured_params.at("qbias").b;
+        const bool kvbias = captured_params.at("kvbias").b;
+        const bool outbias = captured_params.at("outbias").b;
+        const bool bias = qbias || kvbias || outbias;
+
+        op->params["bias"] = bias;
+
+        op->attrs["in_proj_weight"] = captured_attrs.at("op_0.weight") + captured_attrs.at("op_1.weight");
+
+        op->attrs["out_proj.weight"] = captured_attrs.at("out_proj.weight");
+
+        if (bias)
+        {
+            op->attrs["in_proj_bias"] = Attribute();
+            op->attrs["in_proj_bias"].type = op->attrs["in_proj_weight"].type;
+            op->attrs["in_proj_bias"].shape = {embed_dim * 3};
+            // combine qkv bias
+            std::vector<float> in_proj_bias(embed_dim * 3);
+            {
+                float* in_proj_bias_ptr = (float*)in_proj_bias.data();
+                if (qbias)
+                {
+                    auto qb = captured_attrs.at("op_0.bias").get_float32_data();
+                    memcpy(in_proj_bias_ptr, (const void*)qb.data(), embed_dim * sizeof(float));
+                }
+                else
+                {
+                    memset(in_proj_bias_ptr, 0, embed_dim * sizeof(float));
+                }
+                in_proj_bias_ptr += embed_dim;
+                if (kvbias)
+                {
+                    auto kvb = captured_attrs.at("op_1.bias").get_float32_data();
+                    memcpy(in_proj_bias_ptr, (const void*)kvb.data(), embed_dim * 2 * sizeof(float));
+                }
+                else
+                {
+                    memset(in_proj_bias_ptr, 0, embed_dim * 2 * sizeof(float));
+                }
+            }
+            op->attrs["in_proj_bias"].set_float32_data(in_proj_bias);
+
+            if (outbias)
+            {
+                op->attrs["out_proj.bias"] = captured_attrs.at("out_proj.bias");
+            }
+            else
+            {
+                // init bias as zero
+                op->attrs["out_proj.bias"] = Attribute();
+                op->attrs["out_proj.bias"].type = op->attrs["out_proj.weight"].type;
+                op->attrs["out_proj.bias"].shape = {embed_dim};
+                op->attrs["out_proj.bias"].set_float32_data(std::vector<float>(embed_dim, 0.f));
+            }
+        }
+    }
+};
+
+class fuse_multiheadattention_pass_onnx_4 : public fuse_multiheadattention_pass_qkv
+{
+public:
+    const char* match_pattern_graph() const
+    {
+        return R"PNNXIR(7767517
+26 25
+pnnx.Input input_q 0 1 query
+pnnx.Input input_k 0 1 key
+pnnx.Input input_v 0 1 value
+pnnx.Input input_3 0 1 attn_mask
+nn.Linear op_0 1 1 query 20 bias=%qbias in_features=%embed_dim out_features=%embed_dim @bias @weight
+nn.Linear op_1 1 1 key 21 bias=%kbias in_features=%kdim out_features=%embed_dim @bias @weight
+nn.Linear op_2 1 1 value 22 bias=%vbias in_features=%vdim out_features=%embed_dim @bias @weight
+Tensor.reshape op_3 1 1 20 24 shape=(%qsize,%num_heads,%feat_per_head)
+Tensor.reshape op_4 1 1 21 26 shape=(%kvsize,%num_heads,%feat_per_head)
+Tensor.reshape op_5 1 1 22 27 shape=(%kvsize,%num_heads,%feat_per_head)
+Tensor.permute op_6 1 1 24 25 dims=(1,0,2)
+Tensor.permute op_7 1 1 26 30 dims=(1,2,0)
+Tensor.permute op_8 1 1 27 28 dims=(1,0,2)
+pnnx.Expression op_9 1 1 25 29 expr=mul(@0,%inv_sqrt_embed_dim_per_head)
+torch.matmul op_10 2 1 29 30 31
+torch.unsqueeze op_11 1 1 attn_mask 23 dim=0
+pnnx.Expression op_12 2 1 31 23 32 expr=add(@0,@1)
+F.softmax softmax 1 1 32 33 dim=%softmax_dim
+torch.matmul op_14 2 1 33 28 34
+Tensor.permute op_15 1 1 34 35 dims=(1,0,2)
+Tensor.reshape op_16 1 1 35 36 shape=(%qsize,%embed_dim)
+nn.Linear out_proj 1 1 36 37 bias=%outbias in_features=%embed_dim out_features=%embed_dim @bias @weight
+Tensor.reshape op_18 1 1 37 out shape=(%qsize,%batch,%embed_dim)
+Tensor.reshape op_19 1 1 33 39 shape=(%batch,%num_heads,%qsize,%kvsize)
+torch.mean op_20 1 1 39 outweight dim=(1) keepdim=False
+pnnx.Output output 2 0 out outweight
+)PNNXIR";
+    }
+
+    const char* replace_pattern_graph() const
+    {
+        return R"PNNXIR(7767517
+6 6
+pnnx.Input input_0 0 1 query
+pnnx.Input input_1 0 1 key
+pnnx.Input input_2 0 1 value
+pnnx.Input input_3 0 1 attn_mask
+nn.MultiheadAttention attention 4 2 query key value attn_mask out outweight embed_dim=%embed_dim kdim=%kdim vdim=%vdim num_heads=%num_heads batch_first=False add_zero_attn=False add_bias_kv=False $attn_mask=attn_mask
+pnnx.Output output 2 0 out outweight
+)PNNXIR";
+    }
+};
+
+class fuse_multiheadattention_pass_onnx_4_1 : public fuse_multiheadattention_pass_onnx_4
+{
+public:
+    const char* match_pattern_graph() const
+    {
+        return R"PNNXIR(7767517
+25 24
+pnnx.Input input_q 0 1 query
+pnnx.Input input_k 0 1 key
+pnnx.Input input_v 0 1 value
+pnnx.Input input_3 0 1 attn_mask
+nn.Linear op_0 1 1 query 22 bias=%qbias in_features=%embed_dim out_features=%embed_dim @bias @weight
+nn.Linear op_1 1 1 key 23 bias=%kbias in_features=%kdim out_features=%embed_dim @bias @weight
+nn.Linear op_2 1 1 value 24 bias=%vbias in_features=%vdim out_features=%embed_dim @bias @weight
+Tensor.reshape op_3 1 1 22 25 shape=(%qsize,%num_heads,%feat_per_head)
+Tensor.reshape op_4 1 1 23 27 shape=(%kvsize,%num_heads,%feat_per_head)
+Tensor.reshape op_5 1 1 24 28 shape=(%kvsize,%num_heads,%feat_per_head)
+Tensor.permute op_6 1 1 25 26 dims=(1,0,2)
+Tensor.permute op_7 1 1 28 29 dims=(1,0,2)
+Tensor.permute op_8 1 1 27 31 dims=(1,2,0)
+pnnx.Expression op_9 1 1 26 30 expr=mul(@0,%inv_sqrt_embed_dim_per_head)
+torch.matmul op_10 2 1 30 31 32
+pnnx.Expression op_11 2 1 32 attn_mask 33 expr=add(@0,@1)
+F.softmax softmax 1 1 33 34 dim=%softmax_dim
+torch.matmul op_13 2 1 34 29 35
+Tensor.permute op_14 1 1 35 36 dims=(1,0,2)
+Tensor.reshape op_15 1 1 36 37 shape=(%qsize,%embed_dim)
+nn.Linear out_proj 1 1 37 38 bias=%outbias in_features=%embed_dim out_features=%embed_dim @bias @weight
+Tensor.reshape op_16 1 1 38 out shape=(%qsize,%batch,%embed_dim)
+Tensor.reshape op_18 1 1 34 40 shape=(%batch,%num_heads,%qsize,%kvsize)
+torch.mean op_19 1 1 40 outweight dim=(1) keepdim=False
+pnnx.Output output 2 0 out outweight
+)PNNXIR";
+    }
+};
+
 void fuse_multiheadattention(Graph& graph)
 {
 #if TORCH_VERSION_MAJOR >= 2 || (TORCH_VERSION_MAJOR >= 1 && TORCH_VERSION_MINOR >= 9)
@@ -1606,6 +2044,14 @@ void fuse_multiheadattention(Graph& graph)
     fuse_multiheadattention_pass_17_1 p1;
     fuse_multiheadattention_pass_18 q;
     fuse_multiheadattention_pass_18_1 q1;
+
+    fuse_multiheadattention_pass_onnx onnx0;
+    fuse_multiheadattention_pass_onnx_1 onnx1;
+    fuse_multiheadattention_pass_onnx_1_1 onnx1a;
+    fuse_multiheadattention_pass_onnx_2 onnx2;
+    fuse_multiheadattention_pass_onnx_3 onnx3;
+    fuse_multiheadattention_pass_onnx_4 onnx4;
+    fuse_multiheadattention_pass_onnx_4_1 onnx4a;
 
     int opindex = 0;
     pnnx_graph_rewrite(graph, &a, opindex);
@@ -1637,6 +2083,14 @@ void fuse_multiheadattention(Graph& graph)
     pnnx_graph_rewrite(graph, &p1, opindex);
     pnnx_graph_rewrite(graph, &q, opindex);
     pnnx_graph_rewrite(graph, &q1, opindex);
+
+    pnnx_graph_rewrite(graph, &onnx0, opindex);
+    pnnx_graph_rewrite(graph, &onnx1, opindex);
+    pnnx_graph_rewrite(graph, &onnx1a, opindex);
+    pnnx_graph_rewrite(graph, &onnx2, opindex);
+    pnnx_graph_rewrite(graph, &onnx3, opindex);
+    pnnx_graph_rewrite(graph, &onnx4, opindex);
+    pnnx_graph_rewrite(graph, &onnx4a, opindex);
 #endif
 }
 
diff --git a/tools/pnnx/src/pass_onnx/fuse_constant_as_attribute.cpp b/tools/pnnx/src/pass_onnx/fuse_constant_as_attribute.cpp
index 18268d0f0fb..a3021d33c90 100644
--- a/tools/pnnx/src/pass_onnx/fuse_constant_as_attribute.cpp
+++ b/tools/pnnx/src/pass_onnx/fuse_constant_as_attribute.cpp
@@ -36,7 +36,11 @@ static constant_as_attribute caas[] = {
     {"If", 0, "cond"},
     {"Pad", 1, "pads"},
     {"Pad", 2, "value"},
+    {"ReduceMax", 1, "axes"},
     {"ReduceMean", 1, "axes"},
+    {"ReduceMin", 1, "axes"},
+    {"ReduceProd", 1, "axes"},
+    {"ReduceSum", 1, "axes"},
     {"Reshape", 1, "shape"},
     {"Resize", 2, "scales"},
     {"Resize", 3, "sizes"},
diff --git a/tools/pnnx/tests/onnx/CMakeLists.txt b/tools/pnnx/tests/onnx/CMakeLists.txt
index 8ed4b5e480a..12d816cd8e2 100644
--- a/tools/pnnx/tests/onnx/CMakeLists.txt
+++ b/tools/pnnx/tests/onnx/CMakeLists.txt
@@ -19,6 +19,10 @@ pnnx_onnx_add_test(F_conv3d)
 pnnx_onnx_add_test(F_elu)
 pnnx_onnx_add_test(F_gelu)
 # pnnx_onnx_add_test(F_group_norm)
+pnnx_onnx_add_test(F_hardshrink)
+pnnx_onnx_add_test(F_hardsigmoid)
+pnnx_onnx_add_test(F_hardswish)
+pnnx_onnx_add_test(F_hardtanh)
 # pnnx_onnx_add_test(F_instance_norm)
 pnnx_onnx_add_test(F_interpolate)
 pnnx_onnx_add_test(F_layer_norm)
@@ -59,6 +63,10 @@ pnnx_onnx_add_test(nn_ELU)
 pnnx_onnx_add_test(nn_GELU)
 pnnx_onnx_add_test(nn_GroupNorm)
 pnnx_onnx_add_test(nn_GRU)
+pnnx_onnx_add_test(nn_Hardshrink)
+pnnx_onnx_add_test(nn_Hardsigmoid)
+pnnx_onnx_add_test(nn_Hardswish)
+pnnx_onnx_add_test(nn_Hardtanh)
 pnnx_onnx_add_test(nn_InstanceNorm1d)
 pnnx_onnx_add_test(nn_InstanceNorm2d)
 pnnx_onnx_add_test(nn_InstanceNorm3d)
@@ -70,7 +78,7 @@ pnnx_onnx_add_test(nn_LSTM)
 pnnx_onnx_add_test(nn_MaxPool1d)
 pnnx_onnx_add_test(nn_MaxPool2d)
 pnnx_onnx_add_test(nn_MaxPool3d)
-# pnnx_onnx_add_test(nn_MultiheadAttention)
+pnnx_onnx_add_test(nn_MultiheadAttention)
 pnnx_onnx_add_test(nn_PReLU)
 pnnx_onnx_add_test(nn_ReflectionPad1d)
 pnnx_onnx_add_test(nn_ReflectionPad2d)
diff --git a/tools/pnnx/tests/onnx/test_F_hardshrink.py b/tools/pnnx/tests/onnx/test_F_hardshrink.py
index 52d44ff6c15..00dbdd6fd86 100644
--- a/tools/pnnx/tests/onnx/test_F_hardshrink.py
+++ b/tools/pnnx/tests/onnx/test_F_hardshrink.py
@@ -15,6 +15,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from packaging import version
 
 class Model(nn.Module):
     def __init__(self):
@@ -32,6 +33,9 @@ def forward(self, x, y, z, w):
         return x, y, z, w
 
 def test():
+    if version.parse(torch.__version__) < version.parse('1.11'):
+        return True
+
     net = Model()
     net.eval()
diff --git a/tools/pnnx/tests/onnx/test_F_hardsigmoid.py b/tools/pnnx/tests/onnx/test_F_hardsigmoid.py
index 53c3adaefa5..a0f8c7654d3 100644
--- a/tools/pnnx/tests/onnx/test_F_hardsigmoid.py
+++ b/tools/pnnx/tests/onnx/test_F_hardsigmoid.py
@@ -15,6 +15,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from packaging import version
 
 def hardsigmoid_forward_0(x):
     return F.relu6(x + 3., True) / 6.
@@ -49,6 +50,9 @@ def forward(self, x, y, z, w):
         return x, y, z, w
 
 def test():
+    if version.parse(torch.__version__) < version.parse('1.10'):
+        return True
+
     net = Model()
     net.eval()
diff --git a/tools/pnnx/tests/onnx/test_F_hardswish.py b/tools/pnnx/tests/onnx/test_F_hardswish.py
index cf671d4c611..78ada9b3482 100644
--- a/tools/pnnx/tests/onnx/test_F_hardswish.py
+++ b/tools/pnnx/tests/onnx/test_F_hardswish.py
@@ -15,6 +15,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from packaging import version
 
 def hardswish_forward_0(x):
     return x * F.hardsigmoid(x)
@@ -46,6 +47,9 @@ def forward(self, x, y, z, w):
         return x, y, z, w
 
 def test():
+    if version.parse(torch.__version__) < version.parse('1.10'):
+        return True
+
     net = Model()
     net.eval()
diff --git a/tools/pnnx/tests/onnx/test_nn_Hardshrink.py b/tools/pnnx/tests/onnx/test_nn_Hardshrink.py
index c24e48f5bcb..38dd67c464f 100644
--- a/tools/pnnx/tests/onnx/test_nn_Hardshrink.py
+++ b/tools/pnnx/tests/onnx/test_nn_Hardshrink.py
@@ -15,6 +15,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from packaging import version
 
 class Model(nn.Module):
     def __init__(self):
@@ -35,6 +36,9 @@ def forward(self, x, y, z, w):
         return x, y, z, w
 
 def test():
+    if version.parse(torch.__version__) < version.parse('1.11'):
+        return True
+
     net = Model()
     net.eval()
diff --git a/tools/pnnx/tests/onnx/test_nn_Hardsigmoid.py b/tools/pnnx/tests/onnx/test_nn_Hardsigmoid.py
index 35d02bef202..43af7471411 100644
--- a/tools/pnnx/tests/onnx/test_nn_Hardsigmoid.py
+++ b/tools/pnnx/tests/onnx/test_nn_Hardsigmoid.py
@@ -15,6 +15,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from packaging import version
 
 class Model(nn.Module):
     def __init__(self):
@@ -34,6 +35,9 @@ def forward(self, x, y, z, w):
         return x, y, z, w
 
 def test():
+    if version.parse(torch.__version__) < version.parse('1.9'):
+        return True
+
     net = Model()
     net.eval()
diff --git a/tools/pnnx/tests/onnx/test_nn_MultiheadAttention.py b/tools/pnnx/tests/onnx/test_nn_MultiheadAttention.py
new file mode 100644
index 00000000000..9d6cc263ab2
--- /dev/null
+++ b/tools/pnnx/tests/onnx/test_nn_MultiheadAttention.py
@@ -0,0 +1,138 @@
+# Tencent is pleased to support the open source community by making ncnn available.
+#
+# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
+# in compliance with the License. You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software distributed
+# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
+# CONDITIONS OF ANY KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations under the License.
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from packaging import version
+
+class Model(nn.Module):
+    def __init__(self):
+        super(Model, self).__init__()
+
+        self.attention_0_0 = nn.MultiheadAttention(embed_dim=64, num_heads=4)
+        self.attention_0_1 = nn.MultiheadAttention(embed_dim=64, num_heads=8, bias=False, add_bias_kv=False, add_zero_attn=False)
+        self.attention_0_2 = nn.MultiheadAttention(embed_dim=64, num_heads=16, bias=True, add_bias_kv=False, add_zero_attn=False)
+
+        self.attention_0_3 = nn.MultiheadAttention(embed_dim=32, num_heads=8, bias=True)
+        self.attention_0_33 = nn.MultiheadAttention(embed_dim=32, num_heads=8, bias=True)
+
+        self.attention_0_4 = nn.MultiheadAttention(embed_dim=40, num_heads=4, kdim=30, vdim=20)
+        self.attention_0_5 = nn.MultiheadAttention(embed_dim=40, num_heads=8, kdim=30, vdim=20, bias=False, add_bias_kv=False, add_zero_attn=False)
+        self.attention_0_6 = nn.MultiheadAttention(embed_dim=40, num_heads=10, kdim=30, vdim=20, bias=True, add_bias_kv=False, add_zero_attn=False)
+
+        if version.parse(torch.__version__) >= version.parse('1.9'):
+            self.attention_1_0 = nn.MultiheadAttention(embed_dim=64, num_heads=4, batch_first=True)
+            self.attention_1_1 = nn.MultiheadAttention(embed_dim=64, num_heads=8, bias=False, add_bias_kv=False, add_zero_attn=False, batch_first=True)
+            self.attention_1_2 = nn.MultiheadAttention(embed_dim=64, num_heads=16, bias=True, add_bias_kv=False, add_zero_attn=False, batch_first=True)
+
+            self.attention_1_3 = nn.MultiheadAttention(embed_dim=32, num_heads=8, bias=True, batch_first=True)
+            self.attention_1_33 = nn.MultiheadAttention(embed_dim=32, num_heads=8, bias=True, batch_first=True)
+
+            self.attention_1_4 = nn.MultiheadAttention(embed_dim=40, num_heads=4, kdim=30, vdim=20, batch_first=True)
+            self.attention_1_5 = nn.MultiheadAttention(embed_dim=40, num_heads=8, kdim=30, vdim=20, bias=False, add_bias_kv=False, add_zero_attn=False, batch_first=True)
+            self.attention_1_6 = nn.MultiheadAttention(embed_dim=40, num_heads=10, kdim=30, vdim=20, bias=True, add_bias_kv=False, add_zero_attn=False, batch_first=True)
+
+    def forward(self, xq, xk, xv, z, zmask, yq, yk, yv, ymask, ymask2):
+        x0, x0w = self.attention_0_0(xq, xk, xv)
+        x1, x1w = self.attention_0_1(xq, xk, xv)
+        x2, x2w = self.attention_0_2(xq, xk, xk)
+
+        x3, _ = self.attention_0_3(z, z, z, need_weights=False)
+        x33, _ = self.attention_0_33(z, z, z, attn_mask=zmask)
+
+        x4, x4w = self.attention_0_4(yq, yk, yv)
+        x5, x5w = self.attention_0_5(yq, yk, yv, attn_mask=ymask)
+        x6, x6w = self.attention_0_6(yq, yk, yv, attn_mask=ymask2)
+
+        if version.parse(torch.__version__) < version.parse('1.9'):
+            return x0, x0w, x1, x1w, x2, x2w, x3, x33, x4, x4w, x5, x5w, x6, x6w
+
+        xq = xq.transpose(0, 1)
+        xk = xk.transpose(0, 1)
+        xv = xv.transpose(0, 1)
+        z = z.transpose(0, 1)
+        yq = yq.transpose(0, 1)
+        yk = yk.transpose(0, 1)
+        yv = yv.transpose(0, 1)
+
+        y0, y0w = self.attention_1_0(xq, xk, xv)
+        y1, y1w = self.attention_1_1(xq, xk, xv)
+        y2, y2w = self.attention_1_2(xq, xk, xk)
+
+        y3, _ = self.attention_1_3(z, z, z)
+        if version.parse(torch.__version__) >= version.parse('1.12') and version.parse(torch.__version__) < version.parse('1.13'):
+            # HACK pytorch 1.12 breaks 2-dim zmask
+            # https://github.com/pytorch/pytorch/issues/97409
+            # zmask2 = zmask.reshape(1, 1, 30, 30).expand(1, 8, 30, 30)
+            # y33, _ = self.attention_1_33(z, z, z, attn_mask=zmask2)
+            # but it then produces all NaN, so skip the test :(
+            y33 = y3.relu()
+        elif version.parse(torch.__version__) >= version.parse('2.0') and version.parse(torch.__version__) < version.parse('2.1'):
+            # HACK pytorch 2.0 produces all NaN, skip the test :(
+            y33 = y3.relu()
+        else:
+            y33, _ = self.attention_1_33(z, z, z, attn_mask=zmask.relu())
+
+        y4, y4w = self.attention_1_4(yq, yk, yv)
+        y5, y5w = self.attention_1_5(yq, yk, yv, attn_mask=ymask.relu())
+        y6, y6w = self.attention_1_6(yq, yk, yv, attn_mask=ymask2.relu())
+
+        return x0, x0w, x1, x1w, x2, x2w, x3, x33, x4, x4w, x5, x5w, x6, x6w, y0, y0w, y1, y1w, y2, y2w, y3, y33, y4, y4w, y5, y5w, y6, y6w
+
+def test():
+    if version.parse(torch.__version__) < version.parse('1.10'):
+        return True
+
+    if version.parse(torch.__version__) >= version.parse('2.0') and version.parse(torch.__version__) < version.parse('2.1'):
+        return True
+
+    net = Model()
+    net.eval()
+
+    torch.manual_seed(0)
+    xq = torch.rand(20, 1, 64)
+    xk = torch.rand(20, 1, 64)
+    xv = torch.rand(20, 1, 64)
+    z = torch.rand(30, 1, 32)
+    zmask = torch.rand(30, 30)
+    yq = torch.rand(15, 1, 40)
+    yk = torch.rand(24, 1, 30)
+    yv = torch.rand(24, 1, 20)
+    ymask = torch.rand(15, 24)
+    ymask2 = torch.rand(10, 15, 24)
+
+    a = net(xq, xk, xv, z, zmask, yq, yk, yv, ymask, ymask2)
+
+    # export onnx
+    torch.onnx.export(net, (xq, xk, xv, z, zmask, yq, yk, yv, ymask, ymask2), "test_nn_MultiheadAttention.onnx")
+
+    # onnx to pnnx
+    import os
+    os.system("../../src/pnnx test_nn_MultiheadAttention.onnx inputshape=[20,1,64],[20,1,64],[20,1,64],[30,1,32],[30,30],[15,1,40],[24,1,30],[24,1,20],[15,24],[10,15,24]")
+
+    # pnnx inference
+    import test_nn_MultiheadAttention_pnnx
+    b = test_nn_MultiheadAttention_pnnx.test_inference()
+
+    for a0, b0 in zip(a, b):
+        if not torch.allclose(a0, b0, 1e-4, 1e-4):
+            return False
+    return True
+
+if __name__ == "__main__":
+    if test():
+        exit(0)
+    else:
+        exit(1)
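
Note on the Reduce* conversions in this patch: when an ONNX ReduceMax/ReduceMean/ReduceMin/ReduceProd/ReduceSum node carries no axes attribute, ONNX semantics reduce over every dimension, and keepdims defaults to 1; the write() fallback reproduces this by enumerating all input dimensions and forcing keepdim=True. A minimal torch-side sketch of the semantics the passes assume (the tensor and shapes below are illustrative, not taken from the patch):

import torch

x = torch.rand(2, 3, 4)

# No axes attribute on the ONNX node: reduce over all dimensions, with
# keepdims defaulting to 1, so the converted op must behave like an
# explicit all-axes reduction with keepdim=True.
y = torch.mean(x, dim=(0, 1, 2), keepdim=True)
assert y.shape == (1, 1, 1)

# With axes captured (here axes=[1]) and keepdims=0, the conversion maps
# to the ordinary dim/keepdim form.
z = torch.sum(x, dim=(1,), keepdim=False)
assert z.shape == (2, 4)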