Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Unity][Transform] Improved canonicalization of non-dataflow Var #15941

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
247 changes: 173 additions & 74 deletions src/relax/transform/canonicalize_bindings.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,68 +33,187 @@
namespace tvm {
namespace relax {

class BindingCanonicalizer : public ExprMutator {
namespace {

/*! \brief The rewrites computed by CanonicalizePlanner and applied by BindingCanonicalizer
 *
 * All three members are keyed by the `vid` of a variable rather than
 * by the variable itself.
 */
struct CanonicalizationPlan {
// For a variable whose Id is a key, the mutator substitutes the
// mapped Var wherever that variable is used in an expression.
Map<Id, Var> replace_usage;
// For a variable whose Id is a key, the mutator substitutes the
// mapped Var at the point where that variable is defined.
Map<Id, Var> replace_binding;
// Bindings whose bound variable has one of these Ids are skipped by
// the mutator instead of being visited.
std::unordered_set<Id, ObjectPtrHash, ObjectPtrEqual> bindings_to_remove;
};

/*! \brief Utility class to identify usage location
*
* Canonicalization of a variable binding may require information from
* later in the function. For example, replacing `dataflow_x = expr`
* with `var_x = expr` to avoid a trivial binding of `var_x =
* dataflow_x` later in the function. This utility examines a relax
* expression, and plans the changes to be made in a mutation pass.
*/
class CanonicalizePlanner : public ExprVisitor {
public:
BindingCanonicalizer() {}

using ExprMutator::VisitExpr_;

Expr VisitExpr_(const TupleGetItemNode* tuple_get_item) override {
if (auto tuple_var = tuple_get_item->tuple.as<Var>()) {
if (auto tuple_value = LookupBinding(tuple_var.value())) {
if (auto explicit_tuple = tuple_value.as<TupleNode>()) {
CHECK_GE(tuple_get_item->index, 0)
<< "Tuple " << tuple_value << " is accessed at index " << tuple_get_item->index
<< ", but negative indices are not supported in this context.";
CHECK_LT(tuple_get_item->index, explicit_tuple->fields.size())
<< "Tuple " << tuple_value << " is accessed at index " << tuple_get_item->index
<< ", but the tuple size is only " << explicit_tuple->fields.size();
return VisitExpr(explicit_tuple->fields[tuple_get_item->index]);
/*! \brief Analyze an expression and plan how its bindings should be canonicalized
 *
 * Runs a CanonicalizePlanner over `expr`, then builds the plan in two
 * steps.  First, every trivial binding recorded by the visitor is
 * resolved to the variable it ultimately refers to and classified
 * into one of the four cases described below.  Second, any
 * non-dataflow Var that was defined inside a DataflowBlock, is never
 * used outside of one, and was not touched by the first step is
 * scheduled to be replaced by a new DataflowVar.
 *
 * \param expr The expression to analyze.  It is only visited, not modified.
 *
 * \return The plan describing which usages and bindings to replace,
 * and which bindings to remove.
 */
static CanonicalizationPlan Collect(const Expr& expr) {
CanonicalizePlanner visitor;
visitor.VisitExpr(expr);

CanonicalizationPlan plan;

// Variables that take part in a trivial-binding rewrite.  These are
// excluded from the Var -> DataflowVar replacement in the second
// loop below.
std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> handled;

for (const auto& binding_iter : visitor.trivial_bindings_) {
Var bound_var = binding_iter.first;
Var bound_to = binding_iter.second;

while (auto opt = visitor.trivial_bindings_.Get(bound_to)) {
// This may be a trivial binding into a trivial binding. In
// that case, unwrap the bindings until we find the earliest
// non-trivial binding.
bound_to = opt.value();
}

while (auto opt = plan.replace_binding.Get(bound_to->vid)) {
// The variable we are binding to may have already been
// replaced, if it fell into Case 4 (Var = DataflowVar). In
// that case, we check against its replacement instead.
bound_to = opt.value();
}

if (bound_var.as<DataflowVarNode>() || !bound_to.as<DataflowVarNode>()) {
// Case 1: Var = Var
// Case 2: DataflowVar = Var
// Case 3: DataflowVar = DataflowVar
//
// For these three cases, the trivial binding can be
// unwrapped, using the bound variable directly at the point
// of use.
plan.replace_usage.Set(bound_var->vid, bound_to);
plan.bindings_to_remove.insert(bound_var->vid);
handled.insert(bound_to);
} else {
// Case 4: Var = DataflowVar
//
// Replacing a Var with a DataflowVar could result in illegal
// use of a DataflowVar outside of a DataflowBlock. Instead,
// we replace in the opposite direction, replacing the binding
// of the DataflowVar with a binding of the Var.
plan.replace_binding.Set(bound_to->vid, bound_var);
plan.replace_usage.Set(bound_to->vid, bound_var);
plan.bindings_to_remove.insert(bound_var->vid);
handled.insert(bound_var);
}
}

// If a Var has been defined inside a DataflowBlock, is only used
// within a DataflowBlock, and is not already handled by removal
// of trivial bindings, then we can replace it with a DataflowVar.
for (const auto& var : visitor.defined_inside_dataflow_) {
if (!var.as<DataflowVarNode>() && !visitor.used_outside_dataflow_.count(var) &&
!handled.count(var)) {
// The replacement keeps the original name hint and struct info.
DataflowVar new_var(var->name_hint(), GetStructInfo(var));
plan.replace_binding.Set(var->vid, new_var);
plan.replace_usage.Set(var->vid, new_var);
}
}

return plan;
}

private:
void VisitBindingBlock_(const DataflowBlockNode* block) override {
bool cache = inside_dataflow_;
inside_dataflow_ = true;
ExprVisitor::VisitBindingBlock_(block);
inside_dataflow_ = cache;
}

void VisitBinding(const Binding& binding) override {
bool has_same_struct_info = true;
Expr value;
if (auto ptr = binding.as<VarBindingNode>()) {
value = ptr->value;
} else if (auto ptr = binding.as<MatchCastNode>()) {
has_same_struct_info =
StructuralEqual()(GetStructInfo(binding->var), GetStructInfo(ptr->value));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Might we want to consider using the arithmetic solver to check if shapes are statically provable to be equal? The struct info can be equivalent without being structurally equal. (We should factor out that checking into a utility function if it isn't already, as this comes up a lot.)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd like to do so, yes, though I think that would be best isolated in an independent PR. Here, the check for having identical struct info is the same check as was done previously in the CanCanonicalizeVar method, and was kept the same in order to minimize the extent of changes.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay. We can revisit this issue separately (I bet there are other passes where we're checking structural equality of struct info instead of static equivalence).

value = ptr->value;
} else {
LOG(FATAL) << "Invalid binding type: " << binding->GetTypeKey();
}

// Unwrap TupleGetItem, if the Tuple being accessed is known.
if (auto tuple_get_item = value.as<TupleGetItemNode>()) {
Expr tuple = tuple_get_item->tuple;
while (auto tuple_var = tuple.as<Var>()) {
if (auto opt = known_bindings_.Get(tuple_var.value())) {
tuple = opt.value();
} else {
break;
}
}

if (auto ptr = tuple.as<TupleNode>()) {
value = ptr->fields[tuple_get_item->index];
}
}

if (auto parent = value.as<Var>(); parent && has_same_struct_info) {
trivial_bindings_.Set(binding->var, parent.value());
}
return ExprMutator::VisitExpr_(tuple_get_item);

known_bindings_.Set(binding->var, value);

ExprVisitor::VisitBinding(binding);
}

void VisitBinding_(const VarBindingNode* binding) override {
// Unlike default visitor, we do not permit the struct info to change
// if the new value's struct info is different (this preserves user annotations)
Expr new_value = this->VisitExpr(binding->value);
Var new_var = this->VisitVarDef(binding->var);

if (auto opt_var = new_value.as<Var>();
opt_var && CanCanonicalizeVar(new_var, opt_var.value())) {
var_remap_[new_var->vid] = opt_var.value();
} else if (new_var.same_as(binding->var) && new_value.same_as(binding->value)) {
this->builder_->EmitNormalized(GetRef<VarBinding>(binding));
/*! \brief Record variables whose definition occurs inside a DataflowBlock */
void VisitVarDef(const Var& var) override {
  if (!inside_dataflow_) {
    // Definitions outside of a DataflowBlock are not tracked.
    return;
  }
  defined_inside_dataflow_.insert(var);
}

/*! \brief Record variables that are used outside of any DataflowBlock */
void VisitExpr_(const VarNode* var) override {
  if (inside_dataflow_) {
    // Usages inside a DataflowBlock are not tracked.
    return;
  }
  used_outside_dataflow_.insert(GetRef<Var>(var));
}

// True while the visitor is inside a DataflowBlock.
bool inside_dataflow_{false};

// Maps a bound variable to the Var it is bound to, for bindings whose
// value is a plain Var with the same struct info.
Map<Var, Var> trivial_bindings_;
// Maps every bound variable to its value, after unwrapping a
// TupleGetItem of a known Tuple.
Map<Var, Expr> known_bindings_;
// Variables whose definition was visited inside a DataflowBlock.
std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> defined_inside_dataflow_;
// Variables that were used in an expression outside of a DataflowBlock.
std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> used_outside_dataflow_;
};

/*! \brief The mutator class to apply a CanonicalizationPlan */
class BindingCanonicalizer : public ExprMutator {
public:
/*! \brief Canonicalize the bindings of an expression
 *
 * Runs the planning pass over the expression, then applies the
 * resulting plan in a single mutation pass.
 *
 * \param expr The expression to canonicalize.
 *
 * \return The result of visiting `expr` with the plan applied.
 */
static Expr Apply(Expr expr) {
  CanonicalizationPlan plan = CanonicalizePlanner::Collect(expr);
  BindingCanonicalizer mutator(std::move(plan));
  return mutator.VisitExpr(expr);
}

private:
/*! \brief Construct a mutator that applies the given plan
 *
 * \param plan The rewrites to apply.  Taken by value and moved into
 * the member, so that a caller passing an rvalue incurs no copy of
 * the contained maps and set.
 */
explicit BindingCanonicalizer(CanonicalizationPlan plan) : plan_(std::move(plan)) {}

/*! \brief Visit a binding, unless the plan marks it for removal */
void VisitBinding(const Binding& binding) override {
  const bool marked_for_removal = plan_.bindings_to_remove.count(binding->var->vid) != 0;
  if (marked_for_removal) {
    // Skipping the default visit leaves this binding unprocessed.
    return;
  }
  ExprMutator::VisitBinding(binding);
}

Var VisitVarDef(const Var& var) override {
if (auto opt = plan_.replace_binding.Get(var->vid)) {
return ExprMutator::VisitVarDef(opt.value());
} else {
this->builder_->EmitNormalized(VarBinding(new_var, new_value));
return ExprMutator::VisitVarDef(var);
}
}

void VisitBinding_(const MatchCastNode* binding) override {
// If we have a trivial shape check (the struct_info_ of LHS and RHS is the same),
// we can canonicalize to a var binding
Expr new_value = this->VisitExpr(binding->value);
bool has_same_struct_info = StructuralEqual()(binding->struct_info, GetStructInfo(new_value));

if (has_same_struct_info) {
if (auto parent = new_value.as<Var>();
parent && CanCanonicalizeVar(binding->var, parent.value())) {
// LHS and RHS have the same struct info, and occur in a
// context where the RHS can replace the LHS.
var_remap_[binding->var->vid] = parent.value();
} else {
// LHS and RHS have the same struct info, but the RHS is not a
// drop-in replacement for the LHS.
builder_->EmitNormalized(VarBinding(binding->var, new_value));
}
} else if (new_value.same_as(binding->value)) {
builder_->EmitNormalized(GetRef<MatchCast>(binding));
Expr VisitExpr_(const VarNode* var) override {
if (auto opt = plan_.replace_usage.Get(var->vid)) {
return ExprMutator::VisitExpr(opt.value());
} else {
// we can't elide in the same way as with var bindings because
// the struct info comparison has semantics
builder_->EmitNormalized(MatchCast(binding->var, new_value, binding->struct_info));
return ExprMutator::VisitExpr_(var);
}
}

Expand Down Expand Up @@ -200,31 +319,11 @@ class BindingCanonicalizer : public ExprMutator {
}

private:
bool AnnotationsDiffer(const ObjectRef& obj1, const ObjectRef& obj2,
std::function<bool(const ObjectRef&, const ObjectRef&)> check_eq) {
// annotations differ if one is present but not the other
// or they're both present and they differ
bool both_present = obj1.defined() && obj2.defined();
bool neither_present = !obj1.defined() && !obj2.defined();
return !(both_present || neither_present) || (both_present && !check_eq(obj1, obj2));
}

bool CanCanonicalizeVar(Var var, Var parent_var) {
// Cases when we conservatively do not unify:
// 1. The struct_info_ of the child differs from that of the parent
// In this case, we could be overriding user annotations.
// 2. If the child is a Var and the parent is a DataflowVar.
// That could result in a DataflowVar leaving the current DataflowBlock.
bool annotations_differ = AnnotationsDiffer(var->struct_info_, parent_var->struct_info_,
[&](const ObjectRef& lhs, const ObjectRef& rhs) {
return tvm::StructuralEqual()(lhs, rhs);
});
bool var_to_dataflow = (!var.as<DataflowVarNode>() && parent_var.as<DataflowVarNode>());
return !annotations_differ && !var_to_dataflow;
}
CanonicalizationPlan plan_;
};
} // namespace

Expr CanonicalizeBindings(const Expr& e) { return BindingCanonicalizer().VisitExpr(e); }
Expr CanonicalizeBindings(const Expr& expr) { return BindingCanonicalizer::Apply(expr); }

namespace transform {

Expand Down
3 changes: 1 addition & 2 deletions tests/python/relax/test_dataflow_pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -1403,8 +1403,7 @@ def before(x: R.Tensor((1024,))):
@R.function(private=True)
def expected(x: R.Tensor((1024,))):
with R.dataflow():
a = R.add(x, x)
b = a
b = R.add(x, x)
R.output(b)
return b

Expand Down
9 changes: 3 additions & 6 deletions tests/python/relax/test_optimize_layout_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,10 +130,9 @@ def main(
(lv1, lv2),
out_sinfo=R.Tensor((4, 4), dtype="float32"),
)
lv2_1: R.Tensor((16,), dtype="float32") = R.layout_transform(
gv: R.Tensor((16,), dtype="float32") = R.layout_transform(
lv5, index_map=lambda axis0, axis1: (axis0 * 4 + axis1,), pad_value=None
)
gv: R.Tensor((16,), dtype="float32") = lv2_1
R.output(gv)
return gv

Expand Down Expand Up @@ -256,10 +255,9 @@ def main(
(lv3, lv4),
out_sinfo=R.Tensor((4, 4), dtype="float32"),
)
lv6: R.Tensor((16,), dtype="float32") = R.layout_transform(
gv: R.Tensor((16,), dtype="float32") = R.layout_transform(
lv5, index_map=lambda axis0, axis1: (axis0 * 4 + axis1,), pad_value=None
)
gv: R.Tensor((16,), dtype="float32") = lv6
R.output(gv)
return gv

Expand Down Expand Up @@ -399,10 +397,9 @@ def main(x: R.Tensor((14,), dtype="float32")) -> R.Tensor((14,), dtype="float32"
pad_value=None,
axis_separators=[],
)
lv_2 = R.call_tir(
gv = R.call_tir(
Expected.remove_pad, (lv5,), out_sinfo=R.Tensor((14,), dtype="float32")
)
gv: R.Tensor((14,), dtype="float32") = lv_2
R.output(gv)
return gv

Expand Down
3 changes: 1 addition & 2 deletions tests/python/relax/test_remove_redundant_reshape.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,7 @@ def main(
x: R.Tensor((1, 1001, 1, 1), dtype="float16")
) -> R.Tensor((1, 1001), dtype="float16"):
with R.dataflow():
lv: R.Tensor((1, 1001), dtype="float16") = R.reshape(x, R.shape([1, 1001]))
gv: R.Tensor((1, 1001), dtype="float16") = lv
gv: R.Tensor((1, 1001), dtype="float16") = R.reshape(x, R.shape([1, 1001]))
R.output(gv)
return gv

Expand Down
Loading