
Commit aaeb0fb

[mlir][Vector] Update VectorEmulateNarrowType.cpp (2/N)
This is PR 2 in a series of N patches aimed at improving
"VectorEmulateNarrowType.cpp". This is mainly minor refactoring; no major
functional changes are made.

This PR renames the variable "scale". Note, "scale" could mean either:

* "original-elements-per-emulated-type", or
* "emulated-elements-per-original-type".

While from the context it is clear that it's always the former (the original
type is always a sub-byte type and the emulated type is usually `i8`), this PR
reduces the cognitive load by making that explicit.

**DEPENDS ON:**
* llvm#123526

Please only review the [top commit](llvm@d40b31b).

**GitHub issue to track this work**: llvm#123630
1 parent 978310f commit aaeb0fb
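
To make the renamed quantity concrete, here is a minimal standalone sketch of the ratio it captures: how many narrow (emulated) elements pack into one container element. The helper name and the `i2`-in-`i8` numbers are illustrative assumptions, not part of the patch:

```cpp
#include <cassert>

// Sketch only: the ratio the renamed variable captures. For a sub-byte
// element type emulated inside a wider container type, this is the number
// of emulated elements per container element.
int emulatedElemsPerContainerElem(int containerBits, int emulatedBits) {
  assert(containerBits % emulatedBits == 0 && "bit-wise misalignment");
  return containerBits / emulatedBits;
}

int main() {
  // Assumed example: four i2 values fit in one i8 container element.
  assert(emulatedElemsPerContainerElem(8, 2) == 4);
  return 0;
}
```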

1 file changed: +45 −33 lines changed

mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp

@@ -290,13 +290,15 @@ static VectorValue emulatedVectorLoad(OpBuilder &rewriter, Location loc,
                                       int64_t numContainerElemsToLoad,
                                       Type emulatedElemTy,
                                       Type containerElemTy) {
-  auto scale = containerElemTy.getIntOrFloatBitWidth() /
-               emulatedElemTy.getIntOrFloatBitWidth();
+  auto emulatedPerContainerElem = containerElemTy.getIntOrFloatBitWidth() /
+                                  emulatedElemTy.getIntOrFloatBitWidth();
   auto newLoad = rewriter.create<vector::LoadOp>(
       loc, VectorType::get(numContainerElemsToLoad, containerElemTy), base,
       getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
   return rewriter.create<vector::BitCastOp>(
-      loc, VectorType::get(numContainerElemsToLoad * scale, emulatedElemTy),
+      loc,
+      VectorType::get(numContainerElemsToLoad * emulatedPerContainerElem,
+                      emulatedElemTy),
       newLoad);
 }
 
@@ -388,10 +390,11 @@ static Value extractSliceIntoByte(ConversionPatternRewriter &rewriter,
       "sliceNumElements * vector element size must be less than or equal to 8");
   assert(8 % vectorElementType.getIntOrFloatBitWidth() == 0 &&
          "vector element must be a valid sub-byte type");
-  auto scale = 8 / vectorElementType.getIntOrFloatBitWidth();
+  auto emulatedPerContainerElem = 8 / vectorElementType.getIntOrFloatBitWidth();
   auto emptyByteVector = rewriter.create<arith::ConstantOp>(
-      loc, VectorType::get({scale}, vectorElementType),
-      rewriter.getZeroAttr(VectorType::get({scale}, vectorElementType)));
+      loc, VectorType::get({emulatedPerContainerElem}, vectorElementType),
+      rewriter.getZeroAttr(
+          VectorType::get({emulatedPerContainerElem}, vectorElementType)));
   auto extracted = staticallyExtractSubvector(rewriter, loc, vector,
                                               extractOffset, sliceNumElements);
   return staticallyInsertSubvector(rewriter, loc, extracted, emptyByteVector,
@@ -656,9 +659,9 @@ struct ConvertVectorMaskedStore final
               "(bit-wise misalignment)");
     }
 
-    int scale = containerBits / emulatedBits;
+    int emulatedPerContainerElem = containerBits / emulatedBits;
     int origElements = op.getValueToStore().getType().getNumElements();
-    if (origElements % scale != 0)
+    if (origElements % emulatedPerContainerElem != 0)
       return failure();
 
     auto stridedMetadata =
@@ -707,12 +710,13 @@ struct ConvertVectorMaskedStore final
     //
     // FIXME: Make an example based on the comment above work (see #115460 for
     // reproducer).
-    FailureOr<Operation *> newMask =
-        getCompressedMaskOp(rewriter, loc, op.getMask(), origElements, scale);
+    FailureOr<Operation *> newMask = getCompressedMaskOp(
+        rewriter, loc, op.getMask(), origElements, emulatedPerContainerElem);
     if (failed(newMask))
       return failure();
 
-    auto numElements = (origElements + scale - 1) / scale;
+    auto numElements = (origElements + emulatedPerContainerElem - 1) /
+                       emulatedPerContainerElem;
     auto newType = VectorType::get(numElements, containerElemTy);
     auto passThru = rewriter.create<arith::ConstantOp>(
         loc, newType, rewriter.getZeroAttr(newType));
@@ -721,7 +725,8 @@ struct ConvertVectorMaskedStore final
         loc, newType, adaptor.getBase(), linearizedIndices,
         newMask.value()->getResult(0), passThru);
 
-    auto newBitCastType = VectorType::get(numElements * scale, emulatedElemTy);
+    auto newBitCastType =
+        VectorType::get(numElements * emulatedPerContainerElem, emulatedElemTy);
     Value valueToStore =
         rewriter.create<vector::BitCastOp>(loc, newBitCastType, newLoad);
     valueToStore = rewriter.create<arith::SelectOp>(
@@ -765,7 +770,7 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
           op, "impossible to pack emulated elements into container elements "
              "(bit-wise misalignment)");
     }
-    int scale = containerBits / emulatedBits;
+    int emulatedPerContainerElem = containerBits / emulatedBits;
 
     // Adjust the number of elements to load when emulating narrow types,
     // and then cast back to the original type with vector.bitcast op.
@@ -797,7 +802,7 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
     // compile time as they must be constants.
 
     auto origElements = op.getVectorType().getNumElements();
-    bool isAlignedEmulation = origElements % scale == 0;
+    bool isAlignedEmulation = origElements % emulatedPerContainerElem == 0;
 
     auto stridedMetadata =
         rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
@@ -818,9 +823,10 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
             : getConstantIntValue(linearizedInfo.intraDataOffset);
 
     // Always load enough elements which can cover the original elements.
-    int64_t maxintraDataOffset = foldedIntraVectorOffset.value_or(scale - 1);
-    auto numElements =
-        llvm::divideCeil(maxintraDataOffset + origElements, scale);
+    int64_t maxintraDataOffset =
+        foldedIntraVectorOffset.value_or(emulatedPerContainerElem - 1);
+    auto numElements = llvm::divideCeil(maxintraDataOffset + origElements,
+                                        emulatedPerContainerElem);
     Value result =
         emulatedVectorLoad(rewriter, loc, adaptor.getBase(), linearizedIndices,
                            numElements, emulatedElemTy, containerElemTy);
@@ -870,7 +876,7 @@ struct ConvertVectorMaskedLoad final
           op, "impossible to pack emulated elements into container elements "
              "(bit-wise misalignment)");
     }
-    int scale = containerBits / emulatedBits;
+    int emulatedPerContainerElem = containerBits / emulatedBits;
 
     // Adjust the number of elements to load when emulating narrow types,
     // and then cast back to the original type with vector.bitcast op.
@@ -916,7 +922,7 @@ struct ConvertVectorMaskedLoad final
     // subvector at the proper offset after bit-casting.
     auto origType = op.getVectorType();
     auto origElements = origType.getNumElements();
-    bool isAlignedEmulation = origElements % scale == 0;
+    bool isAlignedEmulation = origElements % emulatedPerContainerElem == 0;
 
     auto stridedMetadata =
         rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
@@ -935,18 +941,21 @@ struct ConvertVectorMaskedLoad final
             ? 0
             : getConstantIntValue(linearizedInfo.intraDataOffset);
 
-    int64_t maxIntraDataOffset = foldedIntraVectorOffset.value_or(scale - 1);
-    FailureOr<Operation *> newMask = getCompressedMaskOp(
-        rewriter, loc, op.getMask(), origElements, scale, maxIntraDataOffset);
+    int64_t maxIntraDataOffset =
+        foldedIntraVectorOffset.value_or(emulatedPerContainerElem - 1);
+    FailureOr<Operation *> newMask =
+        getCompressedMaskOp(rewriter, loc, op.getMask(), origElements,
+                            emulatedPerContainerElem, maxIntraDataOffset);
     if (failed(newMask))
       return failure();
 
     Value passthru = op.getPassThru();
 
-    auto numElements =
-        llvm::divideCeil(maxIntraDataOffset + origElements, scale);
+    auto numElements = llvm::divideCeil(maxIntraDataOffset + origElements,
+                                        emulatedPerContainerElem);
     auto loadType = VectorType::get(numElements, containerElemTy);
-    auto newBitcastType = VectorType::get(numElements * scale, emulatedElemTy);
+    auto newBitcastType =
+        VectorType::get(numElements * emulatedPerContainerElem, emulatedElemTy);
 
     auto emptyVector = rewriter.create<arith::ConstantOp>(
         loc, newBitcastType, rewriter.getZeroAttr(newBitcastType));
@@ -973,8 +982,8 @@ struct ConvertVectorMaskedLoad final
         rewriter.create<vector::BitCastOp>(loc, newBitcastType, newLoad);
 
     Value mask = op.getMask();
-    auto newSelectMaskType =
-        VectorType::get(numElements * scale, rewriter.getI1Type());
+    auto newSelectMaskType = VectorType::get(
+        numElements * emulatedPerContainerElem, rewriter.getI1Type());
     // TODO: try to fold if op's mask is constant
     auto emptyMask = rewriter.create<arith::ConstantOp>(
         loc, newSelectMaskType, rewriter.getZeroAttr(newSelectMaskType));
@@ -1033,11 +1042,11 @@ struct ConvertVectorTransferRead final
           op, "impossible to pack emulated elements into container elements "
              "(bit-wise misalignment)");
     }
-    int scale = containerBits / emulatedBits;
+    int emulatedPerContainerElem = containerBits / emulatedBits;
 
     auto origElements = op.getVectorType().getNumElements();
 
-    bool isAlignedEmulation = origElements % scale == 0;
+    bool isAlignedEmulation = origElements % emulatedPerContainerElem == 0;
 
     auto newPadding = rewriter.create<arith::ExtUIOp>(loc, containerElemTy,
                                                       adaptor.getPadding());
@@ -1060,17 +1069,20 @@ struct ConvertVectorTransferRead final
             ? 0
             : getConstantIntValue(linearizedInfo.intraDataOffset);
 
-    int64_t maxIntraDataOffset = foldedIntraVectorOffset.value_or(scale - 1);
-    auto numElements =
-        llvm::divideCeil(maxIntraDataOffset + origElements, scale);
+    int64_t maxIntraDataOffset =
+        foldedIntraVectorOffset.value_or(emulatedPerContainerElem - 1);
+    auto numElements = llvm::divideCeil(maxIntraDataOffset + origElements,
+                                        emulatedPerContainerElem);
 
     auto newRead = rewriter.create<vector::TransferReadOp>(
        loc, VectorType::get(numElements, containerElemTy), adaptor.getSource(),
        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices),
        newPadding);
 
     auto bitCast = rewriter.create<vector::BitCastOp>(
-        loc, VectorType::get(numElements * scale, emulatedElemTy), newRead);
+        loc,
+        VectorType::get(numElements * emulatedPerContainerElem, emulatedElemTy),
+        newRead);
 
     Value result = bitCast->getResult(0);
     if (!foldedIntraVectorOffset) {
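
As a sanity check on the `llvm::divideCeil`-based sizing that appears in the load, masked-load, and transfer-read patterns above, here is a self-contained sketch of the worst-case arithmetic. The concrete numbers (`i2` elements in `i8` containers, a 5-element load) are assumed for illustration and are not taken from the commit:

```cpp
#include <cassert>
#include <cstdint>

// Local stand-in for llvm::divideCeil, kept here so the sketch compiles
// on its own.
int64_t divideCeil(int64_t numerator, int64_t denominator) {
  return (numerator + denominator - 1) / denominator;
}

int main() {
  // Assumed example: i2 elements emulated in i8 containers.
  int64_t emulatedPerContainerElem = 8 / 2; // == 4
  int64_t origElements = 5;                 // e.g. loading vector<5xi2>
  // When the intra-container offset is not known statically, the patterns
  // assume the worst case (emulatedPerContainerElem - 1) so the container
  // loads always cover the requested elements.
  int64_t maxIntraDataOffset = emulatedPerContainerElem - 1; // == 3
  // ceil((3 + 5) / 4) == 2 i8 container elements, i.e. 8 i2 slots --
  // enough to cover 5 i2 elements at any intra-byte offset.
  assert(divideCeil(maxIntraDataOffset + origElements,
                    emulatedPerContainerElem) == 2);
  return 0;
}
```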
