Skip to content

Commit

Permalink
[Unity][Transform] Improved canonicalization of non-dataflow Var
Browse files Browse the repository at this point in the history
Prior to this commit, `relax.transform.CanonicalizeBindings` removed
trivial bindings `var_y = var_x` where a `var_y: relax.DataflowVar`
and `var_x: relax.Var`, but did not remove trivial bindings when
`var_y: relax.Var` and `var_x: relax.DataflowVar`.  This was to avoid
invalid use of a `relax.DataflowVar` outside of a dataflow block.

This commit updates `CanonicalizeBindings` to handle this type of
binding as well.  To ensure that no `relax.DataflowVar` instances are
used outside of a dataflow block, this is done by replacing `var_y:
relax.DataflowVar` at its point of definition, instead of replacing
`var_x: relax.Var` at its point of use.

This commit also canonicalizes `relax.Var` definitions to
`relax.DataflowVar`, if the binding occurs within a dataflow block,
and the variable is never used outside of a dataflow block.
  • Loading branch information
Lunderberg committed Oct 17, 2023
1 parent 354c5f1 commit ddf30fb
Show file tree
Hide file tree
Showing 3 changed files with 420 additions and 76 deletions.
6 changes: 5 additions & 1 deletion src/relax/ir/block_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -612,7 +612,11 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctor<Expr(const Expr&
unchanged &= new_block.same_as(block);
}

this->BeginBindingBlock();
if (block_stack_.size() && CurrentBlockIsDataFlow()) {
this->BeginDataflowBlock();
} else {
this->BeginBindingBlock();
}
// the body may not be a leaf expression, so check for that
Expr new_body = this->NormalizeArgument(op->body);
unchanged &= new_body.same_as(op->body);
Expand Down
251 changes: 181 additions & 70 deletions src/relax/transform/canonicalize_bindings.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,96 +32,207 @@
namespace tvm {
namespace relax {

class BindingCanonicalizer : public ExprMutator {
namespace {

struct CanonicalizationPlan {
std::unordered_map<Id, Var, ObjectPtrHash, ObjectPtrEqual> replace_usage;
std::unordered_map<Id, Var, ObjectPtrHash, ObjectPtrEqual> replace_binding;
std::unordered_set<Id, ObjectPtrHash, ObjectPtrEqual> bindings_to_remove;
};

/*! \brief Utility class to identify usage location
*
* Canonicalization of a variable binding may require information from
* later in the function. For example, replacing `dataflow_x = expr`
* with `var_x = expr` to avoid a trivial binding of `var_x =
* dataflow_x` later in the function. This utility examines a relax
* expression, and plans the changes to be made in a mutation pass.
*/
class CanonicalizePlanner : public ExprVisitor {
public:
BindingCanonicalizer() {}

using ExprMutator::VisitExpr_;

Expr VisitExpr_(const TupleGetItemNode* tuple_get_item) override {
if (auto tuple_var = tuple_get_item->tuple.as<Var>()) {
if (auto tuple_value = LookupBinding(tuple_var.value())) {
if (auto explicit_tuple = tuple_value.as<TupleNode>()) {
CHECK_GE(tuple_get_item->index, 0)
<< "Tuple " << tuple_value << " is accessed at index " << tuple_get_item->index
<< ", but negative indices are not supported in this context.";
CHECK_LT(tuple_get_item->index, explicit_tuple->fields.size())
<< "Tuple " << tuple_value << " is accessed at index " << tuple_get_item->index
<< ", but the tuple size is only " << explicit_tuple->fields.size();
return VisitExpr(explicit_tuple->fields[tuple_get_item->index]);
static CanonicalizationPlan Collect(const Expr& expr) {
CanonicalizePlanner visitor;
visitor.VisitExpr(expr);

CanonicalizationPlan plan;

std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> handled;

for (const auto& binding_iter : visitor.trivial_bindings_) {
Var bound_var = binding_iter.first;
Var bound_to = binding_iter.second;

while (true) {
if (auto it = visitor.trivial_bindings_.find(bound_to);
it != visitor.trivial_bindings_.end()) {
// This may be a trivial binding into a trivial binding. In
// that case, unwrap the bindings until we find the earliest
// non-trivial binding.
bound_to = it->second;
} else {
break;
}
}

while (true) {
if (auto it = plan.replace_binding.find(bound_to->vid); it != plan.replace_binding.end()) {
// The variable we are binding to may have already been
// replaced, if it fell into Case 4 (Var = DataflowVar). In
// that case, we check against its replacement instead.
bound_to = it->second;
} else {
break;
}
}

if (bound_var.as<DataflowVarNode>() || !bound_to.as<DataflowVarNode>()) {
// Case 1: Var = Var
// Case 2: DataflowVar = Var
// Case 3: DataflowVar = DataflowVar
//
// For these three cases, the trivial binding can be
// unwrapped, using the bound variable directly at the point
// of use.
plan.replace_usage[bound_var->vid] = bound_to;
plan.bindings_to_remove.insert(bound_var->vid);
handled.insert(bound_to);
} else {
// Case 4: Var = DataflowVar
//
// Replacing a Var with a DataflowVar could result in illegal
// use of a DataflowVar outside of a DataflowBlock. Instead,
// we replace in the opposite direction, replacing the binding
// of the DataflowVar with a binding of the Var.
plan.replace_binding[bound_to->vid] = bound_var;
plan.replace_usage[bound_to->vid] = bound_var;
plan.bindings_to_remove.insert(bound_var->vid);
handled.insert(bound_var);
}
}
return ExprMutator::VisitExpr_(tuple_get_item);
}

void VisitBinding_(const VarBindingNode* binding) override {
// Unlike default visitor, we do not permit the checked type to change
// if the new value's checked type is different (this preserves user annotations)
Expr new_value = this->VisitExpr(binding->value);
Var new_var = this->VisitVarDef(binding->var);

if (auto opt_var = new_value.as<Var>();
opt_var && CanCanonicalizeVar(new_var, opt_var.value())) {
var_remap_[new_var->vid] = opt_var.value();
} else if (new_var.same_as(binding->var) && new_value.same_as(binding->value)) {
this->builder_->EmitNormalized(GetRef<VarBinding>(binding));
} else {
this->builder_->EmitNormalized(VarBinding(new_var, new_value));
// If a Var has been defined inside a DataflowBlock, is only used
// within a DataflowBlock, and is not already handled by removal
// of trivial bindings, then we can replace it with a DataflowVar.
for (const auto& var : visitor.defined_inside_dataflow_) {
if (!var.as<DataflowVarNode>() && !visitor.used_outside_dataflow_.count(var) &&
!handled.count(var)) {
DataflowVar new_var(var->name_hint(), GetStructInfo(var));
plan.replace_binding[var->vid] = new_var;
plan.replace_usage[var->vid] = new_var;
}
}

return plan;
}

void VisitBinding_(const MatchCastNode* binding) override {
// If we have a trivial shape check (the shape_ of LHS and RHS is the same),
// we can canonicalize to a var binding
Expr new_value = this->VisitExpr(binding->value);
private:
void VisitBindingBlock_(const DataflowBlockNode* block) override {
bool cache = inside_dataflow_;
inside_dataflow_ = true;
ExprVisitor::VisitBindingBlock_(block);
inside_dataflow_ = cache;
}

bool has_same_struct_info = StructuralEqual()(binding->struct_info, GetStructInfo(new_value));
void VisitBinding(const Binding& binding) override {
bool has_same_struct_info = true;
Expr value;
if (auto ptr = binding.as<VarBindingNode>()) {
value = ptr->value;
} else if (auto ptr = binding.as<MatchCastNode>()) {
has_same_struct_info =
StructuralEqual()(GetStructInfo(binding->var), GetStructInfo(ptr->value));
value = ptr->value;
} else {
LOG(FATAL) << "Invalid binding type: " << binding->GetTypeKey();
}

if (has_same_struct_info) {
if (auto parent = new_value.as<Var>();
parent && CanCanonicalizeVar(binding->var, parent.value())) {
// LHS and RHS have the same struct info, and occur in a
// context where the RHS can replace the LHS.
var_remap_[binding->var->vid] = parent.value();
} else {
// LHS and RHS have the same struct info, but the RHS is not a
// drop-in replacement for the LHS.
builder_->EmitNormalized(VarBinding(binding->var, new_value));
// Unwrap TupleGetItem, if the Tuple being accessed is known.
if (auto tuple_get_item = value.as<TupleGetItemNode>()) {
Expr tuple = tuple_get_item->tuple;
while (true) {
if (auto tuple_var = tuple.as<Var>()) {
if (auto it = known_bindings_.find(tuple_var.value()); it != known_bindings_.end()) {
tuple = it->second;
continue;
}
}
break;
}
} else if (new_value.same_as(binding->value)) {
builder_->EmitNormalized(GetRef<MatchCast>(binding));
} else {
builder_->EmitNormalized(MatchCast(binding->var, new_value, binding->struct_info));

if (auto ptr = tuple.as<TupleNode>()) {
value = ptr->fields[tuple_get_item->index];
}
}

if (auto parent = value.as<Var>(); parent && has_same_struct_info) {
trivial_bindings_[binding->var] = parent.value();
}

known_bindings_[binding->var] = value;

ExprVisitor::VisitBinding(binding);
}

void VisitVarDef(const Var& var) override {
if (inside_dataflow_) {
defined_inside_dataflow_.insert(var);
}
}

void VisitExpr_(const VarNode* var) override {
if (!inside_dataflow_) {
used_outside_dataflow_.insert(GetRef<Var>(var));
}
}

bool inside_dataflow_{false};

std::unordered_map<Var, Var, ObjectPtrHash, ObjectPtrEqual> trivial_bindings_;
std::unordered_map<Var, Expr, ObjectPtrHash, ObjectPtrEqual> known_bindings_;
std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> defined_inside_dataflow_;
std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> used_outside_dataflow_;
};

/*! \brief The mutator class to apply a CanonicalizationPlan */
class BindingCanonicalizer : public ExprMutator {
public:
static Expr Apply(Expr expr) {
auto used_outside_dataflow = CanonicalizePlanner::Collect(expr);
BindingCanonicalizer mutator(std::move(used_outside_dataflow));
return mutator.VisitExpr(expr);
}

private:
bool AnnotationsDiffer(const ObjectRef& obj1, const ObjectRef& obj2,
std::function<bool(const ObjectRef&, const ObjectRef&)> check_eq) {
// annotations differ if one is present but not the other
// or they're both present and they differ
bool both_present = obj1.defined() && obj2.defined();
bool neither_present = !obj1.defined() && !obj2.defined();
return !(both_present || neither_present) || (both_present && !check_eq(obj1, obj2));
BindingCanonicalizer(CanonicalizationPlan plan) : plan_(plan) {}

void VisitBinding(const Binding& binding) override {
if (!plan_.bindings_to_remove.count(binding->var->vid)) {
ExprMutator::VisitBinding(binding);
}
}

bool CanCanonicalizeVar(Var var, Var parent_var) {
// Cases when we conservatively do not unify:
// 1. checked_type_ or shape_ of the child differs from that of the parent
// In this case, we could be overriding user annotations.
// 2. If the child is a Var and the parent is a DataflowVar.
// That could result in a DataflowVar leaving the current DataflowBlock.
bool annotations_differ = AnnotationsDiffer(var->struct_info_, parent_var->struct_info_,
[&](const ObjectRef& lhs, const ObjectRef& rhs) {
return tvm::StructuralEqual()(lhs, rhs);
});
bool var_to_dataflow = (!var.as<DataflowVarNode>() && parent_var.as<DataflowVarNode>());
return !annotations_differ && !var_to_dataflow;
Var VisitVarDef(const Var& var) override {
if (auto it = plan_.replace_binding.find(var->vid); it != plan_.replace_binding.end()) {
return ExprMutator::VisitVarDef(it->second);
} else {
return ExprMutator::VisitVarDef(var);
}
}

Expr VisitExpr_(const VarNode* var) override {
if (auto it = plan_.replace_usage.find(var->vid); it != plan_.replace_usage.end()) {
return ExprMutator::VisitExpr(it->second);
} else {
return ExprMutator::VisitExpr_(var);
}
}

private:
CanonicalizationPlan plan_;
};
} // namespace

Expr CanonicalizeBindings(const Expr& e) { return BindingCanonicalizer().VisitExpr(e); }
Expr CanonicalizeBindings(const Expr& expr) { return BindingCanonicalizer::Apply(expr); }

namespace transform {

Expand Down
Loading

0 comments on commit ddf30fb

Please sign in to comment.