@@ -172,7 +172,7 @@ struct UnrollTransferReadPattern
172
172
readOp.getPermutationMapAttr (), readOp.getPadding (), readOp.getMask (),
173
173
readOp.getInBoundsAttr ());
174
174
175
- result = rewriter.create <vector::InsertStridedSliceOp>(
175
+ result = rewriter.createOrFold <vector::InsertStridedSliceOp>(
176
176
loc, slicedRead, result, elementOffsets, strides);
177
177
}
178
178
rewriter.replaceOp (readOp, result);
@@ -213,7 +213,7 @@ struct UnrollTransferWritePattern
213
213
Value resultTensor;
214
214
for (SmallVector<int64_t > elementOffsets :
215
215
StaticTileOffsetRange (originalSize, *targetShape, loopOrder)) {
216
- Value slicedVector = rewriter.create <vector::ExtractStridedSliceOp>(
216
+ Value slicedVector = rewriter.createOrFold <vector::ExtractStridedSliceOp>(
217
217
loc, writeOp.getVector (), elementOffsets, *targetShape, strides);
218
218
SmallVector<Value> indices =
219
219
sliceTransferIndices (elementOffsets, originalIndices,
@@ -289,8 +289,9 @@ struct UnrollContractionPattern
289
289
SmallVector<int64_t > operandShape = applyPermutationMap (
290
290
permutationMap, ArrayRef<int64_t >(*targetShape));
291
291
SmallVector<int64_t > operandStrides (operandOffets.size (), 1 );
292
- slicesOperands[index ] = rewriter.create <vector::ExtractStridedSliceOp>(
293
- loc, operand, operandOffets, operandShape, operandStrides);
292
+ slicesOperands[index ] =
293
+ rewriter.createOrFold <vector::ExtractStridedSliceOp>(
294
+ loc, operand, operandOffets, operandShape, operandStrides);
294
295
};
295
296
296
297
// Extract the new lhs operand.
@@ -333,7 +334,7 @@ struct UnrollContractionPattern
333
334
loc, dstVecType, rewriter.getZeroAttr (dstVecType));
334
335
for (const auto &it : accCache) {
335
336
SmallVector<int64_t > dstStrides (it.first .size (), 1 );
336
- result = rewriter.create <vector::InsertStridedSliceOp>(
337
+ result = rewriter.createOrFold <vector::InsertStridedSliceOp>(
337
338
loc, it.second , result, it.first , dstStrides);
338
339
}
339
340
rewriter.replaceOp (contractOp, result);
@@ -371,8 +372,10 @@ struct UnrollMultiReductionPattern
371
372
StaticTileOffsetRange (originalSize, *targetShape)) {
372
373
SmallVector<Value> operands;
373
374
SmallVector<int64_t > operandStrides (offsets.size (), 1 );
374
- Value slicedOperand = rewriter.create <vector::ExtractStridedSliceOp>(
375
- loc, reductionOp.getSource (), offsets, *targetShape, operandStrides);
375
+ Value slicedOperand =
376
+ rewriter.createOrFold <vector::ExtractStridedSliceOp>(
377
+ loc, reductionOp.getSource (), offsets, *targetShape,
378
+ operandStrides);
376
379
operands.push_back (slicedOperand);
377
380
SmallVector<int64_t > dstShape;
378
381
SmallVector<int64_t > destOffset;
@@ -390,7 +393,7 @@ struct UnrollMultiReductionPattern
390
393
if (accIt != accCache.end ())
391
394
acc = accIt->second ;
392
395
else
393
- acc = rewriter.create <vector::ExtractStridedSliceOp>(
396
+ acc = rewriter.createOrFold <vector::ExtractStridedSliceOp>(
394
397
loc, reductionOp.getAcc (), destOffset, dstShape, accStrides);
395
398
operands.push_back (acc);
396
399
auto targetType = VectorType::get (
@@ -406,7 +409,7 @@ struct UnrollMultiReductionPattern
406
409
rewriter.getZeroAttr (reductionOp.getDestType ()));
407
410
for (const auto &it : accCache) {
408
411
SmallVector<int64_t > dstStrides (it.first .size (), 1 );
409
- result = rewriter.create <vector::InsertStridedSliceOp>(
412
+ result = rewriter.createOrFold <vector::InsertStridedSliceOp>(
410
413
loc, it.second , result, it.first , dstStrides);
411
414
}
412
415
rewriter.replaceOp (reductionOp, result);
@@ -453,12 +456,12 @@ struct UnrollElementwisePattern : public RewritePattern {
453
456
continue ;
454
457
}
455
458
extractOperands.push_back (
456
- rewriter.create <vector::ExtractStridedSliceOp>(
459
+ rewriter.createOrFold <vector::ExtractStridedSliceOp>(
457
460
loc, operand.get (), offsets, *targetShape, strides));
458
461
}
459
462
Operation *newOp = cloneOpWithOperandsAndTypes (
460
463
rewriter, loc, op, extractOperands, newVecType);
461
- result = rewriter.create <vector::InsertStridedSliceOp>(
464
+ result = rewriter.createOrFold <vector::InsertStridedSliceOp>(
462
465
loc, newOp->getResult (0 ), result, offsets, strides);
463
466
}
464
467
rewriter.replaceOp (op, result);
@@ -490,8 +493,9 @@ struct UnrollReductionPattern : public OpRewritePattern<vector::ReductionOp> {
490
493
for (SmallVector<int64_t > offsets :
491
494
StaticTileOffsetRange (originalSize, *targetShape)) {
492
495
SmallVector<int64_t > strides (offsets.size (), 1 );
493
- Value slicedOperand = rewriter.create <vector::ExtractStridedSliceOp>(
494
- loc, reductionOp.getVector (), offsets, *targetShape, strides);
496
+ Value slicedOperand =
497
+ rewriter.createOrFold <vector::ExtractStridedSliceOp>(
498
+ loc, reductionOp.getVector (), offsets, *targetShape, strides);
495
499
Operation *newOp = cloneOpWithOperandsAndTypes (
496
500
rewriter, loc, reductionOp, slicedOperand, reductionOp.getType ());
497
501
Value result = newOp->getResult (0 );
@@ -548,12 +552,13 @@ struct UnrollTransposePattern : public OpRewritePattern<vector::TransposeOp> {
548
552
permutedOffsets[indices.value ()] = elementOffsets[indices.index ()];
549
553
permutedShape[indices.value ()] = (*targetShape)[indices.index ()];
550
554
}
551
- Value slicedOperand = rewriter.create <vector::ExtractStridedSliceOp>(
552
- loc, transposeOp.getVector (), permutedOffsets, permutedShape,
553
- strides);
554
- Value transposedSlice =
555
- rewriter.create <vector::TransposeOp>(loc, slicedOperand, permutation);
556
- result = rewriter.create <vector::InsertStridedSliceOp>(
555
+ Value slicedOperand =
556
+ rewriter.createOrFold <vector::ExtractStridedSliceOp>(
557
+ loc, transposeOp.getVector (), permutedOffsets, permutedShape,
558
+ strides);
559
+ Value transposedSlice = rewriter.createOrFold <vector::TransposeOp>(
560
+ loc, slicedOperand, permutation);
561
+ result = rewriter.createOrFold <vector::InsertStridedSliceOp>(
557
562
loc, transposedSlice, result, elementOffsets, strides);
558
563
}
559
564
rewriter.replaceOp (transposeOp, result);
@@ -596,17 +601,19 @@ struct UnrollGatherPattern : public OpRewritePattern<vector::GatherOp> {
596
601
// To get the unrolled gather, extract the same slice based on the
597
602
// decomposed shape from each of the index, mask, and pass-through
598
603
// vectors.
599
- Value indexSubVec = rewriter.create <vector::ExtractStridedSliceOp>(
604
+ Value indexSubVec = rewriter.createOrFold <vector::ExtractStridedSliceOp>(
600
605
loc, gatherOp.getIndexVec (), elementOffsets, *targetShape, strides);
601
- Value maskSubVec = rewriter.create <vector::ExtractStridedSliceOp>(
606
+ Value maskSubVec = rewriter.createOrFold <vector::ExtractStridedSliceOp>(
602
607
loc, gatherOp.getMask (), elementOffsets, *targetShape, strides);
603
- Value passThruSubVec = rewriter.create <vector::ExtractStridedSliceOp>(
604
- loc, gatherOp.getPassThru (), elementOffsets, *targetShape, strides);
608
+ Value passThruSubVec =
609
+ rewriter.createOrFold <vector::ExtractStridedSliceOp>(
610
+ loc, gatherOp.getPassThru (), elementOffsets, *targetShape,
611
+ strides);
605
612
auto slicedGather = rewriter.create <vector::GatherOp>(
606
613
loc, targetType, gatherOp.getBase (), gatherOp.getIndices (),
607
614
indexSubVec, maskSubVec, passThruSubVec);
608
615
609
- result = rewriter.create <vector::InsertStridedSliceOp>(
616
+ result = rewriter.createOrFold <vector::InsertStridedSliceOp>(
610
617
loc, slicedGather, result, elementOffsets, strides);
611
618
}
612
619
rewriter.replaceOp (gatherOp, result);
0 commit comments