Skip to content

Commit 205c532

Browse files
[mlir] Add a utility method to move operation dependencies. (#129975)
The added utility method moves all SSA values that an operation depends upon before an insertion point. This is useful during transformations where such movements might make transformations (like fusion) more powerful. To test the operation add a transform dialect op that calls the move operation. To be able to capture the `notifyMatchFailure` messages from the transformation and to report/check these in the test modify the `ErrorCheckingTrackingListener` to capture the last match failure notification. --------- Signed-off-by: MaheshRavishankar <[email protected]>
1 parent 5bf0486 commit 205c532

File tree

10 files changed

+452
-0
lines changed

10 files changed

+452
-0
lines changed

mlir/include/mlir/Dialect/Transform/Interfaces/TransformInterfaces.h

+11
Original file line numberDiff line numberDiff line change
@@ -1074,10 +1074,18 @@ class ErrorCheckingTrackingListener : public TrackingListener {
10741074
/// resets the error state to "success".
10751075
DiagnosedSilenceableFailure checkAndResetError();
10761076

1077+
/// Return the latest match notification message. Returns an empty string
1078+
/// when no error message was captured.
1079+
std::string getLatestMatchFailureMessage();
1080+
10771081
/// Return "true" if this tracking listener had a failure.
10781082
bool failed() const;
10791083

10801084
protected:
1085+
void
1086+
notifyMatchFailure(Location loc,
1087+
function_ref<void(Diagnostic &)> reasonCallback) override;
1088+
10811089
void
10821090
notifyPayloadReplacementNotFound(Operation *op, ValueRange values,
10831091
DiagnosedSilenceableFailure &&diag) override;
@@ -1089,6 +1097,9 @@ class ErrorCheckingTrackingListener : public TrackingListener {
10891097

10901098
/// The number of errors that have been encountered.
10911099
int64_t errorCounter = 0;
1100+
1101+
/// Latest message from match failure notification.
1102+
std::optional<Diagnostic> matchFailure;
10921103
};
10931104

10941105
/// This is a special rewriter to be used in transform op implementations,

mlir/include/mlir/Transforms/RegionUtils.h

+11
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "llvm/ADT/SetVector.h"
1616

1717
namespace mlir {
18+
class DominanceInfo;
1819
class RewriterBase;
1920

2021
/// Check if all values in the provided range are defined above the `limit`
@@ -69,6 +70,16 @@ SmallVector<Value> makeRegionIsolatedFromAbove(
6970
llvm::function_ref<bool(Operation *)> cloneOperationIntoRegion =
7071
[](Operation *) { return false; });
7172

73+
/// Move SSA values used within an operation before an insertion point,
74+
/// so that the operation itself (or its replacement) can be moved to
75+
/// the insertion point. Current support is only for movement of
76+
/// dependencies of `op` before `insertionPoint` in the same basic block.
77+
LogicalResult moveOperationDependencies(RewriterBase &rewriter, Operation *op,
78+
Operation *insertionPoint,
79+
DominanceInfo &dominance);
80+
LogicalResult moveOperationDependencies(RewriterBase &rewriter, Operation *op,
81+
Operation *insertionPoint);
82+
7283
/// Run a set of structural simplifications over the given regions. This
7384
/// includes transformations like unreachable block elimination, dead argument
7485
/// elimination, as well as some other DCE. This function returns success if any

mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp

+15
Original file line numberDiff line numberDiff line change
@@ -1390,6 +1390,21 @@ void transform::ErrorCheckingTrackingListener::notifyPayloadReplacementNotFound(
13901390
++errorCounter;
13911391
}
13921392

1393+
std::string
1394+
transform::ErrorCheckingTrackingListener::getLatestMatchFailureMessage() {
1395+
if (!matchFailure) {
1396+
return "";
1397+
}
1398+
return matchFailure->str();
1399+
}
1400+
1401+
void transform::ErrorCheckingTrackingListener::notifyMatchFailure(
1402+
Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
1403+
Diagnostic diag(loc, DiagnosticSeverity::Remark);
1404+
reasonCallback(diag);
1405+
matchFailure = std::move(diag);
1406+
}
1407+
13931408
//===----------------------------------------------------------------------===//
13941409
// TransformRewriter
13951410
//===----------------------------------------------------------------------===//

mlir/lib/Transforms/Utils/RegionUtils.cpp

+61
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,12 @@
77
//===----------------------------------------------------------------------===//
88

99
#include "mlir/Transforms/RegionUtils.h"
10+
11+
#include "mlir/Analysis/SliceAnalysis.h"
1012
#include "mlir/Analysis/TopologicalSortUtils.h"
1113
#include "mlir/IR/Block.h"
1214
#include "mlir/IR/BuiltinOps.h"
15+
#include "mlir/IR/Dominance.h"
1316
#include "mlir/IR/IRMapping.h"
1417
#include "mlir/IR/Operation.h"
1518
#include "mlir/IR/PatternMatch.h"
@@ -1054,3 +1057,61 @@ LogicalResult mlir::simplifyRegions(RewriterBase &rewriter,
10541057
return success(eliminatedBlocks || eliminatedOpsOrArgs ||
10551058
mergedIdenticalBlocks || droppedRedundantArguments);
10561059
}
1060+
1061+
//===---------------------------------------------------------------------===//
1062+
// Move operation dependencies
1063+
//===---------------------------------------------------------------------===//
1064+
1065+
LogicalResult mlir::moveOperationDependencies(RewriterBase &rewriter,
1066+
Operation *op,
1067+
Operation *insertionPoint,
1068+
DominanceInfo &dominance) {
1069+
// Currently unsupported case where the op and insertion point are
1070+
// in different basic blocks.
1071+
if (op->getBlock() != insertionPoint->getBlock()) {
1072+
return rewriter.notifyMatchFailure(
1073+
op, "unsupported caes where operation and insertion point are not in "
1074+
"the same basic block");
1075+
}
1076+
// If `insertionPoint` does not dominate `op`, do nothing
1077+
if (!dominance.properlyDominates(insertionPoint, op)) {
1078+
return rewriter.notifyMatchFailure(op,
1079+
"insertion point does not dominate op");
1080+
}
1081+
1082+
// Find the backward slice of operation for each `Value` the operation
1083+
// depends on. Prune the slice to only include operations not already
1084+
// dominated by the `insertionPoint`
1085+
BackwardSliceOptions options;
1086+
options.inclusive = false;
1087+
options.omitUsesFromAbove = false;
1088+
// Since current support is to only move within a same basic block,
1089+
// the slices dont need to look past block arguments.
1090+
options.omitBlockArguments = true;
1091+
options.filter = [&](Operation *sliceBoundaryOp) {
1092+
return !dominance.properlyDominates(sliceBoundaryOp, insertionPoint);
1093+
};
1094+
llvm::SetVector<Operation *> slice;
1095+
getBackwardSlice(op, &slice, options);
1096+
1097+
// If the slice contains `insertionPoint` cannot move the dependencies.
1098+
if (slice.contains(insertionPoint)) {
1099+
return rewriter.notifyMatchFailure(
1100+
op,
1101+
"cannot move dependencies before operation in backward slice of op");
1102+
}
1103+
1104+
// We should move the slice in topological order, but `getBackwardSlice`
1105+
// already does that. So no need to sort again.
1106+
for (Operation *op : slice) {
1107+
rewriter.moveOpBefore(op, insertionPoint);
1108+
}
1109+
return success();
1110+
}
1111+
1112+
LogicalResult mlir::moveOperationDependencies(RewriterBase &rewriter,
1113+
Operation *op,
1114+
Operation *insertionPoint) {
1115+
DominanceInfo dominance(op);
1116+
return moveOperationDependencies(rewriter, op, insertionPoint, dominance);
1117+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,236 @@
1+
// RUN: mlir-opt --allow-unregistered-dialect --transform-interpreter --split-input-file --verify-diagnostics %s | FileCheck %s
2+
3+
// Check simple move of dependent operation before insertion.
4+
func.func @simple_move() -> f32 {
5+
%0 = "before"() : () -> (f32)
6+
%1 = "moved_op"() : () -> (f32)
7+
%2 = "foo"(%1) : (f32) -> (f32)
8+
return %2 : f32
9+
}
10+
// CHECK-LABEL: func @simple_move()
11+
// CHECK: %[[MOVED:.+]] = "moved_op"
12+
// CHECK: %[[BEFORE:.+]] = "before"
13+
// CHECK: %[[FOO:.+]] = "foo"(%[[MOVED]])
14+
// CHECK: return %[[FOO]]
15+
16+
module attributes {transform.with_named_sequence} {
17+
transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) {
18+
%op1 = transform.structured.match ops{["foo"]} in %arg0
19+
: (!transform.any_op) -> !transform.any_op
20+
%op2 = transform.structured.match ops{["before"]} in %arg0
21+
: (!transform.any_op) -> !transform.any_op
22+
transform.test.move_operand_deps %op1 before %op2
23+
: !transform.any_op, !transform.any_op
24+
transform.yield
25+
}
26+
}
27+
28+
// -----
29+
30+
// Move operands that are implicitly captured by the op
31+
func.func @move_region_dependencies() -> f32 {
32+
%0 = "before"() : () -> (f32)
33+
%1 = "moved_op"() : () -> (f32)
34+
%2 = "foo"() ({
35+
%3 = "inner_op"(%1) : (f32) -> (f32)
36+
"yield"(%3) : (f32) -> ()
37+
}) : () -> (f32)
38+
return %2 : f32
39+
}
40+
// CHECK-LABEL: func @move_region_dependencies()
41+
// CHECK: %[[MOVED:.+]] = "moved_op"
42+
// CHECK: %[[BEFORE:.+]] = "before"
43+
// CHECK: %[[FOO:.+]] = "foo"
44+
// CHECK: return %[[FOO]]
45+
46+
module attributes {transform.with_named_sequence} {
47+
transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) {
48+
%op1 = transform.structured.match ops{["foo"]} in %arg0
49+
: (!transform.any_op) -> !transform.any_op
50+
%op2 = transform.structured.match ops{["before"]} in %arg0
51+
: (!transform.any_op) -> !transform.any_op
52+
transform.test.move_operand_deps %op1 before %op2
53+
: !transform.any_op, !transform.any_op
54+
transform.yield
55+
}
56+
}
57+
58+
// -----
59+
60+
// Move operations in toplogical sort order
61+
func.func @move_in_topological_sort_order() -> f32 {
62+
%0 = "before"() : () -> (f32)
63+
%1 = "moved_op_1"() : () -> (f32)
64+
%2 = "moved_op_2"() : () -> (f32)
65+
%3 = "moved_op_3"(%1) : (f32) -> (f32)
66+
%4 = "moved_op_4"(%1, %3) : (f32, f32) -> (f32)
67+
%5 = "moved_op_5"(%2) : (f32) -> (f32)
68+
%6 = "foo"(%4, %5) : (f32, f32) -> (f32)
69+
return %6 : f32
70+
}
71+
// CHECK-LABEL: func @move_in_topological_sort_order()
72+
// CHECK: %[[MOVED_1:.+]] = "moved_op_1"
73+
// CHECK-DAG: %[[MOVED_2:.+]] = "moved_op_3"(%[[MOVED_1]])
74+
// CHECK-DAG: %[[MOVED_3:.+]] = "moved_op_4"(%[[MOVED_1]], %[[MOVED_2]])
75+
// CHECK-DAG: %[[MOVED_4:.+]] = "moved_op_2"
76+
// CHECK-DAG: %[[MOVED_5:.+]] = "moved_op_5"(%[[MOVED_4]])
77+
// CHECK: %[[BEFORE:.+]] = "before"
78+
// CHECK: %[[FOO:.+]] = "foo"(%[[MOVED_3]], %[[MOVED_5]])
79+
// CHECK: return %[[FOO]]
80+
81+
module attributes {transform.with_named_sequence} {
82+
transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) {
83+
%op1 = transform.structured.match ops{["foo"]} in %arg0
84+
: (!transform.any_op) -> !transform.any_op
85+
%op2 = transform.structured.match ops{["before"]} in %arg0
86+
: (!transform.any_op) -> !transform.any_op
87+
transform.test.move_operand_deps %op1 before %op2
88+
: !transform.any_op, !transform.any_op
89+
transform.yield
90+
}
91+
}
92+
93+
// -----
94+
95+
func.func @move_region_dependencies() -> f32 {
96+
%0 = "before"() : () -> (f32)
97+
%1 = "moved_op_1"() : () -> (f32)
98+
%2 = "moved_op_2"() ({
99+
"yield"(%1) : (f32) -> ()
100+
}) : () -> (f32)
101+
%3 = "foo"() ({
102+
"yield"(%2) : (f32) -> ()
103+
}) : () -> (f32)
104+
return %3 : f32
105+
}
106+
107+
module attributes {transform.with_named_sequence} {
108+
transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) {
109+
%op1 = transform.structured.match ops{["foo"]} in %arg0
110+
: (!transform.any_op) -> !transform.any_op
111+
%op2 = transform.structured.match ops{["before"]} in %arg0
112+
: (!transform.any_op) -> !transform.any_op
113+
transform.test.move_operand_deps %op1 before %op2
114+
: !transform.any_op, !transform.any_op
115+
transform.yield
116+
}
117+
}
118+
// CHECK-LABEL: func @move_region_dependencies()
119+
// CHECK: %[[MOVED_1:.+]] = "moved_op_1"
120+
// CHECK: %[[MOVED_2:.+]] = "moved_op_2"
121+
// CHECK: "yield"(%[[MOVED_1]])
122+
// CHECK: "before"
123+
// CHECK: %[[FOO:.+]] = "foo"
124+
// CHECK: return %[[FOO]]
125+
126+
// -----
127+
128+
// Current implementation omits following basic block argument when
129+
// computing slices. Verify that this gives expected result.
130+
func.func @ignore_basic_block_arguments() -> f32 {
131+
^bb0():
132+
%8 = "test"() : () -> (f32)
133+
return %8: f32
134+
^bb1(%bbArg : f32):
135+
%0 = "before"() : () -> (f32)
136+
%1 = "moved_op"() ({
137+
"yield"(%bbArg) : (f32) -> ()
138+
}) : () -> (f32)
139+
%2 = "foo"() ({
140+
"yield"(%1) : (f32) -> ()
141+
}) : () -> (f32)
142+
return %2 : f32
143+
}
144+
145+
module attributes {transform.with_named_sequence} {
146+
transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) {
147+
%op1 = transform.structured.match ops{["foo"]} in %arg0
148+
: (!transform.any_op) -> !transform.any_op
149+
%op2 = transform.structured.match ops{["before"]} in %arg0
150+
: (!transform.any_op) -> !transform.any_op
151+
transform.test.move_operand_deps %op1 before %op2
152+
: !transform.any_op, !transform.any_op
153+
transform.yield
154+
}
155+
}
156+
// CHECK-LABEL: func @ignore_basic_block_arguments()
157+
// CHECK: %[[MOVED_1:.+]] = "moved_op"
158+
// CHECK: "before"
159+
// CHECK: %[[FOO:.+]] = "foo"
160+
// CHECK: return %[[FOO]]
161+
162+
// -----
163+
164+
// Fail when the "before" operation is part of the operation slice.
165+
func.func @do_not_move_slice() -> f32 {
166+
%0 = "before"() : () -> (f32)
167+
%1 = "moved_op"(%0) : (f32) -> (f32)
168+
%2 = "foo"(%1) : (f32) -> (f32)
169+
return %2 : f32
170+
}
171+
172+
module attributes {transform.with_named_sequence} {
173+
transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) {
174+
%op1 = transform.structured.match ops{["foo"]} in %arg0
175+
: (!transform.any_op) -> !transform.any_op
176+
%op2 = transform.structured.match ops{["before"]} in %arg0
177+
: (!transform.any_op) -> !transform.any_op
178+
// expected-remark@+1{{cannot move dependencies before operation in backward slice of op}}
179+
transform.test.move_operand_deps %op1 before %op2
180+
: !transform.any_op, !transform.any_op
181+
transform.yield
182+
}
183+
}
184+
185+
// -----
186+
187+
// Fail when the "before" operation is part of the operation slice (computed
188+
// when looking through implicitly captured values).
189+
func.func @do_not_move_slice() -> f32 {
190+
%0 = "before"() : () -> (f32)
191+
%1 = "moved_op"() ({
192+
"yield"(%0) : (f32) -> ()
193+
}) : () -> (f32)
194+
%2 = "foo"() ({
195+
"yield"(%1) : (f32) -> ()
196+
}) : () -> (f32)
197+
return %2 : f32
198+
}
199+
200+
module attributes {transform.with_named_sequence} {
201+
transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) {
202+
%op1 = transform.structured.match ops{["foo"]} in %arg0
203+
: (!transform.any_op) -> !transform.any_op
204+
%op2 = transform.structured.match ops{["before"]} in %arg0
205+
: (!transform.any_op) -> !transform.any_op
206+
// expected-remark@+1{{cannot move dependencies before operation in backward slice of op}}
207+
transform.test.move_operand_deps %op1 before %op2
208+
: !transform.any_op, !transform.any_op
209+
transform.yield
210+
}
211+
}
212+
213+
// -----
214+
215+
// Dont move ops when insertion point does not dominate the op
216+
func.func @do_not_move() -> f32 {
217+
%1 = "moved_op"() : () -> (f32)
218+
%2 = "foo"() ({
219+
"yield"(%1) : (f32) -> ()
220+
}) : () -> (f32)
221+
%3 = "before"() : () -> f32
222+
return %2 : f32
223+
}
224+
225+
module attributes {transform.with_named_sequence} {
226+
transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) {
227+
%op1 = transform.structured.match ops{["foo"]} in %arg0
228+
: (!transform.any_op) -> !transform.any_op
229+
%op2 = transform.structured.match ops{["before"]} in %arg0
230+
: (!transform.any_op) -> !transform.any_op
231+
// expected-remark@+1{{insertion point does not dominate op}}
232+
transform.test.move_operand_deps %op1 before %op2
233+
: !transform.any_op, !transform.any_op
234+
transform.yield
235+
}
236+
}

0 commit comments

Comments
 (0)