Skip to content

Commit

Permalink
[mlir] Add a utility method to move operation dependencies. (llvm#129975
Browse files Browse the repository at this point in the history
)

The added utility method moves all SSA values that an operation depends
upon before an insertion point. This is useful during transformations
where such movements might make transformations (like fusion) more
powerful.

To test the operation add a transform dialect op that calls the move
operation. To be able to capture the `notifyMatchFailure` messages from
the transformation and to report/check these in the test modify the
`ErrorCheckingTrackingListener` to capture the last match failure
notification.

---------

Signed-off-by: MaheshRavishankar <[email protected]>
  • Loading branch information
MaheshRavishankar authored Mar 11, 2025
1 parent 5bf0486 commit 205c532
Show file tree
Hide file tree
Showing 10 changed files with 452 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1074,10 +1074,18 @@ class ErrorCheckingTrackingListener : public TrackingListener {
/// resets the error state to "success".
DiagnosedSilenceableFailure checkAndResetError();

/// Return the latest match notification message. Returns an empty string
/// when no error message was captured.
std::string getLatestMatchFailureMessage();

/// Return "true" if this tracking listener had a failure.
bool failed() const;

protected:
void
notifyMatchFailure(Location loc,
function_ref<void(Diagnostic &)> reasonCallback) override;

void
notifyPayloadReplacementNotFound(Operation *op, ValueRange values,
DiagnosedSilenceableFailure &&diag) override;
Expand All @@ -1089,6 +1097,9 @@ class ErrorCheckingTrackingListener : public TrackingListener {

/// The number of errors that have been encountered.
int64_t errorCounter = 0;

/// Latest message from match failure notification.
std::optional<Diagnostic> matchFailure;
};

/// This is a special rewriter to be used in transform op implementations,
Expand Down
11 changes: 11 additions & 0 deletions mlir/include/mlir/Transforms/RegionUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "llvm/ADT/SetVector.h"

namespace mlir {
class DominanceInfo;
class RewriterBase;

/// Check if all values in the provided range are defined above the `limit`
Expand Down Expand Up @@ -69,6 +70,16 @@ SmallVector<Value> makeRegionIsolatedFromAbove(
llvm::function_ref<bool(Operation *)> cloneOperationIntoRegion =
[](Operation *) { return false; });

/// Move SSA values used within an operation before an insertion point,
/// so that the operation itself (or its replacement) can be moved to
/// the insertion point. Current support is only for movement of
/// dependencies of `op` before `insertionPoint` in the same basic block.
LogicalResult moveOperationDependencies(RewriterBase &rewriter, Operation *op,
Operation *insertionPoint,
DominanceInfo &dominance);
LogicalResult moveOperationDependencies(RewriterBase &rewriter, Operation *op,
Operation *insertionPoint);

/// Run a set of structural simplifications over the given regions. This
/// includes transformations like unreachable block elimination, dead argument
/// elimination, as well as some other DCE. This function returns success if any
Expand Down
15 changes: 15 additions & 0 deletions mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1390,6 +1390,21 @@ void transform::ErrorCheckingTrackingListener::notifyPayloadReplacementNotFound(
++errorCounter;
}

std::string
transform::ErrorCheckingTrackingListener::getLatestMatchFailureMessage() {
if (!matchFailure) {
return "";
}
return matchFailure->str();
}

void transform::ErrorCheckingTrackingListener::notifyMatchFailure(
Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
Diagnostic diag(loc, DiagnosticSeverity::Remark);
reasonCallback(diag);
matchFailure = std::move(diag);
}

//===----------------------------------------------------------------------===//
// TransformRewriter
//===----------------------------------------------------------------------===//
Expand Down
61 changes: 61 additions & 0 deletions mlir/lib/Transforms/Utils/RegionUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,12 @@
//===----------------------------------------------------------------------===//

#include "mlir/Transforms/RegionUtils.h"

#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
Expand Down Expand Up @@ -1054,3 +1057,61 @@ LogicalResult mlir::simplifyRegions(RewriterBase &rewriter,
return success(eliminatedBlocks || eliminatedOpsOrArgs ||
mergedIdenticalBlocks || droppedRedundantArguments);
}

//===---------------------------------------------------------------------===//
// Move operation dependencies
//===---------------------------------------------------------------------===//

LogicalResult mlir::moveOperationDependencies(RewriterBase &rewriter,
Operation *op,
Operation *insertionPoint,
DominanceInfo &dominance) {
// Currently unsupported case where the op and insertion point are
// in different basic blocks.
if (op->getBlock() != insertionPoint->getBlock()) {
return rewriter.notifyMatchFailure(
op, "unsupported caes where operation and insertion point are not in "
"the same basic block");
}
// If `insertionPoint` does not dominate `op`, do nothing
if (!dominance.properlyDominates(insertionPoint, op)) {
return rewriter.notifyMatchFailure(op,
"insertion point does not dominate op");
}

// Find the backward slice of operation for each `Value` the operation
// depends on. Prune the slice to only include operations not already
// dominated by the `insertionPoint`
BackwardSliceOptions options;
options.inclusive = false;
options.omitUsesFromAbove = false;
// Since current support is to only move within a same basic block,
// the slices dont need to look past block arguments.
options.omitBlockArguments = true;
options.filter = [&](Operation *sliceBoundaryOp) {
return !dominance.properlyDominates(sliceBoundaryOp, insertionPoint);
};
llvm::SetVector<Operation *> slice;
getBackwardSlice(op, &slice, options);

// If the slice contains `insertionPoint` cannot move the dependencies.
if (slice.contains(insertionPoint)) {
return rewriter.notifyMatchFailure(
op,
"cannot move dependencies before operation in backward slice of op");
}

// We should move the slice in topological order, but `getBackwardSlice`
// already does that. So no need to sort again.
for (Operation *op : slice) {
rewriter.moveOpBefore(op, insertionPoint);
}
return success();
}

LogicalResult mlir::moveOperationDependencies(RewriterBase &rewriter,
Operation *op,
Operation *insertionPoint) {
DominanceInfo dominance(op);
return moveOperationDependencies(rewriter, op, insertionPoint, dominance);
}
236 changes: 236 additions & 0 deletions mlir/test/Transforms/move-operation-deps.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,236 @@
// RUN: mlir-opt --allow-unregistered-dialect --transform-interpreter --split-input-file --verify-diagnostics %s | FileCheck %s

// Check simple move of dependent operation before insertion.
func.func @simple_move() -> f32 {
%0 = "before"() : () -> (f32)
%1 = "moved_op"() : () -> (f32)
%2 = "foo"(%1) : (f32) -> (f32)
return %2 : f32
}
// CHECK-LABEL: func @simple_move()
// CHECK: %[[MOVED:.+]] = "moved_op"
// CHECK: %[[BEFORE:.+]] = "before"
// CHECK: %[[FOO:.+]] = "foo"(%[[MOVED]])
// CHECK: return %[[FOO]]

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) {
%op1 = transform.structured.match ops{["foo"]} in %arg0
: (!transform.any_op) -> !transform.any_op
%op2 = transform.structured.match ops{["before"]} in %arg0
: (!transform.any_op) -> !transform.any_op
transform.test.move_operand_deps %op1 before %op2
: !transform.any_op, !transform.any_op
transform.yield
}
}

// -----

// Move operands that are implicitly captured by the op
func.func @move_region_dependencies() -> f32 {
%0 = "before"() : () -> (f32)
%1 = "moved_op"() : () -> (f32)
%2 = "foo"() ({
%3 = "inner_op"(%1) : (f32) -> (f32)
"yield"(%3) : (f32) -> ()
}) : () -> (f32)
return %2 : f32
}
// CHECK-LABEL: func @move_region_dependencies()
// CHECK: %[[MOVED:.+]] = "moved_op"
// CHECK: %[[BEFORE:.+]] = "before"
// CHECK: %[[FOO:.+]] = "foo"
// CHECK: return %[[FOO]]

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) {
%op1 = transform.structured.match ops{["foo"]} in %arg0
: (!transform.any_op) -> !transform.any_op
%op2 = transform.structured.match ops{["before"]} in %arg0
: (!transform.any_op) -> !transform.any_op
transform.test.move_operand_deps %op1 before %op2
: !transform.any_op, !transform.any_op
transform.yield
}
}

// -----

// Move operations in toplogical sort order
func.func @move_in_topological_sort_order() -> f32 {
%0 = "before"() : () -> (f32)
%1 = "moved_op_1"() : () -> (f32)
%2 = "moved_op_2"() : () -> (f32)
%3 = "moved_op_3"(%1) : (f32) -> (f32)
%4 = "moved_op_4"(%1, %3) : (f32, f32) -> (f32)
%5 = "moved_op_5"(%2) : (f32) -> (f32)
%6 = "foo"(%4, %5) : (f32, f32) -> (f32)
return %6 : f32
}
// CHECK-LABEL: func @move_in_topological_sort_order()
// CHECK: %[[MOVED_1:.+]] = "moved_op_1"
// CHECK-DAG: %[[MOVED_2:.+]] = "moved_op_3"(%[[MOVED_1]])
// CHECK-DAG: %[[MOVED_3:.+]] = "moved_op_4"(%[[MOVED_1]], %[[MOVED_2]])
// CHECK-DAG: %[[MOVED_4:.+]] = "moved_op_2"
// CHECK-DAG: %[[MOVED_5:.+]] = "moved_op_5"(%[[MOVED_4]])
// CHECK: %[[BEFORE:.+]] = "before"
// CHECK: %[[FOO:.+]] = "foo"(%[[MOVED_3]], %[[MOVED_5]])
// CHECK: return %[[FOO]]

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) {
%op1 = transform.structured.match ops{["foo"]} in %arg0
: (!transform.any_op) -> !transform.any_op
%op2 = transform.structured.match ops{["before"]} in %arg0
: (!transform.any_op) -> !transform.any_op
transform.test.move_operand_deps %op1 before %op2
: !transform.any_op, !transform.any_op
transform.yield
}
}

// -----

func.func @move_region_dependencies() -> f32 {
%0 = "before"() : () -> (f32)
%1 = "moved_op_1"() : () -> (f32)
%2 = "moved_op_2"() ({
"yield"(%1) : (f32) -> ()
}) : () -> (f32)
%3 = "foo"() ({
"yield"(%2) : (f32) -> ()
}) : () -> (f32)
return %3 : f32
}

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) {
%op1 = transform.structured.match ops{["foo"]} in %arg0
: (!transform.any_op) -> !transform.any_op
%op2 = transform.structured.match ops{["before"]} in %arg0
: (!transform.any_op) -> !transform.any_op
transform.test.move_operand_deps %op1 before %op2
: !transform.any_op, !transform.any_op
transform.yield
}
}
// CHECK-LABEL: func @move_region_dependencies()
// CHECK: %[[MOVED_1:.+]] = "moved_op_1"
// CHECK: %[[MOVED_2:.+]] = "moved_op_2"
// CHECK: "yield"(%[[MOVED_1]])
// CHECK: "before"
// CHECK: %[[FOO:.+]] = "foo"
// CHECK: return %[[FOO]]

// -----

// Current implementation omits following basic block argument when
// computing slices. Verify that this gives expected result.
func.func @ignore_basic_block_arguments() -> f32 {
^bb0():
%8 = "test"() : () -> (f32)
return %8: f32
^bb1(%bbArg : f32):
%0 = "before"() : () -> (f32)
%1 = "moved_op"() ({
"yield"(%bbArg) : (f32) -> ()
}) : () -> (f32)
%2 = "foo"() ({
"yield"(%1) : (f32) -> ()
}) : () -> (f32)
return %2 : f32
}

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) {
%op1 = transform.structured.match ops{["foo"]} in %arg0
: (!transform.any_op) -> !transform.any_op
%op2 = transform.structured.match ops{["before"]} in %arg0
: (!transform.any_op) -> !transform.any_op
transform.test.move_operand_deps %op1 before %op2
: !transform.any_op, !transform.any_op
transform.yield
}
}
// CHECK-LABEL: func @ignore_basic_block_arguments()
// CHECK: %[[MOVED_1:.+]] = "moved_op"
// CHECK: "before"
// CHECK: %[[FOO:.+]] = "foo"
// CHECK: return %[[FOO]]

// -----

// Fail when the "before" operation is part of the operation slice.
func.func @do_not_move_slice() -> f32 {
%0 = "before"() : () -> (f32)
%1 = "moved_op"(%0) : (f32) -> (f32)
%2 = "foo"(%1) : (f32) -> (f32)
return %2 : f32
}

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) {
%op1 = transform.structured.match ops{["foo"]} in %arg0
: (!transform.any_op) -> !transform.any_op
%op2 = transform.structured.match ops{["before"]} in %arg0
: (!transform.any_op) -> !transform.any_op
// expected-remark@+1{{cannot move dependencies before operation in backward slice of op}}
transform.test.move_operand_deps %op1 before %op2
: !transform.any_op, !transform.any_op
transform.yield
}
}

// -----

// Fail when the "before" operation is part of the operation slice (computed
// when looking through implicitly captured values).
func.func @do_not_move_slice() -> f32 {
%0 = "before"() : () -> (f32)
%1 = "moved_op"() ({
"yield"(%0) : (f32) -> ()
}) : () -> (f32)
%2 = "foo"() ({
"yield"(%1) : (f32) -> ()
}) : () -> (f32)
return %2 : f32
}

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) {
%op1 = transform.structured.match ops{["foo"]} in %arg0
: (!transform.any_op) -> !transform.any_op
%op2 = transform.structured.match ops{["before"]} in %arg0
: (!transform.any_op) -> !transform.any_op
// expected-remark@+1{{cannot move dependencies before operation in backward slice of op}}
transform.test.move_operand_deps %op1 before %op2
: !transform.any_op, !transform.any_op
transform.yield
}
}

// -----

// Dont move ops when insertion point does not dominate the op
func.func @do_not_move() -> f32 {
%1 = "moved_op"() : () -> (f32)
%2 = "foo"() ({
"yield"(%1) : (f32) -> ()
}) : () -> (f32)
%3 = "before"() : () -> f32
return %2 : f32
}

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) {
%op1 = transform.structured.match ops{["foo"]} in %arg0
: (!transform.any_op) -> !transform.any_op
%op2 = transform.structured.match ops{["before"]} in %arg0
: (!transform.any_op) -> !transform.any_op
// expected-remark@+1{{insertion point does not dominate op}}
transform.test.move_operand_deps %op1 before %op2
: !transform.any_op, !transform.any_op
transform.yield
}
}
Loading

0 comments on commit 205c532

Please sign in to comment.