From 0936093139be63d557b23e23bfe3346ccb7d7b51 Mon Sep 17 00:00:00 2001 From: YuchenJin Date: Wed, 21 Jul 2021 20:42:21 -0700 Subject: [PATCH 1/4] Unify the shared pass prefix between vm and graph --- src/relay/backend/build_module.cc | 52 +------------------------ src/relay/backend/utils.cc | 65 +++++++++++++++++++++++++++++++ src/relay/backend/utils.h | 15 +++++++ src/relay/backend/vm/compiler.cc | 52 +------------------------ 4 files changed, 82 insertions(+), 102 deletions(-) diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index ea53c34c793b..3377d1325de8 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -313,57 +313,7 @@ class RelayBuildModule : public runtime::ModuleNode { relay_module_ptr->Update(main_glb_var, new_main); } - Array pass_seqs; - Array entry_functions{"main"}; - pass_seqs.push_back(transform::RemoveUnusedFunctions(entry_functions)); - pass_seqs.push_back(transform::ToBasicBlockNormalForm()); - - // Run all dialect legalization passes. - pass_seqs.push_back(relay::qnn::transform::Legalize()); - - // Legalize pass is restricted to homogeneous execution for now. - if (targets.size() == 1) { - pass_seqs.push_back(transform::Legalize()); - } - - pass_seqs.push_back(transform::SimplifyInference()); - - // Convert Dynamic ops to static versions - pass_seqs.push_back(transform::DynamicToStatic()); - - PackedFunc fskip = PackedFunc([](TVMArgs args, TVMRetValue* rv) { - Expr expr = args[0]; - *rv = false; - if (expr.as()) { - auto call_node = expr.as(); - auto op_node = call_node->op.as(); - if (op_node->name == "cast") { - auto attrs = call_node->attrs.as(); - if (attrs->dtype == DataType::Int(32)) { - *rv = true; - } - } - } - }); - pass_seqs.push_back(transform::EliminateCommonSubexpr(fskip)); - pass_seqs.push_back(transform::SimplifyExpr()); - pass_seqs.push_back(transform::CombineParallelConv2D(3)); - pass_seqs.push_back(transform::CombineParallelDense(3)); - pass_seqs.push_back(transform::CombineParallelBatchMatmul(3)); - pass_seqs.push_back(transform::FoldConstant()); - pass_seqs.push_back(transform::FoldScaleAxis()); - pass_seqs.push_back(transform::CanonicalizeCast()); - pass_seqs.push_back(transform::CanonicalizeOps()); - - // Alter layout transformation is only applied to homogeneous execution yet. - if (targets.size() == 1) { - pass_seqs.push_back(transform::InferType()); - pass_seqs.push_back(transform::AlterOpLayout()); - } - - // Fast math optimizations. - pass_seqs.push_back(transform::FastMath()); - pass_seqs.push_back(transform::FoldConstant()); + Array pass_seqs = GetPrefixOpts(targets, false); if (targets.size() == 1) { const auto& target = (*targets.begin()).second; diff --git a/src/relay/backend/utils.cc b/src/relay/backend/utils.cc index f0c543f1244b..340f1b1f0dc4 100644 --- a/src/relay/backend/utils.cc +++ b/src/relay/backend/utils.cc @@ -120,6 +120,71 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) << ",\n relay_primfuncs=" << node->relay_primfuncs << ")"; }); +Array GetPrefixOpts(const Map& targets, bool is_vm) { + Array pass_seqs; + Array entry_functions{"main"}; + pass_seqs.push_back(transform::RemoveUnusedFunctions(entry_functions)); + pass_seqs.push_back(transform::ToBasicBlockNormalForm()); + // Run all dialect legalization passes. + pass_seqs.push_back(relay::qnn::transform::Legalize()); + + // Legalize pass is restricted to homogeneous execution for now. + if (targets.size() == 1) { + pass_seqs.push_back(transform::Legalize()); + } + + pass_seqs.push_back(transform::SimplifyInference()); + + if (is_vm) { + // eta expand to support constructors in argument position + pass_seqs.push_back(transform::EtaExpand( + /* expand_constructor */ true, /* expand_global_var */ false)); + } else { + // Convert Dynamic ops to static versions + pass_seqs.push_back(transform::DynamicToStatic()); + } + + PackedFunc fskip = PackedFunc([](TVMArgs args, TVMRetValue* rv) { + Expr expr = args[0]; + if (expr.as()) { + auto call_node = expr.as(); + auto op_node = call_node->op.as(); + if (op_node->name == "cast") { + auto attrs = call_node->attrs.as(); + if (attrs->dtype == DataType::Int(32)) { + *rv = true; + } + } + } + *rv = false; + }); + pass_seqs.push_back(transform::EliminateCommonSubexpr(fskip)); + pass_seqs.push_back(transform::SimplifyExpr()); + if (is_vm) { + pass_seqs.push_back(transform::InlinePrimitives()); + } + pass_seqs.push_back(transform::CombineParallelConv2D(3)); + pass_seqs.push_back(transform::CombineParallelDense(3)); + pass_seqs.push_back(transform::CombineParallelBatchMatmul(3)); + pass_seqs.push_back(transform::FoldConstant()); + pass_seqs.push_back(transform::FoldScaleAxis()); + pass_seqs.push_back(transform::CanonicalizeCast()); + pass_seqs.push_back(transform::CanonicalizeOps()); + + // Alter layout transformation is only applied to homogeneous execution yet. + if (targets.size() == 1) { + if (!is_vm) { + pass_seqs.push_back(transform::InferType()); + } + pass_seqs.push_back(transform::AlterOpLayout()); + } + + // Fast math optimizations. + pass_seqs.push_back(transform::FastMath()); + pass_seqs.push_back(transform::FoldConstant()); + return pass_seqs; +} + } // namespace backend } // namespace relay } // namespace tvm diff --git a/src/relay/backend/utils.h b/src/relay/backend/utils.h index d2a173a43f46..7f6663664f90 100644 --- a/src/relay/backend/utils.h +++ b/src/relay/backend/utils.h @@ -28,6 +28,7 @@ #include #include #include +#include #include #include #include @@ -44,7 +45,12 @@ namespace tvm { namespace relay { +namespace transform { +Pass InlinePrimitives(); +} + namespace backend { +using Pass = tvm::transform::Pass; /*! * \brief The static storage information produced by memory planning. @@ -410,6 +416,15 @@ inline bool IsCompileEngineCacheDisabled() { .value(); } +/*! + * \brief Get a shared optimization pass prefix between vm and graph executor. + * + * \param targets The device type to `Target` mapping. + * \param is_vm A boolean indicating if the passes are used for vm or graph executor. + * \return An array of passes. + */ +Array GetPrefixOpts(const Map& targets, bool is_vm); + } // namespace backend } // namespace relay } // namespace tvm diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index 96aa77f286a9..88723db7e1f1 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -1042,57 +1042,7 @@ IRModule VMCompiler::OptimizeModule(IRModule mod, const TargetsMap& targets_arg, mod->Add(gvar, f); } - Array pass_seqs; - Array entry_functions{"main"}; - pass_seqs.push_back(transform::RemoveUnusedFunctions(entry_functions)); - pass_seqs.push_back(transform::ToBasicBlockNormalForm()); - // Run all dialect legalization passes. - pass_seqs.push_back(relay::qnn::transform::Legalize()); - - // Legalize pass is restricted to homogeneous execution for now. - if (targets.size() == 1) { - pass_seqs.push_back(transform::Legalize()); - } - - // eta expand to support constructors in argument position - pass_seqs.push_back(transform::EtaExpand( - /* expand_constructor */ true, /* expand_global_var */ false)); - - pass_seqs.push_back(transform::SimplifyInference()); - PackedFunc fskip = PackedFunc([](TVMArgs args, TVMRetValue* rv) { - Expr expr = args[0]; - if (expr.as()) { - auto call_node = expr.as(); - auto op_node = call_node->op.as(); - if (op_node->name == "cast") { - auto attrs = call_node->attrs.as(); - if (attrs->dtype == DataType::Int(32)) { - *rv = true; - } - } - } - *rv = false; - }); - pass_seqs.push_back(transform::EliminateCommonSubexpr(fskip)); - pass_seqs.push_back(transform::SimplifyExpr()); - pass_seqs.push_back(transform::InlinePrimitives()); - - pass_seqs.push_back(transform::CombineParallelConv2D(3)); - pass_seqs.push_back(transform::CombineParallelDense(3)); - pass_seqs.push_back(transform::CombineParallelBatchMatmul(3)); - pass_seqs.push_back(transform::FoldConstant()); - pass_seqs.push_back(transform::FoldScaleAxis()); - pass_seqs.push_back(transform::CanonicalizeCast()); - pass_seqs.push_back(transform::CanonicalizeOps()); - - // Alter layout transformation is only applied to homogeneous execution yet. - if (targets.size() == 1) { - pass_seqs.push_back(transform::AlterOpLayout()); - } - - // Fast math optimizations. - pass_seqs.push_back(transform::FastMath()); - pass_seqs.push_back(transform::FoldConstant()); + Array pass_seqs = relay::backend::GetPrefixOpts(targets, true); if (targets_.size() > 1) { // Handle heterogeneous compilation. From 6dde56833843f21584c7d80d3f514c8bc21292e3 Mon Sep 17 00:00:00 2001 From: YuchenJin Date: Fri, 23 Jul 2021 23:01:58 -0700 Subject: [PATCH 2/4] retrigger ci From a5e23414796c7822364d08e904610fee301af9fb Mon Sep 17 00:00:00 2001 From: YuchenJin Date: Sat, 24 Jul 2021 09:39:14 -0700 Subject: [PATCH 3/4] retrigger ci From bedf48b26e7f69b8401280753132a5480d90ff75 Mon Sep 17 00:00:00 2001 From: YuchenJin Date: Wed, 28 Jul 2021 20:44:41 -0700 Subject: [PATCH 4/4] Update. --- src/relay/backend/build_module.cc | 2 +- src/relay/backend/utils.cc | 4 +++- src/relay/backend/utils.h | 10 ++++++---- src/relay/backend/vm/compiler.cc | 2 +- 4 files changed, 11 insertions(+), 7 deletions(-) diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index 3377d1325de8..f407436e5868 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -313,7 +313,7 @@ class RelayBuildModule : public runtime::ModuleNode { relay_module_ptr->Update(main_glb_var, new_main); } - Array pass_seqs = GetPrefixOpts(targets, false); + Array pass_seqs = GetPassPrefix(targets, false); if (targets.size() == 1) { const auto& target = (*targets.begin()).second; diff --git a/src/relay/backend/utils.cc b/src/relay/backend/utils.cc index 340f1b1f0dc4..4b4844599e29 100644 --- a/src/relay/backend/utils.cc +++ b/src/relay/backend/utils.cc @@ -24,6 +24,8 @@ #include "utils.h" +#include + namespace tvm { namespace relay { namespace backend { @@ -120,7 +122,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) << ",\n relay_primfuncs=" << node->relay_primfuncs << ")"; }); -Array GetPrefixOpts(const Map& targets, bool is_vm) { +Array GetPassPrefix(const Map& targets, bool is_vm) { Array pass_seqs; Array entry_functions{"main"}; pass_seqs.push_back(transform::RemoveUnusedFunctions(entry_functions)); diff --git a/src/relay/backend/utils.h b/src/relay/backend/utils.h index 7f6663664f90..a0c7a5aad26d 100644 --- a/src/relay/backend/utils.h +++ b/src/relay/backend/utils.h @@ -28,7 +28,6 @@ #include #include #include -#include #include #include #include @@ -417,13 +416,16 @@ inline bool IsCompileEngineCacheDisabled() { } /*! - * \brief Get a shared optimization pass prefix between vm and graph executor. + * \brief Get the sequence of Relay optimization passes based on backend type. + * The prefix of the Relay passes almost overlaps between the vm and graph backend, with some slight + * difference. This function unifies the shared optimization pass prefix between vm and graph + * runtime, and returns the pass prefix given the backend type. * * \param targets The device type to `Target` mapping. - * \param is_vm A boolean indicating if the passes are used for vm or graph executor. + * \param is_vm A boolean indicating if the passes are used for vm or graph runtime. * \return An array of passes. */ -Array GetPrefixOpts(const Map& targets, bool is_vm); +Array GetPassPrefix(const Map& targets, bool is_vm); } // namespace backend } // namespace relay diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index 88723db7e1f1..ddb1911a6b71 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -1042,7 +1042,7 @@ IRModule VMCompiler::OptimizeModule(IRModule mod, const TargetsMap& targets_arg, mod->Add(gvar, f); } - Array pass_seqs = relay::backend::GetPrefixOpts(targets, true); + Array pass_seqs = relay::backend::GetPassPrefix(targets, true); if (targets_.size() > 1) { // Handle heterogeneous compilation.