[mlir][Vector] Fold vector.extract from poison vector #126122

Merged · 2 commits · Feb 7, 2025
18 changes: 14 additions & 4 deletions mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1991,15 +1991,23 @@ static Value foldScalarExtractFromFromElements(ExtractOp extractOp) {

/// Fold an insert or extract operation into a poison value when a poison index
/// is found at any dimension of the static position.
-static ub::PoisonAttr
-foldPoisonIndexInsertExtractOp(MLIRContext *context,
-                               ArrayRef<int64_t> staticPos, int64_t poisonVal) {
+static Attribute foldPoisonIndexInsertExtractOp(MLIRContext *context,
+                                                ArrayRef<int64_t> staticPos,
+                                                int64_t poisonVal) {
   if (!llvm::is_contained(staticPos, poisonVal))
-    return ub::PoisonAttr();
+    return {};
Review comment (Contributor): Just wanted to say thanks for this change because the old style made this code rather confusing when I was trying to debug poison folding the other day. This new style is a lot clearer. :)


return ub::PoisonAttr::get(context);
}

/// Fold a vector extract whose source is a poison value.
static Attribute foldPoisonSrcExtractOp(Attribute srcAttr) {
if (llvm::isa_and_nonnull<ub::PoisonAttr>(srcAttr))
return srcAttr;

return {};
}

OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
// Fold "vector.extract %v[] : vector<2x2xf32> from vector<2x2xf32>" to %v.
// Note: Do not fold "vector.extract %v[] : f32 from vector<f32>" (type
@@ -2009,6 +2017,8 @@ OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
if (auto res = foldPoisonIndexInsertExtractOp(
getContext(), adaptor.getStaticPosition(), kPoisonIndex))
return res;
if (auto res = foldPoisonSrcExtractOp(adaptor.getVector()))
return res;
if (succeeded(foldExtractOpFromExtractChain(*this)))
return getResult();
if (auto res = ExtractFromInsertTransposeChainState(*this).fold())
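
For illustration, here is a minimal sketch (function name hypothetical) of the two folds wired into ExtractOp::fold above. A poison source and the sentinel poison index -1 each fold the extract to a poison value of the result type, mirroring the tests added below:

    func.func @poison_fold_sketch(%w: vector<4x8xf32>) -> (f32, f32) {
      %v = ub.poison : vector<4x8xf32>
      // Poison source: folded by foldPoisonSrcExtractOp.
      %a = vector.extract %v[2, 4] : f32 from vector<4x8xf32>
      // Poison index (-1, i.e. kPoisonIndex): folded by foldPoisonIndexInsertExtractOp.
      %b = vector.extract %w[-1, 0] : f32 from vector<4x8xf32>
      // Both %a and %b canonicalize to "ub.poison : f32".
      return %a, %b : f32, f32
    }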
8 changes: 4 additions & 4 deletions mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
@@ -209,7 +209,7 @@ static Value transposeToShuffle16x16(OpBuilder &builder, Value source, int m,
ImplicitLocOpBuilder b(source.getLoc(), builder);
SmallVector<Value> vs;
for (int64_t i = 0; i < m; ++i)
-vs.push_back(b.create<vector::ExtractOp>(source, i));
+vs.push_back(b.createOrFold<vector::ExtractOp>(source, i));

// Interleave 32-bit lanes using
// 8x _mm512_unpacklo_epi32
@@ -378,9 +378,9 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
SmallVector<int64_t> insertIdxs(extractIdxs);
applyPermutationToVector(insertIdxs, prunedTransp);
 Value extractOp =
-    rewriter.create<vector::ExtractOp>(loc, input, extractIdxs);
-result =
-    rewriter.create<vector::InsertOp>(loc, extractOp, result, insertIdxs);
+    rewriter.createOrFold<vector::ExtractOp>(loc, input, extractIdxs);
+result = rewriter.createOrFold<vector::InsertOp>(loc, extractOp, result,
+                                                 insertIdxs);
}

rewriter.replaceOp(op, result);
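Switching these builders from create to createOrFold makes the rewriter run ExtractOp::fold at creation time, so an extract whose source is poison never materializes. A hedged sketch of the effect (IR values hypothetical), assuming the poison-source folder added in this patch:

    %p = ub.poison : vector<2x4xf32>
    // b.create<vector::ExtractOp>(%p, 0) would materialize:
    //   %row = vector.extract %p[0] : vector<4xf32> from vector<2x4xf32>
    // b.createOrFold<vector::ExtractOp>(%p, 0) folds at build time and yields:
    //   %row = ub.poison : vector<4xf32>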
55 changes: 31 additions & 24 deletions mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -172,7 +172,7 @@ struct UnrollTransferReadPattern
readOp.getPermutationMapAttr(), readOp.getPadding(), readOp.getMask(),
readOp.getInBoundsAttr());

-result = rewriter.create<vector::InsertStridedSliceOp>(
+result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
loc, slicedRead, result, elementOffsets, strides);
}
rewriter.replaceOp(readOp, result);
@@ -213,7 +213,7 @@ struct UnrollTransferWritePattern
Value resultTensor;
for (SmallVector<int64_t> elementOffsets :
StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
-Value slicedVector = rewriter.create<vector::ExtractStridedSliceOp>(
+Value slicedVector = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
loc, writeOp.getVector(), elementOffsets, *targetShape, strides);
SmallVector<Value> indices =
sliceTransferIndices(elementOffsets, originalIndices,
@@ -289,8 +289,9 @@ struct UnrollContractionPattern
SmallVector<int64_t> operandShape = applyPermutationMap(
permutationMap, ArrayRef<int64_t>(*targetShape));
SmallVector<int64_t> operandStrides(operandOffets.size(), 1);
-slicesOperands[index] = rewriter.create<vector::ExtractStridedSliceOp>(
-    loc, operand, operandOffets, operandShape, operandStrides);
+slicesOperands[index] =
+    rewriter.createOrFold<vector::ExtractStridedSliceOp>(
+        loc, operand, operandOffets, operandShape, operandStrides);
};

// Extract the new lhs operand.
@@ -333,7 +334,7 @@ struct UnrollContractionPattern
loc, dstVecType, rewriter.getZeroAttr(dstVecType));
for (const auto &it : accCache) {
SmallVector<int64_t> dstStrides(it.first.size(), 1);
-result = rewriter.create<vector::InsertStridedSliceOp>(
+result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
loc, it.second, result, it.first, dstStrides);
}
rewriter.replaceOp(contractOp, result);
@@ -371,8 +372,10 @@ struct UnrollMultiReductionPattern
StaticTileOffsetRange(originalSize, *targetShape)) {
SmallVector<Value> operands;
SmallVector<int64_t> operandStrides(offsets.size(), 1);
-Value slicedOperand = rewriter.create<vector::ExtractStridedSliceOp>(
-    loc, reductionOp.getSource(), offsets, *targetShape, operandStrides);
+Value slicedOperand =
+    rewriter.createOrFold<vector::ExtractStridedSliceOp>(
+        loc, reductionOp.getSource(), offsets, *targetShape,
+        operandStrides);
operands.push_back(slicedOperand);
SmallVector<int64_t> dstShape;
SmallVector<int64_t> destOffset;
@@ -390,7 +393,7 @@ struct UnrollMultiReductionPattern
if (accIt != accCache.end())
acc = accIt->second;
else
-acc = rewriter.create<vector::ExtractStridedSliceOp>(
+acc = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
loc, reductionOp.getAcc(), destOffset, dstShape, accStrides);
operands.push_back(acc);
auto targetType = VectorType::get(
@@ -406,7 +409,7 @@ struct UnrollMultiReductionPattern
rewriter.getZeroAttr(reductionOp.getDestType()));
for (const auto &it : accCache) {
SmallVector<int64_t> dstStrides(it.first.size(), 1);
-result = rewriter.create<vector::InsertStridedSliceOp>(
+result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
loc, it.second, result, it.first, dstStrides);
}
rewriter.replaceOp(reductionOp, result);
@@ -453,12 +456,12 @@ struct UnrollElementwisePattern : public RewritePattern {
continue;
}
extractOperands.push_back(
-    rewriter.create<vector::ExtractStridedSliceOp>(
+    rewriter.createOrFold<vector::ExtractStridedSliceOp>(
loc, operand.get(), offsets, *targetShape, strides));
}
Operation *newOp = cloneOpWithOperandsAndTypes(
rewriter, loc, op, extractOperands, newVecType);
-result = rewriter.create<vector::InsertStridedSliceOp>(
+result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
loc, newOp->getResult(0), result, offsets, strides);
}
rewriter.replaceOp(op, result);
@@ -490,8 +493,9 @@ struct UnrollReductionPattern : public OpRewritePattern<vector::ReductionOp> {
for (SmallVector<int64_t> offsets :
StaticTileOffsetRange(originalSize, *targetShape)) {
SmallVector<int64_t> strides(offsets.size(), 1);
-Value slicedOperand = rewriter.create<vector::ExtractStridedSliceOp>(
-    loc, reductionOp.getVector(), offsets, *targetShape, strides);
+Value slicedOperand =
+    rewriter.createOrFold<vector::ExtractStridedSliceOp>(
+        loc, reductionOp.getVector(), offsets, *targetShape, strides);
Operation *newOp = cloneOpWithOperandsAndTypes(
rewriter, loc, reductionOp, slicedOperand, reductionOp.getType());
Value result = newOp->getResult(0);
@@ -548,12 +552,13 @@ struct UnrollTransposePattern : public OpRewritePattern<vector::TransposeOp> {
permutedOffsets[indices.value()] = elementOffsets[indices.index()];
permutedShape[indices.value()] = (*targetShape)[indices.index()];
}
-Value slicedOperand = rewriter.create<vector::ExtractStridedSliceOp>(
-    loc, transposeOp.getVector(), permutedOffsets, permutedShape,
-    strides);
-Value transposedSlice =
-    rewriter.create<vector::TransposeOp>(loc, slicedOperand, permutation);
-result = rewriter.create<vector::InsertStridedSliceOp>(
+Value slicedOperand =
+    rewriter.createOrFold<vector::ExtractStridedSliceOp>(
+        loc, transposeOp.getVector(), permutedOffsets, permutedShape,
+        strides);
+Value transposedSlice = rewriter.createOrFold<vector::TransposeOp>(
+    loc, slicedOperand, permutation);
+result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
loc, transposedSlice, result, elementOffsets, strides);
}
rewriter.replaceOp(transposeOp, result);
@@ -596,17 +601,19 @@ struct UnrollGatherPattern : public OpRewritePattern<vector::GatherOp> {
// To get the unrolled gather, extract the same slice based on the
// decomposed shape from each of the index, mask, and pass-through
// vectors.
-Value indexSubVec = rewriter.create<vector::ExtractStridedSliceOp>(
+Value indexSubVec = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
     loc, gatherOp.getIndexVec(), elementOffsets, *targetShape, strides);
-Value maskSubVec = rewriter.create<vector::ExtractStridedSliceOp>(
+Value maskSubVec = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
     loc, gatherOp.getMask(), elementOffsets, *targetShape, strides);
-Value passThruSubVec = rewriter.create<vector::ExtractStridedSliceOp>(
-    loc, gatherOp.getPassThru(), elementOffsets, *targetShape, strides);
+Value passThruSubVec =
+    rewriter.createOrFold<vector::ExtractStridedSliceOp>(
+        loc, gatherOp.getPassThru(), elementOffsets, *targetShape,
+        strides);
auto slicedGather = rewriter.create<vector::GatherOp>(
loc, targetType, gatherOp.getBase(), gatherOp.getIndices(),
indexSubVec, maskSubVec, passThruSubVec);

-result = rewriter.create<vector::InsertStridedSliceOp>(
+result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
loc, slicedGather, result, elementOffsets, strides);
}
rewriter.replaceOp(gatherOp, result);
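The unrolling patterns get the same create-to-createOrFold treatment, which lets any registered folder fire before a new op is materialized. One hedged example, assuming the existing identity-slice folder on vector.extract_strided_slice (a slice that covers the entire source folds to the source itself):

    // With createOrFold, this full-size slice is never created;
    // the builder returns %v directly.
    %s = vector.extract_strided_slice %v
      {offsets = [0], sizes = [8], strides = [1]}
      : vector<8xf32> to vector<8xf32>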
75 changes: 69 additions & 6 deletions mlir/test/Dialect/Vector/canonicalize.mlir
@@ -132,10 +132,35 @@ func.func @extract_from_create_mask_dynamic_position(%dim0: index, %index: index

// -----

// CHECK-LABEL: @extract_scalar_poison
func.func @extract_scalar_poison() -> f32 {
// CHECK-NEXT: %[[UB:.*]] = ub.poison : f32
// CHECK-NOT: vector.extract
Review comment (Member): Instead of checking that there's no extract, consider checking that poison was returned? That would make the check more local (you don't have to scan all the way down to the next label).

Review comment (Contributor): I find CHECK-NOT very helpful (in terms of testing and documenting). This can be re-written as:

Suggested change (replacing the CHECK-NOT line):
// CHECK-NOT: vector.extract
// CHECK-NEXT: %[[RET:.*]] = ub.poison : f32
// CHECK-NEXT: return %[[RET]]

// CHECK-NEXT: return %[[UB]] : f32
%0 = ub.poison : vector<4x8xf32>
%1 = vector.extract %0[2, 4] : f32 from vector<4x8xf32>
return %1 : f32
}

// -----

// CHECK-LABEL: @extract_vector_poison
func.func @extract_vector_poison() -> vector<8xf32> {
// CHECK-NEXT: %[[UB:.*]] = ub.poison : vector<8xf32>
// CHECK-NOT: vector.extract
Review comment (Member): Also here.

// CHECK-NEXT: return %[[UB]] : vector<8xf32>
%0 = ub.poison : vector<4x8xf32>
%1 = vector.extract %0[2] : vector<8xf32> from vector<4x8xf32>
return %1 : vector<8xf32>
}

// -----

// CHECK-LABEL: @extract_scalar_poison_idx
func.func @extract_scalar_poison_idx(%a: vector<4x5xf32>) -> f32 {
+// CHECK-NEXT: %[[UB:.*]] = ub.poison : f32
 // CHECK-NOT: vector.extract
-// CHECK-NEXT: ub.poison : f32
+// CHECK-NEXT: return %[[UB]] : f32
%0 = vector.extract %a[-1, 0] : f32 from vector<4x5xf32>
return %0 : f32
}
@@ -144,8 +169,9 @@ func.func @extract_scalar_poison_idx(%a: vector<4x5xf32>) -> f32 {

// CHECK-LABEL: @extract_vector_poison_idx
func.func @extract_vector_poison_idx(%a: vector<4x5xf32>) -> vector<5xf32> {
+// CHECK-NEXT: %[[UB:.*]] = ub.poison : vector<5xf32>
 // CHECK-NOT: vector.extract
-// CHECK-NEXT: ub.poison : vector<5xf32>
+// CHECK-NEXT: return %[[UB]] : vector<5xf32>
%0 = vector.extract %a[-1] : vector<5xf32> from vector<4x5xf32>
return %0 : vector<5xf32>
}
@@ -155,8 +181,9 @@ func.func @extract_vector_poison_idx(%a: vector<4x5xf32>) -> vector<5xf32> {
// CHECK-LABEL: @extract_multiple_poison_idx
func.func @extract_multiple_poison_idx(%a: vector<4x5x8xf32>)
-> vector<8xf32> {
+// CHECK-NEXT: %[[UB:.*]] = ub.poison : vector<8xf32>
 // CHECK-NOT: vector.extract
-// CHECK-NEXT: ub.poison : vector<8xf32>
+// CHECK-NEXT: return %[[UB]] : vector<8xf32>
%0 = vector.extract %a[-1, -1] : vector<8xf32> from vector<4x5x8xf32>
return %0 : vector<8xf32>
}
@@ -2886,13 +2913,47 @@ func.func @vector_insert_const_regression(%arg0: i8) -> vector<4xi8> {
return %1 : vector<4xi8>
}

// -----

// Inserting a poison value shouldn't be folded, as the resulting vector is not
// fully poison.

// CHECK-LABEL: @insert_scalar_poison
func.func @insert_scalar_poison(%a: vector<4x8xf32>)
-> vector<4x8xf32> {
// CHECK-NEXT: %[[UB:.*]] = ub.poison : f32
// CHECK-NEXT: %[[RES:.*]] = vector.insert %[[UB]]
// CHECK-NEXT: return %[[RES]] : vector<4x8xf32>
%0 = ub.poison : f32
%1 = vector.insert %0, %a[2, 3] : f32 into vector<4x8xf32>
return %1 : vector<4x8xf32>
}

// -----

// Inserting a poison value shouldn't be folded, as the resulting vector is not
// fully poison.

// CHECK-LABEL: @insert_vector_poison
func.func @insert_vector_poison(%a: vector<4x8xf32>)
-> vector<4x8xf32> {
// CHECK-NEXT: %[[UB:.*]] = ub.poison : vector<8xf32>
// CHECK-NEXT: %[[RES:.*]] = vector.insert %[[UB]]
// CHECK-NEXT: return %[[RES]] : vector<4x8xf32>
%0 = ub.poison : vector<8xf32>
%1 = vector.insert %0, %a[2] : vector<8xf32> into vector<4x8xf32>
return %1 : vector<4x8xf32>
}


// -----

// CHECK-LABEL: @insert_scalar_poison_idx
func.func @insert_scalar_poison_idx(%a: vector<4x5xf32>, %b: f32)
-> vector<4x5xf32> {
+// CHECK-NEXT: %[[UB:.*]] = ub.poison : vector<4x5xf32>
 // CHECK-NOT: vector.insert
-// CHECK-NEXT: ub.poison : vector<4x5xf32>
+// CHECK-NEXT: return %[[UB]] : vector<4x5xf32>
%0 = vector.insert %b, %a[-1, 0] : f32 into vector<4x5xf32>
return %0 : vector<4x5xf32>
}
@@ -2902,8 +2963,9 @@ func.func @insert_scalar_poison_idx(%a: vector<4x5xf32>, %b: f32)
// CHECK-LABEL: @insert_vector_poison_idx
func.func @insert_vector_poison_idx(%a: vector<4x5xf32>, %b: vector<5xf32>)
-> vector<4x5xf32> {
+// CHECK-NEXT: %[[UB:.*]] = ub.poison : vector<4x5xf32>
 // CHECK-NOT: vector.insert
-// CHECK-NEXT: ub.poison : vector<4x5xf32>
+// CHECK-NEXT: return %[[UB]] : vector<4x5xf32>
%0 = vector.insert %b, %a[-1] : vector<5xf32> into vector<4x5xf32>
return %0 : vector<4x5xf32>
}
@@ -2913,8 +2975,9 @@ func.func @insert_vector_poison_idx(%a: vector<4x5xf32>, %b: vector<5xf32>)
// CHECK-LABEL: @insert_multiple_poison_idx
func.func @insert_multiple_poison_idx(%a: vector<4x5x8xf32>, %b: vector<8xf32>)
-> vector<4x5x8xf32> {
+// CHECK-NEXT: %[[UB:.*]] = ub.poison : vector<4x5x8xf32>
 // CHECK-NOT: vector.insert
-// CHECK-NEXT: ub.poison : vector<4x5x8xf32>
+// CHECK-NEXT: return %[[UB]] : vector<4x5x8xf32>
%0 = vector.insert %b, %a[-1, -1] : vector<8xf32> into vector<4x5x8xf32>
return %0 : vector<4x5x8xf32>
}