Skip to content

Commit c9e2754

Browse files
committed
Update name
1 parent 562d87e commit c9e2754

File tree

3 files changed

+27
-28
lines changed

3 files changed

+27
-28
lines changed

mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h

+4-4
Original file line numberDiff line numberDiff line change
@@ -364,12 +364,12 @@ void populateVectorMaskMaterializationPatterns(RewritePatternSet &patterns,
364364
PatternBenefit benefit = 1);
365365

366366
/// Appends patterns for emulating vector operations over narrow types with ops
367-
/// over wider types. The `useAtomicWrites` indicates whether to use
368-
/// op `memref.generic_atomic_rmw` to perform atomic subbyte storing, or just a
369-
/// rmw sequence otherwise.
367+
/// over wider types. The `disableAtomicRMW` indicates whether to use a normal
368+
/// read-modify-write sequence instead of using `memref.generic_atomic_rmw` to
369+
/// perform subbyte storing.
370370
void populateVectorNarrowTypeEmulationPatterns(
371371
const arith::NarrowTypeEmulationConverter &typeConverter,
372-
RewritePatternSet &patterns, bool useAtomicWrites = true);
372+
RewritePatternSet &patterns, bool disableAtomicRMW = false);
373373

374374
/// Rewrite a vector `bitcast(trunci)` to use a more efficient sequence of
375375
/// vector operations comprising `shuffle` and `bitwise` ops.

mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp

+17-19
Original file line numberDiff line numberDiff line change
@@ -334,9 +334,9 @@ static Value downcastSelectAndUpcast(OpBuilder &builder, Location loc,
334334
///
335335
/// Result:
336336
/// linearizedMemref = |2|2|3|3| : <4xi2> (<1xi8>)
337-
static void atomicRMWStore(OpBuilder &builder, Location loc,
338-
MemRefValue linearizedMemref, Value storeIdx,
339-
VectorValue valueToStore, Value mask) {
337+
static void atomicRMW(OpBuilder &builder, Location loc,
338+
MemRefValue linearizedMemref, Value storeIdx,
339+
VectorValue valueToStore, Value mask) {
340340
assert(valueToStore.getType().getRank() == 1 && "expected 1-D vector");
341341

342342
// Create an atomic load-modify-write region using
@@ -363,12 +363,11 @@ static void atomicRMWStore(OpBuilder &builder, Location loc,
363363
builder.create<memref::AtomicYieldOp>(loc, scalarMaskedValue);
364364
}
365365

366-
/// Generate a non-atomic read-modify-write sequence for subbyte storing.
367-
/// It has similar logic to `atomicRMWStore`, but without atomicity.
368-
static void nonAtomicRMWStore(OpBuilder &builder, Location loc,
369-
MemRefValue linearizedMemref,
370-
Value linearizedIndex, VectorValue valueToStore,
371-
Value mask) {
366+
/// Generate a non-atomic read-modify-write sequence for storing to the emulated
367+
/// type. It has similar logic to `atomicRMW`, but without atomicity.
368+
static void nonAtomicRMW(OpBuilder &builder, Location loc,
369+
MemRefValue linearizedMemref, Value linearizedIndex,
370+
VectorValue valueToStore, Value mask) {
372371
assert(valueToStore.getType().getRank() == 1 && "expected 1-D vector");
373372

374373
auto oneElemVecType =
@@ -427,9 +426,9 @@ namespace {
427426
struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
428427
using OpConversionPattern::OpConversionPattern;
429428

430-
ConvertVectorStore(MLIRContext *context, bool useAtomicWrites)
429+
ConvertVectorStore(MLIRContext *context, bool disableAtomicRMW)
431430
: OpConversionPattern<vector::StoreOp>(context),
432-
useAtomicWrites_(useAtomicWrites) {}
431+
disableAtomicRMW(disableAtomicRMW) {}
433432

434433
LogicalResult
435434
matchAndRewrite(vector::StoreOp op, OpAdaptor adaptor,
@@ -557,6 +556,8 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
557556
auto subWidthStoreMaskType =
558557
VectorType::get({numSrcElemsPerDest}, rewriter.getI1Type());
559558

559+
auto storeFunc = disableAtomicRMW ? nonAtomicRMW : atomicRMW;
560+
560561
// 1. Partial width store for the leading byte.
561562
// When the store address is not aligned to emulated width boundary, deal
562563
// with the unaligned part so that the rest elements are aligned to width
@@ -581,8 +582,6 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
581582
extractSliceIntoByte(rewriter, loc, valueToStore, 0,
582583
frontSubWidthStoreElem, *foldedNumFrontPadElems);
583584

584-
auto storeFunc = useAtomicWrites_ ? atomicRMWStore : nonAtomicRMWStore;
585-
586585
storeFunc(rewriter, loc, memrefBase, currentDestIndex,
587586
cast<VectorValue>(value), frontMask.getResult());
588587
}
@@ -639,17 +638,16 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
639638
auto backMask = rewriter.create<arith::ConstantOp>(
640639
loc, DenseElementsAttr::get(subWidthStoreMaskType, maskValues));
641640

642-
subEmulatedWidthStore(rewriter, loc, memrefBase, currentDestIndex,
643-
cast<VectorValue>(subWidthStorePart),
644-
backMask.getResult());
641+
storeFunc(rewriter, loc, memrefBase, currentDestIndex,
642+
cast<VectorValue>(subWidthStorePart), backMask.getResult());
645643
}
646644

647645
rewriter.eraseOp(op);
648646
return success();
649647
}
650648

651649
private:
652-
const bool useAtomicWrites_;
650+
const bool disableAtomicRMW;
653651
};
654652

655653
//===----------------------------------------------------------------------===//
@@ -1962,7 +1960,7 @@ struct RewriteVectorTranspose : OpRewritePattern<vector::TransposeOp> {
19621960

19631961
void vector::populateVectorNarrowTypeEmulationPatterns(
19641962
const arith::NarrowTypeEmulationConverter &typeConverter,
1965-
RewritePatternSet &patterns, bool useAtomicWrites) {
1963+
RewritePatternSet &patterns, bool disableAtomicRMW) {
19661964

19671965
// Populate `vector.*` conversion patterns.
19681966
// TODO: #119553 support atomicity
@@ -1973,7 +1971,7 @@ void vector::populateVectorNarrowTypeEmulationPatterns(
19731971
// Populate `vector.*` store conversion patterns. The caller can choose
19741972
// to avoid emitting atomic operations and reduce it to load-modify-write
19751973
// sequence for stores if it is known there are no thread contentions.
1976-
patterns.insert<ConvertVectorStore>(patterns.getContext(), useAtomicWrites);
1974+
patterns.insert<ConvertVectorStore>(patterns.getContext(), disableAtomicRMW);
19771975
}
19781976

19791977
void vector::populateVectorNarrowTypeRewritePatterns(

mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp

+6-5
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ struct TestEmulateNarrowTypePass
100100
arith::populateArithNarrowTypeEmulationPatterns(typeConverter, patterns);
101101
memref::populateMemRefNarrowTypeEmulationPatterns(typeConverter, patterns);
102102
vector::populateVectorNarrowTypeEmulationPatterns(typeConverter, patterns,
103-
atomicStore);
103+
disableAtomicRMW);
104104

105105
if (failed(applyPartialConversion(op, target, std::move(patterns))))
106106
signalPassFailure();
@@ -120,10 +120,11 @@ struct TestEmulateNarrowTypePass
120120
llvm::cl::desc("disable memref type conversion (to test failures)"),
121121
llvm::cl::init(false)};
122122

123-
Option<bool> atomicStore{
124-
*this, "atomic-store",
125-
llvm::cl::desc("use atomic store instead of load-modify-write"),
126-
llvm::cl::init(true)};
123+
Option<bool> disableAtomicRMW{
124+
*this, "disable-atomic-rmw",
125+
llvm::cl::desc("disable atomic read-modify-write and prefer generating "
126+
"normal sequence"),
127+
llvm::cl::init(false)};
127128
};
128129
} // namespace
129130

0 commit comments

Comments
 (0)