@@ -290,13 +290,15 @@ static VectorValue emulatedVectorLoad(OpBuilder &rewriter, Location loc,
290
290
int64_t numContainerElemsToLoad,
291
291
Type emulatedElemTy,
292
292
Type containerElemTy) {
293
- auto scale = containerElemTy.getIntOrFloatBitWidth () /
294
- emulatedElemTy.getIntOrFloatBitWidth ();
293
+ auto emulatedPerContainerElem = containerElemTy.getIntOrFloatBitWidth () /
294
+ emulatedElemTy.getIntOrFloatBitWidth ();
295
295
auto newLoad = rewriter.create <vector::LoadOp>(
296
296
loc, VectorType::get (numContainerElemsToLoad, containerElemTy), base,
297
297
getValueOrCreateConstantIndexOp (rewriter, loc, linearizedIndices));
298
298
return rewriter.create <vector::BitCastOp>(
299
- loc, VectorType::get (numContainerElemsToLoad * scale, emulatedElemTy),
299
+ loc,
300
+ VectorType::get (numContainerElemsToLoad * emulatedPerContainerElem,
301
+ emulatedElemTy),
300
302
newLoad);
301
303
}
302
304
@@ -388,10 +390,11 @@ static Value extractSliceIntoByte(ConversionPatternRewriter &rewriter,
388
390
" sliceNumElements * vector element size must be less than or equal to 8" );
389
391
assert (8 % vectorElementType.getIntOrFloatBitWidth () == 0 &&
390
392
" vector element must be a valid sub-byte type" );
391
- auto scale = 8 / vectorElementType.getIntOrFloatBitWidth ();
393
+ auto emulatedPerContainerElem = 8 / vectorElementType.getIntOrFloatBitWidth ();
392
394
auto emptyByteVector = rewriter.create <arith::ConstantOp>(
393
- loc, VectorType::get ({scale}, vectorElementType),
394
- rewriter.getZeroAttr (VectorType::get ({scale}, vectorElementType)));
395
+ loc, VectorType::get ({emulatedPerContainerElem}, vectorElementType),
396
+ rewriter.getZeroAttr (
397
+ VectorType::get ({emulatedPerContainerElem}, vectorElementType)));
395
398
auto extracted = staticallyExtractSubvector (rewriter, loc, vector,
396
399
extractOffset, sliceNumElements);
397
400
return staticallyInsertSubvector (rewriter, loc, extracted, emptyByteVector,
@@ -656,9 +659,9 @@ struct ConvertVectorMaskedStore final
656
659
" (bit-wise misalignment)" );
657
660
}
658
661
659
- int scale = containerBits / emulatedBits;
662
+ int emulatedPerContainerElem = containerBits / emulatedBits;
660
663
int origElements = op.getValueToStore ().getType ().getNumElements ();
661
- if (origElements % scale != 0 )
664
+ if (origElements % emulatedPerContainerElem != 0 )
662
665
return failure ();
663
666
664
667
auto stridedMetadata =
@@ -707,12 +710,13 @@ struct ConvertVectorMaskedStore final
707
710
//
708
711
// FIXME: Make an example based on the comment above work (see #115460 for
709
712
// reproducer).
710
- FailureOr<Operation *> newMask =
711
- getCompressedMaskOp ( rewriter, loc, op.getMask (), origElements, scale );
713
+ FailureOr<Operation *> newMask = getCompressedMaskOp (
714
+ rewriter, loc, op.getMask (), origElements, emulatedPerContainerElem );
712
715
if (failed (newMask))
713
716
return failure ();
714
717
715
- auto numElements = (origElements + scale - 1 ) / scale;
718
+ auto numElements = (origElements + emulatedPerContainerElem - 1 ) /
719
+ emulatedPerContainerElem;
716
720
auto newType = VectorType::get (numElements, containerElemTy);
717
721
auto passThru = rewriter.create <arith::ConstantOp>(
718
722
loc, newType, rewriter.getZeroAttr (newType));
@@ -721,7 +725,8 @@ struct ConvertVectorMaskedStore final
721
725
loc, newType, adaptor.getBase (), linearizedIndices,
722
726
newMask.value ()->getResult (0 ), passThru);
723
727
724
- auto newBitCastType = VectorType::get (numElements * scale, emulatedElemTy);
728
+ auto newBitCastType =
729
+ VectorType::get (numElements * emulatedPerContainerElem, emulatedElemTy);
725
730
Value valueToStore =
726
731
rewriter.create <vector::BitCastOp>(loc, newBitCastType, newLoad);
727
732
valueToStore = rewriter.create <arith::SelectOp>(
@@ -765,7 +770,7 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
765
770
op, " impossible to pack emulated elements into container elements "
766
771
" (bit-wise misalignment)" );
767
772
}
768
- int scale = containerBits / emulatedBits;
773
+ int emulatedPerContainerElem = containerBits / emulatedBits;
769
774
770
775
// Adjust the number of elements to load when emulating narrow types,
771
776
// and then cast back to the original type with vector.bitcast op.
@@ -797,7 +802,7 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
797
802
// compile time as they must be constants.
798
803
799
804
auto origElements = op.getVectorType ().getNumElements ();
800
- bool isAlignedEmulation = origElements % scale == 0 ;
805
+ bool isAlignedEmulation = origElements % emulatedPerContainerElem == 0 ;
801
806
802
807
auto stridedMetadata =
803
808
rewriter.create <memref::ExtractStridedMetadataOp>(loc, op.getBase ());
@@ -818,9 +823,10 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
818
823
: getConstantIntValue (linearizedInfo.intraDataOffset );
819
824
820
825
// Always load enough elements which can cover the original elements.
821
- int64_t maxintraDataOffset = foldedIntraVectorOffset.value_or (scale - 1 );
822
- auto numElements =
823
- llvm::divideCeil (maxintraDataOffset + origElements, scale);
826
+ int64_t maxintraDataOffset =
827
+ foldedIntraVectorOffset.value_or (emulatedPerContainerElem - 1 );
828
+ auto numElements = llvm::divideCeil (maxintraDataOffset + origElements,
829
+ emulatedPerContainerElem);
824
830
Value result =
825
831
emulatedVectorLoad (rewriter, loc, adaptor.getBase (), linearizedIndices,
826
832
numElements, emulatedElemTy, containerElemTy);
@@ -870,7 +876,7 @@ struct ConvertVectorMaskedLoad final
870
876
op, " impossible to pack emulated elements into container elements "
871
877
" (bit-wise misalignment)" );
872
878
}
873
- int scale = containerBits / emulatedBits;
879
+ int emulatedPerContainerElem = containerBits / emulatedBits;
874
880
875
881
// Adjust the number of elements to load when emulating narrow types,
876
882
// and then cast back to the original type with vector.bitcast op.
@@ -916,7 +922,7 @@ struct ConvertVectorMaskedLoad final
916
922
// subvector at the proper offset after bit-casting.
917
923
auto origType = op.getVectorType ();
918
924
auto origElements = origType.getNumElements ();
919
- bool isAlignedEmulation = origElements % scale == 0 ;
925
+ bool isAlignedEmulation = origElements % emulatedPerContainerElem == 0 ;
920
926
921
927
auto stridedMetadata =
922
928
rewriter.create <memref::ExtractStridedMetadataOp>(loc, op.getBase ());
@@ -935,18 +941,21 @@ struct ConvertVectorMaskedLoad final
935
941
? 0
936
942
: getConstantIntValue (linearizedInfo.intraDataOffset );
937
943
938
- int64_t maxIntraDataOffset = foldedIntraVectorOffset.value_or (scale - 1 );
939
- FailureOr<Operation *> newMask = getCompressedMaskOp (
940
- rewriter, loc, op.getMask (), origElements, scale, maxIntraDataOffset);
944
+ int64_t maxIntraDataOffset =
945
+ foldedIntraVectorOffset.value_or (emulatedPerContainerElem - 1 );
946
+ FailureOr<Operation *> newMask =
947
+ getCompressedMaskOp (rewriter, loc, op.getMask (), origElements,
948
+ emulatedPerContainerElem, maxIntraDataOffset);
941
949
if (failed (newMask))
942
950
return failure ();
943
951
944
952
Value passthru = op.getPassThru ();
945
953
946
- auto numElements =
947
- llvm::divideCeil (maxIntraDataOffset + origElements, scale );
954
+ auto numElements = llvm::divideCeil (maxIntraDataOffset + origElements,
955
+ emulatedPerContainerElem );
948
956
auto loadType = VectorType::get (numElements, containerElemTy);
949
- auto newBitcastType = VectorType::get (numElements * scale, emulatedElemTy);
957
+ auto newBitcastType =
958
+ VectorType::get (numElements * emulatedPerContainerElem, emulatedElemTy);
950
959
951
960
auto emptyVector = rewriter.create <arith::ConstantOp>(
952
961
loc, newBitcastType, rewriter.getZeroAttr (newBitcastType));
@@ -973,8 +982,8 @@ struct ConvertVectorMaskedLoad final
973
982
rewriter.create <vector::BitCastOp>(loc, newBitcastType, newLoad);
974
983
975
984
Value mask = op.getMask ();
976
- auto newSelectMaskType =
977
- VectorType::get ( numElements * scale , rewriter.getI1Type ());
985
+ auto newSelectMaskType = VectorType::get (
986
+ numElements * emulatedPerContainerElem , rewriter.getI1Type ());
978
987
// TODO: try to fold if op's mask is constant
979
988
auto emptyMask = rewriter.create <arith::ConstantOp>(
980
989
loc, newSelectMaskType, rewriter.getZeroAttr (newSelectMaskType));
@@ -1033,11 +1042,11 @@ struct ConvertVectorTransferRead final
1033
1042
op, " impossible to pack emulated elements into container elements "
1034
1043
" (bit-wise misalignment)" );
1035
1044
}
1036
- int scale = containerBits / emulatedBits;
1045
+ int emulatedPerContainerElem = containerBits / emulatedBits;
1037
1046
1038
1047
auto origElements = op.getVectorType ().getNumElements ();
1039
1048
1040
- bool isAlignedEmulation = origElements % scale == 0 ;
1049
+ bool isAlignedEmulation = origElements % emulatedPerContainerElem == 0 ;
1041
1050
1042
1051
auto newPadding = rewriter.create <arith::ExtUIOp>(loc, containerElemTy,
1043
1052
adaptor.getPadding ());
@@ -1060,17 +1069,20 @@ struct ConvertVectorTransferRead final
1060
1069
? 0
1061
1070
: getConstantIntValue (linearizedInfo.intraDataOffset );
1062
1071
1063
- int64_t maxIntraDataOffset = foldedIntraVectorOffset.value_or (scale - 1 );
1064
- auto numElements =
1065
- llvm::divideCeil (maxIntraDataOffset + origElements, scale);
1072
+ int64_t maxIntraDataOffset =
1073
+ foldedIntraVectorOffset.value_or (emulatedPerContainerElem - 1 );
1074
+ auto numElements = llvm::divideCeil (maxIntraDataOffset + origElements,
1075
+ emulatedPerContainerElem);
1066
1076
1067
1077
auto newRead = rewriter.create <vector::TransferReadOp>(
1068
1078
loc, VectorType::get (numElements, containerElemTy), adaptor.getSource (),
1069
1079
getValueOrCreateConstantIndexOp (rewriter, loc, linearizedIndices),
1070
1080
newPadding);
1071
1081
1072
1082
auto bitCast = rewriter.create <vector::BitCastOp>(
1073
- loc, VectorType::get (numElements * scale, emulatedElemTy), newRead);
1083
+ loc,
1084
+ VectorType::get (numElements * emulatedPerContainerElem, emulatedElemTy),
1085
+ newRead);
1074
1086
1075
1087
Value result = bitCast->getResult (0 );
1076
1088
if (!foldedIntraVectorOffset) {
0 commit comments