From 28de742d5270d2df3beb1713538967bf8a6962dd Mon Sep 17 00:00:00 2001
From: Yuchen Jin
Date: Sat, 31 Jul 2021 05:39:51 -0700
Subject: [PATCH] [Refactor] Unify the shared pass prefix between vm and graph
 (#8526)

---
Reviewer notes (text in this section is ignored by `git am`):
 * GetPassPrefix: the fskip callback now assigns its `false` default before
   the cast check, as the removed build_module.cc code did. With the default
   assigned last, the int32-cast `true` result was always overwritten.
 * Template arguments and the #include target were missing from the copy under
   review and have been reconstructed from context; verify against upstream.
   Indentation of context lines is likewise reconstructed.

 src/relay/backend/build_module.cc | 52 +-----------------------
 src/relay/backend/utils.cc        | 67 +++++++++++++++++++++++++++++++
 src/relay/backend/utils.h         | 17 ++++++++
 src/relay/backend/vm/compiler.cc  | 52 +-----------------------
 4 files changed, 86 insertions(+), 102 deletions(-)

diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc
index ea53c34c793b..f407436e5868 100644
--- a/src/relay/backend/build_module.cc
+++ b/src/relay/backend/build_module.cc
@@ -313,57 +313,7 @@ class RelayBuildModule : public runtime::ModuleNode {
       relay_module_ptr->Update(main_glb_var, new_main);
     }
 
-    Array<Pass> pass_seqs;
-    Array<runtime::String> entry_functions{"main"};
-    pass_seqs.push_back(transform::RemoveUnusedFunctions(entry_functions));
-    pass_seqs.push_back(transform::ToBasicBlockNormalForm());
-
-    // Run all dialect legalization passes.
-    pass_seqs.push_back(relay::qnn::transform::Legalize());
-
-    // Legalize pass is restricted to homogeneous execution for now.
-    if (targets.size() == 1) {
-      pass_seqs.push_back(transform::Legalize());
-    }
-
-    pass_seqs.push_back(transform::SimplifyInference());
-
-    // Convert Dynamic ops to static versions
-    pass_seqs.push_back(transform::DynamicToStatic());
-
-    PackedFunc fskip = PackedFunc([](TVMArgs args, TVMRetValue* rv) {
-      Expr expr = args[0];
-      *rv = false;
-      if (expr.as<CallNode>()) {
-        auto call_node = expr.as<CallNode>();
-        auto op_node = call_node->op.as<OpNode>();
-        if (op_node->name == "cast") {
-          auto attrs = call_node->attrs.as<CastAttrs>();
-          if (attrs->dtype == DataType::Int(32)) {
-            *rv = true;
-          }
-        }
-      }
-    });
-    pass_seqs.push_back(transform::EliminateCommonSubexpr(fskip));
-    pass_seqs.push_back(transform::SimplifyExpr());
-    pass_seqs.push_back(transform::CombineParallelConv2D(3));
-    pass_seqs.push_back(transform::CombineParallelDense(3));
-    pass_seqs.push_back(transform::CombineParallelBatchMatmul(3));
-    pass_seqs.push_back(transform::FoldConstant());
-    pass_seqs.push_back(transform::FoldScaleAxis());
-    pass_seqs.push_back(transform::CanonicalizeCast());
-    pass_seqs.push_back(transform::CanonicalizeOps());
-
-    // Alter layout transformation is only applied to homogeneous execution yet.
-    if (targets.size() == 1) {
-      pass_seqs.push_back(transform::InferType());
-      pass_seqs.push_back(transform::AlterOpLayout());
-    }
-
-    // Fast math optimizations.
-    pass_seqs.push_back(transform::FastMath());
-    pass_seqs.push_back(transform::FoldConstant());
+    Array<Pass> pass_seqs = GetPassPrefix(targets, false);
 
     if (targets.size() == 1) {
       const auto& target = (*targets.begin()).second;
diff --git a/src/relay/backend/utils.cc b/src/relay/backend/utils.cc
index f0c543f1244b..4b4844599e29 100644
--- a/src/relay/backend/utils.cc
+++ b/src/relay/backend/utils.cc
@@ -24,6 +24,8 @@
 
 #include "utils.h"
 
+#include <tvm/relay/qnn/transform.h>
+
 namespace tvm {
 namespace relay {
 namespace backend {
@@ -120,6 +122,71 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
                 << ",\n relay_primfuncs=" << node->relay_primfuncs << ")";
     });
 
+Array<Pass> GetPassPrefix(const Map<tvm::Integer, tvm::Target>& targets, bool is_vm) {
+  Array<Pass> pass_seqs;
+  Array<runtime::String> entry_functions{"main"};
+  pass_seqs.push_back(transform::RemoveUnusedFunctions(entry_functions));
+  pass_seqs.push_back(transform::ToBasicBlockNormalForm());
+  // Run all dialect legalization passes.
+  pass_seqs.push_back(relay::qnn::transform::Legalize());
+
+  // Legalize pass is restricted to homogeneous execution for now.
+  if (targets.size() == 1) {
+    pass_seqs.push_back(transform::Legalize());
+  }
+
+  pass_seqs.push_back(transform::SimplifyInference());
+
+  if (is_vm) {
+    // eta expand to support constructors in argument position
+    pass_seqs.push_back(transform::EtaExpand(
+        /* expand_constructor */ true, /* expand_global_var */ false));
+  } else {
+    // Convert Dynamic ops to static versions
+    pass_seqs.push_back(transform::DynamicToStatic());
+  }
+
+  PackedFunc fskip = PackedFunc([](TVMArgs args, TVMRetValue* rv) {
+    Expr expr = args[0];
+    *rv = false;  // Default first, so the int32-cast `true` below is not overwritten.
+    if (expr.as<CallNode>()) {
+      auto call_node = expr.as<CallNode>();
+      auto op_node = call_node->op.as<OpNode>();
+      if (op_node->name == "cast") {
+        auto attrs = call_node->attrs.as<CastAttrs>();
+        if (attrs->dtype == DataType::Int(32)) {
+          *rv = true;
+        }
+      }
+    }
+  });
+  pass_seqs.push_back(transform::EliminateCommonSubexpr(fskip));
+  pass_seqs.push_back(transform::SimplifyExpr());
+  if (is_vm) {
+    pass_seqs.push_back(transform::InlinePrimitives());
+  }
+  pass_seqs.push_back(transform::CombineParallelConv2D(3));
+  pass_seqs.push_back(transform::CombineParallelDense(3));
+  pass_seqs.push_back(transform::CombineParallelBatchMatmul(3));
+  pass_seqs.push_back(transform::FoldConstant());
+  pass_seqs.push_back(transform::FoldScaleAxis());
+  pass_seqs.push_back(transform::CanonicalizeCast());
+  pass_seqs.push_back(transform::CanonicalizeOps());
+
+  // Alter layout transformation is only applied to homogeneous execution yet.
+  if (targets.size() == 1) {
+    if (!is_vm) {
+      pass_seqs.push_back(transform::InferType());
+    }
+    pass_seqs.push_back(transform::AlterOpLayout());
+  }
+
+  // Fast math optimizations.
+  pass_seqs.push_back(transform::FastMath());
+  pass_seqs.push_back(transform::FoldConstant());
+  return pass_seqs;
+}
+
 }  // namespace backend
 }  // namespace relay
 }  // namespace tvm
diff --git a/src/relay/backend/utils.h b/src/relay/backend/utils.h
index d2a173a43f46..a0c7a5aad26d 100644
--- a/src/relay/backend/utils.h
+++ b/src/relay/backend/utils.h
@@ -44,7 +44,12 @@
 
 namespace tvm {
 namespace relay {
+namespace transform {
+Pass InlinePrimitives();
+}
+
 namespace backend {
+using Pass = tvm::transform::Pass;
 
 /*!
  * \brief The static storage information produced by memory planning.
@@ -410,6 +415,18 @@ inline bool IsCompileEngineCacheDisabled() {
       .value();
 }
 
+/*!
+ * \brief Get the sequence of Relay optimization passes based on backend type.
+ * The prefix of the Relay passes almost overlaps between the vm and graph backend, with some slight
+ * difference. This function unifies the shared optimization pass prefix between vm and graph
+ * runtime, and returns the pass prefix given the backend type.
+ *
+ * \param targets The device type to `Target` mapping.
+ * \param is_vm A boolean indicating if the passes are used for vm or graph runtime.
+ * \return An array of passes.
+ */
+Array<Pass> GetPassPrefix(const Map<tvm::Integer, tvm::Target>& targets, bool is_vm);
+
 }  // namespace backend
 }  // namespace relay
 }  // namespace tvm
diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc
index 96aa77f286a9..ddb1911a6b71 100644
--- a/src/relay/backend/vm/compiler.cc
+++ b/src/relay/backend/vm/compiler.cc
@@ -1042,57 +1042,7 @@ IRModule VMCompiler::OptimizeModule(IRModule mod, const TargetsMap& targets_arg,
     mod->Add(gvar, f);
   }
 
-  Array<Pass> pass_seqs;
-  Array<runtime::String> entry_functions{"main"};
-  pass_seqs.push_back(transform::RemoveUnusedFunctions(entry_functions));
-  pass_seqs.push_back(transform::ToBasicBlockNormalForm());
-  // Run all dialect legalization passes.
-  pass_seqs.push_back(relay::qnn::transform::Legalize());
-
-  // Legalize pass is restricted to homogeneous execution for now.
-  if (targets.size() == 1) {
-    pass_seqs.push_back(transform::Legalize());
-  }
-
-  // eta expand to support constructors in argument position
-  pass_seqs.push_back(transform::EtaExpand(
-      /* expand_constructor */ true, /* expand_global_var */ false));
-
-  pass_seqs.push_back(transform::SimplifyInference());
-  PackedFunc fskip = PackedFunc([](TVMArgs args, TVMRetValue* rv) {
-    Expr expr = args[0];
-    if (expr.as<CallNode>()) {
-      auto call_node = expr.as<CallNode>();
-      auto op_node = call_node->op.as<OpNode>();
-      if (op_node->name == "cast") {
-        auto attrs = call_node->attrs.as<CastAttrs>();
-        if (attrs->dtype == DataType::Int(32)) {
-          *rv = true;
-        }
-      }
-    }
-    *rv = false;
-  });
-  pass_seqs.push_back(transform::EliminateCommonSubexpr(fskip));
-  pass_seqs.push_back(transform::SimplifyExpr());
-  pass_seqs.push_back(transform::InlinePrimitives());
-
-  pass_seqs.push_back(transform::CombineParallelConv2D(3));
-  pass_seqs.push_back(transform::CombineParallelDense(3));
-  pass_seqs.push_back(transform::CombineParallelBatchMatmul(3));
-  pass_seqs.push_back(transform::FoldConstant());
-  pass_seqs.push_back(transform::FoldScaleAxis());
-  pass_seqs.push_back(transform::CanonicalizeCast());
-  pass_seqs.push_back(transform::CanonicalizeOps());
-
-  // Alter layout transformation is only applied to homogeneous execution yet.
-  if (targets.size() == 1) {
-    pass_seqs.push_back(transform::AlterOpLayout());
-  }
-
-  // Fast math optimizations.
-  pass_seqs.push_back(transform::FastMath());
-  pass_seqs.push_back(transform::FoldConstant());
+  Array<Pass> pass_seqs = relay::backend::GetPassPrefix(targets, true);
 
   if (targets_.size() > 1) {
     // Handle heterogeneous compilation.