
Commit d40b31b

[mlir][Vector] Update VectorEmulateNarrowType.cpp (2/N)
This is PR 2 in a series of N patches aimed at improving "VectorEmulateNarrowType.cpp". It is mainly minor refactoring; no major functional changes are made.

This PR renames the variable "scale". Note that "scale" could mean either:
* "original-elements-per-emulated-type", or
* "emulated-elements-per-original-type".

While from the context it is clear that it is always the former (the original type is always a sub-byte type and the emulated type is usually `i8`), this PR reduces the cognitive load by making that explicit.
1 parent 5edc342 commit d40b31b
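For context, a minimal sketch (illustrative only, not part of the commit) of the quantity the new name describes, assuming an `i4` element type emulated with an `i8` container:

    // Illustrative values: i4 elements packed into i8 containers.
    unsigned oldBits = 4;  // bit width of the original (sub-byte) element type
    unsigned newBits = 8;  // bit width of the emulated container type
    int elementsPerContainerType = newBits / oldBits;  // == 2 original elements per i8
    // A load of vector<8xi4> is thus emulated as a load of vector<4xi8>,
    // followed by a vector.bitcast back to vector<8xi4>.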

mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp

+44 −33
@@ -282,13 +282,15 @@ emulatedVectorLoad(OpBuilder &rewriter, Location loc, Value base,
                    OpFoldResult linearizedIndices,
                    int64_t numEmultedElementsToLoad, Type origElemType,
                    Type emulatedElemType) {
-  auto scale = emulatedElemType.getIntOrFloatBitWidth() /
-               origElemType.getIntOrFloatBitWidth();
+  auto elementsPerContainerType = emulatedElemType.getIntOrFloatBitWidth() /
+                                  origElemType.getIntOrFloatBitWidth();
   auto newLoad = rewriter.create<vector::LoadOp>(
       loc, VectorType::get(numEmultedElementsToLoad, emulatedElemType), base,
       getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
   return rewriter.create<vector::BitCastOp>(
-      loc, VectorType::get(numEmultedElementsToLoad * scale, origElemType),
+      loc,
+      VectorType::get(numEmultedElementsToLoad * elementsPerContainerType,
+                      origElemType),
       newLoad);
 }
 
@@ -321,7 +323,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
     if (newBits % oldBits != 0) {
       return rewriter.notifyMatchFailure(op, "unalagined element types");
     }
-    int scale = newBits / oldBits;
+    int elementsPerContainerType = newBits / oldBits;
 
     // Adjust the number of elements to store when emulating narrow types.
     // Here only the 1-D vector store is considered, and the N-D memref types
@@ -337,7 +339,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
     // vector<4xi8>
 
     auto origElements = op.getValueToStore().getType().getNumElements();
-    if (origElements % scale != 0)
+    if (origElements % elementsPerContainerType != 0)
       return failure();
 
     auto stridedMetadata =
@@ -352,7 +354,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
             stridedMetadata.getConstifiedMixedStrides(),
             getAsOpFoldResult(adaptor.getIndices()));
 
-    auto numElements = origElements / scale;
+    auto numElements = origElements / elementsPerContainerType;
     auto bitCast = rewriter.create<vector::BitCastOp>(
         loc, VectorType::get(numElements, newElementType),
         op.getValueToStore());
@@ -393,9 +395,9 @@ struct ConvertVectorMaskedStore final
       return rewriter.notifyMatchFailure(op, "unalagined element types");
     }
 
-    int scale = newBits / oldBits;
+    int elementsPerContainerType = newBits / oldBits;
     int origElements = op.getValueToStore().getType().getNumElements();
-    if (origElements % scale != 0)
+    if (origElements % elementsPerContainerType != 0)
       return failure();
 
     auto stridedMetadata =
@@ -444,12 +446,13 @@ struct ConvertVectorMaskedStore final
     //
     // FIXME: Make an example based on the comment above work (see #115460 for
     // reproducer).
-    FailureOr<Operation *> newMask =
-        getCompressedMaskOp(rewriter, loc, op.getMask(), origElements, scale);
+    FailureOr<Operation *> newMask = getCompressedMaskOp(
+        rewriter, loc, op.getMask(), origElements, elementsPerContainerType);
     if (failed(newMask))
       return failure();
 
-    auto numElements = (origElements + scale - 1) / scale;
+    auto numElements = (origElements + elementsPerContainerType - 1) /
+                       elementsPerContainerType;
     auto newType = VectorType::get(numElements, newElementType);
     auto passThru = rewriter.create<arith::ConstantOp>(
         loc, newType, rewriter.getZeroAttr(newType));
@@ -458,7 +461,8 @@ struct ConvertVectorMaskedStore final
         loc, newType, adaptor.getBase(), linearizedIndices,
         newMask.value()->getResult(0), passThru);
 
-    auto newBitCastType = VectorType::get(numElements * scale, oldElementType);
+    auto newBitCastType =
+        VectorType::get(numElements * elementsPerContainerType, oldElementType);
     Value valueToStore =
         rewriter.create<vector::BitCastOp>(loc, newBitCastType, newLoad);
     valueToStore = rewriter.create<arith::SelectOp>(
@@ -500,7 +504,7 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
     if (newBits % oldBits != 0) {
       return rewriter.notifyMatchFailure(op, "unalagined element types");
     }
-    int scale = newBits / oldBits;
+    int elementsPerContainerType = newBits / oldBits;
 
     // Adjust the number of elements to load when emulating narrow types,
     // and then cast back to the original type with vector.bitcast op.
@@ -532,7 +536,7 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
     // compile time as they must be constants.
 
    auto origElements = op.getVectorType().getNumElements();
-    bool isUnalignedEmulation = origElements % scale != 0;
+    bool isUnalignedEmulation = origElements % elementsPerContainerType != 0;
 
     auto stridedMetadata =
         rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
@@ -553,9 +557,10 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
             : 0;
 
     // Always load enough elements which can cover the original elements.
-    int64_t maxintraDataOffset = foldedIntraVectorOffset.value_or(scale - 1);
-    auto numElements =
-        llvm::divideCeil(maxintraDataOffset + origElements, scale);
+    int64_t maxintraDataOffset =
+        foldedIntraVectorOffset.value_or(elementsPerContainerType - 1);
+    auto numElements = llvm::divideCeil(maxintraDataOffset + origElements,
+                                        elementsPerContainerType);
     Value result =
         emulatedVectorLoad(rewriter, loc, adaptor.getBase(), linearizedIndices,
                            numElements, oldElementType, newElementType);
@@ -603,7 +608,7 @@ struct ConvertVectorMaskedLoad final
     if (newBits % oldBits != 0) {
       return rewriter.notifyMatchFailure(op, "unalagined element types");
     }
-    int scale = newBits / oldBits;
+    int elementsPerContainerType = newBits / oldBits;
 
     // Adjust the number of elements to load when emulating narrow types,
     // and then cast back to the original type with vector.bitcast op.
@@ -649,7 +654,7 @@ struct ConvertVectorMaskedLoad final
     // subvector at the proper offset after bit-casting.
     auto origType = op.getVectorType();
     auto origElements = origType.getNumElements();
-    bool isUnalignedEmulation = origElements % scale != 0;
+    bool isUnalignedEmulation = origElements % elementsPerContainerType != 0;
 
     auto stridedMetadata =
         rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
@@ -668,18 +673,21 @@ struct ConvertVectorMaskedLoad final
             ? getConstantIntValue(linearizedInfo.intraDataOffset)
             : 0;
 
-    int64_t maxIntraDataOffset = foldedIntraVectorOffset.value_or(scale - 1);
-    FailureOr<Operation *> newMask = getCompressedMaskOp(
-        rewriter, loc, op.getMask(), origElements, scale, maxIntraDataOffset);
+    int64_t maxIntraDataOffset =
+        foldedIntraVectorOffset.value_or(elementsPerContainerType - 1);
+    FailureOr<Operation *> newMask =
+        getCompressedMaskOp(rewriter, loc, op.getMask(), origElements,
+                            elementsPerContainerType, maxIntraDataOffset);
     if (failed(newMask))
       return failure();
 
     Value passthru = op.getPassThru();
 
-    auto numElements =
-        llvm::divideCeil(maxIntraDataOffset + origElements, scale);
+    auto numElements = llvm::divideCeil(maxIntraDataOffset + origElements,
+                                        elementsPerContainerType);
     auto loadType = VectorType::get(numElements, newElementType);
-    auto newBitcastType = VectorType::get(numElements * scale, oldElementType);
+    auto newBitcastType =
+        VectorType::get(numElements * elementsPerContainerType, oldElementType);
 
     auto emptyVector = rewriter.create<arith::ConstantOp>(
         loc, newBitcastType, rewriter.getZeroAttr(newBitcastType));
@@ -706,8 +714,8 @@ struct ConvertVectorMaskedLoad final
         rewriter.create<vector::BitCastOp>(loc, newBitcastType, newLoad);
 
     Value mask = op.getMask();
-    auto newSelectMaskType =
-        VectorType::get(numElements * scale, rewriter.getI1Type());
+    auto newSelectMaskType = VectorType::get(
+        numElements * elementsPerContainerType, rewriter.getI1Type());
     // TODO: try to fold if op's mask is constant
     auto emptyMask = rewriter.create<arith::ConstantOp>(
         loc, newSelectMaskType, rewriter.getZeroAttr(newSelectMaskType));
@@ -765,11 +773,11 @@ struct ConvertVectorTransferRead final
     if (newBits % oldBits != 0) {
       return rewriter.notifyMatchFailure(op, "unalagined element types");
     }
-    int scale = newBits / oldBits;
+    int elementsPerContainerType = newBits / oldBits;
 
     auto origElements = op.getVectorType().getNumElements();
 
-    bool isUnalignedEmulation = origElements % scale != 0;
+    bool isUnalignedEmulation = origElements % elementsPerContainerType != 0;
 
     auto newPadding = rewriter.create<arith::ExtUIOp>(loc, newElementType,
                                                       adaptor.getPadding());
@@ -792,17 +800,20 @@ struct ConvertVectorTransferRead final
             ? getConstantIntValue(linearizedInfo.intraDataOffset)
             : 0;
 
-    int64_t maxIntraDataOffset = foldedIntraVectorOffset.value_or(scale - 1);
-    auto numElements =
-        llvm::divideCeil(maxIntraDataOffset + origElements, scale);
+    int64_t maxIntraDataOffset =
+        foldedIntraVectorOffset.value_or(elementsPerContainerType - 1);
+    auto numElements = llvm::divideCeil(maxIntraDataOffset + origElements,
+                                        elementsPerContainerType);
 
     auto newRead = rewriter.create<vector::TransferReadOp>(
         loc, VectorType::get(numElements, newElementType), adaptor.getSource(),
         getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices),
         newPadding);
 
     auto bitCast = rewriter.create<vector::BitCastOp>(
-        loc, VectorType::get(numElements * scale, oldElementType), newRead);
+        loc,
+        VectorType::get(numElements * elementsPerContainerType, oldElementType),
+        newRead);
 
     Value result = bitCast->getResult(0);
     if (!foldedIntraVectorOffset) {
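
As a worked example of the numElements computation used in the unaligned load/read paths above (values are illustrative and not taken from the commit), assuming `i2` elements emulated in `i8` containers:

    // Illustrative values: i2 elements packed into i8 containers.
    int elementsPerContainerType = 8 / 2;                       // 4 x i2 per i8
    int64_t origElements = 3;                                   // e.g. vector<3xi2>
    // When the intra-container offset is not known statically, assume the worst case.
    int64_t maxIntraDataOffset = elementsPerContainerType - 1;  // == 3
    auto numElements = llvm::divideCeil(maxIntraDataOffset + origElements,
                                        elementsPerContainerType);  // ceil(6/4) == 2
    // The emulated load therefore reads vector<2xi8> and bit-casts it to
    // vector<8xi2>, from which the requested elements are then extracted.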
