[MLIR] Support non-atomic RMW option for emulated vector stores #124887
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-memref

Author: Alan Li (lialan)

Changes: This patch is a follow-up to #115922. It adds an option to emit a non-atomic RMW code sequence instead of an atomic RMW when emulating narrow-type vector stores.

Full diff: https://github.com/llvm/llvm-project/pull/124887.diff (4 files affected). A usage sketch and the diff itself follow below.
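As a hedged illustration (not part of the patch), here is a minimal sketch of how a downstream pipeline might opt into the non-atomic lowering through the updated entry point. The wrapper name `addNarrowTypeEmulation`, the 8-bit target bitwidth, and the converter header path are assumptions made for this example; `populateVectorNarrowTypeEmulationPatterns` and its `useAtomicWrites` parameter are the API touched by this patch.

#include "mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"

using namespace mlir;

// Sketch only: request the plain load-modify-write lowering for emulated
// vector stores. This is safe only if the caller knows there is no thread
// contention on the bytes that back adjacent narrow elements.
static void addNarrowTypeEmulation(RewritePatternSet &patterns) {
  // Emulate narrow integer types on top of 8-bit memory accesses
  // (bitwidth chosen purely for illustration).
  arith::NarrowTypeEmulationConverter typeConverter(/*targetBitwidth=*/8);

  // Passing `useAtomicWrites = false` selects the non-atomic RMW store path
  // instead of `memref.generic_atomic_rmw`.
  vector::populateVectorNarrowTypeEmulationPatterns(
      typeConverter, patterns, /*useAtomicWrites=*/false);
}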
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index a59f06f3c1ef1b..43478aacb50a14 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -364,10 +364,12 @@ void populateVectorMaskMaterializationPatterns(RewritePatternSet &patterns,
PatternBenefit benefit = 1);
/// Appends patterns for emulating vector operations over narrow types with ops
-/// over wider types.
+/// over wider types. The `useAtomicWrites` indicates whether to use
+/// op `memref.generic_atomic_rmw` to perform atomic subbyte storing, or just a
+/// rmw sequence otherwise.
void populateVectorNarrowTypeEmulationPatterns(
const arith::NarrowTypeEmulationConverter &typeConverter,
- RewritePatternSet &patterns);
+ RewritePatternSet &patterns, bool useAtomicWrites = true);
/// Rewrite a vector `bitcast(trunci)` to use a more efficient sequence of
/// vector operations comprising `shuffle` and `bitwise` ops.
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 7ca88f1e0a0df9..c848d3c0ca98aa 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -363,6 +363,30 @@ static void atomicStore(OpBuilder &builder, Location loc,
builder.create<memref::AtomicYieldOp>(loc, scalarMaskedValue);
}
+/// Generate a non-atomic read-modify-write sequence for subbyte storing.
+/// It has similar logic to `atomicStore`, but without the atomicity.
+static void rmwStore(OpBuilder &builder, Location loc,
+ MemRefValue linearizedMemref, Value linearizedIndex,
+ VectorValue valueToStore, Value mask) {
+ assert(valueToStore.getType().getRank() == 1 && "expected 1-D vector");
+
+ // Load the original value from memory, and cast it to the original element
+ // type.
+ auto oneElemVecType =
+ VectorType::get({1}, linearizedMemref.getType().getElementType());
+ Value origVecValue = builder.create<vector::LoadOp>(
+ loc, oneElemVecType, linearizedMemref, ValueRange{linearizedIndex});
+ origVecValue = builder.create<vector::BitCastOp>(loc, valueToStore.getType(),
+ origVecValue);
+
+ // Construct the final masked value and yield it.
+ Value maskedValue =
+ downcastSelectAndUpcast(builder, loc, valueToStore.getType(),
+ oneElemVecType, mask, valueToStore, origVecValue);
+ builder.create<vector::StoreOp>(loc, maskedValue, linearizedMemref,
+ linearizedIndex);
+}
+
/// Extract `sliceNumElements` from source `vector` at `extractOffset`,
/// and insert it into an empty vector at `insertOffset`.
/// Inputs:
@@ -405,6 +429,10 @@ namespace {
struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
using OpConversionPattern::OpConversionPattern;
+ ConvertVectorStore(MLIRContext *context, bool useAtomicWrites)
+ : OpConversionPattern<vector::StoreOp>(context),
+ useAtomicWrites_(useAtomicWrites) {}
+
LogicalResult
matchAndRewrite(vector::StoreOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
@@ -555,8 +583,8 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
extractSliceIntoByte(rewriter, loc, valueToStore, 0,
frontSubWidthStoreElem, *foldedNumFrontPadElems);
- atomicStore(rewriter, loc, memrefBase, currentDestIndex,
- cast<VectorValue>(value), frontMask.getResult());
+ subEmulatedWidthStore(rewriter, loc, memrefBase, currentDestIndex,
+ cast<VectorValue>(value), frontMask.getResult());
}
if (currentSourceIndex >= origElements) {
@@ -611,13 +639,31 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
auto backMask = rewriter.create<arith::ConstantOp>(
loc, DenseElementsAttr::get(subWidthStoreMaskType, maskValues));
- atomicStore(rewriter, loc, memrefBase, currentDestIndex,
- cast<VectorValue>(subWidthStorePart), backMask.getResult());
+ subEmulatedWidthStore(rewriter, loc, memrefBase, currentDestIndex,
+ cast<VectorValue>(subWidthStorePart),
+ backMask.getResult());
}
rewriter.eraseOp(op);
return success();
}
+
+ /// Store a subbyte-sized value to memory, with a mask. Depending on the
+ /// configuration, it could be an atomic store or a non-atomic RMW sequence.
+ template <typename... Args>
+ void subEmulatedWidthStore(Args &&...args) const {
+ static_assert(
+ std::is_same_v<decltype(atomicStore), decltype(rmwStore)> &&
+ "`atomicStore` and `rmwStore` must have same signature, as per "
+ "the design to keep the code clean, which one to call is "
+ "determined by the `useAtomicWrites` flag.");
+ std::function<decltype(atomicStore)> storeFunc =
+ useAtomicWrites_ ? atomicStore : rmwStore;
+ storeFunc(std::forward<Args>(args)...);
+ }
+
+private:
+ const bool useAtomicWrites_;
};
//===----------------------------------------------------------------------===//
@@ -1930,12 +1976,18 @@ struct RewriteVectorTranspose : OpRewritePattern<vector::TransposeOp> {
void vector::populateVectorNarrowTypeEmulationPatterns(
const arith::NarrowTypeEmulationConverter &typeConverter,
- RewritePatternSet &patterns) {
+ RewritePatternSet &patterns, bool useAtomicWrites) {
// Populate `vector.*` conversion patterns.
- patterns.add<ConvertVectorLoad, ConvertVectorMaskedLoad, ConvertVectorStore,
+ // TODO: #119553 support atomicity
+ patterns.add<ConvertVectorLoad, ConvertVectorMaskedLoad,
ConvertVectorMaskedStore, ConvertVectorTransferRead>(
typeConverter, patterns.getContext());
+
+ // Populate `vector.*` store conversion patterns. The caller can choose
+ // to avoid emitting atomic operations and reduce it to load-modify-write
+ // sequence for stores if it is known there are no thread contentions.
+ patterns.insert<ConvertVectorStore>(patterns.getContext(), useAtomicWrites);
}
void vector::populateVectorNarrowTypeRewritePatterns(
diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned-non-atomic.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned-non-atomic.mlir
new file mode 100644
index 00000000000000..79f8869d043ee3
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned-non-atomic.mlir
@@ -0,0 +1,119 @@
+// RUN: mlir-opt --test-emulate-narrow-int="arith-compute-bitwidth=1 memref-load-bitwidth=8 atomic-store=false" --cse --split-input-file %s | FileCheck %s
+
+// TODO: remove memref.alloc() in the tests to eliminate noises.
+// memref.alloc exists here because sub-byte vector data types such as i2
+// are currently not supported as input arguments.
+
+func.func @vector_store_i2_const_index_two_rmw(%arg0: vector<3xi2>) {
+ %0 = memref.alloc() : memref<3x3xi2>
+ %c0 = arith.constant 0 : index
+ %c2 = arith.constant 2 : index
+ vector.store %arg0, %0[%c2, %c0] :memref<3x3xi2>, vector<3xi2>
+ return
+}
+// Load from bit [12:18), byte [1:2] of total 3 bytes, both bytes needs rmw.
+
+// CHECK: func @vector_store_i2_const_index_two_rmw(
+// CHECK-SAME: %[[ARG0:.+]]: vector<3xi2>)
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<3xi8>
+// CHECK: %[[C1:.+]] = arith.constant 1 : index
+
+// Part 1 RMW sequence
+// CHECK: %[[CST:.+]] = arith.constant dense<[false, false, true, true]>
+// CHECK: %[[CST0:.+]] = arith.constant dense<0> : vector<4xi2>
+// CHECK: %[[EXTRACT:.+]] = vector.extract_strided_slice %[[ARG0]]
+// CHECK-SAME: {offsets = [0], sizes = [2], strides = [1]} : vector<3xi2> to vector<2xi2>
+// CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %[[EXTRACT]], %[[CST0]]
+// CHECK-SAME: {offsets = [2], strides = [1]} : vector<2xi2> into vector<4xi2>
+// CHECK: %[[LOAD:.+]] = vector.load
+// CHECK: %[[UPCAST:.+]] = vector.bitcast %[[LOAD]] : vector<1xi8> to vector<4xi2>
+// CHECK: %[[SELECT:.+]] = arith.select %[[CST]], %[[INSERT]], %[[UPCAST]]
+// CHECK: %[[DOWNCAST:.+]] = vector.bitcast %[[SELECT]]
+// CHECK: vector.store %[[DOWNCAST]], %[[ALLOC]][%[[C1]]]
+
+// Part 2 RMW sequence
+// CHECK: %[[OFFSET:.+]] = arith.addi %[[C1]], %[[C1]] : index
+// CHECK: %[[EXTRACT2:.+]] = vector.extract_strided_slice %[[ARG0]]
+// CHECK-SAME: {offsets = [2], sizes = [1], strides = [1]} : vector<3xi2> to vector<1xi2>
+// CHECK: %[[INSERT2:.+]] = vector.insert_strided_slice %[[EXTRACT2]], %[[CST0]]
+// CHECK-SAME: {offsets = [0], strides = [1]} : vector<1xi2> into vector<4xi2>
+// CHECK: %[[CST1:.+]] = arith.constant dense<[true, false, false, false]> : vector<4xi1>
+// CHECK: %[[LOAD2:.+]] = vector.load
+// CHECK: %[[UPCAST2:.+]] = vector.bitcast %[[LOAD2]] : vector<1xi8> to vector<4xi2>
+// CHECK: %[[SELECT2:.+]] = arith.select %[[CST1]], %[[INSERT2]], %[[UPCAST2]]
+// CHECK: %[[DOWNCAST2:.+]] = vector.bitcast %[[SELECT2]]
+// CHECK: vector.store %[[DOWNCAST2]], %[[ALLOC]][%[[OFFSET]]]
+
+
+// -----
+
+func.func @vector_store_i2_rmw(%arg0: vector<7xi2>) {
+ %0 = memref.alloc() : memref<3x7xi2>
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ vector.store %arg0, %0[%c1, %c0] :memref<3x7xi2>, vector<7xi2>
+ return
+}
+
+// CHECK: func @vector_store_i2_rmw(
+// CHECK-SAME: %[[ARG0:.+]]:
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<6xi8>
+// CHECK: %[[C1:.+]] = arith.constant 1 : index
+// CHECK: %[[CST:.+]] = arith.constant dense<[false, false, false, true]>
+// CHECK: %[[CST0:.+]] = arith.constant dense<0> : vector<4xi2>
+// CHECK: %[[EXTRACT:.+]] = vector.extract_strided_slice %[[ARG0]]
+// CHECK-SAME: {offsets = [0], sizes = [1], strides = [1]}
+// CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %[[EXTRACT]], %[[CST0]]
+// CHECK-SAME: {offsets = [3], strides = [1]}
+// First sub-width RMW:
+// CHECK: %[[LOAD:.+]] = vector.load %[[ALLOC]][%[[C1]]]
+// CHECK: %[[UPCAST:.+]] = vector.bitcast %[[LOAD]] : vector<1xi8> to vector<4xi2>
+// CHECK: %[[SELECT:.+]] = arith.select %[[CST]], %[[INSERT]], %[[UPCAST]]
+// CHECK: %[[DOWNCAST:.+]] = vector.bitcast %[[SELECT]]
+// CHECK: vector.store %[[DOWNCAST]], %[[ALLOC]][%[[C1]]]
+
+// Full-width store:
+// CHECK: %[[INDEX:.+]] = arith.addi %[[C1]], %[[C1]]
+// CHECK: %[[EXTRACT1:.+]] = vector.extract_strided_slice %[[ARG0]]
+// CHECK-SAME: {offsets = [1], sizes = [4], strides = [1]}
+// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[EXTRACT1]]
+// CHECK: vector.store %[[BITCAST]], %[[ALLOC]][%[[INDEX]]]
+
+// Second sub-width RMW:
+// CHECK: %[[INDEX2:.+]] = arith.addi %[[INDEX]], %[[C1]]
+// CHECK: %[[EXTRACT2:.+]] = vector.extract_strided_slice %[[ARG0]]
+// CHECK-SAME: {offsets = [5], sizes = [2], strides = [1]}
+// CHECK: %[[INSERT2:.+]] = vector.insert_strided_slice %[[EXTRACT2]]
+// CHECK-SAME: {offsets = [0], strides = [1]}
+// CHECK: %[[CST1:.+]] = arith.constant dense<[true, true, false, false]>
+// CHECK: %[[LOAD1:.+]] = vector.load %[[ALLOC]][%[[INDEX2]]]
+// CHECK: %[[UPCAST1:.+]] = vector.bitcast %[[LOAD1]]
+// CHECK: %[[SELECT1:.+]] = arith.select %[[CST1]], %[[INSERT2]], %[[UPCAST1]]
+// CHECK: %[[DOWNCAST1:.+]] = vector.bitcast %[[SELECT1]]
+// CHECK: vector.store %[[DOWNCAST1]], %[[ALLOC]][%[[INDEX2]]]
+
+// -----
+
+func.func @vector_store_i2_single_rmw(%arg0: vector<1xi2>) {
+ %0 = memref.alloc() : memref<4x1xi2>
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ vector.store %arg0, %0[%c1, %c0] :memref<4x1xi2>, vector<1xi2>
+ return
+}
+
+// in this test, only emit 1 rmw store
+// CHECK: func @vector_store_i2_single_rmw(
+// CHECK-SAME: %[[ARG0:.+]]: vector<1xi2>)
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<1xi8>
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[CST:.+]] = arith.constant dense<[false, true, false, false]>
+// CHECK: %[[CST0:.+]] = arith.constant dense<0> : vector<4xi2>
+// CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %[[ARG0]], %[[CST0]]
+// CHECK-SAME: {offsets = [1], strides = [1]} : vector<1xi2> into vector<4xi2>
+// CHECK: %[[LOAD:.+]] = vector.load %[[ALLOC]][%[[C0]]] : memref<1xi8>, vector<1xi8>
+// CHECK: %[[UPCAST:.+]] = vector.bitcast %[[LOAD]] : vector<1xi8> to vector<4xi2>
+// CHECK: %[[SELECT:.+]] = arith.select %[[CST]], %[[INSERT]], %[[UPCAST]]
+// CHECK: %[[DOWNCAST:.+]] = vector.bitcast %[[SELECT]]
+// CHECK: vector.store %[[DOWNCAST]], %[[ALLOC]][%[[C0]]]
+
diff --git a/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp b/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp
index 7401e470ed4f2c..9a3fac623fbd7d 100644
--- a/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp
+++ b/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp
@@ -99,7 +99,8 @@ struct TestEmulateNarrowTypePass
arith::populateArithNarrowTypeEmulationPatterns(typeConverter, patterns);
memref::populateMemRefNarrowTypeEmulationPatterns(typeConverter, patterns);
- vector::populateVectorNarrowTypeEmulationPatterns(typeConverter, patterns);
+ vector::populateVectorNarrowTypeEmulationPatterns(typeConverter, patterns,
+ atomicStore);
if (failed(applyPartialConversion(op, target, std::move(patterns))))
signalPassFailure();
@@ -118,6 +119,11 @@ struct TestEmulateNarrowTypePass
*this, "skip-memref-type-conversion",
llvm::cl::desc("disable memref type conversion (to test failures)"),
llvm::cl::init(false)};
+
+ Option<bool> atomicStore{
+ *this, "atomic-store",
+ llvm::cl::desc("use atomic store instead of load-modify-write"),
+ llvm::cl::init(true)};
};
} // namespace
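The comment about thread contention in populateVectorNarrowTypeEmulationPatterns above is the crux of the new option: when two threads own different narrow elements that share the same byte, a plain load-modify-write can silently drop one of the updates, while an atomic RMW cannot. Below is a minimal, self-contained C++ sketch of that hazard; it is purely illustrative and unrelated to the MLIR code paths themselves.

#include <atomic>
#include <cstdint>

// Two 4-bit fields packed into one byte. The non-atomic version reads the
// byte, patches one nibble, and writes the whole byte back, so a concurrent
// write to the other nibble can be lost. The atomic version retries until
// the byte was not modified in between.
void storeLowNibbleNonAtomic(uint8_t &packed, uint8_t v) {
  uint8_t old = packed;                // read
  packed = (old & 0xF0) | (v & 0x0F);  // modify + write (racy under contention)
}

void storeLowNibbleAtomic(std::atomic<uint8_t> &packed, uint8_t v) {
  uint8_t old = packed.load();
  while (!packed.compare_exchange_weak(old, (old & 0xF0) | (v & 0x0F))) {
    // `old` now holds the latest value; the desired byte is recomputed and
    // the exchange is retried.
  }
}

This is also why the patch keeps `useAtomicWrites = true` as the default: the caller must explicitly assert the no-contention precondition to get the cheaper sequence.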
Bump. Can you guys take a look? @hanhanW @dcaballe @banach-space
I think I reviewed it in the previous PR, and it got split out because of other review comments, right? In this case, please wait for an approval from one of the other reviewers.
Yes, you are right.
Sorry for the delay, I've been away and have a bit of catching up to do.
Btw, thanks for sending this as a separate PR - makes much more sense and is easier to review 🙏🏻
Thanks for the updates so far!
/// Generate a non-atomic read-modify-write sequence for subbyte storing.
/// It has similar logic to `atomicRMWStore`, but without atomicity.
I don't have an answer yet, but we may need to replace `sub-byte` with `emulatedType` - not all "emulated" types within this file are sub-bytes. We also use e.g. `i32` to emulate `i8`:
func.func @vector_load_i8(%arg1: index, %arg2: index) -> vector<4xi8> {
  %0 = memref.alloc() : memref<3x4xi8>
  %1 = vector.load %0[%arg1, %arg2] : memref<3x4xi8>, vector<4xi8>
  return %1 : vector<4xi8>
}
// Expect no conversions, i8 is supported.
// CHECK: func @vector_load_i8(
// CHECK-SAME:  %[[ARG0:[a-zA-Z0-9]+]]: index, %[[ARG1:[a-zA-Z0-9]+]]: index)
// CHECK-NEXT:  %[[ALLOC:.+]] = memref.alloc() : memref<3x4xi8>
// CHECK-NEXT:  [[L:%.+]] = vector.load %[[ALLOC]][%[[ARG0]], %[[ARG1]]] : memref<3x4xi8>, vector<4xi8>
// CHECK-NEXT:  return

// CHECK32: #[[MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1 floordiv 4)>
// CHECK32: func @vector_load_i8(
// CHECK32-SAME:  %[[ARG0:[a-zA-Z0-9]+]]: index, %[[ARG1:[a-zA-Z0-9]+]]: index)
// CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<3xi32>
// CHECK32: %[[INDEX:.+]] = affine.apply #[[MAP]]()[%[[ARG0]], %[[ARG1]]]
// CHECK32: %[[VECLOAD:.+]] = vector.load %[[ALLOC]][%[[INDEX]]] : memref<3xi32>, vector<1xi32>
// CHECK32: %[[VEC_I4:.+]] = vector.bitcast %[[VECLOAD]] : vector<1xi32> to vector<4xi8>
// CHECK32: return %[[VEC_I4]]
Either we refrain from using sub-byte or force this logic to only be available for sub-bytes.
Got it. I can see there are many places still using `sub-byte` descriptions. Let's update those to be consistent later.
LG, some final nits.
LGTM % minor nit :)
Thanks for all the improvements, Alan!