Commit f0e1857

[MLIR] Support non-atomic RMW option for emulated vector stores (#124887)
This patch is a follow-up to #115922. It adds an option to emit a non-atomic read-modify-write (RMW) code sequence instead of an atomic RMW.
1 parent da083e2 commit f0e1857

File tree

5 files changed: +188 −16 lines changed

mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h

+4-2
@@ -364,10 +364,12 @@ void populateVectorMaskMaterializationPatterns(RewritePatternSet &patterns,
                                                PatternBenefit benefit = 1);
 
 /// Appends patterns for emulating vector operations over narrow types with ops
-/// over wider types.
+/// over wider types. The `disableAtomicRMW` indicates whether to use a normal
+/// read-modify-write sequence instead of using `memref.generic_atomic_rmw` to
+/// perform subbyte storing.
 void populateVectorNarrowTypeEmulationPatterns(
     const arith::NarrowTypeEmulationConverter &typeConverter,
-    RewritePatternSet &patterns);
+    RewritePatternSet &patterns, bool disableAtomicRMW = false);
 
 /// Rewrite a vector `bitcast(trunci)` to use a more efficient sequence of
 /// vector operations comprising `shuffle` and `bitwise` ops.
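
The default (`false`) keeps the existing atomic lowering. Below is a minimal sketch of how a caller could opt into the non-atomic path with the new parameter; the pass context, the 8-bit load/store width, and the surrounding setup are illustrative assumptions, not part of this commit:

  // Inside a conversion pass (illustrative; bitwidth and setup assumed).
  MLIRContext *ctx = &getContext();
  arith::NarrowTypeEmulationConverter typeConverter(/*targetBitwidth=*/8);
  RewritePatternSet patterns(ctx);
  // Passing `true` requests plain load/select/store sequences instead of
  // memref.generic_atomic_rmw for the partially covered bytes of sub-byte stores.
  vector::populateVectorNarrowTypeEmulationPatterns(typeConverter, patterns,
                                                    /*disableAtomicRMW=*/true);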

mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp

+45-9
@@ -342,9 +342,9 @@ static Value downcastSelectAndUpcast(OpBuilder &builder, Location loc,
 ///
 /// Result:
 /// linearizedMemref = |2|2|3|3| : <4xi2> (<1xi8>)
-static void atomicStore(OpBuilder &builder, Location loc,
-                        MemRefValue linearizedMemref, Value storeIdx,
-                        VectorValue valueToStore, Value mask) {
+static void atomicRMW(OpBuilder &builder, Location loc,
+                      MemRefValue linearizedMemref, Value storeIdx,
+                      VectorValue valueToStore, Value mask) {
   assert(valueToStore.getType().getRank() == 1 && "expected 1-D vector");
 
   // Create an atomic load-modify-write region using
@@ -371,6 +371,27 @@ static void atomicStore(OpBuilder &builder, Location loc,
   builder.create<memref::AtomicYieldOp>(loc, scalarMaskedValue);
 }
 
+/// Generate a non-atomic read-modify-write sequence for storing to the emulated
+/// type. It has similar logic to `atomicRMWStore`, but without atomicity.
+static void nonAtomicRMW(OpBuilder &builder, Location loc,
+                         MemRefValue linearizedMemref, Value linearizedIndex,
+                         VectorValue valueToStore, Value mask) {
+  assert(valueToStore.getType().getRank() == 1 && "expected 1-D vector");
+
+  auto oneElemVecType =
+      VectorType::get({1}, linearizedMemref.getType().getElementType());
+  Value origVecValue = builder.create<vector::LoadOp>(
+      loc, oneElemVecType, linearizedMemref, ValueRange{linearizedIndex});
+  origVecValue = builder.create<vector::BitCastOp>(loc, valueToStore.getType(),
+                                                   origVecValue);
+
+  Value maskedValue =
+      downcastSelectAndUpcast(builder, loc, valueToStore.getType(),
+                              oneElemVecType, mask, valueToStore, origVecValue);
+  builder.create<vector::StoreOp>(loc, maskedValue, linearizedMemref,
+                                  linearizedIndex);
+}
+
 /// Extract `sliceNumElements` from source `vector` at `extractOffset`,
 /// and insert it into an empty vector at `insertOffset`.
 /// Inputs:
@@ -415,6 +436,10 @@ namespace {
 struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
   using OpConversionPattern::OpConversionPattern;
 
+  ConvertVectorStore(MLIRContext *context, bool disableAtomicRMW)
+      : OpConversionPattern<vector::StoreOp>(context),
+        disableAtomicRMW(disableAtomicRMW) {}
+
   LogicalResult
   matchAndRewrite(vector::StoreOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
@@ -544,6 +569,8 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
     auto subWidthStoreMaskType =
         VectorType::get({numSrcElemsPerDest}, rewriter.getI1Type());
 
+    auto storeFunc = disableAtomicRMW ? nonAtomicRMW : atomicRMW;
+
     // 1. Partial width store for the leading byte.
     // When the store address is not aligned to emulated width boundary, deal
     // with the unaligned part so that the rest elements are aligned to width
@@ -568,8 +595,8 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
           extractSliceIntoByte(rewriter, loc, valueToStore, 0,
                                frontSubWidthStoreElem, *foldedNumFrontPadElems);
 
-      atomicStore(rewriter, loc, memrefBase, currentDestIndex,
-                  cast<VectorValue>(value), frontMask.getResult());
+      storeFunc(rewriter, loc, memrefBase, currentDestIndex,
+                cast<VectorValue>(value), frontMask.getResult());
     }
 
     if (currentSourceIndex >= origElements) {
@@ -624,13 +651,16 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
       auto backMask = rewriter.create<arith::ConstantOp>(
          loc, DenseElementsAttr::get(subWidthStoreMaskType, maskValues));
 
-      atomicStore(rewriter, loc, memrefBase, currentDestIndex,
-                  cast<VectorValue>(subWidthStorePart), backMask.getResult());
+      storeFunc(rewriter, loc, memrefBase, currentDestIndex,
+                cast<VectorValue>(subWidthStorePart), backMask.getResult());
     }
 
     rewriter.eraseOp(op);
     return success();
   }
+
+private:
+  const bool disableAtomicRMW;
 };
 
 //===----------------------------------------------------------------------===//
@@ -1969,12 +1999,18 @@ struct RewriteVectorTranspose : OpRewritePattern<vector::TransposeOp> {
 // The emulated type is inferred from the converted memref type.
 void vector::populateVectorNarrowTypeEmulationPatterns(
     const arith::NarrowTypeEmulationConverter &typeConverter,
-    RewritePatternSet &patterns) {
+    RewritePatternSet &patterns, bool disableAtomicRMW) {
 
   // Populate `vector.*` conversion patterns.
-  patterns.add<ConvertVectorLoad, ConvertVectorMaskedLoad, ConvertVectorStore,
+  // TODO: #119553 support atomicity
+  patterns.add<ConvertVectorLoad, ConvertVectorMaskedLoad,
               ConvertVectorMaskedStore, ConvertVectorTransferRead>(
      typeConverter, patterns.getContext());
+
+  // Populate `vector.*` store conversion patterns. The caller can choose
+  // to avoid emitting atomic operations and reduce it to read-modify-write
+  // sequence for stores if it is known there are no thread contentions.
+  patterns.insert<ConvertVectorStore>(patterns.getContext(), disableAtomicRMW);
 }
 
 void vector::populateVectorNarrowTypeRewritePatterns(
@@ -0,0 +1,128 @@
+// RUN: mlir-opt --test-emulate-narrow-int="arith-compute-bitwidth=1 memref-load-bitwidth=8 disable-atomic-rmw=true" --cse --split-input-file %s | FileCheck %s
+
+// TODO: remove memref.alloc() in the tests to eliminate noises.
+// memref.alloc exists here because sub-byte vector data types such as i2
+// are currently not supported as input arguments.
+
+///----------------------------------------------------------------------------------------
+/// vector.store
+///----------------------------------------------------------------------------------------
+
+func.func @vector_store_i2_const_index_two_partial_stores(%arg0: vector<3xi2>) {
+  %0 = memref.alloc() : memref<3x3xi2>
+  %c0 = arith.constant 0 : index
+  %c2 = arith.constant 2 : index
+  vector.store %arg0, %0[%c2, %c0] :memref<3x3xi2>, vector<3xi2>
+  return
+}
+
+// Emit two non-atomic RMW partial stores. Store 6 bits from the input vector (bits [12:18)),
+// into bytes [1:2] from a 3-byte output memref. Due to partial storing,
+// both bytes are accessed partially through masking.
+
+// CHECK: func @vector_store_i2_const_index_two_partial_stores(
+// CHECK-SAME: %[[ARG0:.+]]: vector<3xi2>)
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<3xi8>
+// CHECK: %[[C1:.+]] = arith.constant 1 : index
+
+// Part 1 RMW sequence
+// CHECK: %[[CST:.+]] = arith.constant dense<[false, false, true, true]>
+// CHECK: %[[CST0:.+]] = arith.constant dense<0> : vector<4xi2>
+// CHECK: %[[EXTRACT:.+]] = vector.extract_strided_slice %[[ARG0]]
+// CHECK-SAME: {offsets = [0], sizes = [2], strides = [1]} : vector<3xi2> to vector<2xi2>
+// CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %[[EXTRACT]], %[[CST0]]
+// CHECK-SAME: {offsets = [2], strides = [1]} : vector<2xi2> into vector<4xi2>
+// CHECK: %[[LOAD:.+]] = vector.load
+// CHECK: %[[DOWNCAST:.+]] = vector.bitcast %[[LOAD]] : vector<1xi8> to vector<4xi2>
+// CHECK: %[[SELECT:.+]] = arith.select %[[CST]], %[[INSERT]], %[[DOWNCAST]]
+// CHECK: %[[UPCAST:.+]] = vector.bitcast %[[SELECT]]
+// CHECK: vector.store %[[UPCAST]], %[[ALLOC]][%[[C1]]]
+
+// Part 2 RMW sequence
+// CHECK: %[[OFFSET:.+]] = arith.addi %[[C1]], %[[C1]] : index
+// CHECK: %[[EXTRACT2:.+]] = vector.extract_strided_slice %[[ARG0]]
+// CHECK-SAME: {offsets = [2], sizes = [1], strides = [1]} : vector<3xi2> to vector<1xi2>
+// CHECK: %[[INSERT2:.+]] = vector.insert_strided_slice %[[EXTRACT2]], %[[CST0]]
+// CHECK-SAME: {offsets = [0], strides = [1]} : vector<1xi2> into vector<4xi2>
+// CHECK: %[[CST1:.+]] = arith.constant dense<[true, false, false, false]> : vector<4xi1>
+// CHECK: %[[LOAD2:.+]] = vector.load
+// CHECK: %[[UPCAST2:.+]] = vector.bitcast %[[LOAD2]] : vector<1xi8> to vector<4xi2>
+// CHECK: %[[SELECT2:.+]] = arith.select %[[CST1]], %[[INSERT2]], %[[UPCAST2]]
+// CHECK: %[[DOWNCAST2:.+]] = vector.bitcast %[[SELECT2]]
+// CHECK: vector.store %[[DOWNCAST2]], %[[ALLOC]][%[[OFFSET]]]
+
+
+// -----
+
+func.func @vector_store_i2_two_partial_one_full_stores(%arg0: vector<7xi2>) {
+  %0 = memref.alloc() : memref<3x7xi2>
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  vector.store %arg0, %0[%c1, %c0] :memref<3x7xi2>, vector<7xi2>
+  return
+}
+
+// In this example, emit two RMW stores and one full-width store.
+
+// CHECK: func @vector_store_i2_two_partial_one_full_stores(
+// CHECK-SAME: %[[ARG0:.+]]:
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<6xi8>
+// CHECK: %[[C1:.+]] = arith.constant 1 : index
+// CHECK: %[[CST:.+]] = arith.constant dense<[false, false, false, true]>
+// CHECK: %[[CST0:.+]] = arith.constant dense<0> : vector<4xi2>
+// CHECK: %[[EXTRACT:.+]] = vector.extract_strided_slice %[[ARG0]]
+// CHECK-SAME: {offsets = [0], sizes = [1], strides = [1]}
+// CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %[[EXTRACT]], %[[CST0]]
+// CHECK-SAME: {offsets = [3], strides = [1]}
+// First sub-width RMW:
+// CHECK: %[[LOAD:.+]] = vector.load %[[ALLOC]][%[[C1]]]
+// CHECK: %[[UPCAST:.+]] = vector.bitcast %[[LOAD]] : vector<1xi8> to vector<4xi2>
+// CHECK: %[[SELECT:.+]] = arith.select %[[CST]], %[[INSERT]], %[[UPCAST]]
+// CHECK: %[[DOWNCAST:.+]] = vector.bitcast %[[SELECT]]
+// CHECK: vector.store %[[DOWNCAST]], %[[ALLOC]][%[[C1]]]
+
+// Full-width store:
+// CHECK: %[[INDEX:.+]] = arith.addi %[[C1]], %[[C1]]
+// CHECK: %[[EXTRACT1:.+]] = vector.extract_strided_slice %[[ARG0]]
+// CHECK-SAME: {offsets = [1], sizes = [4], strides = [1]}
+// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[EXTRACT1]]
+// CHECK: vector.store %[[BITCAST]], %[[ALLOC]][%[[INDEX]]]
+
+// Second sub-width RMW:
+// CHECK: %[[INDEX2:.+]] = arith.addi %[[INDEX]], %[[C1]]
+// CHECK: %[[EXTRACT2:.+]] = vector.extract_strided_slice %[[ARG0]]
+// CHECK-SAME: {offsets = [5], sizes = [2], strides = [1]}
+// CHECK: %[[INSERT2:.+]] = vector.insert_strided_slice %[[EXTRACT2]]
+// CHECK-SAME: {offsets = [0], strides = [1]}
+// CHECK: %[[CST1:.+]] = arith.constant dense<[true, true, false, false]>
+// CHECK: %[[LOAD2:.+]] = vector.load %[[ALLOC]][%[[INDEX2]]]
+// CHECK: %[[UPCAST2:.+]] = vector.bitcast %[[LOAD2]]
+// CHECK: %[[SELECT2:.+]] = arith.select %[[CST1]], %[[INSERT2]], %[[UPCAST2]]
+// CHECK: %[[DOWNCAST2:.+]] = vector.bitcast %[[SELECT2]]
+// CHECK: vector.store %[[DOWNCAST2]], %[[ALLOC]][%[[INDEX2]]]
+
+// -----
+
+func.func @vector_store_i2_const_index_one_partial_store(%arg0: vector<1xi2>) {
+  %0 = memref.alloc() : memref<4x1xi2>
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  vector.store %arg0, %0[%c1, %c0] :memref<4x1xi2>, vector<1xi2>
+  return
+}
+
+// in this test, only emit partial RMW store as the store is within one byte.
+
+// CHECK: func @vector_store_i2_const_index_one_partial_store(
+// CHECK-SAME: %[[ARG0:.+]]: vector<1xi2>)
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<1xi8>
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[CST:.+]] = arith.constant dense<[false, true, false, false]>
+// CHECK: %[[CST0:.+]] = arith.constant dense<0> : vector<4xi2>
+// CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %[[ARG0]], %[[CST0]]
+// CHECK-SAME: {offsets = [1], strides = [1]} : vector<1xi2> into vector<4xi2>
+// CHECK: %[[LOAD:.+]] = vector.load %[[ALLOC]][%[[C0]]] : memref<1xi8>, vector<1xi8>
+// CHECK: %[[UPCAST:.+]] = vector.bitcast %[[LOAD]] : vector<1xi8> to vector<4xi2>
+// CHECK: %[[SELECT:.+]] = arith.select %[[CST]], %[[INSERT]], %[[UPCAST]]
+// CHECK: %[[DOWNCAST:.+]] = vector.bitcast %[[SELECT]]
+// CHECK: vector.store %[[DOWNCAST]], %[[ALLOC]][%[[C0]]]

mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir

+3-4
@@ -369,10 +369,9 @@ func.func @vector_store_i2_const_index_two_partial_stores(%arg0: vector<3xi2>) {
   return
 }
 
-// In this example, emit 2 atomic RMWs.
-//
-// Note, sizeof(%src) = 18 bits. This is modelled as %src_as_bytes:
-// <3xi8> (bits [0, 18) with the input values from %src, and [18, 24) are masked out)
+// Emit two atomic RMW partial stores. Store 6 bits from the input vector (bits [12:18)),
+// into bytes [1:2] from a 3-byte output memref. Due to partial storing,
+// both bytes are accessed partially through masking.
 
 // CHECK-LABEL: func @vector_store_i2_const_index_two_partial_stores(
 // CHECK-SAME: %[[ARG0:.+]]: vector<3xi2>)

mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp

+8-1
@@ -99,7 +99,8 @@ struct TestEmulateNarrowTypePass
 
     arith::populateArithNarrowTypeEmulationPatterns(typeConverter, patterns);
     memref::populateMemRefNarrowTypeEmulationPatterns(typeConverter, patterns);
-    vector::populateVectorNarrowTypeEmulationPatterns(typeConverter, patterns);
+    vector::populateVectorNarrowTypeEmulationPatterns(typeConverter, patterns,
+                                                      disableAtomicRMW);
 
     if (failed(applyPartialConversion(op, target, std::move(patterns))))
       signalPassFailure();
@@ -118,6 +119,12 @@ struct TestEmulateNarrowTypePass
       *this, "skip-memref-type-conversion",
       llvm::cl::desc("disable memref type conversion (to test failures)"),
       llvm::cl::init(false)};
+
+  Option<bool> disableAtomicRMW{
+      *this, "disable-atomic-rmw",
+      llvm::cl::desc("disable atomic read-modify-write and prefer generating "
+                     "normal sequence"),
+      llvm::cl::init(false)};
 };
 } // namespace
