[Unity][Transform] Handle relax.Var as call_tir args when lowering #15916
Conversation
Force-pushed 424d179 to a2f5ed8.

Rebased onto …
I think you may need to update the StructInfo inference for …
Array<Expr> args = [&]() {
  if (auto ptr = arg_tuple.as<TupleNode>()) {
    return ptr->fields;
  } else if (auto ptr = arg_tuple->struct_info_.as<TupleStructInfoNode>()) {
    size_t n_args = ptr->fields.size();
    Array<Expr> args;
    for (size_t i = 0; i < n_args; i++) {
      args.push_back(TupleGetItem(arg_tuple, i));
    }
    return args;
  } else {
    LOG(FATAL) << "Lowering of " << call
               << " requires knowing how many arguments are passed to the function. "
               << "However, the tuple of arguments " << arg_tuple
               << " is not itself a tuple, "
               << "nor does its struct info " << GetStructInfo(arg_tuple)
               << " define the number of arguments.";
  }
}();
Not that I particularly mind it, but is it preferable to use this construction with a lambda, as opposed to just assigning `args` in different branches?
Array<Expr> args;
if (case1) {
args = ...
} else if (case2) {
args = ...
}
// etc.
Does it avoid an allocation or something?
I tend to use this construction for a few reasons:
- It avoids ever having a partially-initialized variable.
- It limits the scope of temporary variables that should only be used in the initialization.
- It simplifies nested if/else cases with early returns.

Effectively, the immediately-invoked lambda expression acts as a block scope with a return type, similar to Rust's braces (e.g. the value of `{let i = 5; i+1}` is `6`) or Relax's SeqExpr.
python/tvm/relax/op/base.py (Outdated)

if (
    isinstance(args, Expr)
    and not isinstance(args, RxTuple)
    and not isinstance(args.struct_info_, TupleStructInfo)
You should do this for the other call_tir variants in this file too.
Good call, and updated.
 * \return The value bound to the input \p var.
 * \note For function parameters, this function returns NullOpt.
 */
inline Optional<Expr> LookupBinding(const Var& var) {
I think we should instead enforce call_tir to be a more restricted form.

Thanks for the PR. I know this is indeed a generalization, and there are some tradeoffs to be considered here. Specifically, we should consider the following alternative:

This indeed limits the flexibility in terms of what is valid, but would greatly simplify the logic that leverages CallTIR. Such simplicity helps us in a lot of cases, since passes are simpler and more passes can depend on it. From an expressiveness point of view, there is nothing to be lost; we only need passes that generate call_tir to be able to explicitly unpack tuples.
See also this issue. Allowing …

FWIW, I think our utilities for looking up bindings make it simple enough to deal with cases like not having a tuple literal in …
This does increase the complexity of passes that interact with … I think there are ways we could ensure that all …

With either of these long-term options, the arguments would be immediately accessible by passes that interact with them.
@slyubomirsky I pulled in the infer struct info improvements and the unit test updates from PR #15971, so this should now include all fixes from both sets of changes.
My concerns were addressed, but we should have agreement on the design question @tqchen posed before we merge (if we will merge).
Making a new IR node for …

Since we have the StructInfo system as a type system, I don't see a reason why …
The principle here is that we make common cases (and their optimizations) easy, while placing the burden on less common cases. As for pass-writing patterns, most of our current Relax passes start by mutating a leaf node, e.g. Call, where we can explicitly unpack the patterns.

Mutating Tuple expressions directly has the issue of needing to consider possible recursion, given that our normal form allows nested tuples, exactly for the ease of representing structural information in some special intrinsics. It can become really hard to reason about most cases due to the existence of recursion. There is likely only a limited set of passes that involve this direct mutation (as a matter of fact, I only recall struct info deduction, which is consistent), especially when they are tied to key optimization needs, and they can be restructured into the above pattern.

Taking that rationale, and given that a lot of intrinsics also follow that pattern (e.g. CallPacked and other possible low-level variants, such as an intrinsic that explicitly marks reads and writes), supporting intrinsics that allow special structures in the signature helps to simplify passes. It is already part of our design; solving it for CallTIR alone likely won't resolve the issue for other calls and for future intrinsics that we might need (like explicit read/write grouping for multi-stream analysis).

Given that CallTIR and other structured intrinsics are a central part of a lot of our analyses, I think the tradeoff is worth taking. Such restrictions can be checked through the well-formedness check, so they won't propagate into runtime.
If we intend to have special cases like …

I would argue that it doesn't bring that many advantages to require the arguments to be in-line tuples.
Oh, absolutely. The comment was to ensure that we had a clean implementation for the design comparison.

Thank you for this phrasing, as I was trying to put this sentiment into words. The existence of special requirements that must be preserved makes generic implementations of other passes much more difficult.
What we are discussing is closely related to our original consideration of what forms the normal form and the general needs, so let me elaborate a bit more.

The purpose of a normal form is to restrict the set of possibilities for representing the same program, thus helping reduce the assumptions one might make when writing a pass. Of course, it imposes an extra demand on passes to ensure things go back to the normal form (the well-formedness check). Currently, the Relax normal form requires that we normalize all non-tuple nesting, but expand all tuple constructions. That means we encourage forms like

def func(x, y):
    lv0 = call_tir(mm, (x, y))

but do not allow the following in normal form:

def func(x, y):
    lv0 = (x, y)
    lv1 = call_tir(mm, lv0)

The reason we do that is to serve a general category of needs: when constructing intrinsics, we would like the intrinsics to present certain structures in them. Pattern matching and rewriting of these structural intrinsics is a first-class need in our case.

Such a normal form might bring a bit of restriction, but won't impose too much on other possible "generic" passes. This also depends on the pass-writing patterns we use. Let us consider the following patterns: …

I would claim that PT0 is actually more desirable and handles almost all pass needs, and it won't cause any issues with the structural intrinsics, since all arguments are transformed and pattern-matched together. I think most of the possible difficulties mentioned are related to passes that follow PT1. Under the current normal form, however, we naturally need to be very careful about PT1, because we do not want to lift common values into tuple constructions, and instead encourage tuples to always be unpacked. A CSE pass that tries to detect sub-tuples and lift common values via PT1 would actually violate the normal-form requirement; a correct CSE pass under the current normal form should not rely on PT1.

A rule of thumb under the current normal form is that we almost always won't do PT1, except in passes that populate tuple annotations (aka deduction). If there is another pass that does so, we should revisit the assumption and check carefully whether the result can violate the normal-form assumption. If there is really a need to do tuple replacement in a pass, and there is a concern that the result might violate the structural-intrinsic requirement, one simple approach is to have a common pass (like ConvertSSA in TIR) that goes over the calls and always unpacks tuple struct info to the bottom. This pass can be inserted after whatever pass needs it, so the assumption is preserved, and we only need to register a property (requires structured tuples) in the ….

Finally, one should note that the passes that might touch PT1 are usually developed in-repo; as a result, the extra consideration of renormalization won't hurt, as our experience is sufficient to tackle these considerations. The burdens of R0/R1 are unbounded; a lot of our designs (e.g. dataflow blocks) go into simplifying them for developers who may not have a deep background.

Of course, all of the above roots back to our rationale for the normal form. One can argue that we should take other normal forms (e.g. that tuple values should always be bound to a variable), in which case the considerations would be different and there are other tradeoffs. We considered these before arriving at the current one.
I think the major issue is the inconsistency of normalization. From the code itself, there is no clear expectation of this being the normal form for Relax, and the majority of the handling in the codebase suggests otherwise.

Behavior suggesting that the de facto normal form allows Tuple variables: …

Behavior suggesting that the de facto normal form requires in-line Tuples: …

In order for in-line tuples to be the normalization used in practice, and not just in principle, the normalizer itself would need to enforce this.
Great work in making the full comparison list, Eric. Are we presently using inline tuples to simplify things significantly? I really don't think it's that big a deal to check the type of a variable and determine if you need to use TupleGetItem, etc.
I haven't found any locations that would require anything more than a binding lookup. Of the 19 instances of …

The most complicated case involving …
The simpler approach is actually to just enable the well-formedness check to detect the related issues and require tuples to be inlined. The Normalizer serves as a way to create new bound variables when composite expressions occur, which is the common case, but it does not necessarily have to handle all the inlining cases; the pass writer can simply inline the items if needed, once detected by the well-formedness check.

The purpose of the inline tuple is to remove the need for a binding lookup, and to make the structural information in the arguments explicit. Syntactically, that means explicit values, instead of the form below.

def main(x, y):
    lv0 = (x, y)
    lv1 = call_tir(mm, lv0)

The goal is to be aligned with the rationale to support N0: structural intrinsic arguments. Again, the issue is that the use cases of structural matching of intrinsics can be unbounded (as many passes are pattern rewrites of structural intrinsics); we would like to simplify these cases both in terms of syntax and in the pattern-matching code. That is, we would like to eliminate the extra binding lookup, both for syntactic reasons and to reduce the mental overhead of having to think about an optional binding lookup or polymorphism. This complexity would get even worse once we start to look into nested structures, which we might need in the future; it also grows with the number of structural intrinsics we add, which goes beyond call_tir itself.

In the meantime, the complexity of keeping things normalized (e.g. a util postprocessing pass that automatically inserts TupleGetItem for cases that need to be unpacked, for passes that might generate tuple argument passing) is manageable with a clear well-formedness check, and likely simplifies most of our needs here.
Oh, certainly agreed that the well-formedness check is the simpler one to introduce in the short term. I think it adds complexity in the long term, though.
This is the case I want to avoid, because this is too late of a warning. It occurs after a pass has been designed and written, often only when more comprehensive inputs are used in CI. Adding a new design constraint that late in development often requires re-design of the implementation in order to follow the new constraint, resulting in significant repetition of effort.
I agree with the reasoning of reducing the mental overhead, but I think that it comes from having as many constraints as possible exposed to the C++ type system. (See my earlier comment for ways this could be done.) By making illegal states impossible to represent, the mental overhead of avoiding illegal states is removed; if illegal states can be represented, we increase the mental overhead of avoiding them.
The `ExprMutator` class provides a `LookupBinding` utility for use by subclasses. This commit provides the same functionality to subclasses of `ExprVisitor`.
Prior to this commit, the `relax.transform.FoldConstant` pass assumed that the `relax.call_tir` builtin had arguments expressed as a `relax.Tuple`, and failed if provided with a `relax.Var` that had been bound to a tuple. This commit updates the `FoldConstant` pass to handle variables annotated with `TupleStructInfo`. If the variable's value was determined within the scope of the mutated function, we can look up the bound tuple and find the argument. If the variable's value was produced as output from another function, then we cannot use it in a constant expression, and must leave it as-is.

Prior to this commit, the `relax.transform.FuseOps` pass assumed that the `relax.call_tir` builtin had arguments expressed as a `relax.Tuple`, and failed if provided with a `relax.Var` that had been bound to a tuple. This commit updates the `FuseOps` pass to unwrap variable bindings prior to downcasting from `relax::Expr` to `relax::Tuple`.

Prior to this commit, the `relax.transform.RewriteDataflowReshape` pass assumed that the `relax.call_tir` builtin had arguments expressed as a `relax.Tuple`, and failed if provided with a `relax.Var` that had been bound to a tuple. This commit updates the `RewriteDataflowReshape` pass to handle variables annotated with `TupleStructInfo`. The identification of a reshape can be done with only the number of arguments, which can be extracted from the variable's `TupleStructInfo` instead of requiring a `Tuple`. If the TIR function is a reshape, then the tuple variable can be unwrapped to a known tuple in order to find the argument, or a `TupleGetItem` node can be used to extract the argument.
Force-pushed 8edb8bd to 5089f47.
This is indeed also a tradeoff in terms of IR design. We cannot encode all constraints; for example, ANF itself is not encoded directly in the C++ data structures, but is still enforced through well-formedness checkers. For this particular case, we want to allow the unbounded set of use cases under N0 (structural intrinsic arguments). The possible use cases go beyond CallTIR; they also include call_packed, possible nested structures, etc. So enabling it in the normal form is a good tradeoff.

A pass that follows the common pattern (rewriting Call, not rewriting Tuple) will not face such issues. A pass that needs to do tuple replacement already has other complications to consider, and if the pass writer is not willing to handle them, a follow-up util function (like ConvertSSA) can simply unpack the tuple values and convert things back to the normal form, without the need to redesign the original pass. Indeed, this extra burden falls on whoever develops a pass that involves tuple remapping; such a pass by nature requires an extra set of considerations, and the extra overhead of inserting a re-normalization is not high. These passes are less common than passes that pattern-match and rewrite structural intrinsic arguments, which is why the design prioritizes the latter.
@slyubomirsky @tqchen Can you take a look at PRs #16067 and #16068? The former introduces new functionality through a …
Properties (2) and (3) fulfill my goals of avoiding an increased mental burden when writing an upstream pass, and of making all IR requirements explicit within the code.
Prior to this commit, the different `R.call_tir*` variations would wrap the arguments into an in-line `relax.Tuple`, if it is not already a `relax.Tuple`. While this allows a tensor to be passed into these functions as a single argument (`R.call_tir(func, arg, ...)` instead of `R.call_tir(func, [arg], ...)`), the wrapped Relax variable may already refer to a tuple. This use of a variable to refer to an argument tuple rather than an in-line argument tuple is not allowed by Relax. (See discussion on apache#15916 for details.) However, by wrapping a variable `args: R.Tuple(R.Tensor, R.Tensor, ...)` into a tuple-of-tuples, the error occurs after the expression has already been generated, and refers to an expression `R.Tuple(R.Tuple(R.Tensor, R.Tensor, ...))` that doesn't appear anywhere in the user's input. This can make debugging difficult (see apache#17239 for an example). This commit updates the argument-handling in `R.call_tir` to only generate an in-line `relax.Tuple` if the arguments do not already have `relax.TupleStructInfo`. If the argument was provided as a Relax variable bound to a tuple of arguments, it will still produce an error. However, that error will occur much earlier, and will explicitly state that the argument must be a `relax.Tuple` instead of a `relax.Var`.
Prior to this commit, several transforms assumed that the arguments passed to a `call_tir` builtin were provided as in-line `relax::Tuple` objects. Because it would be equally valid for the arguments to instead be a `relax::Var` instance that had previously been bound to a `relax::Tuple` object, or had been passed as an input parameter with `relax::TupleStructInfo`, this assumption shouldn't be made. This PR updates the `CallTIRRewrite`, `FoldConstant`, `FuseOps`, and `RewriteDataflowReshape` passes to handle variables providing the arguments.