Skip to content

Commit a5fb2e9

Browse files
dcaballe authored and
Icohedron committed
[mlir][Vector] Fold vector.extract from poison vector (llvm#126122)
This PR adds a folder for `vector.extract(ub.poison) -> ub.poison`. It also replaces `create` with `createOrFold` when building insert/extract ops in the vector unroll and transpose lowering patterns, so that the recently introduced poison folders are triggered.
1 parent a5eb2e7 commit a5fb2e9

File tree

4 files changed

+118
-38
lines changed

4 files changed

+118
-38
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

+14-4
Original file line numberDiff line numberDiff line change
@@ -1991,15 +1991,23 @@ static Value foldScalarExtractFromFromElements(ExtractOp extractOp) {
19911991

19921992
/// Fold an insert or extract operation into a poison value when a poison index
19931993
/// is found at any dimension of the static position.
1994-
static ub::PoisonAttr
1995-
foldPoisonIndexInsertExtractOp(MLIRContext *context,
1996-
ArrayRef<int64_t> staticPos, int64_t poisonVal) {
1994+
static Attribute foldPoisonIndexInsertExtractOp(MLIRContext *context,
1995+
ArrayRef<int64_t> staticPos,
1996+
int64_t poisonVal) {
19971997
if (!llvm::is_contained(staticPos, poisonVal))
1998-
return ub::PoisonAttr();
1998+
return {};
19991999

20002000
return ub::PoisonAttr::get(context);
20012001
}
20022002

2003+
/// Fold a vector extract from a poison source into a poison value.
2004+
static Attribute foldPoisonSrcExtractOp(Attribute srcAttr) {
2005+
if (llvm::isa_and_nonnull<ub::PoisonAttr>(srcAttr))
2006+
return srcAttr;
2007+
2008+
return {};
2009+
}
2010+
20032011
OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
20042012
// Fold "vector.extract %v[] : vector<2x2xf32> from vector<2x2xf32>" to %v.
20052013
// Note: Do not fold "vector.extract %v[] : f32 from vector<f32>" (type
@@ -2009,6 +2017,8 @@ OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
20092017
if (auto res = foldPoisonIndexInsertExtractOp(
20102018
getContext(), adaptor.getStaticPosition(), kPoisonIndex))
20112019
return res;
2020+
if (auto res = foldPoisonSrcExtractOp(adaptor.getVector()))
2021+
return res;
20122022
if (succeeded(foldExtractOpFromExtractChain(*this)))
20132023
return getResult();
20142024
if (auto res = ExtractFromInsertTransposeChainState(*this).fold())

mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp

+4-4
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,7 @@ static Value transposeToShuffle16x16(OpBuilder &builder, Value source, int m,
209209
ImplicitLocOpBuilder b(source.getLoc(), builder);
210210
SmallVector<Value> vs;
211211
for (int64_t i = 0; i < m; ++i)
212-
vs.push_back(b.create<vector::ExtractOp>(source, i));
212+
vs.push_back(b.createOrFold<vector::ExtractOp>(source, i));
213213

214214
// Interleave 32-bit lanes using
215215
// 8x _mm512_unpacklo_epi32
@@ -378,9 +378,9 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
378378
SmallVector<int64_t> insertIdxs(extractIdxs);
379379
applyPermutationToVector(insertIdxs, prunedTransp);
380380
Value extractOp =
381-
rewriter.create<vector::ExtractOp>(loc, input, extractIdxs);
382-
result =
383-
rewriter.create<vector::InsertOp>(loc, extractOp, result, insertIdxs);
381+
rewriter.createOrFold<vector::ExtractOp>(loc, input, extractIdxs);
382+
result = rewriter.createOrFold<vector::InsertOp>(loc, extractOp, result,
383+
insertIdxs);
384384
}
385385

386386
rewriter.replaceOp(op, result);

mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp

+31-24
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ struct UnrollTransferReadPattern
172172
readOp.getPermutationMapAttr(), readOp.getPadding(), readOp.getMask(),
173173
readOp.getInBoundsAttr());
174174

175-
result = rewriter.create<vector::InsertStridedSliceOp>(
175+
result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
176176
loc, slicedRead, result, elementOffsets, strides);
177177
}
178178
rewriter.replaceOp(readOp, result);
@@ -213,7 +213,7 @@ struct UnrollTransferWritePattern
213213
Value resultTensor;
214214
for (SmallVector<int64_t> elementOffsets :
215215
StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
216-
Value slicedVector = rewriter.create<vector::ExtractStridedSliceOp>(
216+
Value slicedVector = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
217217
loc, writeOp.getVector(), elementOffsets, *targetShape, strides);
218218
SmallVector<Value> indices =
219219
sliceTransferIndices(elementOffsets, originalIndices,
@@ -289,8 +289,9 @@ struct UnrollContractionPattern
289289
SmallVector<int64_t> operandShape = applyPermutationMap(
290290
permutationMap, ArrayRef<int64_t>(*targetShape));
291291
SmallVector<int64_t> operandStrides(operandOffets.size(), 1);
292-
slicesOperands[index] = rewriter.create<vector::ExtractStridedSliceOp>(
293-
loc, operand, operandOffets, operandShape, operandStrides);
292+
slicesOperands[index] =
293+
rewriter.createOrFold<vector::ExtractStridedSliceOp>(
294+
loc, operand, operandOffets, operandShape, operandStrides);
294295
};
295296

296297
// Extract the new lhs operand.
@@ -333,7 +334,7 @@ struct UnrollContractionPattern
333334
loc, dstVecType, rewriter.getZeroAttr(dstVecType));
334335
for (const auto &it : accCache) {
335336
SmallVector<int64_t> dstStrides(it.first.size(), 1);
336-
result = rewriter.create<vector::InsertStridedSliceOp>(
337+
result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
337338
loc, it.second, result, it.first, dstStrides);
338339
}
339340
rewriter.replaceOp(contractOp, result);
@@ -371,8 +372,10 @@ struct UnrollMultiReductionPattern
371372
StaticTileOffsetRange(originalSize, *targetShape)) {
372373
SmallVector<Value> operands;
373374
SmallVector<int64_t> operandStrides(offsets.size(), 1);
374-
Value slicedOperand = rewriter.create<vector::ExtractStridedSliceOp>(
375-
loc, reductionOp.getSource(), offsets, *targetShape, operandStrides);
375+
Value slicedOperand =
376+
rewriter.createOrFold<vector::ExtractStridedSliceOp>(
377+
loc, reductionOp.getSource(), offsets, *targetShape,
378+
operandStrides);
376379
operands.push_back(slicedOperand);
377380
SmallVector<int64_t> dstShape;
378381
SmallVector<int64_t> destOffset;
@@ -390,7 +393,7 @@ struct UnrollMultiReductionPattern
390393
if (accIt != accCache.end())
391394
acc = accIt->second;
392395
else
393-
acc = rewriter.create<vector::ExtractStridedSliceOp>(
396+
acc = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
394397
loc, reductionOp.getAcc(), destOffset, dstShape, accStrides);
395398
operands.push_back(acc);
396399
auto targetType = VectorType::get(
@@ -406,7 +409,7 @@ struct UnrollMultiReductionPattern
406409
rewriter.getZeroAttr(reductionOp.getDestType()));
407410
for (const auto &it : accCache) {
408411
SmallVector<int64_t> dstStrides(it.first.size(), 1);
409-
result = rewriter.create<vector::InsertStridedSliceOp>(
412+
result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
410413
loc, it.second, result, it.first, dstStrides);
411414
}
412415
rewriter.replaceOp(reductionOp, result);
@@ -453,12 +456,12 @@ struct UnrollElementwisePattern : public RewritePattern {
453456
continue;
454457
}
455458
extractOperands.push_back(
456-
rewriter.create<vector::ExtractStridedSliceOp>(
459+
rewriter.createOrFold<vector::ExtractStridedSliceOp>(
457460
loc, operand.get(), offsets, *targetShape, strides));
458461
}
459462
Operation *newOp = cloneOpWithOperandsAndTypes(
460463
rewriter, loc, op, extractOperands, newVecType);
461-
result = rewriter.create<vector::InsertStridedSliceOp>(
464+
result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
462465
loc, newOp->getResult(0), result, offsets, strides);
463466
}
464467
rewriter.replaceOp(op, result);
@@ -490,8 +493,9 @@ struct UnrollReductionPattern : public OpRewritePattern<vector::ReductionOp> {
490493
for (SmallVector<int64_t> offsets :
491494
StaticTileOffsetRange(originalSize, *targetShape)) {
492495
SmallVector<int64_t> strides(offsets.size(), 1);
493-
Value slicedOperand = rewriter.create<vector::ExtractStridedSliceOp>(
494-
loc, reductionOp.getVector(), offsets, *targetShape, strides);
496+
Value slicedOperand =
497+
rewriter.createOrFold<vector::ExtractStridedSliceOp>(
498+
loc, reductionOp.getVector(), offsets, *targetShape, strides);
495499
Operation *newOp = cloneOpWithOperandsAndTypes(
496500
rewriter, loc, reductionOp, slicedOperand, reductionOp.getType());
497501
Value result = newOp->getResult(0);
@@ -548,12 +552,13 @@ struct UnrollTransposePattern : public OpRewritePattern<vector::TransposeOp> {
548552
permutedOffsets[indices.value()] = elementOffsets[indices.index()];
549553
permutedShape[indices.value()] = (*targetShape)[indices.index()];
550554
}
551-
Value slicedOperand = rewriter.create<vector::ExtractStridedSliceOp>(
552-
loc, transposeOp.getVector(), permutedOffsets, permutedShape,
553-
strides);
554-
Value transposedSlice =
555-
rewriter.create<vector::TransposeOp>(loc, slicedOperand, permutation);
556-
result = rewriter.create<vector::InsertStridedSliceOp>(
555+
Value slicedOperand =
556+
rewriter.createOrFold<vector::ExtractStridedSliceOp>(
557+
loc, transposeOp.getVector(), permutedOffsets, permutedShape,
558+
strides);
559+
Value transposedSlice = rewriter.createOrFold<vector::TransposeOp>(
560+
loc, slicedOperand, permutation);
561+
result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
557562
loc, transposedSlice, result, elementOffsets, strides);
558563
}
559564
rewriter.replaceOp(transposeOp, result);
@@ -596,17 +601,19 @@ struct UnrollGatherPattern : public OpRewritePattern<vector::GatherOp> {
596601
// To get the unrolled gather, extract the same slice based on the
597602
// decomposed shape from each of the index, mask, and pass-through
598603
// vectors.
599-
Value indexSubVec = rewriter.create<vector::ExtractStridedSliceOp>(
604+
Value indexSubVec = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
600605
loc, gatherOp.getIndexVec(), elementOffsets, *targetShape, strides);
601-
Value maskSubVec = rewriter.create<vector::ExtractStridedSliceOp>(
606+
Value maskSubVec = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
602607
loc, gatherOp.getMask(), elementOffsets, *targetShape, strides);
603-
Value passThruSubVec = rewriter.create<vector::ExtractStridedSliceOp>(
604-
loc, gatherOp.getPassThru(), elementOffsets, *targetShape, strides);
608+
Value passThruSubVec =
609+
rewriter.createOrFold<vector::ExtractStridedSliceOp>(
610+
loc, gatherOp.getPassThru(), elementOffsets, *targetShape,
611+
strides);
605612
auto slicedGather = rewriter.create<vector::GatherOp>(
606613
loc, targetType, gatherOp.getBase(), gatherOp.getIndices(),
607614
indexSubVec, maskSubVec, passThruSubVec);
608615

609-
result = rewriter.create<vector::InsertStridedSliceOp>(
616+
result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
610617
loc, slicedGather, result, elementOffsets, strides);
611618
}
612619
rewriter.replaceOp(gatherOp, result);

mlir/test/Dialect/Vector/canonicalize.mlir

+69-6
Original file line numberDiff line numberDiff line change
@@ -132,10 +132,35 @@ func.func @extract_from_create_mask_dynamic_position(%dim0: index, %index: index
132132

133133
// -----
134134

135+
// CHECK-LABEL: @extract_scalar_poison
136+
func.func @extract_scalar_poison() -> f32 {
137+
// CHECK-NEXT: %[[UB:.*]] = ub.poison : f32
138+
// CHECK-NOT: vector.extract
139+
// CHECK-NEXT: return %[[UB]] : f32
140+
%0 = ub.poison : vector<4x8xf32>
141+
%1 = vector.extract %0[2, 4] : f32 from vector<4x8xf32>
142+
return %1 : f32
143+
}
144+
145+
// -----
146+
147+
// CHECK-LABEL: @extract_vector_poison
148+
func.func @extract_vector_poison() -> vector<8xf32> {
149+
// CHECK-NEXT: %[[UB:.*]] = ub.poison : vector<8xf32>
150+
// CHECK-NOT: vector.extract
151+
// CHECK-NEXT: return %[[UB]] : vector<8xf32>
152+
%0 = ub.poison : vector<4x8xf32>
153+
%1 = vector.extract %0[2] : vector<8xf32> from vector<4x8xf32>
154+
return %1 : vector<8xf32>
155+
}
156+
157+
// -----
158+
135159
// CHECK-LABEL: @extract_scalar_poison_idx
136160
func.func @extract_scalar_poison_idx(%a: vector<4x5xf32>) -> f32 {
161+
// CHECK-NEXT: %[[UB:.*]] = ub.poison : f32
137162
// CHECK-NOT: vector.extract
138-
// CHECK-NEXT: ub.poison : f32
163+
// CHECK-NEXT: return %[[UB]] : f32
139164
%0 = vector.extract %a[-1, 0] : f32 from vector<4x5xf32>
140165
return %0 : f32
141166
}
@@ -144,8 +169,9 @@ func.func @extract_scalar_poison_idx(%a: vector<4x5xf32>) -> f32 {
144169

145170
// CHECK-LABEL: @extract_vector_poison_idx
146171
func.func @extract_vector_poison_idx(%a: vector<4x5xf32>) -> vector<5xf32> {
172+
// CHECK-NEXT: %[[UB:.*]] = ub.poison : vector<5xf32>
147173
// CHECK-NOT: vector.extract
148-
// CHECK-NEXT: ub.poison : vector<5xf32>
174+
// CHECK-NEXT: return %[[UB]] : vector<5xf32>
149175
%0 = vector.extract %a[-1] : vector<5xf32> from vector<4x5xf32>
150176
return %0 : vector<5xf32>
151177
}
@@ -155,8 +181,9 @@ func.func @extract_vector_poison_idx(%a: vector<4x5xf32>) -> vector<5xf32> {
155181
// CHECK-LABEL: @extract_multiple_poison_idx
156182
func.func @extract_multiple_poison_idx(%a: vector<4x5x8xf32>)
157183
-> vector<8xf32> {
184+
// CHECK-NEXT: %[[UB:.*]] = ub.poison : vector<8xf32>
158185
// CHECK-NOT: vector.extract
159-
// CHECK-NEXT: ub.poison : vector<8xf32>
186+
// CHECK-NEXT: return %[[UB]] : vector<8xf32>
160187
%0 = vector.extract %a[-1, -1] : vector<8xf32> from vector<4x5x8xf32>
161188
return %0 : vector<8xf32>
162189
}
@@ -2886,13 +2913,47 @@ func.func @vector_insert_const_regression(%arg0: i8) -> vector<4xi8> {
28862913
return %1 : vector<4xi8>
28872914
}
28882915

2916+
// -----
2917+
2918+
// Inserting a poison value shouldn't be folded as the resulting vector is not
2919+
// fully poison.
2920+
2921+
// CHECK-LABEL: @insert_scalar_poison
2922+
func.func @insert_scalar_poison(%a: vector<4x8xf32>)
2923+
-> vector<4x8xf32> {
2924+
// CHECK-NEXT: %[[UB:.*]] = ub.poison : f32
2925+
// CHECK-NEXT: %[[RES:.*]] = vector.insert %[[UB]]
2926+
// CHECK-NEXT: return %[[RES]] : vector<4x8xf32>
2927+
%0 = ub.poison : f32
2928+
%1 = vector.insert %0, %a[2, 3] : f32 into vector<4x8xf32>
2929+
return %1 : vector<4x8xf32>
2930+
}
2931+
2932+
// -----
2933+
2934+
// Inserting a poison value shouldn't be folded as the resulting vector is not
2935+
// fully poison.
2936+
2937+
// CHECK-LABEL: @insert_vector_poison
2938+
func.func @insert_vector_poison(%a: vector<4x8xf32>)
2939+
-> vector<4x8xf32> {
2940+
// CHECK-NEXT: %[[UB:.*]] = ub.poison : vector<8xf32>
2941+
// CHECK-NEXT: %[[RES:.*]] = vector.insert %[[UB]]
2942+
// CHECK-NEXT: return %[[RES]] : vector<4x8xf32>
2943+
%0 = ub.poison : vector<8xf32>
2944+
%1 = vector.insert %0, %a[2] : vector<8xf32> into vector<4x8xf32>
2945+
return %1 : vector<4x8xf32>
2946+
}
2947+
2948+
28892949
// -----
28902950

28912951
// CHECK-LABEL: @insert_scalar_poison_idx
28922952
func.func @insert_scalar_poison_idx(%a: vector<4x5xf32>, %b: f32)
28932953
-> vector<4x5xf32> {
2954+
// CHECK-NEXT: %[[UB:.*]] = ub.poison : vector<4x5xf32>
28942955
// CHECK-NOT: vector.insert
2895-
// CHECK-NEXT: ub.poison : vector<4x5xf32>
2956+
// CHECK-NEXT: return %[[UB]] : vector<4x5xf32>
28962957
%0 = vector.insert %b, %a[-1, 0] : f32 into vector<4x5xf32>
28972958
return %0 : vector<4x5xf32>
28982959
}
@@ -2902,8 +2963,9 @@ func.func @insert_scalar_poison_idx(%a: vector<4x5xf32>, %b: f32)
29022963
// CHECK-LABEL: @insert_vector_poison_idx
29032964
func.func @insert_vector_poison_idx(%a: vector<4x5xf32>, %b: vector<5xf32>)
29042965
-> vector<4x5xf32> {
2966+
// CHECK-NEXT: %[[UB:.*]] = ub.poison : vector<4x5xf32>
29052967
// CHECK-NOT: vector.insert
2906-
// CHECK-NEXT: ub.poison : vector<4x5xf32>
2968+
// CHECK-NEXT: return %[[UB]] : vector<4x5xf32>
29072969
%0 = vector.insert %b, %a[-1] : vector<5xf32> into vector<4x5xf32>
29082970
return %0 : vector<4x5xf32>
29092971
}
@@ -2913,8 +2975,9 @@ func.func @insert_vector_poison_idx(%a: vector<4x5xf32>, %b: vector<5xf32>)
29132975
// CHECK-LABEL: @insert_multiple_poison_idx
29142976
func.func @insert_multiple_poison_idx(%a: vector<4x5x8xf32>, %b: vector<8xf32>)
29152977
-> vector<4x5x8xf32> {
2978+
// CHECK-NEXT: %[[UB:.*]] = ub.poison : vector<4x5x8xf32>
29162979
// CHECK-NOT: vector.insert
2917-
// CHECK-NEXT: ub.poison : vector<4x5x8xf32>
2980+
// CHECK-NEXT: return %[[UB]] : vector<4x5x8xf32>
29182981
%0 = vector.insert %b, %a[-1, -1] : vector<8xf32> into vector<4x5x8xf32>
29192982
return %0 : vector<4x5x8xf32>
29202983
}

0 commit comments

Comments
 (0)