Skip to content

Commit

Permalink
pnnx reset onnx input shape, convert torch.tile torch.where (#5517)
Browse files Browse the repository at this point in the history
* pnnx reset onnx input shape

* eliminate noop cast
  • Loading branch information
nihui authored Jun 19, 2024
1 parent b786af5 commit 2828e7a
Show file tree
Hide file tree
Showing 19 changed files with 620 additions and 68 deletions.
4 changes: 3 additions & 1 deletion tools/pnnx/src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -270,13 +270,15 @@ set(pnnx_pass_level2_SRCS
pass_level2/torch_sum.cpp
pass_level2/torch_t.cpp
pass_level2/torch_tensor_split.cpp
pass_level2/torch_tile.cpp
pass_level2/torch_topk.cpp
pass_level2/torch_transpose.cpp
pass_level2/torch_unbind.cpp
pass_level2/torch_unsqueeze.cpp
pass_level2/torch_var.cpp
pass_level2/torch_view_as_complex.cpp
pass_level2/torch_view_as_real.cpp
pass_level2/torch_view_as_real.cpp
pass_level2/torch_where.cpp
pass_level2/torch_zeros.cpp
pass_level2/torch_zeros_like.cpp
pass_level2/torch_stft.cpp
Expand Down
50 changes: 10 additions & 40 deletions tools/pnnx/src/load_onnx.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -564,16 +564,6 @@ int load_onnx(const std::string& onnxpath, Graph& pnnx_graph,

fprintf(stderr, "%8.2fms\n", t1 - t0);

fprintf(stderr, "%-34s", "dead_code_elimination ... ");

t0 = get_current_time();

onnx2pnnx::dead_code_elimination(model);

t1 = get_current_time();

fprintf(stderr, "%8.2fms\n", t1 - t0);

fprintf(stderr, "%-34s", "fold_constants ... ");

t0 = get_current_time();
Expand All @@ -584,51 +574,51 @@ int load_onnx(const std::string& onnxpath, Graph& pnnx_graph,

fprintf(stderr, "%8.2fms\n", t1 - t0);

fprintf(stderr, "%-34s", "dead_code_elimination ... ");
fprintf(stderr, "%-34s", "canonicalize ... ");

t0 = get_current_time();

onnx2pnnx::dead_code_elimination(model);
onnx2pnnx::canonicalize(model);

t1 = get_current_time();

fprintf(stderr, "%8.2fms\n", t1 - t0);

fprintf(stderr, "%-34s", "canonicalize ... ");
fprintf(stderr, "%-34s", "shape_inference ... ");

t0 = get_current_time();

onnx2pnnx::canonicalize(model);
onnx2pnnx::shape_inference(model, input_shapes, input_types, input_shapes2, input_types2);

t1 = get_current_time();

fprintf(stderr, "%8.2fms\n", t1 - t0);

fprintf(stderr, "%-34s", "shape_inference ... ");
fprintf(stderr, "%-34s", "fold_constants_dynamic_shape ... ");

t0 = get_current_time();

onnx2pnnx::shape_inference(model, input_shapes, input_types, input_shapes2, input_types2);
onnx2pnnx::fold_constants_dynamic_shape(model, input_shapes, input_types);

t1 = get_current_time();

fprintf(stderr, "%8.2fms\n", t1 - t0);

fprintf(stderr, "%-34s", "fold_constants_dynamic_shape ... ");
fprintf(stderr, "%-34s", "fuse_constant_as_attribute ... ");

t0 = get_current_time();

onnx2pnnx::fold_constants_dynamic_shape(model, input_shapes, input_types);
onnx2pnnx::fuse_constant_as_attribute(model);

t1 = get_current_time();

fprintf(stderr, "%8.2fms\n", t1 - t0);

fprintf(stderr, "%-34s", "dead_code_elimination ... ");
fprintf(stderr, "%-34s", "eliminate_noop_with_shape ... ");

t0 = get_current_time();

onnx2pnnx::dead_code_elimination(model);
onnx2pnnx::eliminate_noop_with_shape(model);

t1 = get_current_time();

Expand All @@ -653,26 +643,6 @@ int load_onnx(const std::string& onnxpath, Graph& pnnx_graph,
}
}

fprintf(stderr, "%-34s", "fuse_constant_as_attribute ... ");

t0 = get_current_time();

onnx2pnnx::fuse_constant_as_attribute(model);

t1 = get_current_time();

fprintf(stderr, "%8.2fms\n", t1 - t0);

fprintf(stderr, "%-34s", "dead_code_elimination ... ");

t0 = get_current_time();

onnx2pnnx::dead_code_elimination(model);

t1 = get_current_time();

fprintf(stderr, "%8.2fms\n", t1 - t0);

onnx2pnnx::ModelStat newstat = onnx2pnnx::get_model_stat(model);

onnx2pnnx::print_model_stat(oldstat, newstat);
Expand Down
28 changes: 28 additions & 0 deletions tools/pnnx/src/pass_level2/F_embedding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,4 +41,32 @@ pnnx.Output output 1 0 out

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_embedding, 10)

// Convert the ONNX Gather operator (gathering rows of a weight tensor by
// integer indices) into F.embedding.
class F_embedding_onnx : public GraphRewriterPass
{
public:
    // Pattern: Gather consumes (weight, input) in that order — data first,
    // indices second — matching the ONNX Gather operand layout.
    const char* match_pattern_graph() const
    {
        return R"PNNXIR(7767517
4 3
pnnx.Input input_0 0 1 input
pnnx.Input input_1 0 1 weight
Gather op_0 2 1 weight input out
pnnx.Output output 1 0 out
)PNNXIR";
    }

    const char* type_str() const
    {
        return "F.embedding";
    }

    void write(Operator* op, const std::map<std::string, Parameter>& /*captured_params*/) const
    {
        // ONNX Gather has no equivalent of these torch options; emit the
        // torch defaults explicitly.
        op->params["scale_grad_by_freq"] = false;
        op->params["sparse"] = false;
    }
};

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_embedding_onnx, 10)

} // namespace pnnx
55 changes: 55 additions & 0 deletions tools/pnnx/src/pass_level2/F_gelu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,61 @@ pnnx.Output output 1 0 out

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_gelu_3, 9)

// Fuse the explicit tanh-approximation formula of GELU into a single F.gelu:
//   (x * 0.5) * (tanh((x + (torch.pow(x, 3.0) * 0.044715)) * sqrt(2/pi)) + 1)
class F_gelu_4 : public GraphRewriterPass
{
public:
    // (x * 0.5) * (tanh((x + (torch.pow(x, 3.0) * 0.044715)) * 7.978846e-01) + 1)
    const char* match_pattern_graph() const
    {
        return R"PNNXIR(7767517
15 14
pnnx.Input input 0 1 input
prim::Constant op_0 0 1 zp5 value=0.5
aten::mul op_1 2 1 input zp5 pnnx_90
prim::Constant op_2 0 1 three value=%3
aten::pow op_3 2 1 input three pnnx_91
prim::Constant op_4 0 1 zp044715 value=%0p044715
aten::mul op_5 2 1 pnnx_91 zp044715 pnnx_92
aten::add op_6 2 1 input pnnx_92 pnnx_93
prim::Constant op_7 0 1 sqrt2dpi value=%sqrt2dpi
aten::mul op_8 2 1 pnnx_93 sqrt2dpi pnnx_94
aten::tanh op_9 1 1 pnnx_94 pnnx_95
prim::Constant op_10 0 1 one value=%1
aten::add op_11 2 1 pnnx_95 one pnnx_96
aten::mul op_12 2 1 pnnx_90 pnnx_96 out
pnnx.Output output 1 0 out
)PNNXIR";
    }

    // Only fuse when the captured constants are (within tolerance) the exact
    // tanh-GELU coefficients; otherwise this is some other elementwise chain.
    bool match(const std::map<std::string, Parameter>& captured_params) const
    {
        // cubic-term coefficient must be ~0.044715
        if (fabs(captured_params.at("0p044715").f - 0.044715f) > 0.0001f)
            return false;

        // inner scale must be ~sqrt(2/pi)
        if (fabs(captured_params.at("sqrt2dpi").f - sqrt(2.f / M_PI)) > 0.0001f)
            return false;

        // the "+1" and the pow exponent "3" may be captured as either an
        // integer (type 2, read via .i) or a float (type 3, read via .f)
        if ((captured_params.at("1").type == 2 && captured_params.at("1").i != 1) || (captured_params.at("1").type == 3 && captured_params.at("1").f != 1.f))
            return false;

        if ((captured_params.at("3").type == 2 && captured_params.at("3").i != 3) || (captured_params.at("3").type == 3 && captured_params.at("3").f != 3.f))
            return false;

        return true;
    }

    const char* type_str() const
    {
        return "F.gelu";
    }

    // F.gelu needs no parameters here; the pattern match alone is sufficient.
    void write(Operator* /*op*/, const std::map<std::string, Parameter>& /*captured_params*/) const
    {
    }
};

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_gelu_4, 9)

class F_gelu_onnx : public GraphRewriterPass
{
public:
Expand Down
28 changes: 28 additions & 0 deletions tools/pnnx/src/pass_level2/torch_arange.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -128,4 +128,32 @@ pnnx.Output output 1 0 out

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_arange_3, 20)

// Convert the ONNX Range operator (start, end, step as dynamic inputs)
// into torch.arange.
class torch_arange_onnx : public GraphRewriterPass
{
public:
    const char* match_pattern_graph() const
    {
        return R"PNNXIR(7767517
5 4
pnnx.Input input_0 0 1 start
pnnx.Input input_1 0 1 end
pnnx.Input input_2 0 1 step
Range op_0 3 1 start end step out
pnnx.Output output 1 0 out
)PNNXIR";
    }

    const char* type_str() const
    {
        return "torch.arange";
    }

    // Parameter name commented out: nothing is captured by the pattern, and
    // this matches the unused-parameter convention of the sibling onnx passes.
    void write(Operator* op, const std::map<std::string, Parameter>& /*captured_params*/) const
    {
        // ONNX Range carries no explicit dtype attribute; leave dtype unset
        // (a default-constructed Parameter) so it is inferred downstream.
        op->params["dtype"] = Parameter();
    }
};

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_arange_onnx, 20)

} // namespace pnnx
27 changes: 27 additions & 0 deletions tools/pnnx/src/pass_level2/torch_full.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,4 +66,31 @@ pnnx.Output output 1 0 out

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_full, 20)

// Convert the ONNX ConstantOfShape operator (output shape given as a dynamic
// input, fill value as an attribute) into torch.full.
class torch_full_onnx : public GraphRewriterPass
{
public:
    const char* match_pattern_graph() const
    {
        return R"PNNXIR(7767517
3 2
pnnx.Input input 0 1 size
ConstantOfShape op_0 1 1 size out value=%fill_value
pnnx.Output output 1 0 out
)PNNXIR";
    }

    const char* type_str() const
    {
        return "torch.full";
    }

    void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
    {
        // forward the captured ConstantOfShape value attribute as fill_value
        op->params["fill_value"] = captured_params.at("fill_value");
        // no explicit dtype in ONNX ConstantOfShape here; leave it unset so
        // it is inferred downstream
        op->params["dtype"] = Parameter();
    }
};

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_full_onnx, 20)

} // namespace pnnx
63 changes: 63 additions & 0 deletions tools/pnnx/src/pass_level2/torch_tile.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level2.h"

namespace pnnx {

// Map TorchScript aten::tile (input tensor + repeat counts) to torch.tile.
class torch_tile : public GraphRewriterPass
{
public:
    const char* match_pattern_graph() const
    {
        return R"PNNXIR(7767517
4 3
pnnx.Input input_0 0 1 input
pnnx.Input input_1 0 1 dims
aten::tile op_0 2 1 input dims out
pnnx.Output output 1 0 out
)PNNXIR";
    }

    const char* type_str() const
    {
        return "torch.tile";
    }
};

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_tile, 20)

// Map the ONNX Tile operator (input tensor + repeats input) to torch.tile.
class torch_tile_onnx : public GraphRewriterPass
{
public:
    const char* match_pattern_graph() const
    {
        return R"PNNXIR(7767517
4 3
pnnx.Input input_0 0 1 input
pnnx.Input input_1 0 1 dims
Tile op_0 2 1 input dims out
pnnx.Output output 1 0 out
)PNNXIR";
    }

    const char* type_str() const
    {
        return "torch.tile";
    }
};

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_tile_onnx, 20)

} // namespace pnnx
42 changes: 42 additions & 0 deletions tools/pnnx/src/pass_level2/torch_where.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level2.h"

namespace pnnx {

// Map TorchScript aten::where (condition, value-if-true, value-if-false)
// to torch.where.
class torch_where : public GraphRewriterPass
{
public:
    const char* match_pattern_graph() const
    {
        return R"PNNXIR(7767517
5 4
pnnx.Input input_0 0 1 condition
pnnx.Input input_1 0 1 input
pnnx.Input input_2 0 1 other
aten::where op_0 3 1 condition input other out
pnnx.Output output 1 0 out
)PNNXIR";
    }

    const char* type_str() const
    {
        return "torch.where";
    }
};

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_where, 20)

} // namespace pnnx
Loading

0 comments on commit 2828e7a

Please sign in to comment.