diff --git a/mlir/include/mlir/Dialect/Transform/Interfaces/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/Interfaces/TransformInterfaces.h index e51aac02936b5..b9f2af22e9483 100644 --- a/mlir/include/mlir/Dialect/Transform/Interfaces/TransformInterfaces.h +++ b/mlir/include/mlir/Dialect/Transform/Interfaces/TransformInterfaces.h @@ -1074,10 +1074,18 @@ class ErrorCheckingTrackingListener : public TrackingListener { /// resets the error state to "success". DiagnosedSilenceableFailure checkAndResetError(); + /// Return the latest match notification message. Returns an empty string + /// when no error message was captured. + std::string getLatestMatchFailureMessage(); + /// Return "true" if this tracking listener had a failure. bool failed() const; protected: + void + notifyMatchFailure(Location loc, + function_ref reasonCallback) override; + void notifyPayloadReplacementNotFound(Operation *op, ValueRange values, DiagnosedSilenceableFailure &&diag) override; @@ -1089,6 +1097,9 @@ class ErrorCheckingTrackingListener : public TrackingListener { /// The number of errors that have been encountered. int64_t errorCounter = 0; + + /// Latest message from match failure notification. + std::optional matchFailure; }; /// This is a special rewriter to be used in transform op implementations, diff --git a/mlir/include/mlir/Transforms/RegionUtils.h b/mlir/include/mlir/Transforms/RegionUtils.h index 5c57dd5b7532a..e6b928d8ebecc 100644 --- a/mlir/include/mlir/Transforms/RegionUtils.h +++ b/mlir/include/mlir/Transforms/RegionUtils.h @@ -15,6 +15,7 @@ #include "llvm/ADT/SetVector.h" namespace mlir { +class DominanceInfo; class RewriterBase; /// Check if all values in the provided range are defined above the `limit` @@ -69,6 +70,16 @@ SmallVector makeRegionIsolatedFromAbove( llvm::function_ref cloneOperationIntoRegion = [](Operation *) { return false; }); +/// Move SSA values used within an operation before an insertion point, +/// so that the operation itself (or its replacement) can be moved to +/// the insertion point. Current support is only for movement of +/// dependencies of `op` before `insertionPoint` in the same basic block. +LogicalResult moveOperationDependencies(RewriterBase &rewriter, Operation *op, + Operation *insertionPoint, + DominanceInfo &dominance); +LogicalResult moveOperationDependencies(RewriterBase &rewriter, Operation *op, + Operation *insertionPoint); + /// Run a set of structural simplifications over the given regions. This /// includes transformations like unreachable block elimination, dead argument /// elimination, as well as some other DCE. This function returns success if any diff --git a/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp index 1e0ef5add358e..e0a5df0c758b3 100644 --- a/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp +++ b/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp @@ -1390,6 +1390,21 @@ void transform::ErrorCheckingTrackingListener::notifyPayloadReplacementNotFound( ++errorCounter; } +std::string +transform::ErrorCheckingTrackingListener::getLatestMatchFailureMessage() { + if (!matchFailure) { + return ""; + } + return matchFailure->str(); +} + +void transform::ErrorCheckingTrackingListener::notifyMatchFailure( + Location loc, function_ref reasonCallback) { + Diagnostic diag(loc, DiagnosticSeverity::Remark); + reasonCallback(diag); + matchFailure = std::move(diag); +} + //===----------------------------------------------------------------------===// // TransformRewriter //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp index e55ef6eb66b9c..da0d486f0fdcb 100644 --- a/mlir/lib/Transforms/Utils/RegionUtils.cpp +++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp @@ -7,9 +7,12 @@ //===----------------------------------------------------------------------===// #include "mlir/Transforms/RegionUtils.h" + +#include "mlir/Analysis/SliceAnalysis.h" #include "mlir/Analysis/TopologicalSortUtils.h" #include "mlir/IR/Block.h" #include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Dominance.h" #include "mlir/IR/IRMapping.h" #include "mlir/IR/Operation.h" #include "mlir/IR/PatternMatch.h" @@ -1054,3 +1057,61 @@ LogicalResult mlir::simplifyRegions(RewriterBase &rewriter, return success(eliminatedBlocks || eliminatedOpsOrArgs || mergedIdenticalBlocks || droppedRedundantArguments); } + +//===---------------------------------------------------------------------===// +// Move operation dependencies +//===---------------------------------------------------------------------===// + +LogicalResult mlir::moveOperationDependencies(RewriterBase &rewriter, + Operation *op, + Operation *insertionPoint, + DominanceInfo &dominance) { + // Currently unsupported case where the op and insertion point are + // in different basic blocks. + if (op->getBlock() != insertionPoint->getBlock()) { + return rewriter.notifyMatchFailure( + op, "unsupported caes where operation and insertion point are not in " + "the same basic block"); + } + // If `insertionPoint` does not dominate `op`, do nothing + if (!dominance.properlyDominates(insertionPoint, op)) { + return rewriter.notifyMatchFailure(op, + "insertion point does not dominate op"); + } + + // Find the backward slice of operation for each `Value` the operation + // depends on. Prune the slice to only include operations not already + // dominated by the `insertionPoint` + BackwardSliceOptions options; + options.inclusive = false; + options.omitUsesFromAbove = false; + // Since current support is to only move within a same basic block, + // the slices dont need to look past block arguments. + options.omitBlockArguments = true; + options.filter = [&](Operation *sliceBoundaryOp) { + return !dominance.properlyDominates(sliceBoundaryOp, insertionPoint); + }; + llvm::SetVector slice; + getBackwardSlice(op, &slice, options); + + // If the slice contains `insertionPoint` cannot move the dependencies. + if (slice.contains(insertionPoint)) { + return rewriter.notifyMatchFailure( + op, + "cannot move dependencies before operation in backward slice of op"); + } + + // We should move the slice in topological order, but `getBackwardSlice` + // already does that. So no need to sort again. + for (Operation *op : slice) { + rewriter.moveOpBefore(op, insertionPoint); + } + return success(); +} + +LogicalResult mlir::moveOperationDependencies(RewriterBase &rewriter, + Operation *op, + Operation *insertionPoint) { + DominanceInfo dominance(op); + return moveOperationDependencies(rewriter, op, insertionPoint, dominance); +} diff --git a/mlir/test/Transforms/move-operation-deps.mlir b/mlir/test/Transforms/move-operation-deps.mlir new file mode 100644 index 0000000000000..37637152938f6 --- /dev/null +++ b/mlir/test/Transforms/move-operation-deps.mlir @@ -0,0 +1,236 @@ +// RUN: mlir-opt --allow-unregistered-dialect --transform-interpreter --split-input-file --verify-diagnostics %s | FileCheck %s + +// Check simple move of dependent operation before insertion. +func.func @simple_move() -> f32 { + %0 = "before"() : () -> (f32) + %1 = "moved_op"() : () -> (f32) + %2 = "foo"(%1) : (f32) -> (f32) + return %2 : f32 +} +// CHECK-LABEL: func @simple_move() +// CHECK: %[[MOVED:.+]] = "moved_op" +// CHECK: %[[BEFORE:.+]] = "before" +// CHECK: %[[FOO:.+]] = "foo"(%[[MOVED]]) +// CHECK: return %[[FOO]] + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) { + %op1 = transform.structured.match ops{["foo"]} in %arg0 + : (!transform.any_op) -> !transform.any_op + %op2 = transform.structured.match ops{["before"]} in %arg0 + : (!transform.any_op) -> !transform.any_op + transform.test.move_operand_deps %op1 before %op2 + : !transform.any_op, !transform.any_op + transform.yield + } +} + +// ----- + +// Move operands that are implicitly captured by the op +func.func @move_region_dependencies() -> f32 { + %0 = "before"() : () -> (f32) + %1 = "moved_op"() : () -> (f32) + %2 = "foo"() ({ + %3 = "inner_op"(%1) : (f32) -> (f32) + "yield"(%3) : (f32) -> () + }) : () -> (f32) + return %2 : f32 +} +// CHECK-LABEL: func @move_region_dependencies() +// CHECK: %[[MOVED:.+]] = "moved_op" +// CHECK: %[[BEFORE:.+]] = "before" +// CHECK: %[[FOO:.+]] = "foo" +// CHECK: return %[[FOO]] + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) { + %op1 = transform.structured.match ops{["foo"]} in %arg0 + : (!transform.any_op) -> !transform.any_op + %op2 = transform.structured.match ops{["before"]} in %arg0 + : (!transform.any_op) -> !transform.any_op + transform.test.move_operand_deps %op1 before %op2 + : !transform.any_op, !transform.any_op + transform.yield + } +} + +// ----- + +// Move operations in toplogical sort order +func.func @move_in_topological_sort_order() -> f32 { + %0 = "before"() : () -> (f32) + %1 = "moved_op_1"() : () -> (f32) + %2 = "moved_op_2"() : () -> (f32) + %3 = "moved_op_3"(%1) : (f32) -> (f32) + %4 = "moved_op_4"(%1, %3) : (f32, f32) -> (f32) + %5 = "moved_op_5"(%2) : (f32) -> (f32) + %6 = "foo"(%4, %5) : (f32, f32) -> (f32) + return %6 : f32 +} +// CHECK-LABEL: func @move_in_topological_sort_order() +// CHECK: %[[MOVED_1:.+]] = "moved_op_1" +// CHECK-DAG: %[[MOVED_2:.+]] = "moved_op_3"(%[[MOVED_1]]) +// CHECK-DAG: %[[MOVED_3:.+]] = "moved_op_4"(%[[MOVED_1]], %[[MOVED_2]]) +// CHECK-DAG: %[[MOVED_4:.+]] = "moved_op_2" +// CHECK-DAG: %[[MOVED_5:.+]] = "moved_op_5"(%[[MOVED_4]]) +// CHECK: %[[BEFORE:.+]] = "before" +// CHECK: %[[FOO:.+]] = "foo"(%[[MOVED_3]], %[[MOVED_5]]) +// CHECK: return %[[FOO]] + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) { + %op1 = transform.structured.match ops{["foo"]} in %arg0 + : (!transform.any_op) -> !transform.any_op + %op2 = transform.structured.match ops{["before"]} in %arg0 + : (!transform.any_op) -> !transform.any_op + transform.test.move_operand_deps %op1 before %op2 + : !transform.any_op, !transform.any_op + transform.yield + } +} + +// ----- + +func.func @move_region_dependencies() -> f32 { + %0 = "before"() : () -> (f32) + %1 = "moved_op_1"() : () -> (f32) + %2 = "moved_op_2"() ({ + "yield"(%1) : (f32) -> () + }) : () -> (f32) + %3 = "foo"() ({ + "yield"(%2) : (f32) -> () + }) : () -> (f32) + return %3 : f32 +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) { + %op1 = transform.structured.match ops{["foo"]} in %arg0 + : (!transform.any_op) -> !transform.any_op + %op2 = transform.structured.match ops{["before"]} in %arg0 + : (!transform.any_op) -> !transform.any_op + transform.test.move_operand_deps %op1 before %op2 + : !transform.any_op, !transform.any_op + transform.yield + } +} +// CHECK-LABEL: func @move_region_dependencies() +// CHECK: %[[MOVED_1:.+]] = "moved_op_1" +// CHECK: %[[MOVED_2:.+]] = "moved_op_2" +// CHECK: "yield"(%[[MOVED_1]]) +// CHECK: "before" +// CHECK: %[[FOO:.+]] = "foo" +// CHECK: return %[[FOO]] + +// ----- + +// Current implementation omits following basic block argument when +// computing slices. Verify that this gives expected result. +func.func @ignore_basic_block_arguments() -> f32 { +^bb0(): + %8 = "test"() : () -> (f32) + return %8: f32 +^bb1(%bbArg : f32): + %0 = "before"() : () -> (f32) + %1 = "moved_op"() ({ + "yield"(%bbArg) : (f32) -> () + }) : () -> (f32) + %2 = "foo"() ({ + "yield"(%1) : (f32) -> () + }) : () -> (f32) + return %2 : f32 +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) { + %op1 = transform.structured.match ops{["foo"]} in %arg0 + : (!transform.any_op) -> !transform.any_op + %op2 = transform.structured.match ops{["before"]} in %arg0 + : (!transform.any_op) -> !transform.any_op + transform.test.move_operand_deps %op1 before %op2 + : !transform.any_op, !transform.any_op + transform.yield + } +} +// CHECK-LABEL: func @ignore_basic_block_arguments() +// CHECK: %[[MOVED_1:.+]] = "moved_op" +// CHECK: "before" +// CHECK: %[[FOO:.+]] = "foo" +// CHECK: return %[[FOO]] + +// ----- + +// Fail when the "before" operation is part of the operation slice. +func.func @do_not_move_slice() -> f32 { + %0 = "before"() : () -> (f32) + %1 = "moved_op"(%0) : (f32) -> (f32) + %2 = "foo"(%1) : (f32) -> (f32) + return %2 : f32 +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) { + %op1 = transform.structured.match ops{["foo"]} in %arg0 + : (!transform.any_op) -> !transform.any_op + %op2 = transform.structured.match ops{["before"]} in %arg0 + : (!transform.any_op) -> !transform.any_op + // expected-remark@+1{{cannot move dependencies before operation in backward slice of op}} + transform.test.move_operand_deps %op1 before %op2 + : !transform.any_op, !transform.any_op + transform.yield + } +} + +// ----- + +// Fail when the "before" operation is part of the operation slice (computed +// when looking through implicitly captured values). +func.func @do_not_move_slice() -> f32 { + %0 = "before"() : () -> (f32) + %1 = "moved_op"() ({ + "yield"(%0) : (f32) -> () + }) : () -> (f32) + %2 = "foo"() ({ + "yield"(%1) : (f32) -> () + }) : () -> (f32) + return %2 : f32 +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) { + %op1 = transform.structured.match ops{["foo"]} in %arg0 + : (!transform.any_op) -> !transform.any_op + %op2 = transform.structured.match ops{["before"]} in %arg0 + : (!transform.any_op) -> !transform.any_op + // expected-remark@+1{{cannot move dependencies before operation in backward slice of op}} + transform.test.move_operand_deps %op1 before %op2 + : !transform.any_op, !transform.any_op + transform.yield + } +} + +// ----- + +// Dont move ops when insertion point does not dominate the op +func.func @do_not_move() -> f32 { + %1 = "moved_op"() : () -> (f32) + %2 = "foo"() ({ + "yield"(%1) : (f32) -> () + }) : () -> (f32) + %3 = "before"() : () -> f32 + return %2 : f32 +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) { + %op1 = transform.structured.match ops{["foo"]} in %arg0 + : (!transform.any_op) -> !transform.any_op + %op2 = transform.structured.match ops{["before"]} in %arg0 + : (!transform.any_op) -> !transform.any_op + // expected-remark@+1{{insertion point does not dominate op}} + transform.test.move_operand_deps %op1 before %op2 + : !transform.any_op, !transform.any_op + transform.yield + } +} diff --git a/mlir/test/lib/Transforms/CMakeLists.txt b/mlir/test/lib/Transforms/CMakeLists.txt index 1b9b9bffa5279..c053fd4b20473 100644 --- a/mlir/test/lib/Transforms/CMakeLists.txt +++ b/mlir/test/lib/Transforms/CMakeLists.txt @@ -1,3 +1,8 @@ +set(LLVM_TARGET_DEFINITIONS TestTransformsOps.td) +mlir_tablegen(TestTransformsOps.h.inc -gen-op-decls) +mlir_tablegen(TestTransformsOps.cpp.inc -gen-op-defs) +add_public_tablegen_target(MLIRTestTransformsOpsIncGen) + set(LLVM_OPTIONAL_SOURCES TestDialectConversion.cpp) set(MLIRTestTransformsPDLDep) @@ -25,6 +30,7 @@ add_mlir_library(MLIRTestTransforms TestControlFlowSink.cpp TestInlining.cpp TestMakeIsolatedFromAbove.cpp + TestTransformsOps.cpp ${MLIRTestTransformsPDLSrc} EXCLUDE_FROM_LIBMLIR @@ -34,6 +40,7 @@ add_mlir_library(MLIRTestTransforms DEPENDS ${MLIRTestTransformsPDLDep} + MLIRTestTransformsOpsIncGen LINK_LIBS PUBLIC MLIRTestDialect @@ -43,6 +50,7 @@ mlir_target_link_libraries(MLIRTestTransforms PUBLIC MLIRFuncDialect MLIRInferIntRangeInterface MLIRTransforms + MLIRTransformDialect ) target_include_directories(MLIRTestTransforms diff --git a/mlir/test/lib/Transforms/TestTransformsOps.cpp b/mlir/test/lib/Transforms/TestTransformsOps.cpp new file mode 100644 index 0000000000000..aaa566d9938a3 --- /dev/null +++ b/mlir/test/lib/Transforms/TestTransformsOps.cpp @@ -0,0 +1,66 @@ +//===- TestTransformsOps.cpp - Test Transforms ----------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines transform dialect operations for testing MLIR +// transformations +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Transform/IR/TransformAttrs.h" +#include "mlir/Dialect/Transform/IR/TransformDialect.h" +#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" +#include "mlir/Transforms/RegionUtils.h" + +#define GET_OP_CLASSES +#include "TestTransformsOps.h.inc" + +using namespace mlir; +using namespace mlir::transform; + +#define GET_OP_CLASSES +#include "TestTransformsOps.cpp.inc" + +DiagnosedSilenceableFailure +transform::TestMoveOperandDeps::apply(TransformRewriter &rewriter, + TransformResults &TransformResults, + TransformState &state) { + Operation *op = *state.getPayloadOps(getOp()).begin(); + Operation *moveBefore = *state.getPayloadOps(getInsertionPoint()).begin(); + if (failed(moveOperationDependencies(rewriter, op, moveBefore))) { + auto listener = cast(rewriter.getListener()); + std::string errorMsg = listener->getLatestMatchFailureMessage(); + (void)emitRemark(errorMsg); + } + return DiagnosedSilenceableFailure::success(); +} + +namespace { + +class TestTransformsDialectExtension + : public transform::TransformDialectExtension< + TestTransformsDialectExtension> { +public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTransformsDialectExtension) + + using Base::Base; + + void init() { + registerTransformOps< +#define GET_OP_LIST +#include "TestTransformsOps.cpp.inc" + >(); + } +}; +} // namespace + +namespace test { +void registerTestTransformsTransformDialectExtension( + DialectRegistry ®istry) { + registry.addExtensions(); +} +} // namespace test diff --git a/mlir/test/lib/Transforms/TestTransformsOps.td b/mlir/test/lib/Transforms/TestTransformsOps.td new file mode 100644 index 0000000000000..f514702cef5bc --- /dev/null +++ b/mlir/test/lib/Transforms/TestTransformsOps.td @@ -0,0 +1,41 @@ +//===- TestTransformOps.td ---------------------------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef TEST_TRANSFORM_OPS +#define TEST_TRANSFORM_OPS + +include "mlir/Dialect/Transform/IR/TransformDialect.td" +include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td" +include "mlir/Dialect/Transform/IR/TransformTypes.td" +include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir/IR/OpBase.td" + +/// Transform dialect operations for testing transformations in MLIR + +def TestMoveOperandDeps : + Op, + ReportTrackingListenerFailuresOpTrait]> { + let description = [{ + Moves all dependencies of on operation before another operation. + }]; + + let arguments = + (ins TransformHandleTypeInterface:$op, + TransformHandleTypeInterface:$insertion_point); + + let results = (outs); + + let assemblyFormat = [{ + $op `before` $insertion_point attr-dict + `:` type($op) `,` type($insertion_point) + }]; +} + +#endif // TEST_TRANSFORM_OPS diff --git a/mlir/test/lib/Transforms/lit.local.cfg b/mlir/test/lib/Transforms/lit.local.cfg index 8ffccee1d6d79..7f4d25f1ba025 100644 --- a/mlir/test/lib/Transforms/lit.local.cfg +++ b/mlir/test/lib/Transforms/lit.local.cfg @@ -1 +1,2 @@ config.suffixes.remove(".pdll") +config.suffixes.remove(".td") diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp index f18ad45dfb708..d06ff8070e7cf 100644 --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -170,6 +170,7 @@ void registerTestDialect(DialectRegistry &); void registerTestDynDialect(DialectRegistry &); void registerTestTilingInterfaceTransformDialectExtension(DialectRegistry &); void registerTestTransformDialectExtension(DialectRegistry &); +void registerTestTransformsTransformDialectExtension(DialectRegistry &); } // namespace test #ifdef MLIR_INCLUDE_TESTS @@ -323,6 +324,7 @@ int main(int argc, char **argv) { #ifdef MLIR_INCLUDE_TESTS ::test::registerTestDialect(registry); ::test::registerTestTransformDialectExtension(registry); + ::test::registerTestTransformsTransformDialectExtension(registry); ::test::registerTestTilingInterfaceTransformDialectExtension(registry); ::test::registerTestDynDialect(registry); #endif