-
Notifications
You must be signed in to change notification settings - Fork 13k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][Vector] Fold vector.extract
from poison vector
#126122
Conversation
This PR adds a folder for `vector.extract(ub.poison) -> ub.poison`. It also replaces `create` with `createOrFold` for insert/extract ops in vector unroll and transpose lowering patterns to trigger the poison foldings introduced recently.
@llvm/pr-subscribers-mlir-vector Author: Diego Caballero (dcaballe) ChangesThis PR adds a folder for Full diff: https://github.com/llvm/llvm-project/pull/126122.diff 4 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 30ff2df7c38fc34..b4a5461f4405dcf 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1991,15 +1991,23 @@ static Value foldScalarExtractFromFromElements(ExtractOp extractOp) {
/// Fold an insert or extract operation into a poison value when a poison index
/// is found at any dimension of the static position.
-static ub::PoisonAttr
-foldPoisonIndexInsertExtractOp(MLIRContext *context,
- ArrayRef<int64_t> staticPos, int64_t poisonVal) {
+static Attribute foldPoisonIndexInsertExtractOp(MLIRContext *context,
+ ArrayRef<int64_t> staticPos,
+ int64_t poisonVal) {
if (!llvm::is_contained(staticPos, poisonVal))
- return ub::PoisonAttr();
+ return {};
return ub::PoisonAttr::get(context);
}
+/// Fold a vector extract whose source is a poison value.
+static Attribute foldPoisonSrcExtractOp(Attribute srcAttr) {
+ if (llvm::isa_and_nonnull<ub::PoisonAttr>(srcAttr))
+ return srcAttr;
+
+ return {};
+}
+
OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
// Fold "vector.extract %v[] : vector<2x2xf32> from vector<2x2xf32>" to %v.
// Note: Do not fold "vector.extract %v[] : f32 from vector<f32>" (type
@@ -2009,6 +2017,8 @@ OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
if (auto res = foldPoisonIndexInsertExtractOp(
getContext(), adaptor.getStaticPosition(), kPoisonIndex))
return res;
+ if (auto res = foldPoisonSrcExtractOp(adaptor.getVector()))
+ return res;
if (succeeded(foldExtractOpFromExtractChain(*this)))
return getResult();
if (auto res = ExtractFromInsertTransposeChainState(*this).fold())
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
index 3c92b222e6bc80f..6135a1290d559f5 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
@@ -209,7 +209,7 @@ static Value transposeToShuffle16x16(OpBuilder &builder, Value source, int m,
ImplicitLocOpBuilder b(source.getLoc(), builder);
SmallVector<Value> vs;
for (int64_t i = 0; i < m; ++i)
- vs.push_back(b.create<vector::ExtractOp>(source, i));
+ vs.push_back(b.createOrFold<vector::ExtractOp>(source, i));
// Interleave 32-bit lanes using
// 8x _mm512_unpacklo_epi32
@@ -378,9 +378,9 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
SmallVector<int64_t> insertIdxs(extractIdxs);
applyPermutationToVector(insertIdxs, prunedTransp);
Value extractOp =
- rewriter.create<vector::ExtractOp>(loc, input, extractIdxs);
- result =
- rewriter.create<vector::InsertOp>(loc, extractOp, result, insertIdxs);
+ rewriter.createOrFold<vector::ExtractOp>(loc, input, extractIdxs);
+ result = rewriter.createOrFold<vector::InsertOp>(loc, extractOp, result,
+ insertIdxs);
}
rewriter.replaceOp(op, result);
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index 800c1d9fb1dbfd6..c1e3850f05c5ec7 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -172,7 +172,7 @@ struct UnrollTransferReadPattern
readOp.getPermutationMapAttr(), readOp.getPadding(), readOp.getMask(),
readOp.getInBoundsAttr());
- result = rewriter.create<vector::InsertStridedSliceOp>(
+ result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
loc, slicedRead, result, elementOffsets, strides);
}
rewriter.replaceOp(readOp, result);
@@ -213,7 +213,7 @@ struct UnrollTransferWritePattern
Value resultTensor;
for (SmallVector<int64_t> elementOffsets :
StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
- Value slicedVector = rewriter.create<vector::ExtractStridedSliceOp>(
+ Value slicedVector = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
loc, writeOp.getVector(), elementOffsets, *targetShape, strides);
SmallVector<Value> indices =
sliceTransferIndices(elementOffsets, originalIndices,
@@ -289,8 +289,9 @@ struct UnrollContractionPattern
SmallVector<int64_t> operandShape = applyPermutationMap(
permutationMap, ArrayRef<int64_t>(*targetShape));
SmallVector<int64_t> operandStrides(operandOffets.size(), 1);
- slicesOperands[index] = rewriter.create<vector::ExtractStridedSliceOp>(
- loc, operand, operandOffets, operandShape, operandStrides);
+ slicesOperands[index] =
+ rewriter.createOrFold<vector::ExtractStridedSliceOp>(
+ loc, operand, operandOffets, operandShape, operandStrides);
};
// Extract the new lhs operand.
@@ -333,7 +334,7 @@ struct UnrollContractionPattern
loc, dstVecType, rewriter.getZeroAttr(dstVecType));
for (const auto &it : accCache) {
SmallVector<int64_t> dstStrides(it.first.size(), 1);
- result = rewriter.create<vector::InsertStridedSliceOp>(
+ result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
loc, it.second, result, it.first, dstStrides);
}
rewriter.replaceOp(contractOp, result);
@@ -371,8 +372,10 @@ struct UnrollMultiReductionPattern
StaticTileOffsetRange(originalSize, *targetShape)) {
SmallVector<Value> operands;
SmallVector<int64_t> operandStrides(offsets.size(), 1);
- Value slicedOperand = rewriter.create<vector::ExtractStridedSliceOp>(
- loc, reductionOp.getSource(), offsets, *targetShape, operandStrides);
+ Value slicedOperand =
+ rewriter.createOrFold<vector::ExtractStridedSliceOp>(
+ loc, reductionOp.getSource(), offsets, *targetShape,
+ operandStrides);
operands.push_back(slicedOperand);
SmallVector<int64_t> dstShape;
SmallVector<int64_t> destOffset;
@@ -390,7 +393,7 @@ struct UnrollMultiReductionPattern
if (accIt != accCache.end())
acc = accIt->second;
else
- acc = rewriter.create<vector::ExtractStridedSliceOp>(
+ acc = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
loc, reductionOp.getAcc(), destOffset, dstShape, accStrides);
operands.push_back(acc);
auto targetType = VectorType::get(
@@ -406,7 +409,7 @@ struct UnrollMultiReductionPattern
rewriter.getZeroAttr(reductionOp.getDestType()));
for (const auto &it : accCache) {
SmallVector<int64_t> dstStrides(it.first.size(), 1);
- result = rewriter.create<vector::InsertStridedSliceOp>(
+ result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
loc, it.second, result, it.first, dstStrides);
}
rewriter.replaceOp(reductionOp, result);
@@ -453,12 +456,12 @@ struct UnrollElementwisePattern : public RewritePattern {
continue;
}
extractOperands.push_back(
- rewriter.create<vector::ExtractStridedSliceOp>(
+ rewriter.createOrFold<vector::ExtractStridedSliceOp>(
loc, operand.get(), offsets, *targetShape, strides));
}
Operation *newOp = cloneOpWithOperandsAndTypes(
rewriter, loc, op, extractOperands, newVecType);
- result = rewriter.create<vector::InsertStridedSliceOp>(
+ result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
loc, newOp->getResult(0), result, offsets, strides);
}
rewriter.replaceOp(op, result);
@@ -490,8 +493,9 @@ struct UnrollReductionPattern : public OpRewritePattern<vector::ReductionOp> {
for (SmallVector<int64_t> offsets :
StaticTileOffsetRange(originalSize, *targetShape)) {
SmallVector<int64_t> strides(offsets.size(), 1);
- Value slicedOperand = rewriter.create<vector::ExtractStridedSliceOp>(
- loc, reductionOp.getVector(), offsets, *targetShape, strides);
+ Value slicedOperand =
+ rewriter.createOrFold<vector::ExtractStridedSliceOp>(
+ loc, reductionOp.getVector(), offsets, *targetShape, strides);
Operation *newOp = cloneOpWithOperandsAndTypes(
rewriter, loc, reductionOp, slicedOperand, reductionOp.getType());
Value result = newOp->getResult(0);
@@ -548,12 +552,13 @@ struct UnrollTransposePattern : public OpRewritePattern<vector::TransposeOp> {
permutedOffsets[indices.value()] = elementOffsets[indices.index()];
permutedShape[indices.value()] = (*targetShape)[indices.index()];
}
- Value slicedOperand = rewriter.create<vector::ExtractStridedSliceOp>(
- loc, transposeOp.getVector(), permutedOffsets, permutedShape,
- strides);
- Value transposedSlice =
- rewriter.create<vector::TransposeOp>(loc, slicedOperand, permutation);
- result = rewriter.create<vector::InsertStridedSliceOp>(
+ Value slicedOperand =
+ rewriter.createOrFold<vector::ExtractStridedSliceOp>(
+ loc, transposeOp.getVector(), permutedOffsets, permutedShape,
+ strides);
+ Value transposedSlice = rewriter.createOrFold<vector::TransposeOp>(
+ loc, slicedOperand, permutation);
+ result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
loc, transposedSlice, result, elementOffsets, strides);
}
rewriter.replaceOp(transposeOp, result);
@@ -596,17 +601,19 @@ struct UnrollGatherPattern : public OpRewritePattern<vector::GatherOp> {
// To get the unrolled gather, extract the same slice based on the
// decomposed shape from each of the index, mask, and pass-through
// vectors.
- Value indexSubVec = rewriter.create<vector::ExtractStridedSliceOp>(
+ Value indexSubVec = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
loc, gatherOp.getIndexVec(), elementOffsets, *targetShape, strides);
- Value maskSubVec = rewriter.create<vector::ExtractStridedSliceOp>(
+ Value maskSubVec = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
loc, gatherOp.getMask(), elementOffsets, *targetShape, strides);
- Value passThruSubVec = rewriter.create<vector::ExtractStridedSliceOp>(
- loc, gatherOp.getPassThru(), elementOffsets, *targetShape, strides);
+ Value passThruSubVec =
+ rewriter.createOrFold<vector::ExtractStridedSliceOp>(
+ loc, gatherOp.getPassThru(), elementOffsets, *targetShape,
+ strides);
auto slicedGather = rewriter.create<vector::GatherOp>(
loc, targetType, gatherOp.getBase(), gatherOp.getIndices(),
indexSubVec, maskSubVec, passThruSubVec);
- result = rewriter.create<vector::InsertStridedSliceOp>(
+ result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
loc, slicedGather, result, elementOffsets, strides);
}
rewriter.replaceOp(gatherOp, result);
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 61e858f5f226a13..d016c2efa142628 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -132,6 +132,28 @@ func.func @extract_from_create_mask_dynamic_position(%dim0: index, %index: index
// -----
+// CHECK-LABEL: @extract_scalar_poison
+func.func @extract_scalar_poison() -> f32 {
+ // CHECK-NEXT: ub.poison : f32
+ // CHECK-NOT: vector.extract
+ %0 = ub.poison : vector<4x8xf32>
+ %1 = vector.extract %0[2, 4] : f32 from vector<4x8xf32>
+ return %1 : f32
+}
+
+// -----
+
+// CHECK-LABEL: @extract_vector_poison
+func.func @extract_vector_poison() -> vector<8xf32> {
+ // CHECK-NEXT: ub.poison : vector<8xf32>
+ // CHECK-NOT: vector.extract
+ %0 = ub.poison : vector<4x8xf32>
+ %1 = vector.extract %0[2] : vector<8xf32> from vector<4x8xf32>
+ return %1 : vector<8xf32>
+}
+
+// -----
+
// CHECK-LABEL: @extract_scalar_poison_idx
func.func @extract_scalar_poison_idx(%a: vector<4x5xf32>) -> f32 {
// CHECK-NOT: vector.extract
@@ -2886,6 +2908,37 @@ func.func @vector_insert_const_regression(%arg0: i8) -> vector<4xi8> {
return %1 : vector<4xi8>
}
+// -----
+
+// Inserting a poison value shouldn't be folded as the resulting vector is not
+// fully poison.
+
+// CHECK-LABEL: @insert_scalar_poison
+func.func @insert_scalar_poison(%a: vector<4x8xf32>)
+ -> vector<4x8xf32> {
+ // CHECK-NEXT: %[[UB:.*]] = ub.poison : f32
+ // CHECK-NEXT: vector.insert %[[UB]]
+ %0 = ub.poison : f32
+ %1 = vector.insert %0, %a[2, 3] : f32 into vector<4x8xf32>
+ return %1 : vector<4x8xf32>
+}
+
+// -----
+
+// Inserting a poison value shouldn't be folded as the resulting vector is not
+// fully poison.
+
+// CHECK-LABEL: @insert_vector_poison
+func.func @insert_vector_poison(%a: vector<4x8xf32>)
+ -> vector<4x8xf32> {
+ // CHECK-NEXT: %[[UB:.*]] = ub.poison : vector<8xf32>
+ // CHECK-NEXT: vector.insert %[[UB]]
+ %0 = ub.poison : vector<8xf32>
+ %1 = vector.insert %0, %a[2] : vector<8xf32> into vector<4x8xf32>
+ return %1 : vector<4x8xf32>
+}
+
+
// -----
// CHECK-LABEL: @insert_scalar_poison_idx
|
@llvm/pr-subscribers-mlir Author: Diego Caballero (dcaballe) ChangesThis PR adds a folder for Full diff: https://github.com/llvm/llvm-project/pull/126122.diff 4 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 30ff2df7c38fc34..b4a5461f4405dcf 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1991,15 +1991,23 @@ static Value foldScalarExtractFromFromElements(ExtractOp extractOp) {
/// Fold an insert or extract operation into a poison value when a poison index
/// is found at any dimension of the static position.
-static ub::PoisonAttr
-foldPoisonIndexInsertExtractOp(MLIRContext *context,
- ArrayRef<int64_t> staticPos, int64_t poisonVal) {
+static Attribute foldPoisonIndexInsertExtractOp(MLIRContext *context,
+ ArrayRef<int64_t> staticPos,
+ int64_t poisonVal) {
if (!llvm::is_contained(staticPos, poisonVal))
- return ub::PoisonAttr();
+ return {};
return ub::PoisonAttr::get(context);
}
+/// Fold a vector extract whose source is a poison value.
+static Attribute foldPoisonSrcExtractOp(Attribute srcAttr) {
+ if (llvm::isa_and_nonnull<ub::PoisonAttr>(srcAttr))
+ return srcAttr;
+
+ return {};
+}
+
OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
// Fold "vector.extract %v[] : vector<2x2xf32> from vector<2x2xf32>" to %v.
// Note: Do not fold "vector.extract %v[] : f32 from vector<f32>" (type
@@ -2009,6 +2017,8 @@ OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
if (auto res = foldPoisonIndexInsertExtractOp(
getContext(), adaptor.getStaticPosition(), kPoisonIndex))
return res;
+ if (auto res = foldPoisonSrcExtractOp(adaptor.getVector()))
+ return res;
if (succeeded(foldExtractOpFromExtractChain(*this)))
return getResult();
if (auto res = ExtractFromInsertTransposeChainState(*this).fold())
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
index 3c92b222e6bc80f..6135a1290d559f5 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
@@ -209,7 +209,7 @@ static Value transposeToShuffle16x16(OpBuilder &builder, Value source, int m,
ImplicitLocOpBuilder b(source.getLoc(), builder);
SmallVector<Value> vs;
for (int64_t i = 0; i < m; ++i)
- vs.push_back(b.create<vector::ExtractOp>(source, i));
+ vs.push_back(b.createOrFold<vector::ExtractOp>(source, i));
// Interleave 32-bit lanes using
// 8x _mm512_unpacklo_epi32
@@ -378,9 +378,9 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
SmallVector<int64_t> insertIdxs(extractIdxs);
applyPermutationToVector(insertIdxs, prunedTransp);
Value extractOp =
- rewriter.create<vector::ExtractOp>(loc, input, extractIdxs);
- result =
- rewriter.create<vector::InsertOp>(loc, extractOp, result, insertIdxs);
+ rewriter.createOrFold<vector::ExtractOp>(loc, input, extractIdxs);
+ result = rewriter.createOrFold<vector::InsertOp>(loc, extractOp, result,
+ insertIdxs);
}
rewriter.replaceOp(op, result);
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index 800c1d9fb1dbfd6..c1e3850f05c5ec7 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -172,7 +172,7 @@ struct UnrollTransferReadPattern
readOp.getPermutationMapAttr(), readOp.getPadding(), readOp.getMask(),
readOp.getInBoundsAttr());
- result = rewriter.create<vector::InsertStridedSliceOp>(
+ result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
loc, slicedRead, result, elementOffsets, strides);
}
rewriter.replaceOp(readOp, result);
@@ -213,7 +213,7 @@ struct UnrollTransferWritePattern
Value resultTensor;
for (SmallVector<int64_t> elementOffsets :
StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
- Value slicedVector = rewriter.create<vector::ExtractStridedSliceOp>(
+ Value slicedVector = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
loc, writeOp.getVector(), elementOffsets, *targetShape, strides);
SmallVector<Value> indices =
sliceTransferIndices(elementOffsets, originalIndices,
@@ -289,8 +289,9 @@ struct UnrollContractionPattern
SmallVector<int64_t> operandShape = applyPermutationMap(
permutationMap, ArrayRef<int64_t>(*targetShape));
SmallVector<int64_t> operandStrides(operandOffets.size(), 1);
- slicesOperands[index] = rewriter.create<vector::ExtractStridedSliceOp>(
- loc, operand, operandOffets, operandShape, operandStrides);
+ slicesOperands[index] =
+ rewriter.createOrFold<vector::ExtractStridedSliceOp>(
+ loc, operand, operandOffets, operandShape, operandStrides);
};
// Extract the new lhs operand.
@@ -333,7 +334,7 @@ struct UnrollContractionPattern
loc, dstVecType, rewriter.getZeroAttr(dstVecType));
for (const auto &it : accCache) {
SmallVector<int64_t> dstStrides(it.first.size(), 1);
- result = rewriter.create<vector::InsertStridedSliceOp>(
+ result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
loc, it.second, result, it.first, dstStrides);
}
rewriter.replaceOp(contractOp, result);
@@ -371,8 +372,10 @@ struct UnrollMultiReductionPattern
StaticTileOffsetRange(originalSize, *targetShape)) {
SmallVector<Value> operands;
SmallVector<int64_t> operandStrides(offsets.size(), 1);
- Value slicedOperand = rewriter.create<vector::ExtractStridedSliceOp>(
- loc, reductionOp.getSource(), offsets, *targetShape, operandStrides);
+ Value slicedOperand =
+ rewriter.createOrFold<vector::ExtractStridedSliceOp>(
+ loc, reductionOp.getSource(), offsets, *targetShape,
+ operandStrides);
operands.push_back(slicedOperand);
SmallVector<int64_t> dstShape;
SmallVector<int64_t> destOffset;
@@ -390,7 +393,7 @@ struct UnrollMultiReductionPattern
if (accIt != accCache.end())
acc = accIt->second;
else
- acc = rewriter.create<vector::ExtractStridedSliceOp>(
+ acc = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
loc, reductionOp.getAcc(), destOffset, dstShape, accStrides);
operands.push_back(acc);
auto targetType = VectorType::get(
@@ -406,7 +409,7 @@ struct UnrollMultiReductionPattern
rewriter.getZeroAttr(reductionOp.getDestType()));
for (const auto &it : accCache) {
SmallVector<int64_t> dstStrides(it.first.size(), 1);
- result = rewriter.create<vector::InsertStridedSliceOp>(
+ result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
loc, it.second, result, it.first, dstStrides);
}
rewriter.replaceOp(reductionOp, result);
@@ -453,12 +456,12 @@ struct UnrollElementwisePattern : public RewritePattern {
continue;
}
extractOperands.push_back(
- rewriter.create<vector::ExtractStridedSliceOp>(
+ rewriter.createOrFold<vector::ExtractStridedSliceOp>(
loc, operand.get(), offsets, *targetShape, strides));
}
Operation *newOp = cloneOpWithOperandsAndTypes(
rewriter, loc, op, extractOperands, newVecType);
- result = rewriter.create<vector::InsertStridedSliceOp>(
+ result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
loc, newOp->getResult(0), result, offsets, strides);
}
rewriter.replaceOp(op, result);
@@ -490,8 +493,9 @@ struct UnrollReductionPattern : public OpRewritePattern<vector::ReductionOp> {
for (SmallVector<int64_t> offsets :
StaticTileOffsetRange(originalSize, *targetShape)) {
SmallVector<int64_t> strides(offsets.size(), 1);
- Value slicedOperand = rewriter.create<vector::ExtractStridedSliceOp>(
- loc, reductionOp.getVector(), offsets, *targetShape, strides);
+ Value slicedOperand =
+ rewriter.createOrFold<vector::ExtractStridedSliceOp>(
+ loc, reductionOp.getVector(), offsets, *targetShape, strides);
Operation *newOp = cloneOpWithOperandsAndTypes(
rewriter, loc, reductionOp, slicedOperand, reductionOp.getType());
Value result = newOp->getResult(0);
@@ -548,12 +552,13 @@ struct UnrollTransposePattern : public OpRewritePattern<vector::TransposeOp> {
permutedOffsets[indices.value()] = elementOffsets[indices.index()];
permutedShape[indices.value()] = (*targetShape)[indices.index()];
}
- Value slicedOperand = rewriter.create<vector::ExtractStridedSliceOp>(
- loc, transposeOp.getVector(), permutedOffsets, permutedShape,
- strides);
- Value transposedSlice =
- rewriter.create<vector::TransposeOp>(loc, slicedOperand, permutation);
- result = rewriter.create<vector::InsertStridedSliceOp>(
+ Value slicedOperand =
+ rewriter.createOrFold<vector::ExtractStridedSliceOp>(
+ loc, transposeOp.getVector(), permutedOffsets, permutedShape,
+ strides);
+ Value transposedSlice = rewriter.createOrFold<vector::TransposeOp>(
+ loc, slicedOperand, permutation);
+ result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
loc, transposedSlice, result, elementOffsets, strides);
}
rewriter.replaceOp(transposeOp, result);
@@ -596,17 +601,19 @@ struct UnrollGatherPattern : public OpRewritePattern<vector::GatherOp> {
// To get the unrolled gather, extract the same slice based on the
// decomposed shape from each of the index, mask, and pass-through
// vectors.
- Value indexSubVec = rewriter.create<vector::ExtractStridedSliceOp>(
+ Value indexSubVec = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
loc, gatherOp.getIndexVec(), elementOffsets, *targetShape, strides);
- Value maskSubVec = rewriter.create<vector::ExtractStridedSliceOp>(
+ Value maskSubVec = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
loc, gatherOp.getMask(), elementOffsets, *targetShape, strides);
- Value passThruSubVec = rewriter.create<vector::ExtractStridedSliceOp>(
- loc, gatherOp.getPassThru(), elementOffsets, *targetShape, strides);
+ Value passThruSubVec =
+ rewriter.createOrFold<vector::ExtractStridedSliceOp>(
+ loc, gatherOp.getPassThru(), elementOffsets, *targetShape,
+ strides);
auto slicedGather = rewriter.create<vector::GatherOp>(
loc, targetType, gatherOp.getBase(), gatherOp.getIndices(),
indexSubVec, maskSubVec, passThruSubVec);
- result = rewriter.create<vector::InsertStridedSliceOp>(
+ result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
loc, slicedGather, result, elementOffsets, strides);
}
rewriter.replaceOp(gatherOp, result);
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 61e858f5f226a13..d016c2efa142628 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -132,6 +132,28 @@ func.func @extract_from_create_mask_dynamic_position(%dim0: index, %index: index
// -----
+// CHECK-LABEL: @extract_scalar_poison
+func.func @extract_scalar_poison() -> f32 {
+ // CHECK-NEXT: ub.poison : f32
+ // CHECK-NOT: vector.extract
+ %0 = ub.poison : vector<4x8xf32>
+ %1 = vector.extract %0[2, 4] : f32 from vector<4x8xf32>
+ return %1 : f32
+}
+
+// -----
+
+// CHECK-LABEL: @extract_vector_poison
+func.func @extract_vector_poison() -> vector<8xf32> {
+ // CHECK-NEXT: ub.poison : vector<8xf32>
+ // CHECK-NOT: vector.extract
+ %0 = ub.poison : vector<4x8xf32>
+ %1 = vector.extract %0[2] : vector<8xf32> from vector<4x8xf32>
+ return %1 : vector<8xf32>
+}
+
+// -----
+
// CHECK-LABEL: @extract_scalar_poison_idx
func.func @extract_scalar_poison_idx(%a: vector<4x5xf32>) -> f32 {
// CHECK-NOT: vector.extract
@@ -2886,6 +2908,37 @@ func.func @vector_insert_const_regression(%arg0: i8) -> vector<4xi8> {
return %1 : vector<4xi8>
}
+// -----
+
+// Inserting a poison value shouldn't be folded as the resulting vector is not
+// fully poison.
+
+// CHECK-LABEL: @insert_scalar_poison
+func.func @insert_scalar_poison(%a: vector<4x8xf32>)
+ -> vector<4x8xf32> {
+ // CHECK-NEXT: %[[UB:.*]] = ub.poison : f32
+ // CHECK-NEXT: vector.insert %[[UB]]
+ %0 = ub.poison : f32
+ %1 = vector.insert %0, %a[2, 3] : f32 into vector<4x8xf32>
+ return %1 : vector<4x8xf32>
+}
+
+// -----
+
+// Inserting a poison value shouldn't be folded as the resulting vector is not
+// fully poison.
+
+// CHECK-LABEL: @insert_vector_poison
+func.func @insert_vector_poison(%a: vector<4x8xf32>)
+ -> vector<4x8xf32> {
+ // CHECK-NEXT: %[[UB:.*]] = ub.poison : vector<8xf32>
+ // CHECK-NEXT: vector.insert %[[UB]]
+ %0 = ub.poison : vector<8xf32>
+ %1 = vector.insert %0, %a[2] : vector<8xf32> into vector<4x8xf32>
+ return %1 : vector<4x8xf32>
+}
+
+
// -----
// CHECK-LABEL: @insert_scalar_poison_idx
|
// CHECK-LABEL: @extract_scalar_poison | ||
func.func @extract_scalar_poison() -> f32 { | ||
// CHECK-NEXT: ub.poison : f32 | ||
// CHECK-NOT: vector.extract |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Instead of checking that there's no extract, consider checking that poison was returned? That would make the check more local (you don't have to scan all the way down to the next label).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I find CHECK-NOT
very helpful (in terms of testing and documenting). This can be re-written as:
// CHECK-NOT: vector.extract | |
// CHECK-NOT: vector.extract | |
// CHECK-NEXT: %[[RET:.*]] = ub.poison : f32 | |
// CHECK-NEXT: return %[[RET]] |
// CHECK-LABEL: @extract_vector_poison | ||
func.func @extract_vector_poison() -> vector<8xf32> { | ||
// CHECK-NEXT: ub.poison : vector<8xf32> | ||
// CHECK-NOT: vector.extract |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also here
func.func @insert_scalar_poison(%a: vector<4x8xf32>) | ||
-> vector<4x8xf32> { | ||
// CHECK-NEXT: %[[UB:.*]] = ub.poison : f32 | ||
// CHECK-NEXT: vector.insert %[[UB]] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we check that this is returned?
func.func @insert_vector_poison(%a: vector<4x8xf32>) | ||
-> vector<4x8xf32> { | ||
// CHECK-NEXT: %[[UB:.*]] = ub.poison : vector<8xf32> | ||
// CHECK-NEXT: vector.insert %[[UB]] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also here
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There isn't poison folding for ExtractStridedSliceOp, right? Should there be?
if (!llvm::is_contained(staticPos, poisonVal)) | ||
return ub::PoisonAttr(); | ||
return {}; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just wanted to say thanks for this change because the old style made this code rather confusing when I was trying to debug poison folding the other day. This new style is a lot clearer. :)
I'm adding support incrementally to these ops since having poison working end-to-end has proven to be quite involved even for a few ops. We'll get there :) |
This PR adds a folder for `vector.extract(ub.poison) -> ub.poison`. It also replaces `create` with `createOrFold` for insert/extract ops in vector unroll and transpose lowering patterns to trigger the poison foldings introduced recently.
This PR adds a folder for
vector.extract(ub.poison) -> ub.poison
. It also replacescreate
withcreateOrFold
insert/extract ops in vector unroll and transpose lowering patterns to trigger the poison foldings introduced recently.