
[mlir][Vector] Fold vector.extract from poison vector #126122

Merged: 2 commits (Feb 7, 2025)

Conversation

@dcaballe (Contributor) commented Feb 6, 2025

This PR adds a folder for `vector.extract(ub.poison) -> ub.poison`. It also replaces `create` with `createOrFold` when building insert/extract ops in the vector unroll and transpose lowering patterns, so that the recently introduced poison foldings are triggered.
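
As a quick illustration (this mirrors the new tests added to canonicalize.mlir below), the folder rewrites an extract whose source vector is poison into a poison value of the result type:

```mlir
// Before folding: the source of the extract is a poison vector.
%0 = ub.poison : vector<4x8xf32>
%1 = vector.extract %0[2] : vector<8xf32> from vector<4x8xf32>

// After folding: the extract disappears and the result is poison of the result type.
%1 = ub.poison : vector<8xf32>
```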

@llvmbot (Member) commented Feb 6, 2025

@llvm/pr-subscribers-mlir-vector

Author: Diego Caballero (dcaballe)

Changes

This PR adds a folder for `vector.extract(ub.poison) -> ub.poison`. It also replaces `create` with `createOrFold` when building insert/extract ops in the vector unroll and transpose lowering patterns, so that the recently introduced poison foldings are triggered.


Full diff: https://github.com/llvm/llvm-project/pull/126122.diff

4 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+14-4)
  • (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp (+4-4)
  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp (+31-24)
  • (modified) mlir/test/Dialect/Vector/canonicalize.mlir (+53)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 30ff2df7c38fc34..b4a5461f4405dcf 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1991,15 +1991,23 @@ static Value foldScalarExtractFromFromElements(ExtractOp extractOp) {
 
 /// Fold an insert or extract operation into an poison value when a poison index
 /// is found at any dimension of the static position.
-static ub::PoisonAttr
-foldPoisonIndexInsertExtractOp(MLIRContext *context,
-                               ArrayRef<int64_t> staticPos, int64_t poisonVal) {
+static Attribute foldPoisonIndexInsertExtractOp(MLIRContext *context,
+                                                ArrayRef<int64_t> staticPos,
+                                                int64_t poisonVal) {
   if (!llvm::is_contained(staticPos, poisonVal))
-    return ub::PoisonAttr();
+    return {};
 
   return ub::PoisonAttr::get(context);
 }
 
+/// Fold a vector extract from is a poison source.
+static Attribute foldPoisonSrcExtractOp(Attribute srcAttr) {
+  if (llvm::isa_and_nonnull<ub::PoisonAttr>(srcAttr))
+    return srcAttr;
+
+  return {};
+}
+
 OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
   // Fold "vector.extract %v[] : vector<2x2xf32> from vector<2x2xf32>" to %v.
   // Note: Do not fold "vector.extract %v[] : f32 from vector<f32>" (type
@@ -2009,6 +2017,8 @@ OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
   if (auto res = foldPoisonIndexInsertExtractOp(
           getContext(), adaptor.getStaticPosition(), kPoisonIndex))
     return res;
+  if (auto res = foldPoisonSrcExtractOp(adaptor.getVector()))
+    return res;
   if (succeeded(foldExtractOpFromExtractChain(*this)))
     return getResult();
   if (auto res = ExtractFromInsertTransposeChainState(*this).fold())
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
index 3c92b222e6bc80f..6135a1290d559f5 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
@@ -209,7 +209,7 @@ static Value transposeToShuffle16x16(OpBuilder &builder, Value source, int m,
   ImplicitLocOpBuilder b(source.getLoc(), builder);
   SmallVector<Value> vs;
   for (int64_t i = 0; i < m; ++i)
-    vs.push_back(b.create<vector::ExtractOp>(source, i));
+    vs.push_back(b.createOrFold<vector::ExtractOp>(source, i));
 
   // Interleave 32-bit lanes using
   //   8x _mm512_unpacklo_epi32
@@ -378,9 +378,9 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
       SmallVector<int64_t> insertIdxs(extractIdxs);
       applyPermutationToVector(insertIdxs, prunedTransp);
       Value extractOp =
-          rewriter.create<vector::ExtractOp>(loc, input, extractIdxs);
-      result =
-          rewriter.create<vector::InsertOp>(loc, extractOp, result, insertIdxs);
+          rewriter.createOrFold<vector::ExtractOp>(loc, input, extractIdxs);
+      result = rewriter.createOrFold<vector::InsertOp>(loc, extractOp, result,
+                                                       insertIdxs);
     }
 
     rewriter.replaceOp(op, result);
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index 800c1d9fb1dbfd6..c1e3850f05c5ec7 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -172,7 +172,7 @@ struct UnrollTransferReadPattern
           readOp.getPermutationMapAttr(), readOp.getPadding(), readOp.getMask(),
           readOp.getInBoundsAttr());
 
-      result = rewriter.create<vector::InsertStridedSliceOp>(
+      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
           loc, slicedRead, result, elementOffsets, strides);
     }
     rewriter.replaceOp(readOp, result);
@@ -213,7 +213,7 @@ struct UnrollTransferWritePattern
     Value resultTensor;
     for (SmallVector<int64_t> elementOffsets :
          StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
-      Value slicedVector = rewriter.create<vector::ExtractStridedSliceOp>(
+      Value slicedVector = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
           loc, writeOp.getVector(), elementOffsets, *targetShape, strides);
       SmallVector<Value> indices =
           sliceTransferIndices(elementOffsets, originalIndices,
@@ -289,8 +289,9 @@ struct UnrollContractionPattern
         SmallVector<int64_t> operandShape = applyPermutationMap(
             permutationMap, ArrayRef<int64_t>(*targetShape));
         SmallVector<int64_t> operandStrides(operandOffets.size(), 1);
-        slicesOperands[index] = rewriter.create<vector::ExtractStridedSliceOp>(
-            loc, operand, operandOffets, operandShape, operandStrides);
+        slicesOperands[index] =
+            rewriter.createOrFold<vector::ExtractStridedSliceOp>(
+                loc, operand, operandOffets, operandShape, operandStrides);
       };
 
       // Extract the new lhs operand.
@@ -333,7 +334,7 @@ struct UnrollContractionPattern
         loc, dstVecType, rewriter.getZeroAttr(dstVecType));
     for (const auto &it : accCache) {
       SmallVector<int64_t> dstStrides(it.first.size(), 1);
-      result = rewriter.create<vector::InsertStridedSliceOp>(
+      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
           loc, it.second, result, it.first, dstStrides);
     }
     rewriter.replaceOp(contractOp, result);
@@ -371,8 +372,10 @@ struct UnrollMultiReductionPattern
          StaticTileOffsetRange(originalSize, *targetShape)) {
       SmallVector<Value> operands;
       SmallVector<int64_t> operandStrides(offsets.size(), 1);
-      Value slicedOperand = rewriter.create<vector::ExtractStridedSliceOp>(
-          loc, reductionOp.getSource(), offsets, *targetShape, operandStrides);
+      Value slicedOperand =
+          rewriter.createOrFold<vector::ExtractStridedSliceOp>(
+              loc, reductionOp.getSource(), offsets, *targetShape,
+              operandStrides);
       operands.push_back(slicedOperand);
       SmallVector<int64_t> dstShape;
       SmallVector<int64_t> destOffset;
@@ -390,7 +393,7 @@ struct UnrollMultiReductionPattern
       if (accIt != accCache.end())
         acc = accIt->second;
       else
-        acc = rewriter.create<vector::ExtractStridedSliceOp>(
+        acc = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
             loc, reductionOp.getAcc(), destOffset, dstShape, accStrides);
       operands.push_back(acc);
       auto targetType = VectorType::get(
@@ -406,7 +409,7 @@ struct UnrollMultiReductionPattern
         rewriter.getZeroAttr(reductionOp.getDestType()));
     for (const auto &it : accCache) {
       SmallVector<int64_t> dstStrides(it.first.size(), 1);
-      result = rewriter.create<vector::InsertStridedSliceOp>(
+      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
           loc, it.second, result, it.first, dstStrides);
     }
     rewriter.replaceOp(reductionOp, result);
@@ -453,12 +456,12 @@ struct UnrollElementwisePattern : public RewritePattern {
           continue;
         }
         extractOperands.push_back(
-            rewriter.create<vector::ExtractStridedSliceOp>(
+            rewriter.createOrFold<vector::ExtractStridedSliceOp>(
                 loc, operand.get(), offsets, *targetShape, strides));
       }
       Operation *newOp = cloneOpWithOperandsAndTypes(
           rewriter, loc, op, extractOperands, newVecType);
-      result = rewriter.create<vector::InsertStridedSliceOp>(
+      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
           loc, newOp->getResult(0), result, offsets, strides);
     }
     rewriter.replaceOp(op, result);
@@ -490,8 +493,9 @@ struct UnrollReductionPattern : public OpRewritePattern<vector::ReductionOp> {
     for (SmallVector<int64_t> offsets :
          StaticTileOffsetRange(originalSize, *targetShape)) {
       SmallVector<int64_t> strides(offsets.size(), 1);
-      Value slicedOperand = rewriter.create<vector::ExtractStridedSliceOp>(
-          loc, reductionOp.getVector(), offsets, *targetShape, strides);
+      Value slicedOperand =
+          rewriter.createOrFold<vector::ExtractStridedSliceOp>(
+              loc, reductionOp.getVector(), offsets, *targetShape, strides);
       Operation *newOp = cloneOpWithOperandsAndTypes(
           rewriter, loc, reductionOp, slicedOperand, reductionOp.getType());
       Value result = newOp->getResult(0);
@@ -548,12 +552,13 @@ struct UnrollTransposePattern : public OpRewritePattern<vector::TransposeOp> {
         permutedOffsets[indices.value()] = elementOffsets[indices.index()];
         permutedShape[indices.value()] = (*targetShape)[indices.index()];
       }
-      Value slicedOperand = rewriter.create<vector::ExtractStridedSliceOp>(
-          loc, transposeOp.getVector(), permutedOffsets, permutedShape,
-          strides);
-      Value transposedSlice =
-          rewriter.create<vector::TransposeOp>(loc, slicedOperand, permutation);
-      result = rewriter.create<vector::InsertStridedSliceOp>(
+      Value slicedOperand =
+          rewriter.createOrFold<vector::ExtractStridedSliceOp>(
+              loc, transposeOp.getVector(), permutedOffsets, permutedShape,
+              strides);
+      Value transposedSlice = rewriter.createOrFold<vector::TransposeOp>(
+          loc, slicedOperand, permutation);
+      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
           loc, transposedSlice, result, elementOffsets, strides);
     }
     rewriter.replaceOp(transposeOp, result);
@@ -596,17 +601,19 @@ struct UnrollGatherPattern : public OpRewritePattern<vector::GatherOp> {
       // To get the unrolled gather, extract the same slice based on the
       // decomposed shape from each of the index, mask, and pass-through
       // vectors.
-      Value indexSubVec = rewriter.create<vector::ExtractStridedSliceOp>(
+      Value indexSubVec = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
           loc, gatherOp.getIndexVec(), elementOffsets, *targetShape, strides);
-      Value maskSubVec = rewriter.create<vector::ExtractStridedSliceOp>(
+      Value maskSubVec = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
           loc, gatherOp.getMask(), elementOffsets, *targetShape, strides);
-      Value passThruSubVec = rewriter.create<vector::ExtractStridedSliceOp>(
-          loc, gatherOp.getPassThru(), elementOffsets, *targetShape, strides);
+      Value passThruSubVec =
+          rewriter.createOrFold<vector::ExtractStridedSliceOp>(
+              loc, gatherOp.getPassThru(), elementOffsets, *targetShape,
+              strides);
       auto slicedGather = rewriter.create<vector::GatherOp>(
           loc, targetType, gatherOp.getBase(), gatherOp.getIndices(),
           indexSubVec, maskSubVec, passThruSubVec);
 
-      result = rewriter.create<vector::InsertStridedSliceOp>(
+      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
           loc, slicedGather, result, elementOffsets, strides);
     }
     rewriter.replaceOp(gatherOp, result);
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 61e858f5f226a13..d016c2efa142628 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -132,6 +132,28 @@ func.func @extract_from_create_mask_dynamic_position(%dim0: index, %index: index
 
 // -----
 
+// CHECK-LABEL: @extract_scalar_poison
+func.func @extract_scalar_poison() -> f32 {
+  // CHECK-NEXT: ub.poison : f32
+  //  CHECK-NOT: vector.extract
+  %0 = ub.poison : vector<4x8xf32>
+  %1 = vector.extract %0[2, 4] : f32 from vector<4x8xf32>
+  return %1 : f32
+}
+
+// -----
+
+// CHECK-LABEL: @extract_vector_poison
+func.func @extract_vector_poison() -> vector<8xf32> {
+  // CHECK-NEXT: ub.poison : vector<8xf32>
+  //  CHECK-NOT: vector.extract
+  %0 = ub.poison : vector<4x8xf32>
+  %1 = vector.extract %0[2] : vector<8xf32> from vector<4x8xf32>
+  return %1 : vector<8xf32>
+}
+
+// -----
+
 // CHECK-LABEL: @extract_scalar_poison_idx
 func.func @extract_scalar_poison_idx(%a: vector<4x5xf32>) -> f32 {
   //  CHECK-NOT: vector.extract
@@ -2886,6 +2908,37 @@ func.func @vector_insert_const_regression(%arg0: i8) -> vector<4xi8> {
   return %1 : vector<4xi8>
 }
 
+// -----
+
+// Insert a poison value shouldn't be folded as the resulting vector is not
+// fully poison.
+
+// CHECK-LABEL: @insert_scalar_poison
+func.func @insert_scalar_poison(%a: vector<4x8xf32>)
+    -> vector<4x8xf32> {
+  //  CHECK-NEXT: %[[UB:.*]] = ub.poison : f32
+  //  CHECK-NEXT: vector.insert %[[UB]]
+  %0 = ub.poison : f32
+  %1 = vector.insert %0, %a[2, 3] : f32 into vector<4x8xf32>
+  return %1 : vector<4x8xf32>
+}
+
+// -----
+
+// Insert a poison value shouldn't be folded as the resulting vector is not
+// fully poison.
+
+// CHECK-LABEL: @insert_vector_poison
+func.func @insert_vector_poison(%a: vector<4x8xf32>)
+    -> vector<4x8xf32> {
+  //  CHECK-NEXT: %[[UB:.*]] = ub.poison : vector<8xf32>
+  //  CHECK-NEXT: vector.insert %[[UB]]
+  %0 = ub.poison : vector<8xf32>
+  %1 = vector.insert %0, %a[2] : vector<8xf32> into vector<4x8xf32>
+  return %1 : vector<4x8xf32>
+}
+
+
 // -----
 
 // CHECK-LABEL: @insert_scalar_poison_idx

@llvmbot (Member) commented Feb 6, 2025

@llvm/pr-subscribers-mlir

// CHECK-LABEL: @extract_scalar_poison
func.func @extract_scalar_poison() -> f32 {
// CHECK-NEXT: ub.poison : f32
// CHECK-NOT: vector.extract
Member

Instead of checking that there's no extract, consider checking that poison was returned? That would make the check more local (you don't have to scan all the way down to the next label).

Contributor

I find CHECK-NOT very helpful (in terms of testing and documenting). This can be re-written as:

Suggested change
// CHECK-NOT: vector.extract
// CHECK-NOT: vector.extract
// CHECK-NEXT: %[[RET:.*]] = ub.poison : f32
// CHECK-NEXT return [[RET]]

// CHECK-LABEL: @extract_vector_poison
func.func @extract_vector_poison() -> vector<8xf32> {
// CHECK-NEXT: ub.poison : vector<8xf32>
// CHECK-NOT: vector.extract
Member

Also here

func.func @insert_scalar_poison(%a: vector<4x8xf32>)
-> vector<4x8xf32> {
// CHECK-NEXT: %[[UB:.*]] = ub.poison : f32
// CHECK-NEXT: vector.insert %[[UB]]
Member

Can we check that this is returned?

func.func @insert_vector_poison(%a: vector<4x8xf32>)
-> vector<4x8xf32> {
// CHECK-NEXT: %[[UB:.*]] = ub.poison : vector<8xf32>
// CHECK-NEXT: vector.insert %[[UB]]
Member

Also here

@kuhar kuhar requested a review from andfau-amd February 7, 2025 16:01
@andfau-amd (Contributor) left a comment

There isn't poison folding for ExtractStridedSliceOp, right? Should there be?

if (!llvm::is_contained(staticPos, poisonVal))
return ub::PoisonAttr();
return {};
Contributor

Just wanted to say thanks for this change because the old style made this code rather confusing when I was trying to debug poison folding the other day. This new style is a lot clearer. :)

@dcaballe (Contributor, Author) commented Feb 7, 2025

There isn't poison folding for ExtractStridedSliceOp, right? Should there be?

I'm adding support incrementally to these ops since having poison working end-to-end has proven to be quite involved even for a few ops. We'll get there :)
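
For context, here is a hypothetical sketch of the kind of fold being discussed; it is not part of this PR, and the example is only an assumed illustration of what such a folding could look like:

```mlir
// Hypothetical future fold (not implemented in this PR): extracting a strided
// slice from a fully-poison vector could likewise fold to poison.
%0 = ub.poison : vector<8xf32>
%1 = vector.extract_strided_slice %0
    {offsets = [2], sizes = [4], strides = [1]}
    : vector<8xf32> to vector<4xf32>

// Expected result if such a fold existed:
%1 = ub.poison : vector<4xf32>
```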

@dcaballe merged commit 6832514 into llvm:main on Feb 7, 2025
5 of 6 checks passed
Icohedron pushed a commit to Icohedron/llvm-project that referenced this pull request Feb 11, 2025
This PR adds a folder for `vector.extract(ub.poison) -> ub.poison`. It
also replaces `create` with `createOrFold` insert/extract ops in vector
unroll and transpose lowering patterns to trigger the poison foldings
introduced recently.