@@ -334,9 +334,9 @@ static Value downcastSelectAndUpcast(OpBuilder &builder, Location loc,
 ///
 /// Result:
 ///   linearizedMemref = |2|2|3|3| : <4xi2> (<1xi8>)
-static void atomicRMWStore(OpBuilder &builder, Location loc,
-                           MemRefValue linearizedMemref, Value storeIdx,
-                           VectorValue valueToStore, Value mask) {
+static void atomicRMW(OpBuilder &builder, Location loc,
+                      MemRefValue linearizedMemref, Value storeIdx,
+                      VectorValue valueToStore, Value mask) {
   assert(valueToStore.getType().getRank() == 1 && "expected 1-D vector");
 
   // Create an atomic load-modify-write region using
@@ -363,12 +363,11 @@ static void atomicRMWStore(OpBuilder &builder, Location loc,
   builder.create<memref::AtomicYieldOp>(loc, scalarMaskedValue);
 }
 
-/// Generate a non-atomic read-modify-write sequence for subbyte storing.
-/// It has similar logic to `atomicRMWStore`, but without atomicity.
-static void nonAtomicRMWStore(OpBuilder &builder, Location loc,
-                              MemRefValue linearizedMemref,
-                              Value linearizedIndex, VectorValue valueToStore,
-                              Value mask) {
+/// Generate a non-atomic read-modify-write sequence for storing to the emulated
+/// type. It has similar logic to `atomicRMWStore`, but without atomicity.
+static void nonAtomicRMW(OpBuilder &builder, Location loc,
+                         MemRefValue linearizedMemref, Value linearizedIndex,
+                         VectorValue valueToStore, Value mask) {
   assert(valueToStore.getType().getRank() == 1 && "expected 1-D vector");
 
   auto oneElemVecType =
@@ -427,9 +426,9 @@ namespace {
 struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
   using OpConversionPattern::OpConversionPattern;
 
-  ConvertVectorStore(MLIRContext *context, bool useAtomicWrites)
+  ConvertVectorStore(MLIRContext *context, bool disableAtomicRMW)
       : OpConversionPattern<vector::StoreOp>(context),
-        useAtomicWrites_(useAtomicWrites) {}
+        disableAtomicRMW(disableAtomicRMW) {}
 
   LogicalResult
   matchAndRewrite(vector::StoreOp op, OpAdaptor adaptor,
@@ -557,6 +556,8 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
     auto subWidthStoreMaskType =
         VectorType::get({numSrcElemsPerDest}, rewriter.getI1Type());
 
+    auto storeFunc = disableAtomicRMW ? nonAtomicRMW : atomicRMW;
+
     // 1. Partial width store for the leading byte.
     // When the store address is not aligned to emulated width boundary, deal
     // with the unaligned part so that the rest elements are aligned to width
@@ -581,8 +582,6 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
           extractSliceIntoByte(rewriter, loc, valueToStore, 0,
                                frontSubWidthStoreElem, *foldedNumFrontPadElems);
 
-      auto storeFunc = useAtomicWrites_ ? atomicRMWStore : nonAtomicRMWStore;
-
       storeFunc(rewriter, loc, memrefBase, currentDestIndex,
                 cast<VectorValue>(value), frontMask.getResult());
     }
@@ -639,17 +638,16 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
       auto backMask = rewriter.create<arith::ConstantOp>(
           loc, DenseElementsAttr::get(subWidthStoreMaskType, maskValues));
 
-      subEmulatedWidthStore(rewriter, loc, memrefBase, currentDestIndex,
-                            cast<VectorValue>(subWidthStorePart),
-                            backMask.getResult());
+      storeFunc(rewriter, loc, memrefBase, currentDestIndex,
+                cast<VectorValue>(subWidthStorePart), backMask.getResult());
     }
 
     rewriter.eraseOp(op);
     return success();
   }
 
 private:
-  const bool useAtomicWrites_;
+  const bool disableAtomicRMW;
 };
 
 //===----------------------------------------------------------------------===//
@@ -1962,7 +1960,7 @@ struct RewriteVectorTranspose : OpRewritePattern<vector::TransposeOp> {
 
 void vector::populateVectorNarrowTypeEmulationPatterns(
     const arith::NarrowTypeEmulationConverter &typeConverter,
-    RewritePatternSet &patterns, bool useAtomicWrites) {
+    RewritePatternSet &patterns, bool disableAtomicRMW) {
 
   // Populate `vector.*` conversion patterns.
   // TODO: #119553 support atomicity
@@ -1973,7 +1971,7 @@ void vector::populateVectorNarrowTypeEmulationPatterns(
   // Populate `vector.*` store conversion patterns. The caller can choose
   // to avoid emitting atomic operations and reduce it to load-modify-write
   // sequence for stores if it is known there are no thread contentions.
-  patterns.insert<ConvertVectorStore>(patterns.getContext(), useAtomicWrites);
+  patterns.insert<ConvertVectorStore>(patterns.getContext(), disableAtomicRMW);
 }
 
 void vector::populateVectorNarrowTypeRewritePatterns(
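
To illustrate the new knob, here is a minimal, hypothetical caller that forwards a "no thread contention" flag into the updated populate API. Only `populateVectorNarrowTypeEmulationPatterns`, its `disableAtomicRMW` parameter, and `arith::NarrowTypeEmulationConverter` come from the diff above; the helper function, its name, and the header paths are assumptions for the sketch, not part of this commit.

// Hypothetical wiring of a client-chosen flag into the emulation patterns.
#include "mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h" // assumed path
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"       // assumed path
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

// `disableAtomicRMW` is what the caller passes when it knows there is no
// thread contention on the packed bytes.
static void addNarrowTypeEmulation(arith::NarrowTypeEmulationConverter &converter,
                                   RewritePatternSet &patterns,
                                   bool disableAtomicRMW) {
  // When the flag is true, ConvertVectorStore lowers partial-byte stores with
  // the plain load-modify-write helper (nonAtomicRMW); otherwise it wraps the
  // update in an atomic region yielding via memref::AtomicYieldOp (atomicRMW).
  vector::populateVectorNarrowTypeEmulationPatterns(converter, patterns,
                                                    disableAtomicRMW);
}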