diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index a59f06f3c1ef1..7de4a6a315750 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -364,10 +364,12 @@ void populateVectorMaskMaterializationPatterns(RewritePatternSet &patterns,
                                                PatternBenefit benefit = 1);
 
 /// Appends patterns for emulating vector operations over narrow types with ops
-/// over wider types.
+/// over wider types. When `disableAtomicRMW` is true, sub-byte stores are
+/// emulated with a plain (non-atomic) read-modify-write sequence instead of
+/// `memref.generic_atomic_rmw`.
 void populateVectorNarrowTypeEmulationPatterns(
     const arith::NarrowTypeEmulationConverter &typeConverter,
-    RewritePatternSet &patterns);
+    RewritePatternSet &patterns, bool disableAtomicRMW = false);
 
 /// Rewrite a vector `bitcast(trunci)` to use a more efficient sequence of
 /// vector operations comprising `shuffle` and `bitwise` ops.
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 7ca88f1e0a0df..acd4ac3496789 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -334,9 +334,9 @@ static Value downcastSelectAndUpcast(OpBuilder &builder, Location loc,
 ///
 /// Result:
 ///   linearizedMemref = |2|2|3|3| : <4xi2> (<1xi8>)
-static void atomicStore(OpBuilder &builder, Location loc,
-                        MemRefValue linearizedMemref, Value storeIdx,
-                        VectorValue valueToStore, Value mask) {
+static void atomicRMW(OpBuilder &builder, Location loc,
+                      MemRefValue linearizedMemref, Value storeIdx,
+                      VectorValue valueToStore, Value mask) {
   assert(valueToStore.getType().getRank() == 1 && "expected 1-D vector");
 
   // Create an atomic load-modify-write region using
@@ -363,6 +363,27 @@ static void atomicStore(OpBuilder &builder, Location loc,
   builder.create<memref::AtomicYieldOp>(loc, scalarMaskedValue);
 }
 
+/// Generate a non-atomic read-modify-write sequence for storing to the
+/// emulated type. It has similar logic to `atomicRMW`, but without atomicity.
+static void nonAtomicRMW(OpBuilder &builder, Location loc,
+                         MemRefValue linearizedMemref, Value linearizedIndex,
+                         VectorValue valueToStore, Value mask) {
+  assert(valueToStore.getType().getRank() == 1 && "expected 1-D vector");
+
+  auto oneElemVecType =
+      VectorType::get({1}, linearizedMemref.getType().getElementType());
+  Value origVecValue = builder.create<vector::LoadOp>(
+      loc, oneElemVecType, linearizedMemref, ValueRange{linearizedIndex});
+  origVecValue = builder.create<vector::BitCastOp>(loc, valueToStore.getType(),
+                                                   origVecValue);
+
+  Value maskedValue =
+      downcastSelectAndUpcast(builder, loc, valueToStore.getType(),
+                              oneElemVecType, mask, valueToStore, origVecValue);
+  builder.create<vector::StoreOp>(loc, maskedValue, linearizedMemref,
+                                  linearizedIndex);
+}
+
 /// Extract `sliceNumElements` from source `vector` at `extractOffset`,
 /// and insert it into an empty vector at `insertOffset`.
 /// Inputs:
@@ -405,6 +426,10 @@ namespace {
 struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
   using OpConversionPattern::OpConversionPattern;
 
+  ConvertVectorStore(MLIRContext *context, bool disableAtomicRMW)
+      : OpConversionPattern<vector::StoreOp>(context),
+        disableAtomicRMW(disableAtomicRMW) {}
+
   LogicalResult
   matchAndRewrite(vector::StoreOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
@@ -531,6 +556,8 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
     auto subWidthStoreMaskType =
         VectorType::get({numSrcElemsPerDest}, rewriter.getI1Type());
 
+    auto storeFunc = disableAtomicRMW ? nonAtomicRMW : atomicRMW;
+
     // 1. Partial width store for the leading byte.
     // When the store address is not aligned to emulated width boundary, deal
     // with the unaligned part so that the rest elements are aligned to width
@@ -555,8 +582,8 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
           extractSliceIntoByte(rewriter, loc, valueToStore, 0,
                                frontSubWidthStoreElem, *foldedNumFrontPadElems);
 
-      atomicStore(rewriter, loc, memrefBase, currentDestIndex,
-                  cast<VectorValue>(value), frontMask.getResult());
+      storeFunc(rewriter, loc, memrefBase, currentDestIndex,
+                cast<VectorValue>(value), frontMask.getResult());
     }
 
     if (currentSourceIndex >= origElements) {
@@ -611,13 +638,16 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
       auto backMask = rewriter.create<arith::ConstantOp>(
          loc, DenseElementsAttr::get(subWidthStoreMaskType, maskValues));
 
-      atomicStore(rewriter, loc, memrefBase, currentDestIndex,
-                  cast<VectorValue>(subWidthStorePart), backMask.getResult());
+      storeFunc(rewriter, loc, memrefBase, currentDestIndex,
+                cast<VectorValue>(subWidthStorePart), backMask.getResult());
     }
 
     rewriter.eraseOp(op);
     return success();
   }
+
+private:
+  const bool disableAtomicRMW;
 };
 
 //===----------------------------------------------------------------------===//
@@ -1930,12 +1960,18 @@ struct RewriteVectorTranspose : OpRewritePattern<vector::TransposeOp> {
 
 void vector::populateVectorNarrowTypeEmulationPatterns(
     const arith::NarrowTypeEmulationConverter &typeConverter,
-    RewritePatternSet &patterns) {
+    RewritePatternSet &patterns, bool disableAtomicRMW) {
 
   // Populate `vector.*` conversion patterns.
-  patterns.add<ConvertVectorLoad, ConvertVectorMaskedLoad, ConvertVectorStore,
-               ConvertVectorMaskedStore, ConvertVectorTransferRead>(
+  patterns.add<ConvertVectorLoad, ConvertVectorMaskedLoad,
+               ConvertVectorMaskedStore, ConvertVectorTransferRead>(
       typeConverter, patterns.getContext());
+
+  // Populate `vector.*` store conversion patterns. The caller can choose to
+  // avoid emitting atomic operations and use a plain read-modify-write
+  // sequence instead if it is known there is no thread contention.
+  patterns.insert<ConvertVectorStore>(patterns.getContext(), disableAtomicRMW);
 }
 
 void vector::populateVectorNarrowTypeRewritePatterns(
diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned-non-atomic.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned-non-atomic.mlir
new file mode 100644
index 0000000000000..1d6263535ae80
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned-non-atomic.mlir
@@ -0,0 +1,128 @@
+// RUN: mlir-opt --test-emulate-narrow-int="arith-compute-bitwidth=1 memref-load-bitwidth=8 disable-atomic-rmw=true" --cse --split-input-file %s | FileCheck %s
+
+// TODO: remove memref.alloc() in the tests to eliminate noise.
+// memref.alloc exists here because sub-byte vector data types such as i2
+// are currently not supported as input arguments.
+
+///----------------------------------------------------------------------------------------
+/// vector.store
+///----------------------------------------------------------------------------------------
+
+func.func @vector_store_i2_const_index_two_partial_stores(%arg0: vector<3xi2>) {
+  %0 = memref.alloc() : memref<3x3xi2>
+  %c0 = arith.constant 0 : index
+  %c2 = arith.constant 2 : index
+  vector.store %arg0, %0[%c2, %c0] : memref<3x3xi2>, vector<3xi2>
+  return
+}
+
+// Emit two non-atomic RMW partial stores. Store 6 bits from the input vector
+// (bits [12:18)) into bytes [1:2] of a 3-byte output memref. Because neither
+// byte is fully covered, both bytes are accessed partially through masking.
+
+// CHECK: func @vector_store_i2_const_index_two_partial_stores(
+// CHECK-SAME: %[[ARG0:.+]]: vector<3xi2>)
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<3xi8>
+// CHECK: %[[C1:.+]] = arith.constant 1 : index
+
+// Part 1 RMW sequence
+// CHECK: %[[CST:.+]] = arith.constant dense<[false, false, true, true]>
+// CHECK: %[[CST0:.+]] = arith.constant dense<0> : vector<4xi2>
+// CHECK: %[[EXTRACT:.+]] = vector.extract_strided_slice %[[ARG0]]
+// CHECK-SAME: {offsets = [0], sizes = [2], strides = [1]} : vector<3xi2> to vector<2xi2>
+// CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %[[EXTRACT]], %[[CST0]]
+// CHECK-SAME: {offsets = [2], strides = [1]} : vector<2xi2> into vector<4xi2>
+// CHECK: %[[LOAD:.+]] = vector.load
+// CHECK: %[[DOWNCAST:.+]] = vector.bitcast %[[LOAD]] : vector<1xi8> to vector<4xi2>
+// CHECK: %[[SELECT:.+]] = arith.select %[[CST]], %[[INSERT]], %[[DOWNCAST]]
+// CHECK: %[[UPCAST:.+]] = vector.bitcast %[[SELECT]]
+// CHECK: vector.store %[[UPCAST]], %[[ALLOC]][%[[C1]]]
+
+// Part 2 RMW sequence
+// CHECK: %[[OFFSET:.+]] = arith.addi %[[C1]], %[[C1]] : index
+// CHECK: %[[EXTRACT2:.+]] = vector.extract_strided_slice %[[ARG0]]
+// CHECK-SAME: {offsets = [2], sizes = [1], strides = [1]} : vector<3xi2> to vector<1xi2>
+// CHECK: %[[INSERT2:.+]] = vector.insert_strided_slice %[[EXTRACT2]], %[[CST0]]
+// CHECK-SAME: {offsets = [0], strides = [1]} : vector<1xi2> into vector<4xi2>
+// CHECK: %[[CST1:.+]] = arith.constant dense<[true, false, false, false]> : vector<4xi1>
+// CHECK: %[[LOAD2:.+]] = vector.load
+// CHECK: %[[UPCAST2:.+]] = vector.bitcast %[[LOAD2]] : vector<1xi8> to vector<4xi2>
+// CHECK: %[[SELECT2:.+]] = arith.select %[[CST1]], %[[INSERT2]], %[[UPCAST2]]
+// CHECK: %[[DOWNCAST2:.+]] = vector.bitcast %[[SELECT2]]
+// CHECK: vector.store %[[DOWNCAST2]], %[[ALLOC]][%[[OFFSET]]]
+
+
+// -----
+
+func.func @vector_store_i2_two_partial_one_full_stores(%arg0: vector<7xi2>) {
+  %0 = memref.alloc() : memref<3x7xi2>
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  vector.store %arg0, %0[%c1, %c0] : memref<3x7xi2>, vector<7xi2>
+  return
+}
+
+// In this example, emit two partial RMW stores and one full-width store.
+
+// CHECK: func @vector_store_i2_two_partial_one_full_stores(
+// CHECK-SAME: %[[ARG0:.+]]:
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<6xi8>
+// CHECK: %[[C1:.+]] = arith.constant 1 : index
+// CHECK: %[[CST:.+]] = arith.constant dense<[false, false, false, true]>
+// CHECK: %[[CST0:.+]] = arith.constant dense<0> : vector<4xi2>
+// CHECK: %[[EXTRACT:.+]] = vector.extract_strided_slice %[[ARG0]]
+// CHECK-SAME: {offsets = [0], sizes = [1], strides = [1]}
+// CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %[[EXTRACT]], %[[CST0]]
+// CHECK-SAME: {offsets = [3], strides = [1]}
+// First sub-width RMW:
+// CHECK: %[[LOAD:.+]] = vector.load %[[ALLOC]][%[[C1]]]
+// CHECK: %[[UPCAST:.+]] = vector.bitcast %[[LOAD]] : vector<1xi8> to vector<4xi2>
+// CHECK: %[[SELECT:.+]] = arith.select %[[CST]], %[[INSERT]], %[[UPCAST]]
+// CHECK: %[[DOWNCAST:.+]] = vector.bitcast %[[SELECT]]
+// CHECK: vector.store %[[DOWNCAST]], %[[ALLOC]][%[[C1]]]
+
+// Full-width store:
+// CHECK: %[[INDEX:.+]] = arith.addi %[[C1]], %[[C1]]
+// CHECK: %[[EXTRACT1:.+]] = vector.extract_strided_slice %[[ARG0]]
+// CHECK-SAME: {offsets = [1], sizes = [4], strides = [1]}
+// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[EXTRACT1]]
+// CHECK: vector.store %[[BITCAST]], %[[ALLOC]][%[[INDEX]]]
+
+// Second sub-width RMW:
+// CHECK: %[[INDEX2:.+]] = arith.addi %[[INDEX]], %[[C1]]
+// CHECK: %[[EXTRACT2:.+]] = vector.extract_strided_slice %[[ARG0]]
+// CHECK-SAME: {offsets = [5], sizes = [2], strides = [1]}
+// CHECK: %[[INSERT2:.+]] = vector.insert_strided_slice %[[EXTRACT2]]
+// CHECK-SAME: {offsets = [0], strides = [1]}
+// CHECK: %[[CST1:.+]] = arith.constant dense<[true, true, false, false]>
+// CHECK: %[[LOAD2:.+]] = vector.load %[[ALLOC]][%[[INDEX2]]]
+// CHECK: %[[UPCAST2:.+]] = vector.bitcast %[[LOAD2]]
+// CHECK: %[[SELECT2:.+]] = arith.select %[[CST1]], %[[INSERT2]], %[[UPCAST2]]
+// CHECK: %[[DOWNCAST2:.+]] = vector.bitcast %[[SELECT2]]
+// CHECK: vector.store %[[DOWNCAST2]], %[[ALLOC]][%[[INDEX2]]]
+
+// -----
+
+func.func @vector_store_i2_const_index_one_partial_store(%arg0: vector<1xi2>) {
+  %0 = memref.alloc() : memref<4x1xi2>
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  vector.store %arg0, %0[%c1, %c0] : memref<4x1xi2>, vector<1xi2>
+  return
+}
+
+// In this test, only one partial RMW store is emitted, as the whole store fits within a single byte.
+
+// CHECK: func @vector_store_i2_const_index_one_partial_store(
+// CHECK-SAME: %[[ARG0:.+]]: vector<1xi2>)
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<1xi8>
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[CST:.+]] = arith.constant dense<[false, true, false, false]>
+// CHECK: %[[CST0:.+]] = arith.constant dense<0> : vector<4xi2>
+// CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %[[ARG0]], %[[CST0]]
+// CHECK-SAME: {offsets = [1], strides = [1]} : vector<1xi2> into vector<4xi2>
+// CHECK: %[[LOAD:.+]] = vector.load %[[ALLOC]][%[[C0]]] : memref<1xi8>, vector<1xi8>
+// CHECK: %[[UPCAST:.+]] = vector.bitcast %[[LOAD]] : vector<1xi8> to vector<4xi2>
+// CHECK: %[[SELECT:.+]] = arith.select %[[CST]], %[[INSERT]], %[[UPCAST]]
+// CHECK: %[[DOWNCAST:.+]] = vector.bitcast %[[SELECT]]
+// CHECK: vector.store %[[DOWNCAST]], %[[ALLOC]][%[[C0]]]
diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
index 89cb8e0bde875..6fc974200c6f3 100644
--- a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
@@ -369,10 +369,9 @@ func.func @vector_store_i2_const_index_two_partial_stores(%arg0: vector<3xi2>) {
   return
 }
 
-// In this example, emit 2 atomic RMWs.
-//
-// Note, sizeof(%src) = 18 bits. This is modelled as %src_as_bytes:
-// <3xi8> (bits [0, 18) with the input values from %src, and [18, 24) are masked out)
+// Emit two atomic RMW partial stores. Store 6 bits from the input vector
+// (bits [12:18)) into bytes [1:2] of a 3-byte output memref. Because neither
+// byte is fully covered, both bytes are accessed partially through masking.
 
 // CHECK-LABEL: func @vector_store_i2_const_index_two_partial_stores(
 // CHECK-SAME: %[[ARG0:.+]]: vector<3xi2>)
diff --git a/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp b/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp
index 7401e470ed4f2..ba2ea40e83d96 100644
--- a/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp
+++ b/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp
@@ -99,7 +99,8 @@ struct TestEmulateNarrowTypePass
     arith::populateArithNarrowTypeEmulationPatterns(typeConverter, patterns);
     memref::populateMemRefNarrowTypeEmulationPatterns(typeConverter, patterns);
 
-    vector::populateVectorNarrowTypeEmulationPatterns(typeConverter, patterns);
+    vector::populateVectorNarrowTypeEmulationPatterns(typeConverter, patterns,
+                                                      disableAtomicRMW);
 
     if (failed(applyPartialConversion(op, target, std::move(patterns))))
       signalPassFailure();
@@ -118,6 +119,12 @@ struct TestEmulateNarrowTypePass
         *this, "skip-memref-type-conversion",
         llvm::cl::desc("disable memref type conversion (to test failures)"),
         llvm::cl::init(false)};
+
+  Option<bool> disableAtomicRMW{
+      *this, "disable-atomic-rmw",
+      llvm::cl::desc("disable atomic read-modify-write and prefer generating "
+                     "normal sequence"),
+      llvm::cl::init(false)};
 };
 } // namespace
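
Usage note (illustrative, not part of the patch): a downstream conversion pass that knows its sub-byte stores are free of thread contention can opt into the cheaper lowering by passing the new flag when populating the emulation patterns. A minimal sketch, modelled on the TestEmulateNarrowType.cpp wiring above; `ctx`, `op`, and `target` are assumed to be provided by the surrounding pass, and the converter's bitwidth of 8 mirrors memref-load-bitwidth=8 in the RUN line.

  // Assumed available from the surrounding pass: MLIRContext *ctx,
  // Operation *op, and a ConversionTarget `target` configured for
  // narrow-type emulation.
  arith::NarrowTypeEmulationConverter typeConverter(/*targetBitwidth=*/8);
  RewritePatternSet patterns(ctx);
  arith::populateArithNarrowTypeEmulationPatterns(typeConverter, patterns);
  memref::populateMemRefNarrowTypeEmulationPatterns(typeConverter, patterns);
  // Lower partial sub-byte stores to plain load/select/store sequences
  // instead of memref.generic_atomic_rmw.
  vector::populateVectorNarrowTypeEmulationPatterns(typeConverter, patterns,
                                                    /*disableAtomicRMW=*/true);
  if (failed(applyPartialConversion(op, target, std::move(patterns))))
    signalPassFailure();

From mlir-opt, the same behaviour is exercised through the test pass option, as in the RUN line of the new test file: --test-emulate-narrow-int="... disable-atomic-rmw=true".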