Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MLIR] Support non-atomic RMW option for emulated vector stores #124887

Merged
merged 8 commits into from
Feb 6, 2025

Conversation

lialan
Copy link
Member

@lialan lialan commented Jan 29, 2025

This patch is a followup of the previous one: #115922, It adds an option to turn on emitting non-atomic rmw code sequence instead of atomic rmw.

@llvmbot
Copy link
Member

llvmbot commented Jan 29, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-memref

Author: Alan Li (lialan)

Changes

This patch is a followup of the previous one: #115922, It adds an option to turn on emitting non-atomic rmw code sequence instead of atomic rmw.


Full diff: https://github.com/llvm/llvm-project/pull/124887.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h (+4-2)
  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp (+58-6)
  • (added) mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned-non-atomic.mlir (+119)
  • (modified) mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp (+7-1)
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index a59f06f3c1ef1b..43478aacb50a14 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -364,10 +364,12 @@ void populateVectorMaskMaterializationPatterns(RewritePatternSet &patterns,
                                                PatternBenefit benefit = 1);
 
 /// Appends patterns for emulating vector operations over narrow types with ops
-/// over wider types.
+/// over wider types. The `useAtomicWrites` indicates whether to use
+/// op `memref.generic_atomic_rmw` to perform atomic subbyte storing, or just a
+/// rmw sequence otherwise.
 void populateVectorNarrowTypeEmulationPatterns(
     const arith::NarrowTypeEmulationConverter &typeConverter,
-    RewritePatternSet &patterns);
+    RewritePatternSet &patterns, bool useAtomicWrites = true);
 
 /// Rewrite a vector `bitcast(trunci)` to use a more efficient sequence of
 /// vector operations comprising `shuffle` and `bitwise` ops.
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 7ca88f1e0a0df9..c848d3c0ca98aa 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -363,6 +363,30 @@ static void atomicStore(OpBuilder &builder, Location loc,
   builder.create<memref::AtomicYieldOp>(loc, scalarMaskedValue);
 }
 
+/// Generate a non-atomic read-modify-write sequence for subbyte storing.
+/// It has similar logic to `atomicStore`, but without the atomicity.
+static void rmwStore(OpBuilder &builder, Location loc,
+                     MemRefValue linearizedMemref, Value linearizedIndex,
+                     VectorValue valueToStore, Value mask) {
+  assert(valueToStore.getType().getRank() == 1 && "expected 1-D vector");
+
+  // Load the original value from memory, and cast it to the original element
+  // type.
+  auto oneElemVecType =
+      VectorType::get({1}, linearizedMemref.getType().getElementType());
+  Value origVecValue = builder.create<vector::LoadOp>(
+      loc, oneElemVecType, linearizedMemref, ValueRange{linearizedIndex});
+  origVecValue = builder.create<vector::BitCastOp>(loc, valueToStore.getType(),
+                                                   origVecValue);
+
+  // Construct the final masked value and yield it.
+  Value maskedValue =
+      downcastSelectAndUpcast(builder, loc, valueToStore.getType(),
+                              oneElemVecType, mask, valueToStore, origVecValue);
+  builder.create<vector::StoreOp>(loc, maskedValue, linearizedMemref,
+                                  linearizedIndex);
+}
+
 /// Extract `sliceNumElements` from source `vector` at `extractOffset`,
 /// and insert it into an empty vector at `insertOffset`.
 /// Inputs:
@@ -405,6 +429,10 @@ namespace {
 struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
   using OpConversionPattern::OpConversionPattern;
 
+  ConvertVectorStore(MLIRContext *context, bool useAtomicWrites)
+      : OpConversionPattern<vector::StoreOp>(context),
+        useAtomicWrites_(useAtomicWrites) {}
+
   LogicalResult
   matchAndRewrite(vector::StoreOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
@@ -555,8 +583,8 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
           extractSliceIntoByte(rewriter, loc, valueToStore, 0,
                                frontSubWidthStoreElem, *foldedNumFrontPadElems);
 
-      atomicStore(rewriter, loc, memrefBase, currentDestIndex,
-                  cast<VectorValue>(value), frontMask.getResult());
+      subEmulatedWidthStore(rewriter, loc, memrefBase, currentDestIndex,
+                            cast<VectorValue>(value), frontMask.getResult());
     }
 
     if (currentSourceIndex >= origElements) {
@@ -611,13 +639,31 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
       auto backMask = rewriter.create<arith::ConstantOp>(
           loc, DenseElementsAttr::get(subWidthStoreMaskType, maskValues));
 
-      atomicStore(rewriter, loc, memrefBase, currentDestIndex,
-                  cast<VectorValue>(subWidthStorePart), backMask.getResult());
+      subEmulatedWidthStore(rewriter, loc, memrefBase, currentDestIndex,
+                            cast<VectorValue>(subWidthStorePart),
+                            backMask.getResult());
     }
 
     rewriter.eraseOp(op);
     return success();
   }
+
+  /// Store a subbyte-sized value to memory, with a mask. Depending on the
+  /// configuration, it could be an atomic store or a non-atomic RMW sequence.
+  template <typename... Args>
+  void subEmulatedWidthStore(Args &&...args) const {
+    static_assert(
+        std::is_same_v<decltype(atomicStore), decltype(rmwStore)> &&
+        "`atomicStore` and `rmwStore` must have same signature, as per "
+        "the design to keep the code clean, which one to call is "
+        "determined by the `useAtomicWrites` flag.");
+    std::function<decltype(atomicStore)> storeFunc =
+        useAtomicWrites_ ? atomicStore : rmwStore;
+    storeFunc(std::forward<Args>(args)...);
+  }
+
+private:
+  const bool useAtomicWrites_;
 };
 
 //===----------------------------------------------------------------------===//
@@ -1930,12 +1976,18 @@ struct RewriteVectorTranspose : OpRewritePattern<vector::TransposeOp> {
 
 void vector::populateVectorNarrowTypeEmulationPatterns(
     const arith::NarrowTypeEmulationConverter &typeConverter,
-    RewritePatternSet &patterns) {
+    RewritePatternSet &patterns, bool useAtomicWrites) {
 
   // Populate `vector.*` conversion patterns.
-  patterns.add<ConvertVectorLoad, ConvertVectorMaskedLoad, ConvertVectorStore,
+  // TODO: #119553 support atomicity
+  patterns.add<ConvertVectorLoad, ConvertVectorMaskedLoad,
                ConvertVectorMaskedStore, ConvertVectorTransferRead>(
       typeConverter, patterns.getContext());
+
+  // Populate `vector.*` store conversion patterns. The caller can choose
+  // to avoid emitting atomic operations and reduce it to load-modify-write
+  // sequence for stores if it is known there are no thread contentions.
+  patterns.insert<ConvertVectorStore>(patterns.getContext(), useAtomicWrites);
 }
 
 void vector::populateVectorNarrowTypeRewritePatterns(
diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned-non-atomic.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned-non-atomic.mlir
new file mode 100644
index 00000000000000..79f8869d043ee3
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned-non-atomic.mlir
@@ -0,0 +1,119 @@
+// RUN: mlir-opt --test-emulate-narrow-int="arith-compute-bitwidth=1 memref-load-bitwidth=8 atomic-store=false" --cse --split-input-file %s | FileCheck %s
+
+// TODO: remove memref.alloc() in the tests to eliminate noises.
+// memref.alloc exists here because sub-byte vector data types such as i2
+// are currently not supported as input arguments.
+
+func.func @vector_store_i2_const_index_two_rmw(%arg0: vector<3xi2>) {
+    %0 = memref.alloc() : memref<3x3xi2>
+    %c0 = arith.constant 0 : index
+    %c2 = arith.constant 2 : index
+    vector.store %arg0, %0[%c2, %c0] :memref<3x3xi2>, vector<3xi2>
+    return
+}
+// Load from bit [12:18), byte [1:2] of total 3 bytes, both bytes needs rmw.
+
+// CHECK: func @vector_store_i2_const_index_two_rmw(
+// CHECK-SAME: %[[ARG0:.+]]: vector<3xi2>)
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<3xi8>
+// CHECK: %[[C1:.+]] = arith.constant 1 : index
+
+// Part 1 RMW sequence
+// CHECK: %[[CST:.+]] = arith.constant dense<[false, false, true, true]>
+// CHECK: %[[CST0:.+]] = arith.constant dense<0> : vector<4xi2>
+// CHECK: %[[EXTRACT:.+]] = vector.extract_strided_slice %[[ARG0]]
+// CHECK-SAME: {offsets = [0], sizes = [2], strides = [1]} : vector<3xi2> to vector<2xi2>
+// CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %[[EXTRACT]], %[[CST0]]
+// CHECK-SAME: {offsets = [2], strides = [1]} : vector<2xi2> into vector<4xi2>
+// CHECK: %[[LOAD:.+]] = vector.load
+// CHECK: %[[UPCAST:.+]] = vector.bitcast %[[LOAD]] : vector<1xi8> to vector<4xi2>
+// CHECK: %[[SELECT:.+]] = arith.select %[[CST]], %[[INSERT]], %[[UPCAST]]
+// CHECK: %[[DOWNCAST:.+]] = vector.bitcast %[[SELECT]]
+// CHECK: vector.store %[[DOWNCAST]], %[[ALLOC]][%[[C1]]]
+
+// Part 2 RMW sequence
+// CHECK: %[[OFFSET:.+]] = arith.addi %[[C1]], %[[C1]] : index
+// CHECK: %[[EXTRACT2:.+]] = vector.extract_strided_slice %[[ARG0]]
+// CHECK-SAME: {offsets = [2], sizes = [1], strides = [1]} : vector<3xi2> to vector<1xi2>
+// CHECK: %[[INSERT2:.+]] = vector.insert_strided_slice %[[EXTRACT2]], %[[CST0]]
+// CHECK-SAME: {offsets = [0], strides = [1]} : vector<1xi2> into vector<4xi2>
+// CHECK: %[[CST1:.+]] = arith.constant dense<[true, false, false, false]> : vector<4xi1>
+// CHECK: %[[LOAD2:.+]] = vector.load
+// CHECK: %[[UPCAST2:.+]] = vector.bitcast %[[LOAD2]] : vector<1xi8> to vector<4xi2>
+// CHECK: %[[SELECT2:.+]] = arith.select %[[CST1]], %[[INSERT2]], %[[UPCAST2]]
+// CHECK: %[[DOWNCAST2:.+]] = vector.bitcast %[[SELECT2]]
+// CHECK: vector.store %[[DOWNCAST2]], %[[ALLOC]][%[[OFFSET]]]
+
+
+// -----
+
+func.func @vector_store_i2_rmw(%arg0: vector<7xi2>) {
+    %0 = memref.alloc() : memref<3x7xi2>
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    vector.store %arg0, %0[%c1, %c0] :memref<3x7xi2>, vector<7xi2>
+    return
+}
+
+// CHECK: func @vector_store_i2_rmw(
+// CHECK-SAME: %[[ARG0:.+]]:
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<6xi8>
+// CHECK: %[[C1:.+]] = arith.constant 1 : index
+// CHECK: %[[CST:.+]] = arith.constant dense<[false, false, false, true]>
+// CHECK: %[[CST0:.+]] = arith.constant dense<0> : vector<4xi2>
+// CHECK: %[[EXTRACT:.+]] = vector.extract_strided_slice %[[ARG0]]
+// CHECK-SAME: {offsets = [0], sizes = [1], strides = [1]}
+// CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %[[EXTRACT]], %[[CST0]]
+// CHECK-SAME: {offsets = [3], strides = [1]}
+// First sub-width RMW:
+// CHECK: %[[LOAD:.+]] = vector.load %[[ALLOC]][%[[C1]]]
+// CHECK: %[[UPCAST:.+]] = vector.bitcast %[[LOAD]] : vector<1xi8> to vector<4xi2>
+// CHECK: %[[SELECT:.+]] = arith.select %[[CST]], %[[INSERT]], %[[UPCAST]]
+// CHECK: %[[DOWNCAST:.+]] = vector.bitcast %[[SELECT]]
+// CHECK: vector.store %[[DOWNCAST]], %[[ALLOC]][%[[C1]]]
+
+// Full-width store:
+// CHECK: %[[INDEX:.+]] = arith.addi %[[C1]], %[[C1]]
+// CHECK: %[[EXTRACT1:.+]] = vector.extract_strided_slice %[[ARG0]]
+// CHECK-SAME: {offsets = [1], sizes = [4], strides = [1]}
+// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[EXTRACT1]]
+// CHECK: vector.store %[[BITCAST]], %[[ALLOC]][%[[INDEX]]]
+
+// Second sub-width RMW:
+// CHECK: %[[INDEX2:.+]] = arith.addi %[[INDEX]], %[[C1]]
+// CHECK: %[[EXTRACT2:.+]] = vector.extract_strided_slice %[[ARG0]]
+// CHECK-SAME: {offsets = [5], sizes = [2], strides = [1]}
+// CHECK: %[[INSERT2:.+]] = vector.insert_strided_slice %[[EXTRACT2]]
+// CHECK-SAME: {offsets = [0], strides = [1]}
+// CHECK: %[[CST1:.+]] = arith.constant dense<[true, true, false, false]>
+// CHECK: %[[LOAD1:.+]] = vector.load %[[ALLOC]][%[[INDEX2]]]
+// CHECK: %[[UPCAST1:.+]] = vector.bitcast %[[LOAD1]]
+// CHECK: %[[SELECT1:.+]] = arith.select %[[CST1]], %[[INSERT2]], %[[UPCAST1]]
+// CHECK: %[[DOWNCAST1:.+]] = vector.bitcast %[[SELECT1]]
+// CHECK: vector.store %[[DOWNCAST1]], %[[ALLOC]][%[[INDEX2]]]
+
+// -----
+
+func.func @vector_store_i2_single_rmw(%arg0: vector<1xi2>) {
+    %0 = memref.alloc() : memref<4x1xi2>
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    vector.store %arg0, %0[%c1, %c0] :memref<4x1xi2>, vector<1xi2>
+    return
+}
+
+// in this test, only emit 1 rmw store
+// CHECK: func @vector_store_i2_single_rmw(
+// CHECK-SAME: %[[ARG0:.+]]: vector<1xi2>)
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<1xi8>
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[CST:.+]] = arith.constant dense<[false, true, false, false]>
+// CHECK: %[[CST0:.+]] = arith.constant dense<0> : vector<4xi2>
+// CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %[[ARG0]], %[[CST0]]
+// CHECK-SAME: {offsets = [1], strides = [1]} : vector<1xi2> into vector<4xi2>
+// CHECK: %[[LOAD:.+]] = vector.load %[[ALLOC]][%[[C0]]] : memref<1xi8>, vector<1xi8>
+// CHECK: %[[UPCAST:.+]] = vector.bitcast %[[LOAD]] : vector<1xi8> to vector<4xi2>
+// CHECK: %[[SELECT:.+]] = arith.select %[[CST]], %[[INSERT]], %[[UPCAST]]
+// CHECK: %[[DOWNCAST:.+]] = vector.bitcast %[[SELECT]]
+// CHECK: vector.store %[[DOWNCAST]], %[[ALLOC]][%[[C0]]]
+
diff --git a/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp b/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp
index 7401e470ed4f2c..9a3fac623fbd7d 100644
--- a/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp
+++ b/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp
@@ -99,7 +99,8 @@ struct TestEmulateNarrowTypePass
 
     arith::populateArithNarrowTypeEmulationPatterns(typeConverter, patterns);
     memref::populateMemRefNarrowTypeEmulationPatterns(typeConverter, patterns);
-    vector::populateVectorNarrowTypeEmulationPatterns(typeConverter, patterns);
+    vector::populateVectorNarrowTypeEmulationPatterns(typeConverter, patterns,
+                                                      atomicStore);
 
     if (failed(applyPartialConversion(op, target, std::move(patterns))))
       signalPassFailure();
@@ -118,6 +119,11 @@ struct TestEmulateNarrowTypePass
       *this, "skip-memref-type-conversion",
       llvm::cl::desc("disable memref type conversion (to test failures)"),
       llvm::cl::init(false)};
+
+  Option<bool> atomicStore{
+      *this, "atomic-store",
+      llvm::cl::desc("use atomic store instead of load-modify-write"),
+      llvm::cl::init(true)};
 };
 } // namespace
 

@lialan
Copy link
Member Author

lialan commented Feb 3, 2025

Bump. Can you guys take a look? @hanhanW @dcaballe @banach-space

Copy link
Contributor

@hanhanW hanhanW left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I reviewed it in the previous PR, and it got split out because of other review comments, right? In this case, please wait for an approval from one of other reviewers.

@lialan
Copy link
Member Author

lialan commented Feb 3, 2025

I think I reviewed it in the previous PR, and it got split out because of other review comments, right? In this case, please wait for an approval from one of other reviewers.

Yes you are right.

Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry for the delay, I've been away and have a bit of catching up to do.

Btw, thanks for sending this as a separate PR - makes much more sense and is easier to review 🙏🏻

@lialan lialan requested a review from banach-space February 4, 2025 05:53
Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the updates so far!

Comment on lines 366 to 367
/// Generate a non-atomic read-modify-write sequence for subbyte storing.
/// It has similar logic to `atomicRMWStore`, but without atomicity.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't have an answer yet, but we may need to replace sub-byte with emulatedType - not all "emulated" types within this file are sub-bytes. We also use e.g. i32 to emulate i8:

func.func @vector_load_i8(%arg1: index, %arg2: index) -> vector<4xi8> {
%0 = memref.alloc() : memref<3x4xi8>
%1 = vector.load %0[%arg1, %arg2] : memref<3x4xi8>, vector<4xi8>
return %1 : vector<4xi8>
}
// Expect no conversions, i8 is supported.
// CHECK: func @vector_load_i8(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index, %[[ARG1:[a-zA-Z0-9]+]]: index)
// CHECK-NEXT: %[[ALLOC:.+]] = memref.alloc() : memref<3x4xi8>
// CHECK-NEXT: [[L:%.+]] = vector.load %[[ALLOC]][%[[ARG0]], %[[ARG1]]] : memref<3x4xi8>, vector<4xi8>
// CHECK-NEXT: return
// CHECK32: #[[MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1 floordiv 4)>
// CHECK32: func @vector_load_i8(
// CHECK32-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index, %[[ARG1:[a-zA-Z0-9]+]]: index)
// CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<3xi32>
// CHECK32: %[[INDEX:.+]] = affine.apply #[[MAP]]()[%[[ARG0]], %[[ARG1]]]
// CHECK32: %[[VECLOAD:.+]] = vector.load %[[ALLOC]][%[[INDEX]]] : memref<3xi32>, vector<1xi32>
// CHECK32: %[[VEC_I4:.+]] = vector.bitcast %[[VECLOAD]] : vector<1xi32> to vector<4xi8>
// CHECK32: return %[[VEC_I4]]

Either we refrain from using sub-byte or force this logic to only be available for sub-bytes.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

got it. I can see there are many places still using sub-byte descriptions. Let's update those to be consistent later.

@lialan lialan force-pushed the lialan/non_atomic branch from 64b354a to 4fbbcbe Compare February 4, 2025 13:20
@lialan lialan force-pushed the lialan/non_atomic branch from 4fbbcbe to c9e2754 Compare February 4, 2025 21:09
@lialan lialan requested a review from banach-space February 4, 2025 21:12
Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LG, some final nits.

@lialan lialan requested a review from banach-space February 6, 2025 17:43
Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM % minor nit :)

Thanks for all the improvements, Alan!

@lialan lialan merged commit f0e1857 into llvm:main Feb 6, 2025
6 of 7 checks passed
@lialan lialan deleted the lialan/non_atomic branch February 6, 2025 21:22
Icohedron pushed a commit to Icohedron/llvm-project that referenced this pull request Feb 11, 2025
…#124887)

This patch is a followup of the previous one: llvm#115922, It adds an option
to turn on emitting non-atomic rmw code sequence instead of atomic rmw.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants