Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir] Add a utility method to move operation dependencies. #129975

Conversation

MaheshRavishankar
Copy link
Contributor

@MaheshRavishankar MaheshRavishankar commented Mar 6, 2025

The added utility method moves all SSA values that an operation depends upon before an insertion point. This is useful during transformations where such movements might make transformations (like fusion) more powerful.

To test the operation add a transform dialect op that calls the move operation. To be able to capture the notifyMatchFailure messages from the transformation and to report/check these in the test modify the ErrorCheckingTrackingListener to capture the last match failure notification.

The added utility method moves all SSA values that an operation
depends upon before an insertion point. This is useful during
transformations where such movements might make transformations (like
fusion) more powerful.

To test the operation add a transform dialect op that calls the move
operation.

Signed-off-by: MaheshRavishankar <[email protected]>
@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir labels Mar 6, 2025
@llvmbot
Copy link
Member

llvmbot commented Mar 6, 2025

@llvm/pr-subscribers-mlir-core

Author: None (MaheshRavishankar)

Changes

The added utility method moves all SSA values that an operation depends upon before an insertion point. This is useful during transformations where such movements might make transformations (like fusion) more powerful.

To test the operation add a transform dialect op that calls the move operation.


Full diff: https://github.com/llvm/llvm-project/pull/129975.diff

7 Files Affected:

  • (modified) mlir/include/mlir/Transforms/RegionUtils.h (+10)
  • (modified) mlir/lib/Transforms/Utils/RegionUtils.cpp (+65)
  • (added) mlir/test/Transforms/move-operation-deps.mlir (+113)
  • (modified) mlir/test/lib/Transforms/CMakeLists.txt (+8)
  • (added) mlir/test/lib/Transforms/TestTransformsOps.cpp (+66)
  • (added) mlir/test/lib/Transforms/TestTransformsOps.td (+41)
  • (modified) mlir/tools/mlir-opt/mlir-opt.cpp (+2)
diff --git a/mlir/include/mlir/Transforms/RegionUtils.h b/mlir/include/mlir/Transforms/RegionUtils.h
index 5c57dd5b7532a..4acc8528efe97 100644
--- a/mlir/include/mlir/Transforms/RegionUtils.h
+++ b/mlir/include/mlir/Transforms/RegionUtils.h
@@ -15,6 +15,7 @@
 #include "llvm/ADT/SetVector.h"
 
 namespace mlir {
+class DominanceInfo;
 class RewriterBase;
 
 /// Check if all values in the provided range are defined above the `limit`
@@ -69,6 +70,15 @@ SmallVector<Value> makeRegionIsolatedFromAbove(
     llvm::function_ref<bool(Operation *)> cloneOperationIntoRegion =
         [](Operation *) { return false; });
 
+/// Move SSA values used within an operation before an insertion point,
+/// so that the operation itself (or its replacement) can be moved to
+/// the insertion point.
+LogicalResult moveOperationDependencies(RewriterBase &rewriter, Operation *op,
+                                        Operation *insertionPoint,
+                                        DominanceInfo &dominance);
+LogicalResult moveOperationDependencies(RewriterBase &rewriter, Operation *op,
+                                        Operation *insertionPoint);
+
 /// Run a set of structural simplifications over the given regions. This
 /// includes transformations like unreachable block elimination, dead argument
 /// elimination, as well as some other DCE. This function returns success if any
diff --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp
index e55ef6eb66b9c..7040243bed83b 100644
--- a/mlir/lib/Transforms/Utils/RegionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp
@@ -7,9 +7,12 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Transforms/RegionUtils.h"
+
+#include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/Analysis/TopologicalSortUtils.h"
 #include "mlir/IR/Block.h"
 #include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Dominance.h"
 #include "mlir/IR/IRMapping.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/IR/PatternMatch.h"
@@ -1054,3 +1057,65 @@ LogicalResult mlir::simplifyRegions(RewriterBase &rewriter,
   return success(eliminatedBlocks || eliminatedOpsOrArgs ||
                  mergedIdenticalBlocks || droppedRedundantArguments);
 }
+
+//===---------------------------------------------------------------------===//
+// Move operation dependencies
+//===---------------------------------------------------------------------===//
+
+LogicalResult mlir::moveOperationDependencies(RewriterBase &rewriter,
+                                              Operation *op,
+                                              Operation *insertionPoint,
+                                              DominanceInfo &dominance) {
+  // Currently unsupported case where the op and insertion point are
+  // in different basic blocks.
+  if (op->getBlock() != insertionPoint->getBlock()) {
+    return rewriter.notifyMatchFailure(
+        op, "unsupported caes where operation and insertion point are not in "
+            "the sme basic block");
+  }
+
+  // Find the backward slice of operation for each `Value` the operation
+  // depends on. Prune the slice to only include operations not already
+  // dominated by the `insertionPoint`
+  BackwardSliceOptions options;
+  options.inclusive = true;
+  options.filter = [&](Operation *sliceBoundaryOp) {
+    return !dominance.properlyDominates(sliceBoundaryOp, insertionPoint);
+  };
+  llvm::SetVector<Operation *> slice;
+
+  // Get the defined slice for operands.
+  for (Value operand : op->getOperands()) {
+    getBackwardSlice(operand, &slice, options);
+  }
+  auto regions = op->getRegions();
+  if (!regions.empty()) {
+    // If op has region, get the defined slice for all captured values.
+    llvm::SetVector<Value> capturedVals;
+    mlir::getUsedValuesDefinedAbove(regions, capturedVals);
+    for (auto value : capturedVals) {
+      getBackwardSlice(value, &slice, options);
+    }
+  }
+
+  // If the slice contains `insertionPoint` cannot move the dependencies.
+  if (slice.contains(insertionPoint)) {
+    return rewriter.notifyMatchFailure(
+        op,
+        "cannot move dependencies before operation in backward slice of op");
+  }
+
+  // Sort the slice topologically, ad move in topological order.
+  mlir::topologicalSort(slice);
+  for (auto op : slice) {
+    rewriter.moveOpBefore(op, insertionPoint);
+  }
+  return success();
+}
+
+LogicalResult mlir::moveOperationDependencies(RewriterBase &rewriter,
+                                              Operation *op,
+                                              Operation *insertionPoint) {
+  DominanceInfo dominance(op);
+  return moveOperationDependencies(rewriter, op, insertionPoint, dominance);
+}
\ No newline at end of file
diff --git a/mlir/test/Transforms/move-operation-deps.mlir b/mlir/test/Transforms/move-operation-deps.mlir
new file mode 100644
index 0000000000000..90c66a0f14938
--- /dev/null
+++ b/mlir/test/Transforms/move-operation-deps.mlir
@@ -0,0 +1,113 @@
+// RUN: mlir-opt --allow-unregistered-dialect --transform-interpreter --split-input-file %s | FileCheck %s
+
+// Check simple move of dependent operation before insertion.
+func.func @simple_move() -> f32 {
+  %0 = "before"() : () -> (f32)
+  %1 = "moved_op"() : () -> (f32)
+  %2 = "foo"(%1) : (f32) -> (f32)
+  return %2 : f32
+}
+// CHECK-LABEL: func @simple_move()
+//       CHECK:   %[[MOVED:.+]] = "moved_op"
+//       CHECK:   %[[BEFORE:.+]] = "before"
+//       CHECK:   %[[FOO:.+]] = "foo"(%[[MOVED]])
+//       CHECK:   return %[[FOO]]
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) {
+    %op1 = transform.structured.match ops{["foo"]} in %arg0
+        : (!transform.any_op) -> !transform.any_op
+    %op2 = transform.structured.match ops{["before"]} in %arg0
+        : (!transform.any_op) -> !transform.any_op
+    transform.test.move_operand_deps %op1 before %op2
+        : !transform.any_op, !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// Move operands that are implicitly captured by the op
+func.func @move_region_dependencies() -> f32 {
+  %0 = "before"() : () -> (f32)
+  %1 = "moved_op"() : () -> (f32)
+  %2 = "foo"() ({
+    %3 = "inner_op"(%1) : (f32) -> (f32)
+    "yield"(%3) : (f32) -> ()
+  }) : () -> (f32)
+  return %2 : f32
+}
+// CHECK-LABEL: func @move_region_dependencies()
+//       CHECK:   %[[MOVED:.+]] = "moved_op"
+//       CHECK:   %[[BEFORE:.+]] = "before"
+//       CHECK:   %[[FOO:.+]] = "foo"
+//       CHECK:   return %[[FOO]]
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) {
+    %op1 = transform.structured.match ops{["foo"]} in %arg0
+        : (!transform.any_op) -> !transform.any_op
+    %op2 = transform.structured.match ops{["before"]} in %arg0
+        : (!transform.any_op) -> !transform.any_op
+    transform.test.move_operand_deps %op1 before %op2
+        : !transform.any_op, !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// Move operations in toplogical sort order
+func.func @move_in_topological_sort_order() -> f32 {
+  %0 = "before"() : () -> (f32)
+  %1 = "moved_op_1"() : () -> (f32)
+  %2 = "moved_op_2"() : () -> (f32)
+  %3 = "moved_op_3"(%1) : (f32) -> (f32)
+  %4 = "moved_op_4"(%1, %3) : (f32, f32) -> (f32)
+  %5 = "moved_op_5"(%2) : (f32) -> (f32)
+  %6 = "foo"(%4, %5) : (f32, f32) -> (f32)
+  return %6 : f32
+}
+// CHECK-LABEL: func @move_in_topological_sort_order()
+//       CHECK:   %[[MOVED_1:.+]] = "moved_op_1"
+//   CHECK-DAG:   %[[MOVED_2:.+]] = "moved_op_3"(%[[MOVED_1]])
+//   CHECK-DAG:   %[[MOVED_3:.+]] = "moved_op_4"(%[[MOVED_1]], %[[MOVED_2]])
+//   CHECK-DAG:   %[[MOVED_4:.+]] = "moved_op_2"
+//   CHECK-DAG:   %[[MOVED_5:.+]] = "moved_op_5"(%[[MOVED_4]])
+//       CHECK:   %[[BEFORE:.+]] = "before"
+//       CHECK:   %[[FOO:.+]] = "foo"(%[[MOVED_3]], %[[MOVED_5]])
+//       CHECK:   return %[[FOO]]
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) {
+    %op1 = transform.structured.match ops{["foo"]} in %arg0
+        : (!transform.any_op) -> !transform.any_op
+    %op2 = transform.structured.match ops{["before"]} in %arg0
+        : (!transform.any_op) -> !transform.any_op
+    transform.test.move_operand_deps %op1 before %op2
+        : !transform.any_op, !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// Fail when the "before" operation is part of the operation slice.
+func.func @do_not_move_slice() -> f32 {
+  %0 = "before"() : () -> (f32)
+  %1 = "moved_op"(%0) : (f32) -> (f32)
+  %2 = "foo"(%1) : (f32) -> (f32)
+  return %2 : f32
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) {
+    %op1 = transform.structured.match ops{["foo"]} in %arg0
+        : (!transform.any_op) -> !transform.any_op
+    %op2 = transform.structured.match ops{["before"]} in %arg0
+        : (!transform.any_op) -> !transform.any_op
+    transform.test.move_operand_deps %op1 before %op2
+        : !transform.any_op, !transform.any_op
+    transform.yield
+  }
+}
diff --git a/mlir/test/lib/Transforms/CMakeLists.txt b/mlir/test/lib/Transforms/CMakeLists.txt
index 1b9b9bffa5279..c053fd4b20473 100644
--- a/mlir/test/lib/Transforms/CMakeLists.txt
+++ b/mlir/test/lib/Transforms/CMakeLists.txt
@@ -1,3 +1,8 @@
+set(LLVM_TARGET_DEFINITIONS TestTransformsOps.td)
+mlir_tablegen(TestTransformsOps.h.inc -gen-op-decls)
+mlir_tablegen(TestTransformsOps.cpp.inc -gen-op-defs)
+add_public_tablegen_target(MLIRTestTransformsOpsIncGen)
+
 set(LLVM_OPTIONAL_SOURCES 
   TestDialectConversion.cpp)
 set(MLIRTestTransformsPDLDep)
@@ -25,6 +30,7 @@ add_mlir_library(MLIRTestTransforms
   TestControlFlowSink.cpp
   TestInlining.cpp
   TestMakeIsolatedFromAbove.cpp
+  TestTransformsOps.cpp
   ${MLIRTestTransformsPDLSrc}
 
   EXCLUDE_FROM_LIBMLIR
@@ -34,6 +40,7 @@ add_mlir_library(MLIRTestTransforms
 
   DEPENDS
   ${MLIRTestTransformsPDLDep}
+  MLIRTestTransformsOpsIncGen
 
   LINK_LIBS PUBLIC
   MLIRTestDialect
@@ -43,6 +50,7 @@ mlir_target_link_libraries(MLIRTestTransforms PUBLIC
   MLIRFuncDialect
   MLIRInferIntRangeInterface
   MLIRTransforms
+  MLIRTransformDialect
   )
 
 target_include_directories(MLIRTestTransforms
diff --git a/mlir/test/lib/Transforms/TestTransformsOps.cpp b/mlir/test/lib/Transforms/TestTransformsOps.cpp
new file mode 100644
index 0000000000000..427930b0c7ed1
--- /dev/null
+++ b/mlir/test/lib/Transforms/TestTransformsOps.cpp
@@ -0,0 +1,66 @@
+//===- TestTransformsOps.cpp - Test Transforms ----------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines transform dialect operations for testing MLIR
+// transformations
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Transform/IR/TransformAttrs.h"
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
+#include "mlir/Transforms/RegionUtils.h"
+
+#define GET_OP_CLASSES
+#include "TestTransformsOps.h.inc"
+
+using namespace mlir;
+using namespace mlir::transform;
+
+#define GET_OP_CLASSES
+#include "TestTransformsOps.cpp.inc"
+
+DiagnosedSilenceableFailure
+transform::TestMoveOperandDeps::apply(TransformRewriter &rewriter,
+                                      TransformResults &TransformResults,
+                                      TransformState &state) {
+  Operation *op = *state.getPayloadOps(getOp()).begin();
+  Operation *moveBefore = *state.getPayloadOps(getInsertionPoint()).begin();
+  if (failed(moveOperationDependencies(rewriter, op, moveBefore))) {
+    auto listener = cast<ErrorCheckingTrackingListener>(rewriter.getListener());
+    std::string errorMsg = listener->checkAndResetError().getMessage();
+    return emitSilenceableFailure(op, errorMsg);
+  }
+  return DiagnosedSilenceableFailure::success();
+}
+
+namespace {
+
+class TestTransformsDialectExtension
+    : public transform::TransformDialectExtension<
+          TestTransformsDialectExtension> {
+public:
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTransformsDialectExtension)
+
+  using Base::Base;
+
+  void init() {
+    registerTransformOps<
+#define GET_OP_LIST
+#include "TestTransformsOps.cpp.inc"
+        >();
+  }
+};
+} // namespace
+
+namespace test {
+void registerTestTransformsTransformDialectExtension(
+    DialectRegistry &registry) {
+  registry.addExtensions<TestTransformsDialectExtension>();
+}
+} // namespace test
\ No newline at end of file
diff --git a/mlir/test/lib/Transforms/TestTransformsOps.td b/mlir/test/lib/Transforms/TestTransformsOps.td
new file mode 100644
index 0000000000000..ef19d00f999c3
--- /dev/null
+++ b/mlir/test/lib/Transforms/TestTransformsOps.td
@@ -0,0 +1,41 @@
+//===- TestTransformOps.td ---------------------------------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef TEST_TRANSFORM_OPS
+#define TEST_TRANSFORM_OPS
+
+include "mlir/Dialect/Transform/IR/TransformDialect.td"
+include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
+include "mlir/Dialect/Transform/IR/TransformTypes.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/IR/OpBase.td"
+
+/// Transform dialect perations for testing transformations in MLIR
+
+def TestMoveOperandDeps :
+    Op<Transform_Dialect, "test.move_operand_deps",
+        [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+         DeclareOpInterfaceMethods<TransformOpInterface>,
+         ReportTrackingListenerFailuresOpTrait]> {
+  let description = [{
+    Moves all dependencies of on operation before another operation.
+  }];
+
+  let arguments =
+    (ins TransformHandleTypeInterface:$op,
+         TransformHandleTypeInterface:$insertion_point);
+  
+  let results = (outs);
+
+  let assemblyFormat = [{
+    $op `before` $insertion_point attr-dict
+    `:` type($op) `,` type($insertion_point)
+  }];
+}
+
+#endif // TEST_TRANSFORM_OPS
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index f18ad45dfb708..d06ff8070e7cf 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -170,6 +170,7 @@ void registerTestDialect(DialectRegistry &);
 void registerTestDynDialect(DialectRegistry &);
 void registerTestTilingInterfaceTransformDialectExtension(DialectRegistry &);
 void registerTestTransformDialectExtension(DialectRegistry &);
+void registerTestTransformsTransformDialectExtension(DialectRegistry &);
 } // namespace test
 
 #ifdef MLIR_INCLUDE_TESTS
@@ -323,6 +324,7 @@ int main(int argc, char **argv) {
 #ifdef MLIR_INCLUDE_TESTS
   ::test::registerTestDialect(registry);
   ::test::registerTestTransformDialectExtension(registry);
+  ::test::registerTestTransformsTransformDialectExtension(registry);
   ::test::registerTestTilingInterfaceTransformDialectExtension(registry);
   ::test::registerTestDynDialect(registry);
 #endif

@llvmbot
Copy link
Member

llvmbot commented Mar 6, 2025

@llvm/pr-subscribers-mlir

Author: None (MaheshRavishankar)

Changes

The added utility method moves all SSA values that an operation depends upon before an insertion point. This is useful during transformations where such movements might make transformations (like fusion) more powerful.

To test the operation add a transform dialect op that calls the move operation.


Full diff: https://github.com/llvm/llvm-project/pull/129975.diff

7 Files Affected:

  • (modified) mlir/include/mlir/Transforms/RegionUtils.h (+10)
  • (modified) mlir/lib/Transforms/Utils/RegionUtils.cpp (+65)
  • (added) mlir/test/Transforms/move-operation-deps.mlir (+113)
  • (modified) mlir/test/lib/Transforms/CMakeLists.txt (+8)
  • (added) mlir/test/lib/Transforms/TestTransformsOps.cpp (+66)
  • (added) mlir/test/lib/Transforms/TestTransformsOps.td (+41)
  • (modified) mlir/tools/mlir-opt/mlir-opt.cpp (+2)
diff --git a/mlir/include/mlir/Transforms/RegionUtils.h b/mlir/include/mlir/Transforms/RegionUtils.h
index 5c57dd5b7532a..4acc8528efe97 100644
--- a/mlir/include/mlir/Transforms/RegionUtils.h
+++ b/mlir/include/mlir/Transforms/RegionUtils.h
@@ -15,6 +15,7 @@
 #include "llvm/ADT/SetVector.h"
 
 namespace mlir {
+class DominanceInfo;
 class RewriterBase;
 
 /// Check if all values in the provided range are defined above the `limit`
@@ -69,6 +70,15 @@ SmallVector<Value> makeRegionIsolatedFromAbove(
     llvm::function_ref<bool(Operation *)> cloneOperationIntoRegion =
         [](Operation *) { return false; });
 
+/// Move SSA values used within an operation before an insertion point,
+/// so that the operation itself (or its replacement) can be moved to
+/// the insertion point.
+LogicalResult moveOperationDependencies(RewriterBase &rewriter, Operation *op,
+                                        Operation *insertionPoint,
+                                        DominanceInfo &dominance);
+LogicalResult moveOperationDependencies(RewriterBase &rewriter, Operation *op,
+                                        Operation *insertionPoint);
+
 /// Run a set of structural simplifications over the given regions. This
 /// includes transformations like unreachable block elimination, dead argument
 /// elimination, as well as some other DCE. This function returns success if any
diff --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp
index e55ef6eb66b9c..7040243bed83b 100644
--- a/mlir/lib/Transforms/Utils/RegionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp
@@ -7,9 +7,12 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Transforms/RegionUtils.h"
+
+#include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/Analysis/TopologicalSortUtils.h"
 #include "mlir/IR/Block.h"
 #include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Dominance.h"
 #include "mlir/IR/IRMapping.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/IR/PatternMatch.h"
@@ -1054,3 +1057,65 @@ LogicalResult mlir::simplifyRegions(RewriterBase &rewriter,
   return success(eliminatedBlocks || eliminatedOpsOrArgs ||
                  mergedIdenticalBlocks || droppedRedundantArguments);
 }
+
+//===---------------------------------------------------------------------===//
+// Move operation dependencies
+//===---------------------------------------------------------------------===//
+
+LogicalResult mlir::moveOperationDependencies(RewriterBase &rewriter,
+                                              Operation *op,
+                                              Operation *insertionPoint,
+                                              DominanceInfo &dominance) {
+  // Currently unsupported case where the op and insertion point are
+  // in different basic blocks.
+  if (op->getBlock() != insertionPoint->getBlock()) {
+    return rewriter.notifyMatchFailure(
+        op, "unsupported caes where operation and insertion point are not in "
+            "the sme basic block");
+  }
+
+  // Find the backward slice of operation for each `Value` the operation
+  // depends on. Prune the slice to only include operations not already
+  // dominated by the `insertionPoint`
+  BackwardSliceOptions options;
+  options.inclusive = true;
+  options.filter = [&](Operation *sliceBoundaryOp) {
+    return !dominance.properlyDominates(sliceBoundaryOp, insertionPoint);
+  };
+  llvm::SetVector<Operation *> slice;
+
+  // Get the defined slice for operands.
+  for (Value operand : op->getOperands()) {
+    getBackwardSlice(operand, &slice, options);
+  }
+  auto regions = op->getRegions();
+  if (!regions.empty()) {
+    // If op has region, get the defined slice for all captured values.
+    llvm::SetVector<Value> capturedVals;
+    mlir::getUsedValuesDefinedAbove(regions, capturedVals);
+    for (auto value : capturedVals) {
+      getBackwardSlice(value, &slice, options);
+    }
+  }
+
+  // If the slice contains `insertionPoint` cannot move the dependencies.
+  if (slice.contains(insertionPoint)) {
+    return rewriter.notifyMatchFailure(
+        op,
+        "cannot move dependencies before operation in backward slice of op");
+  }
+
+  // Sort the slice topologically, ad move in topological order.
+  mlir::topologicalSort(slice);
+  for (auto op : slice) {
+    rewriter.moveOpBefore(op, insertionPoint);
+  }
+  return success();
+}
+
+LogicalResult mlir::moveOperationDependencies(RewriterBase &rewriter,
+                                              Operation *op,
+                                              Operation *insertionPoint) {
+  DominanceInfo dominance(op);
+  return moveOperationDependencies(rewriter, op, insertionPoint, dominance);
+}
\ No newline at end of file
diff --git a/mlir/test/Transforms/move-operation-deps.mlir b/mlir/test/Transforms/move-operation-deps.mlir
new file mode 100644
index 0000000000000..90c66a0f14938
--- /dev/null
+++ b/mlir/test/Transforms/move-operation-deps.mlir
@@ -0,0 +1,113 @@
+// RUN: mlir-opt --allow-unregistered-dialect --transform-interpreter --split-input-file %s | FileCheck %s
+
+// Check simple move of dependent operation before insertion.
+func.func @simple_move() -> f32 {
+  %0 = "before"() : () -> (f32)
+  %1 = "moved_op"() : () -> (f32)
+  %2 = "foo"(%1) : (f32) -> (f32)
+  return %2 : f32
+}
+// CHECK-LABEL: func @simple_move()
+//       CHECK:   %[[MOVED:.+]] = "moved_op"
+//       CHECK:   %[[BEFORE:.+]] = "before"
+//       CHECK:   %[[FOO:.+]] = "foo"(%[[MOVED]])
+//       CHECK:   return %[[FOO]]
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) {
+    %op1 = transform.structured.match ops{["foo"]} in %arg0
+        : (!transform.any_op) -> !transform.any_op
+    %op2 = transform.structured.match ops{["before"]} in %arg0
+        : (!transform.any_op) -> !transform.any_op
+    transform.test.move_operand_deps %op1 before %op2
+        : !transform.any_op, !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// Move operands that are implicitly captured by the op
+func.func @move_region_dependencies() -> f32 {
+  %0 = "before"() : () -> (f32)
+  %1 = "moved_op"() : () -> (f32)
+  %2 = "foo"() ({
+    %3 = "inner_op"(%1) : (f32) -> (f32)
+    "yield"(%3) : (f32) -> ()
+  }) : () -> (f32)
+  return %2 : f32
+}
+// CHECK-LABEL: func @move_region_dependencies()
+//       CHECK:   %[[MOVED:.+]] = "moved_op"
+//       CHECK:   %[[BEFORE:.+]] = "before"
+//       CHECK:   %[[FOO:.+]] = "foo"
+//       CHECK:   return %[[FOO]]
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) {
+    %op1 = transform.structured.match ops{["foo"]} in %arg0
+        : (!transform.any_op) -> !transform.any_op
+    %op2 = transform.structured.match ops{["before"]} in %arg0
+        : (!transform.any_op) -> !transform.any_op
+    transform.test.move_operand_deps %op1 before %op2
+        : !transform.any_op, !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// Move operations in toplogical sort order
+func.func @move_in_topological_sort_order() -> f32 {
+  %0 = "before"() : () -> (f32)
+  %1 = "moved_op_1"() : () -> (f32)
+  %2 = "moved_op_2"() : () -> (f32)
+  %3 = "moved_op_3"(%1) : (f32) -> (f32)
+  %4 = "moved_op_4"(%1, %3) : (f32, f32) -> (f32)
+  %5 = "moved_op_5"(%2) : (f32) -> (f32)
+  %6 = "foo"(%4, %5) : (f32, f32) -> (f32)
+  return %6 : f32
+}
+// CHECK-LABEL: func @move_in_topological_sort_order()
+//       CHECK:   %[[MOVED_1:.+]] = "moved_op_1"
+//   CHECK-DAG:   %[[MOVED_2:.+]] = "moved_op_3"(%[[MOVED_1]])
+//   CHECK-DAG:   %[[MOVED_3:.+]] = "moved_op_4"(%[[MOVED_1]], %[[MOVED_2]])
+//   CHECK-DAG:   %[[MOVED_4:.+]] = "moved_op_2"
+//   CHECK-DAG:   %[[MOVED_5:.+]] = "moved_op_5"(%[[MOVED_4]])
+//       CHECK:   %[[BEFORE:.+]] = "before"
+//       CHECK:   %[[FOO:.+]] = "foo"(%[[MOVED_3]], %[[MOVED_5]])
+//       CHECK:   return %[[FOO]]
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) {
+    %op1 = transform.structured.match ops{["foo"]} in %arg0
+        : (!transform.any_op) -> !transform.any_op
+    %op2 = transform.structured.match ops{["before"]} in %arg0
+        : (!transform.any_op) -> !transform.any_op
+    transform.test.move_operand_deps %op1 before %op2
+        : !transform.any_op, !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// Fail when the "before" operation is part of the operation slice.
+func.func @do_not_move_slice() -> f32 {
+  %0 = "before"() : () -> (f32)
+  %1 = "moved_op"(%0) : (f32) -> (f32)
+  %2 = "foo"(%1) : (f32) -> (f32)
+  return %2 : f32
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) {
+    %op1 = transform.structured.match ops{["foo"]} in %arg0
+        : (!transform.any_op) -> !transform.any_op
+    %op2 = transform.structured.match ops{["before"]} in %arg0
+        : (!transform.any_op) -> !transform.any_op
+    transform.test.move_operand_deps %op1 before %op2
+        : !transform.any_op, !transform.any_op
+    transform.yield
+  }
+}
diff --git a/mlir/test/lib/Transforms/CMakeLists.txt b/mlir/test/lib/Transforms/CMakeLists.txt
index 1b9b9bffa5279..c053fd4b20473 100644
--- a/mlir/test/lib/Transforms/CMakeLists.txt
+++ b/mlir/test/lib/Transforms/CMakeLists.txt
@@ -1,3 +1,8 @@
+set(LLVM_TARGET_DEFINITIONS TestTransformsOps.td)
+mlir_tablegen(TestTransformsOps.h.inc -gen-op-decls)
+mlir_tablegen(TestTransformsOps.cpp.inc -gen-op-defs)
+add_public_tablegen_target(MLIRTestTransformsOpsIncGen)
+
 set(LLVM_OPTIONAL_SOURCES 
   TestDialectConversion.cpp)
 set(MLIRTestTransformsPDLDep)
@@ -25,6 +30,7 @@ add_mlir_library(MLIRTestTransforms
   TestControlFlowSink.cpp
   TestInlining.cpp
   TestMakeIsolatedFromAbove.cpp
+  TestTransformsOps.cpp
   ${MLIRTestTransformsPDLSrc}
 
   EXCLUDE_FROM_LIBMLIR
@@ -34,6 +40,7 @@ add_mlir_library(MLIRTestTransforms
 
   DEPENDS
   ${MLIRTestTransformsPDLDep}
+  MLIRTestTransformsOpsIncGen
 
   LINK_LIBS PUBLIC
   MLIRTestDialect
@@ -43,6 +50,7 @@ mlir_target_link_libraries(MLIRTestTransforms PUBLIC
   MLIRFuncDialect
   MLIRInferIntRangeInterface
   MLIRTransforms
+  MLIRTransformDialect
   )
 
 target_include_directories(MLIRTestTransforms
diff --git a/mlir/test/lib/Transforms/TestTransformsOps.cpp b/mlir/test/lib/Transforms/TestTransformsOps.cpp
new file mode 100644
index 0000000000000..427930b0c7ed1
--- /dev/null
+++ b/mlir/test/lib/Transforms/TestTransformsOps.cpp
@@ -0,0 +1,66 @@
+//===- TestTransformsOps.cpp - Test Transforms ----------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines transform dialect operations for testing MLIR
+// transformations
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Transform/IR/TransformAttrs.h"
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
+#include "mlir/Transforms/RegionUtils.h"
+
+#define GET_OP_CLASSES
+#include "TestTransformsOps.h.inc"
+
+using namespace mlir;
+using namespace mlir::transform;
+
+#define GET_OP_CLASSES
+#include "TestTransformsOps.cpp.inc"
+
+DiagnosedSilenceableFailure
+transform::TestMoveOperandDeps::apply(TransformRewriter &rewriter,
+                                      TransformResults &TransformResults,
+                                      TransformState &state) {
+  Operation *op = *state.getPayloadOps(getOp()).begin();
+  Operation *moveBefore = *state.getPayloadOps(getInsertionPoint()).begin();
+  if (failed(moveOperationDependencies(rewriter, op, moveBefore))) {
+    auto listener = cast<ErrorCheckingTrackingListener>(rewriter.getListener());
+    std::string errorMsg = listener->checkAndResetError().getMessage();
+    return emitSilenceableFailure(op, errorMsg);
+  }
+  return DiagnosedSilenceableFailure::success();
+}
+
+namespace {
+
+class TestTransformsDialectExtension
+    : public transform::TransformDialectExtension<
+          TestTransformsDialectExtension> {
+public:
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTransformsDialectExtension)
+
+  using Base::Base;
+
+  void init() {
+    registerTransformOps<
+#define GET_OP_LIST
+#include "TestTransformsOps.cpp.inc"
+        >();
+  }
+};
+} // namespace
+
+namespace test {
+void registerTestTransformsTransformDialectExtension(
+    DialectRegistry &registry) {
+  registry.addExtensions<TestTransformsDialectExtension>();
+}
+} // namespace test
\ No newline at end of file
diff --git a/mlir/test/lib/Transforms/TestTransformsOps.td b/mlir/test/lib/Transforms/TestTransformsOps.td
new file mode 100644
index 0000000000000..ef19d00f999c3
--- /dev/null
+++ b/mlir/test/lib/Transforms/TestTransformsOps.td
@@ -0,0 +1,41 @@
+//===- TestTransformOps.td ---------------------------------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef TEST_TRANSFORM_OPS
+#define TEST_TRANSFORM_OPS
+
+include "mlir/Dialect/Transform/IR/TransformDialect.td"
+include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
+include "mlir/Dialect/Transform/IR/TransformTypes.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/IR/OpBase.td"
+
+/// Transform dialect perations for testing transformations in MLIR
+
+def TestMoveOperandDeps :
+    Op<Transform_Dialect, "test.move_operand_deps",
+        [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+         DeclareOpInterfaceMethods<TransformOpInterface>,
+         ReportTrackingListenerFailuresOpTrait]> {
+  let description = [{
+    Moves all dependencies of on operation before another operation.
+  }];
+
+  let arguments =
+    (ins TransformHandleTypeInterface:$op,
+         TransformHandleTypeInterface:$insertion_point);
+  
+  let results = (outs);
+
+  let assemblyFormat = [{
+    $op `before` $insertion_point attr-dict
+    `:` type($op) `,` type($insertion_point)
+  }];
+}
+
+#endif // TEST_TRANSFORM_OPS
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index f18ad45dfb708..d06ff8070e7cf 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -170,6 +170,7 @@ void registerTestDialect(DialectRegistry &);
 void registerTestDynDialect(DialectRegistry &);
 void registerTestTilingInterfaceTransformDialectExtension(DialectRegistry &);
 void registerTestTransformDialectExtension(DialectRegistry &);
+void registerTestTransformsTransformDialectExtension(DialectRegistry &);
 } // namespace test
 
 #ifdef MLIR_INCLUDE_TESTS
@@ -323,6 +324,7 @@ int main(int argc, char **argv) {
 #ifdef MLIR_INCLUDE_TESTS
   ::test::registerTestDialect(registry);
   ::test::registerTestTransformDialectExtension(registry);
+  ::test::registerTestTransformsTransformDialectExtension(registry);
   ::test::registerTestTilingInterfaceTransformDialectExtension(registry);
   ::test::registerTestDynDialect(registry);
 #endif

Copy link
Member

@kuhar kuhar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just a few drive-by nits

Signed-off-by: MaheshRavishankar <[email protected]>
…nd use it to test failure in the op.

Signed-off-by: MaheshRavishankar <[email protected]>
Signed-off-by: MaheshRavishankar <[email protected]>
Signed-off-by: MaheshRavishankar <[email protected]>
Copy link
Contributor

@IanWood1 IanWood1 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice. I'm a bit unsure about the changes to ErrorCheckingTrackingListener. Maybe you could create a listener in TestTransformsOps.cpp and attach it to the rewriter (I'm not entirely sure this is possible)? Then you could emit the remark directly from the added listener.

Signed-off-by: MaheshRavishankar <[email protected]>
@MaheshRavishankar MaheshRavishankar force-pushed the users/MaheshRavishankar/moveOperationDeps branch from 221decd to ac9e0d1 Compare March 10, 2025 18:02
@MaheshRavishankar
Copy link
Contributor Author

@ftynse I made the change you recommended. Ill land this and if you have any concerns, ill address those post landing.

@MaheshRavishankar MaheshRavishankar merged commit 205c532 into llvm:main Mar 11, 2025
11 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:core MLIR Core Infrastructure mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants