@@ -282,13 +282,15 @@ emulatedVectorLoad(OpBuilder &rewriter, Location loc, Value base,
                                OpFoldResult linearizedIndices,
                                int64_t numEmultedElementsToLoad, Type origElemType,
                                Type emulatedElemType) {
-  auto scale = emulatedElemType.getIntOrFloatBitWidth() /
-               origElemType.getIntOrFloatBitWidth();
+  auto elementsPerContainerType = emulatedElemType.getIntOrFloatBitWidth() /
+                                  origElemType.getIntOrFloatBitWidth();
  auto newLoad = rewriter.create<vector::LoadOp>(
      loc, VectorType::get(numEmultedElementsToLoad, emulatedElemType), base,
      getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
  return rewriter.create<vector::BitCastOp>(
-      loc, VectorType::get(numEmultedElementsToLoad * scale, origElemType),
+      loc,
+      VectorType::get(numEmultedElementsToLoad * elementsPerContainerType,
+                      origElemType),
      newLoad);
}

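For intuition behind the rename: with i4 data emulated in i8 containers, elementsPerContainerType is 8 / 4 = 2, so loading N container elements yields a vector<Nxi8> that bit-casts to vector<2Nxi4>. Below is a minimal standalone sketch of that shape arithmetic in plain C++; the helper name is ours and nothing here is part of the patch:

    #include <cstdint>

    // Narrow (emulated) elements per container element, mirroring
    // emulatedElemType.getIntOrFloatBitWidth() / origElemType.getIntOrFloatBitWidth().
    constexpr int64_t elementsPerContainer(int64_t containerBits, int64_t emulatedBits) {
      return containerBits / emulatedBits; // caller guarantees exact divisibility
    }

    // i4-in-i8: two narrow elements per container element.
    static_assert(elementsPerContainer(8, 4) == 2, "i8 holds two i4 values");
    // Loading 3 containers (vector<3xi8>) bit-casts to a vector<6xi4>.
    static_assert(3 * elementsPerContainer(8, 4) == 6, "bitcast result width");

    int main() { return 0; }
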
@@ -321,7 +323,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
    if (newBits % oldBits != 0) {
      return rewriter.notifyMatchFailure(op, "unaligned element types");
    }
-    int scale = newBits / oldBits;
+    int elementsPerContainerType = newBits / oldBits;

    // Adjust the number of elements to store when emulating narrow types.
    // Here only the 1-D vector store is considered, and the N-D memref types
@@ -337,7 +339,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
    // vector<4xi8>

    auto origElements = op.getValueToStore().getType().getNumElements();
-    if (origElements % scale != 0)
+    if (origElements % elementsPerContainerType != 0)
      return failure();

    auto stridedMetadata =
@@ -352,7 +354,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
        stridedMetadata.getConstifiedMixedStrides(),
        getAsOpFoldResult(adaptor.getIndices()));

-    auto numElements = origElements / scale;
+    auto numElements = origElements / elementsPerContainerType;
    auto bitCast = rewriter.create<vector::BitCastOp>(
        loc, VectorType::get(numElements, newElementType),
        op.getValueToStore());
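Note that this store path only matches the aligned case: the earlier `origElements % elementsPerContainerType != 0` check bails out otherwise, and the stored vector shrinks by exactly that factor. A hedged plain-C++ model of the check plus division; the helper name is illustrative, not from the pattern:

    #include <cstdint>
    #include <optional>

    // Container elements to store, or nullopt when the narrow vector does not
    // pack evenly (the pattern returns failure() in that case).
    std::optional<int64_t> containerStoreWidth(int64_t origElements, int64_t n) {
      if (origElements % n != 0)
        return std::nullopt;
      return origElements / n;
    }

    int main() {
      // vector<8xi4> stored through i8 containers (n == 2) -> vector<4xi8>.
      return (containerStoreWidth(8, 2) == 4 && !containerStoreWidth(3, 2)) ? 0 : 1;
    }
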
@@ -393,9 +395,9 @@ struct ConvertVectorMaskedStore final
      return rewriter.notifyMatchFailure(op, "unaligned element types");
    }

-    int scale = newBits / oldBits;
+    int elementsPerContainerType = newBits / oldBits;
    int origElements = op.getValueToStore().getType().getNumElements();
-    if (origElements % scale != 0)
+    if (origElements % elementsPerContainerType != 0)
      return failure();

    auto stridedMetadata =
@@ -444,12 +446,13 @@ struct ConvertVectorMaskedStore final
    //
    // FIXME: Make an example based on the comment above work (see #115460 for
    // reproducer).
-    FailureOr<Operation *> newMask =
-        getCompressedMaskOp(rewriter, loc, op.getMask(), origElements, scale);
+    FailureOr<Operation *> newMask = getCompressedMaskOp(
+        rewriter, loc, op.getMask(), origElements, elementsPerContainerType);
    if (failed(newMask))
      return failure();

-    auto numElements = (origElements + scale - 1) / scale;
+    auto numElements = (origElements + elementsPerContainerType - 1) /
+                       elementsPerContainerType;
    auto newType = VectorType::get(numElements, newElementType);
    auto passThru = rewriter.create<arith::ConstantOp>(
        loc, newType, rewriter.getZeroAttr(newType));
@@ -458,7 +461,8 @@ struct ConvertVectorMaskedStore final
        loc, newType, adaptor.getBase(), linearizedIndices,
        newMask.value()->getResult(0), passThru);

-    auto newBitCastType = VectorType::get(numElements * scale, oldElementType);
+    auto newBitCastType =
+        VectorType::get(numElements * elementsPerContainerType, oldElementType);
    Value valueToStore =
        rewriter.create<vector::BitCastOp>(loc, newBitCastType, newLoad);
    valueToStore = rewriter.create<arith::SelectOp>(
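The masked store, unlike the plain store, rounds the container count up: `(origElements + elementsPerContainerType - 1) / elementsPerContainerType` is the standard ceiling-division idiom for positive integers and computes the same value as the llvm::divideCeil used elsewhere in this file. A quick standalone check of the equivalence (plain C++, illustrative):

    #include <cassert>
    #include <cstdint>

    // Ceiling division for positive integers, as spelled in the masked-store path.
    int64_t ceilDivIdiom(int64_t a, int64_t b) { return (a + b - 1) / b; }

    int main() {
      // i4-in-i8 emulation, n == 2:
      assert(ceilDivIdiom(8, 2) == 4); // exact multiple
      assert(ceilDivIdiom(7, 2) == 4); // 7 narrow elements still span 4 containers
      assert(ceilDivIdiom(1, 2) == 1); // one i4 still occupies a whole i8
      return 0;
    }
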
@@ -500,7 +504,7 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
    if (newBits % oldBits != 0) {
      return rewriter.notifyMatchFailure(op, "unaligned element types");
    }
-    int scale = newBits / oldBits;
+    int elementsPerContainerType = newBits / oldBits;

    // Adjust the number of elements to load when emulating narrow types,
    // and then cast back to the original type with vector.bitcast op.
@@ -532,7 +536,7 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
    // compile time as they must be constants.

    auto origElements = op.getVectorType().getNumElements();
-    bool isUnalignedEmulation = origElements % scale != 0;
+    bool isUnalignedEmulation = origElements % elementsPerContainerType != 0;

    auto stridedMetadata =
        rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
@@ -553,9 +557,10 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
            : 0;

    // Always load enough elements which can cover the original elements.
-    int64_t maxintraDataOffset = foldedIntraVectorOffset.value_or(scale - 1);
-    auto numElements =
-        llvm::divideCeil(maxintraDataOffset + origElements, scale);
+    int64_t maxintraDataOffset =
+        foldedIntraVectorOffset.value_or(elementsPerContainerType - 1);
+    auto numElements = llvm::divideCeil(maxintraDataOffset + origElements,
+                                        elementsPerContainerType);
    Value result =
        emulatedVectorLoad(rewriter, loc, adaptor.getBase(), linearizedIndices,
                           numElements, oldElementType, newElementType);
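The key subtlety here: when the offset of the first narrow element inside its container cannot be folded to a constant, the pattern assumes the worst case of elementsPerContainerType - 1, so the load covers all requested elements at any runtime offset. A minimal sketch of that computation under our own helper names (plain C++, illustrative only):

    #include <cassert>
    #include <cstdint>
    #include <optional>

    int64_t divideCeil(int64_t a, int64_t b) { return (a + b - 1) / b; }

    // Containers to load so `origElements` narrow values are covered, starting
    // at an intra-container offset that may be unknown at compile time.
    int64_t containersToLoad(std::optional<int64_t> foldedOffset,
                             int64_t origElements, int64_t n) {
      int64_t maxOffset = foldedOffset.value_or(n - 1); // worst case if unknown
      return divideCeil(maxOffset + origElements, n);
    }

    int main() {
      // vector<3xi4> through i8 containers (n == 2):
      assert(containersToLoad(0, 3, 2) == 2);            // offset known to be 0
      assert(containersToLoad(std::nullopt, 3, 2) == 2); // worst case: 1 + 3
      assert(containersToLoad(std::nullopt, 4, 2) == 3); // worst case: 1 + 4
      return 0;
    }
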
@@ -603,7 +608,7 @@ struct ConvertVectorMaskedLoad final
    if (newBits % oldBits != 0) {
      return rewriter.notifyMatchFailure(op, "unaligned element types");
    }
-    int scale = newBits / oldBits;
+    int elementsPerContainerType = newBits / oldBits;

    // Adjust the number of elements to load when emulating narrow types,
    // and then cast back to the original type with vector.bitcast op.
@@ -649,7 +654,7 @@ struct ConvertVectorMaskedLoad final
    // subvector at the proper offset after bit-casting.
    auto origType = op.getVectorType();
    auto origElements = origType.getNumElements();
-    bool isUnalignedEmulation = origElements % scale != 0;
+    bool isUnalignedEmulation = origElements % elementsPerContainerType != 0;

    auto stridedMetadata =
        rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
@@ -668,18 +673,21 @@ struct ConvertVectorMaskedLoad final
            ? getConstantIntValue(linearizedInfo.intraDataOffset)
            : 0;

-    int64_t maxIntraDataOffset = foldedIntraVectorOffset.value_or(scale - 1);
-    FailureOr<Operation *> newMask = getCompressedMaskOp(
-        rewriter, loc, op.getMask(), origElements, scale, maxIntraDataOffset);
+    int64_t maxIntraDataOffset =
+        foldedIntraVectorOffset.value_or(elementsPerContainerType - 1);
+    FailureOr<Operation *> newMask =
+        getCompressedMaskOp(rewriter, loc, op.getMask(), origElements,
+                            elementsPerContainerType, maxIntraDataOffset);
    if (failed(newMask))
      return failure();

    Value passthru = op.getPassThru();

-    auto numElements =
-        llvm::divideCeil(maxIntraDataOffset + origElements, scale);
+    auto numElements = llvm::divideCeil(maxIntraDataOffset + origElements,
+                                        elementsPerContainerType);
    auto loadType = VectorType::get(numElements, newElementType);
-    auto newBitcastType = VectorType::get(numElements * scale, oldElementType);
+    auto newBitcastType =
+        VectorType::get(numElements * elementsPerContainerType, oldElementType);

    auto emptyVector = rewriter.create<arith::ConstantOp>(
        loc, newBitcastType, rewriter.getZeroAttr(newBitcastType));
@@ -706,8 +714,8 @@ struct ConvertVectorMaskedLoad final
        rewriter.create<vector::BitCastOp>(loc, newBitcastType, newLoad);

    Value mask = op.getMask();
-    auto newSelectMaskType =
-        VectorType::get(numElements * scale, rewriter.getI1Type());
+    auto newSelectMaskType = VectorType::get(
+        numElements * elementsPerContainerType, rewriter.getI1Type());
    // TODO: try to fold if op's mask is constant
    auto emptyMask = rewriter.create<arith::ConstantOp>(
        loc, newSelectMaskType, rewriter.getZeroAttr(newSelectMaskType));
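Two different mask granularities meet here: the compressed mask drives the container-level masked load (numElements lanes), while the arith.select that re-inserts the pass-through values runs at narrow granularity, hence numElements * elementsPerContainerType i1 lanes to match the bit-cast vector. As a rough scalar model of the compression, assume a container lane must be live whenever any narrow lane it holds is live; the real getCompressedMaskOp builds this with vector ops and only supports particular mask-producing ops, none of which is reproduced here:

    #include <cassert>
    #include <cstdint>
    #include <vector>

    // Scalar model: container lane i is live when any of its n narrow lanes is.
    std::vector<bool> compressMask(const std::vector<bool> &narrow, int64_t n) {
      std::vector<bool> containers((narrow.size() + n - 1) / n, false);
      for (size_t i = 0; i < narrow.size(); ++i)
        if (narrow[i])
          containers[i / n] = true;
      return containers;
    }

    int main() {
      // A 5-lane narrow mask over i4 data, n == 2 -> 3 container lanes.
      std::vector<bool> narrow = {true, false, false, false, true};
      auto wide = compressMask(narrow, 2);
      assert(wide.size() == 3 && wide[0] && !wide[1] && wide[2]);
      // The select mask, by contrast, stays narrow but is padded out to the
      // bit-cast width: numElements * n lanes.
      return 0;
    }
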
@@ -765,11 +773,11 @@ struct ConvertVectorTransferRead final
    if (newBits % oldBits != 0) {
      return rewriter.notifyMatchFailure(op, "unaligned element types");
    }
-    int scale = newBits / oldBits;
+    int elementsPerContainerType = newBits / oldBits;

    auto origElements = op.getVectorType().getNumElements();

-    bool isUnalignedEmulation = origElements % scale != 0;
+    bool isUnalignedEmulation = origElements % elementsPerContainerType != 0;

    auto newPadding = rewriter.create<arith::ExtUIOp>(loc, newElementType,
                                                      adaptor.getPadding());
@@ -792,17 +800,20 @@ struct ConvertVectorTransferRead final
            ? getConstantIntValue(linearizedInfo.intraDataOffset)
            : 0;

-    int64_t maxIntraDataOffset = foldedIntraVectorOffset.value_or(scale - 1);
-    auto numElements =
-        llvm::divideCeil(maxIntraDataOffset + origElements, scale);
+    int64_t maxIntraDataOffset =
+        foldedIntraVectorOffset.value_or(elementsPerContainerType - 1);
+    auto numElements = llvm::divideCeil(maxIntraDataOffset + origElements,
+                                        elementsPerContainerType);

    auto newRead = rewriter.create<vector::TransferReadOp>(
        loc, VectorType::get(numElements, newElementType), adaptor.getSource(),
        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices),
        newPadding);

    auto bitCast = rewriter.create<vector::BitCastOp>(
-        loc, VectorType::get(numElements * scale, oldElementType), newRead);
+        loc,
+        VectorType::get(numElements * elementsPerContainerType, oldElementType),
+        newRead);

    Value result = bitCast->getResult(0);
    if (!foldedIntraVectorOffset) {