Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Unity] Implement FNormalize for relax.op.call_tir #16068

Merged
merged 4 commits into from
Nov 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions src/relax/analysis/well_formed.cc
Original file line number Diff line number Diff line change
Expand Up @@ -319,9 +319,19 @@ class WellFormedChecker : public relax::ExprVisitor,

if (auto func_normalize = op_map_normalize_.get(call->op, nullptr); func_normalize != nullptr) {
auto dummy_builder = tvm::relax::BlockBuilder::Create(mod_);
auto before_normalize = GetRef<Call>(call);
auto after_normalize = func_normalize(dummy_builder, before_normalize);
if (!before_normalize.same_as(after_normalize)) {
Call before_normalize = GetRef<Call>(call);
Optional<Expr> after_normalize = NullOpt;
try {
after_normalize = func_normalize(dummy_builder, before_normalize);
} catch (std::exception& err) {
Malformed(
Diagnostic::Error(call)
<< "If an operator defines an operator-specific normalization function (FNormalize), "
<< "calls to that operator must be normalized with it. "
<< "However, normalization of " << before_normalize << " resulted in the error: \n"
<< err.what());
}
if (after_normalize && !before_normalize.same_as(after_normalize)) {
Malformed(
Diagnostic::Error(call)
<< "If an operator defines an operator-specific normalization function (FNormalize), "
Expand Down
89 changes: 74 additions & 15 deletions src/relax/op/op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -253,11 +253,70 @@ StructInfo InferStructInfoCallTIR(const Call& call, const BlockBuilder& ctx) {
return call->sinfo_args[0];
}

Expr NormalizeCallTIR(const BlockBuilder&, Call call) {
// Temporary implementation to ensure that at least one op has a
// registered value for FNormalize. This temporary implementation
// is fully implemented in follow-up PR
// https://github.com/apache/tvm/pull/16068.
// Operator-specific normalization (FNormalize) that rewrites a call so
// that its argument tuple, `call->args[1]`, is an in-line tuple.
//
// This function is used for normalization of `relax.call_tir`,
// along with the variants `relax.call_tir_with_grad` and
// `relax.call_tir_inplace`.  Therefore, all error messages should
// be written in terms of `call->op`, and should not explicitly
// reference the `relax.call_tir` operator.
//
// Parameters:
//   ctx  - Block builder, used to look up the value bound to a variable.
//   call - The call to be normalized.
//
// Returns `call` unchanged if the argument tuple is already in-line.
// Otherwise, returns a copy of `call` whose second argument has been
// replaced by an in-line tuple.
//
// A CHECK fails if the call does not have two or three arguments, if
// the second argument does not have tuple struct info, or if the
// second argument is neither an in-line tuple nor a variable.
Expr NormalizeCallTIR(const BlockBuilder& ctx, Call call) {
  CHECK(call->args.size() == 2 || call->args.size() == 3)
      << "Operation " << call->op << " expects either two arguments [callee, arg_tuple], "
      << "or three arguments [callee, arg_tuple, tir_args], "
      << "but " << call << " has " << call->args.size() << " arguments.";

  Expr arg_expr = call->args[1];

  CHECK(arg_expr->struct_info_.as<TupleStructInfoNode>())
      << "Operation " << call->op << " expects the second argument to be a tuple of relax Expr. "
      << "However, the second argument " << arg_expr << " has struct info "
      << arg_expr->struct_info_ << ".";

  // Already normalized, nothing to do.
  if (arg_expr.as<TupleNode>()) {
    return std::move(call);
  }

  CHECK(arg_expr.as<VarNode>())
      << "Operation " << call->op << " must hold its arguments as an in-line tuple. "
      << "However, " << call << " has arguments " << arg_expr
      << ", which is neither an in-line tuple, "
      << "nor a variable binding that may be normalized to an in-line tuple.";

  // Returns the value bound to `expr` if `expr` is a variable with a
  // binding known to `ctx`, and NullOpt otherwise.
  auto unwrap_binding = [&ctx](Expr expr) -> Optional<Expr> {
    if (auto var = expr.as<Var>()) {
      if (auto bound_value = ctx->LookupBinding(var.value())) {
        return bound_value.value();
      }
    }
    return NullOpt;
  };

  // Follow the chain of variable bindings, but only step through a
  // binding when the bound value is itself a variable or an in-line
  // tuple.  If the variable is bound to any other expression (e.g. a
  // call that returns a tuple), stop at the variable.  Otherwise,
  // the fallback case below would duplicate that expression inside
  // every TupleGetItem, rather than referring to the variable.
  while (auto unwrapped = unwrap_binding(arg_expr)) {
    Expr bound_value = unwrapped.value();
    if (!bound_value.as<VarNode>() && !bound_value.as<TupleNode>()) {
      break;
    }
    arg_expr = bound_value;
  }

  // After the loop above, `arg_expr` is either an in-line tuple or a
  // variable.
  Tuple new_arg_expr = [&]() {
    // Preferred replacement.  The argument tuple is provided as a
    // variable, but we know the tuple bound to that variable.
    if (auto opt = arg_expr.as<Tuple>()) {
      return opt.value();
    }

    // Fallback case.  The argument tuple is provided as a variable,
    // and we don't know of an in-line tuple bound to that variable.
    // For example, if a relax function accepted a tuple as a
    // parameter, then provided that same tuple as an argument to
    // call_tir.  The tuple is rebuilt by extracting each field from
    // the variable.
    Array<Expr> tuple_elements;
    size_t num_fields = Downcast<TupleStructInfo>(arg_expr->struct_info_)->fields.size();
    for (size_t i = 0; i < num_fields; i++) {
      tuple_elements.push_back(TupleGetItem(arg_expr, i));
    }
    return Tuple(tuple_elements);
  }();

  // Replace only the argument tuple, leaving the callee and the
  // optional `tir_args` untouched.
  auto new_args = call->args;
  new_args.Set(1, new_arg_expr);
  call.CopyOnWrite()->args = new_args;

  return std::move(call);
}

Expand Down Expand Up @@ -314,6 +373,7 @@ RELAY_REGISTER_OP("relax.call_tir_with_grad")
"ShapeExpr representing a tuple of ints to unpack during runtime. Omitted from "
"args if unused")
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoCallTIR)
.set_attr<FNormalize>("FNormalize", NormalizeCallTIR)
.set_attr<Bool>("FPurity", Bool(true));

Expr MakeCallTIRWithGrad(Expr func, Tuple args, Array<TensorStructInfo> out_sinfo_list,
Expand Down Expand Up @@ -353,14 +413,12 @@ TVM_REGISTER_GLOBAL("relax.op.call_tir_with_grad").set_body_typed(MakeCallTIRWit

// call_tir_inplace

StructInfo InferStructInfoCallTIRInplace(const Call& call, const BlockBuilder& ctx) {
if (call->sinfo_args.size() != 1) {
ctx->ReportFatal(Diagnostic::Error(call)
<< "sinfo_args should have exactly 1 output struct info.");
}
CHECK(call->args[0]->IsInstance<GlobalVarNode>())
<< "call_tir expects the first argument to be a GlobalVar referring to a TIR PrimFunc. "
<< "However, gets " << call->args[0];
Expr NormalizeCallTIRInPlace(const BlockBuilder& ctx, Call call) {
// Apply normalization before error checks. This allows the error
// checks to safely apply `Downcast<Tuple>(call->args[1])`, which
// may result in an error if performed before normalization.
call = Downcast<Call>(NormalizeCallTIR(ctx, std::move(call)));

// there must be an inplace index for each output
const auto* attrs = call->attrs.as<CallTIRInplaceAttrs>();
size_t num_outputs = 1U;
Expand Down Expand Up @@ -443,7 +501,7 @@ StructInfo InferStructInfoCallTIRInplace(const Call& call, const BlockBuilder& c
}
}

return call->sinfo_args[0];
return std::move(call);
}

TVM_REGISTER_NODE_TYPE(CallTIRInplaceAttrs);
Expand All @@ -456,7 +514,8 @@ RELAY_REGISTER_OP("relax.call_tir_inplace")
.add_argument("packed_ints", "Expr",
"ShapeExpr representing a tuple of ints to unpack during runtime. Omitted from "
"args if unused")
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoCallTIRInplace)
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoCallTIR)
.set_attr<FNormalize>("FNormalize", NormalizeCallTIRInPlace)
// Warning: considered pure, but it has the potential to create visible effects!
// This should only be used if it has been *checked* that it is safe (no aliases, in-place
// arguments will no longer be live)
Expand Down
47 changes: 47 additions & 0 deletions tests/python/relax/test_transform_cse.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,5 +290,52 @@ def foo(x: R.Tensor((2, 3), dtype="float32")):
verify(Before, Expected)


def test_call_tir_tuple_arg():
    """Check that CSE followed by DCE leaves `R.call_tir` calls with tuple arguments unchanged.

    The module calls two different PrimFuncs with the same in-line
    argument tuple `[A, B]`.  The expected module is the input module
    itself.
    """

    @I.ir_module
    class Before:
        @R.function
        def main(A: R.Tensor([16, 16], "int32"), B: R.Tensor([16, 16], "int32")):
            cls = Before
            Prod = R.call_tir(cls.product, [A, B], out_sinfo=R.Tensor([16, 16], "int32"))
            Sum = R.call_tir(cls.sum, [A, B], out_sinfo=R.Tensor([16, 16], "int32"))
            return (Prod, Sum)

        @T.prim_func(private=True)
        def product(
            A: T.Buffer([16, 16], "int32"),
            B: T.Buffer([16, 16], "int32"),
            C: T.Buffer([16, 16], "int32"),
        ):
            for iters in T.grid(*A.shape):
                with T.block("compute"):
                    i, j = T.axis.remap("SS", iters)
                    C[i, j] = A[i, j] * B[i, j]

        @T.prim_func(private=True)
        def sum(
            A: T.Buffer([16, 16], "int32"),
            B: T.Buffer([16, 16], "int32"),
            C: T.Buffer([16, 16], "int32"),
        ):
            for iters in T.grid(*A.shape):
                with T.block("compute"):
                    i, j = T.axis.remap("SS", iters)
                    C[i, j] = A[i, j] + B[i, j]

    # The transformed module is expected to be identical to the input.
    Expected = Before

    # If EliminateCommonSubexpr produces unnormalized expressions,
    # normalization of those expressions may produce additional
    # variable bindings.  This test case should be agnostic to those
    # additional bindings, so DCE is applied after CSE.
    After = tvm.ir.transform.Sequential(
        [
            EliminateCommonSubexpr(),
            tvm.relax.transform.DeadCodeElimination(),
        ]
    )(Before)
    tvm.ir.assert_structural_equal(Expected, After)


if __name__ == "__main__":
tvm.testing.main()
Loading