Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir] Add a utility method to move operation dependencies. #129975

Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions mlir/include/mlir/Transforms/RegionUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "llvm/ADT/SetVector.h"

namespace mlir {
class DominanceInfo;
class RewriterBase;

/// Check if all values in the provided range are defined above the `limit`
Expand Down Expand Up @@ -69,6 +70,15 @@ SmallVector<Value> makeRegionIsolatedFromAbove(
llvm::function_ref<bool(Operation *)> cloneOperationIntoRegion =
[](Operation *) { return false; });

/// Move SSA values used within an operation before an insertion point,
/// so that the operation itself (or its replacement) can be moved to
/// the insertion point.
LogicalResult moveOperationDependencies(RewriterBase &rewriter, Operation *op,
Operation *insertionPoint,
DominanceInfo &dominance);
LogicalResult moveOperationDependencies(RewriterBase &rewriter, Operation *op,
Operation *insertionPoint);

/// Run a set of structural simplifications over the given regions. This
/// includes transformations like unreachable block elimination, dead argument
/// elimination, as well as some other DCE. This function returns success if any
Expand Down
66 changes: 66 additions & 0 deletions mlir/lib/Transforms/Utils/RegionUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,12 @@
//===----------------------------------------------------------------------===//

#include "mlir/Transforms/RegionUtils.h"

#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
Expand Down Expand Up @@ -1054,3 +1057,66 @@ LogicalResult mlir::simplifyRegions(RewriterBase &rewriter,
return success(eliminatedBlocks || eliminatedOpsOrArgs ||
mergedIdenticalBlocks || droppedRedundantArguments);
}

//===---------------------------------------------------------------------===//
// Move operation dependencies
//===---------------------------------------------------------------------===//

LogicalResult mlir::moveOperationDependencies(RewriterBase &rewriter,
Operation *op,
Operation *insertionPoint,
DominanceInfo &dominance) {
// Currently unsupported case where the op and insertion point are
// in different basic blocks.
if (op->getBlock() != insertionPoint->getBlock()) {
return rewriter.notifyMatchFailure(
op, "unsupported caes where operation and insertion point are not in "
"the same basic block");
}

// Find the backward slice of operation for each `Value` the operation
// depends on. Prune the slice to only include operations not already
// dominated by the `insertionPoint`
BackwardSliceOptions options;
options.inclusive = true;
options.omitUsesFromAbove = false;
options.filter = [&](Operation *sliceBoundaryOp) {
return !dominance.properlyDominates(sliceBoundaryOp, insertionPoint);
};
llvm::SetVector<Operation *> slice;

// Get the defined slice for operands.
for (Value operand : op->getOperands()) {
getBackwardSlice(operand, &slice, options);
}
auto regions = op->getRegions();
if (!regions.empty()) {
// If op has region, get the defined slice for all captured values.
llvm::SetVector<Value> capturedVals;
mlir::getUsedValuesDefinedAbove(regions, capturedVals);
for (Value value : capturedVals) {
getBackwardSlice(value, &slice, options);
}
}

// If the slice contains `insertionPoint` cannot move the dependencies.
if (slice.contains(insertionPoint)) {
return rewriter.notifyMatchFailure(
op,
"cannot move dependencies before operation in backward slice of op");
}

// Sort the slice topologically, and move in topological order.
mlir::topologicalSort(slice);
for (Operation *op : slice) {
rewriter.moveOpBefore(op, insertionPoint);
}
return success();
}

LogicalResult mlir::moveOperationDependencies(RewriterBase &rewriter,
Operation *op,
Operation *insertionPoint) {
DominanceInfo dominance(op);
return moveOperationDependencies(rewriter, op, insertionPoint, dominance);
}
171 changes: 171 additions & 0 deletions mlir/test/Transforms/move-operation-deps.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
// RUN: mlir-opt --allow-unregistered-dialect --transform-interpreter --split-input-file %s | FileCheck %s

// Check simple move of dependent operation before insertion.
func.func @simple_move() -> f32 {
%0 = "before"() : () -> (f32)
%1 = "moved_op"() : () -> (f32)
%2 = "foo"(%1) : (f32) -> (f32)
return %2 : f32
}
// CHECK-LABEL: func @simple_move()
// CHECK: %[[MOVED:.+]] = "moved_op"
// CHECK: %[[BEFORE:.+]] = "before"
// CHECK: %[[FOO:.+]] = "foo"(%[[MOVED]])
// CHECK: return %[[FOO]]

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) {
%op1 = transform.structured.match ops{["foo"]} in %arg0
: (!transform.any_op) -> !transform.any_op
%op2 = transform.structured.match ops{["before"]} in %arg0
: (!transform.any_op) -> !transform.any_op
transform.test.move_operand_deps %op1 before %op2
: !transform.any_op, !transform.any_op
transform.yield
}
}

// -----

// Move operands that are implicitly captured by the op
func.func @move_region_dependencies() -> f32 {
%0 = "before"() : () -> (f32)
%1 = "moved_op"() : () -> (f32)
%2 = "foo"() ({
%3 = "inner_op"(%1) : (f32) -> (f32)
"yield"(%3) : (f32) -> ()
}) : () -> (f32)
return %2 : f32
}
// CHECK-LABEL: func @move_region_dependencies()
// CHECK: %[[MOVED:.+]] = "moved_op"
// CHECK: %[[BEFORE:.+]] = "before"
// CHECK: %[[FOO:.+]] = "foo"
// CHECK: return %[[FOO]]

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) {
%op1 = transform.structured.match ops{["foo"]} in %arg0
: (!transform.any_op) -> !transform.any_op
%op2 = transform.structured.match ops{["before"]} in %arg0
: (!transform.any_op) -> !transform.any_op
transform.test.move_operand_deps %op1 before %op2
: !transform.any_op, !transform.any_op
transform.yield
}
}

// -----

// Move operations in toplogical sort order
func.func @move_in_topological_sort_order() -> f32 {
%0 = "before"() : () -> (f32)
%1 = "moved_op_1"() : () -> (f32)
%2 = "moved_op_2"() : () -> (f32)
%3 = "moved_op_3"(%1) : (f32) -> (f32)
%4 = "moved_op_4"(%1, %3) : (f32, f32) -> (f32)
%5 = "moved_op_5"(%2) : (f32) -> (f32)
%6 = "foo"(%4, %5) : (f32, f32) -> (f32)
return %6 : f32
}
// CHECK-LABEL: func @move_in_topological_sort_order()
// CHECK: %[[MOVED_1:.+]] = "moved_op_1"
// CHECK-DAG: %[[MOVED_2:.+]] = "moved_op_3"(%[[MOVED_1]])
// CHECK-DAG: %[[MOVED_3:.+]] = "moved_op_4"(%[[MOVED_1]], %[[MOVED_2]])
// CHECK-DAG: %[[MOVED_4:.+]] = "moved_op_2"
// CHECK-DAG: %[[MOVED_5:.+]] = "moved_op_5"(%[[MOVED_4]])
// CHECK: %[[BEFORE:.+]] = "before"
// CHECK: %[[FOO:.+]] = "foo"(%[[MOVED_3]], %[[MOVED_5]])
// CHECK: return %[[FOO]]

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) {
%op1 = transform.structured.match ops{["foo"]} in %arg0
: (!transform.any_op) -> !transform.any_op
%op2 = transform.structured.match ops{["before"]} in %arg0
: (!transform.any_op) -> !transform.any_op
transform.test.move_operand_deps %op1 before %op2
: !transform.any_op, !transform.any_op
transform.yield
}
}

// -----

func.func @move_region_dependencies() -> f32 {
%0 = "before"() : () -> (f32)
%1 = "moved_op_1"() : () -> (f32)
%2 = "moved_op"() ({
"yield"(%1) : (f32) -> ()
}) : () -> (f32)
%3 = "foo"() ({
"yield"(%2) : (f32) -> ()
}) : () -> (f32)
return %3 : f32
}

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) {
%op1 = transform.structured.match ops{["foo"]} in %arg0
: (!transform.any_op) -> !transform.any_op
%op2 = transform.structured.match ops{["before"]} in %arg0
: (!transform.any_op) -> !transform.any_op
transform.test.move_operand_deps %op1 before %op2
: !transform.any_op, !transform.any_op
transform.yield
}
}
// CHECK-LABEL: func @move_region_dependencies()
// CHECK: %[[MOVED_1:.+]] = "moved_op_1"
// CHECK: %[[MOVED_2:.+]] = "moved_op_2"
// CHECK: "yield"(%[[MOVED_1]])
// CHECK: "before"
// CHECK: %[[FOO:.+]] = "foo"
// CHECK: return %[[FOO]]

// -----

// Fail when the "before" operation is part of the operation slice.
func.func @do_not_move_slice() -> f32 {
%0 = "before"() : () -> (f32)
%1 = "moved_op"(%0) : (f32) -> (f32)
%2 = "foo"(%1) : (f32) -> (f32)
return %2 : f32
}

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) {
%op1 = transform.structured.match ops{["foo"]} in %arg0
: (!transform.any_op) -> !transform.any_op
%op2 = transform.structured.match ops{["before"]} in %arg0
: (!transform.any_op) -> !transform.any_op
transform.test.move_operand_deps %op1 before %op2
: !transform.any_op, !transform.any_op
transform.yield
}
}

// -----

func.func @move_region_dependencies() -> f32 {
%0 = "before"() : () -> (f32)
%1 = "moved_op"() ({
"yield"(%0) : (f32) -> ()
}) : () -> (f32)
%2 = "foo"() ({
"yield"(%1) : (f32) -> ()
}) : () -> (f32)
return %2 : f32
}

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) {
%op1 = transform.structured.match ops{["foo"]} in %arg0
: (!transform.any_op) -> !transform.any_op
%op2 = transform.structured.match ops{["before"]} in %arg0
: (!transform.any_op) -> !transform.any_op
transform.test.move_operand_deps %op1 before %op2
: !transform.any_op, !transform.any_op
transform.yield
}
}
8 changes: 8 additions & 0 deletions mlir/test/lib/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
set(LLVM_TARGET_DEFINITIONS TestTransformsOps.td)
mlir_tablegen(TestTransformsOps.h.inc -gen-op-decls)
mlir_tablegen(TestTransformsOps.cpp.inc -gen-op-defs)
add_public_tablegen_target(MLIRTestTransformsOpsIncGen)

set(LLVM_OPTIONAL_SOURCES
TestDialectConversion.cpp)
set(MLIRTestTransformsPDLDep)
Expand Down Expand Up @@ -25,6 +30,7 @@ add_mlir_library(MLIRTestTransforms
TestControlFlowSink.cpp
TestInlining.cpp
TestMakeIsolatedFromAbove.cpp
TestTransformsOps.cpp
${MLIRTestTransformsPDLSrc}

EXCLUDE_FROM_LIBMLIR
Expand All @@ -34,6 +40,7 @@ add_mlir_library(MLIRTestTransforms

DEPENDS
${MLIRTestTransformsPDLDep}
MLIRTestTransformsOpsIncGen

LINK_LIBS PUBLIC
MLIRTestDialect
Expand All @@ -43,6 +50,7 @@ mlir_target_link_libraries(MLIRTestTransforms PUBLIC
MLIRFuncDialect
MLIRInferIntRangeInterface
MLIRTransforms
MLIRTransformDialect
)

target_include_directories(MLIRTestTransforms
Expand Down
66 changes: 66 additions & 0 deletions mlir/test/lib/Transforms/TestTransformsOps.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
//===- TestTransformsOps.cpp - Test Transforms ----------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines transform dialect operations for testing MLIR
// transformations
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Transform/IR/TransformAttrs.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/Transforms/RegionUtils.h"

#define GET_OP_CLASSES
#include "TestTransformsOps.h.inc"

using namespace mlir;
using namespace mlir::transform;

#define GET_OP_CLASSES
#include "TestTransformsOps.cpp.inc"

DiagnosedSilenceableFailure
transform::TestMoveOperandDeps::apply(TransformRewriter &rewriter,
TransformResults &TransformResults,
TransformState &state) {
Operation *op = *state.getPayloadOps(getOp()).begin();
Operation *moveBefore = *state.getPayloadOps(getInsertionPoint()).begin();
if (failed(moveOperationDependencies(rewriter, op, moveBefore))) {
auto listener = cast<ErrorCheckingTrackingListener>(rewriter.getListener());
std::string errorMsg = listener->checkAndResetError().getMessage();
return emitSilenceableFailure(op, errorMsg);
}
return DiagnosedSilenceableFailure::success();
}

namespace {

class TestTransformsDialectExtension
: public transform::TransformDialectExtension<
TestTransformsDialectExtension> {
public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTransformsDialectExtension)

using Base::Base;

void init() {
registerTransformOps<
#define GET_OP_LIST
#include "TestTransformsOps.cpp.inc"
>();
}
};
} // namespace

namespace test {
void registerTestTransformsTransformDialectExtension(
DialectRegistry &registry) {
registry.addExtensions<TestTransformsDialectExtension>();
}
} // namespace test
Loading
Loading