diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index a59f06f3c1ef1..64bb3a2204cfd 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -364,10 +364,11 @@ void populateVectorMaskMaterializationPatterns(RewritePatternSet &patterns,
                                                PatternBenefit benefit = 1);
 
 /// Appends patterns for emulating vector operations over narrow types with ops
-/// over wider types.
+/// over wider types. `useAtomicWrites` indicates whether to use atomic
+/// read-modify-write operations where other threads may contend for the data.
 void populateVectorNarrowTypeEmulationPatterns(
     const arith::NarrowTypeEmulationConverter &typeConverter,
-    RewritePatternSet &patterns);
+    RewritePatternSet &patterns, bool useAtomicWrites = true);
 
 /// Rewrite a vector `bitcast(trunci)` to use a more efficient sequence of
 /// vector operations comprising `shuffle` and `bitwise` ops.
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index dc8bab325184b..5eb85dce146e0 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -33,6 +33,7 @@
 #include "mlir/Transforms/DialectConversion.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/Support/Debug.h"
+#include "llvm/Support/LogicalResult.h"
 #include "llvm/Support/MathExtras.h"
 #include "llvm/Support/raw_ostream.h"
 #include <cstdint>
@@ -208,13 +209,10 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
 /// Extracts 1-D subvector from a 1-D vector. It is a wrapper function for
 /// emitting `vector.extract_strided_slice`.
 static Value staticallyExtractSubvector(OpBuilder &rewriter, Location loc,
-                                        VectorType extractType, Value source,
-                                        int64_t frontOffset,
+                                        Value source, int64_t frontOffset,
                                         int64_t subvecSize) {
   auto vectorType = cast<VectorType>(source.getType());
-  assert((vectorType.getRank() == 1 && extractType.getRank() == 1) &&
-         "expected 1-D source and destination types");
-  (void)vectorType;
+  assert(vectorType.getRank() == 1 && "expected a 1-D source vector");
   assert(frontOffset + subvecSize <= vectorType.getNumElements() &&
          "subvector out of bounds");
 
@@ -225,9 +223,12 @@ static Value staticallyExtractSubvector(OpBuilder &rewriter, Location loc,
   auto offsets = rewriter.getI64ArrayAttr({frontOffset});
   auto sizes = rewriter.getI64ArrayAttr({subvecSize});
   auto strides = rewriter.getI64ArrayAttr({1});
+
+  auto resultVectorType =
+      VectorType::get({subvecSize}, vectorType.getElementType());
   return rewriter
-      .create<vector::ExtractStridedSliceOp>(loc, extractType, source, offsets,
-                                             sizes, strides)
+      .create<vector::ExtractStridedSliceOp>(loc, resultVectorType, source,
+                                             offsets, sizes, strides)
       ->getResult(0);
 }
 
@@ -306,6 +307,73 @@ emulatedVectorLoad(OpBuilder &rewriter, Location loc, Value base,
                                      newLoad);
 }
 
+/// Atomically store a sub-byte-sized value to memory, with a mask.
+static void atomicStore(OpBuilder &rewriter, Location loc,
+                        TypedValue<MemRefType> emulatedMemref,
+                        Value emulatedIndex, TypedValue<VectorType> value,
+                        Value mask, int64_t scale) {
+  auto atomicOp = rewriter.create<memref::GenericAtomicRMWOp>(
+      loc, emulatedMemref, ValueRange{emulatedIndex});
+  OpBuilder builder =
+      OpBuilder::atBlockEnd(atomicOp.getBody(), rewriter.getListener());
+  Value origValue = atomicOp.getCurrentValue();
+
+  // i8 -> vector<1xi8> -> vector of the sub-byte type, e.g. vector<4xi2>.
+  auto oneVectorType = VectorType::get({1}, origValue.getType());
+  auto fromElem = builder.create<vector::FromElementsOp>(
+      loc, oneVectorType, ValueRange{origValue});
+  auto vectorBitCast =
+      builder.create<vector::BitCastOp>(loc, value.getType(), fromElem);
+
+  // Take `value` where the mask is set, and the loaded value elsewhere.
+  auto select =
+      builder.create<arith::SelectOp>(loc, mask, value, vectorBitCast);
+  auto bitcast2 = builder.create<vector::BitCastOp>(loc, oneVectorType, select);
+  auto extract = builder.create<vector::ExtractOp>(loc, bitcast2, 0);
+  builder.create<memref::AtomicYieldOp>(loc, extract.getResult());
+}
+
+/// Generate a non-atomic read-modify-write sequence for sub-byte storing.
+static void rmwStore(OpBuilder &rewriter, Location loc,
+                     TypedValue<MemRefType> emulatedMemref, Value emulatedIndex,
+                     TypedValue<VectorType> value, Value mask,
+                     int64_t numSrcElemsPerDest) {
+  auto emulatedIOType =
+      VectorType::get({1}, emulatedMemref.getType().getElementType());
+  auto elemLoad = rewriter.create<vector::LoadOp>(
+      loc, emulatedIOType, emulatedMemref, ValueRange{emulatedIndex});
+  auto fromBitcast = rewriter.create<vector::BitCastOp>(
+      loc,
+      VectorType::get({numSrcElemsPerDest}, value.getType().getElementType()),
+      elemLoad);
+  // Same mask convention as `atomicStore`: the mask selects the new value.
+  auto select = rewriter.create<arith::SelectOp>(loc, mask, value, fromBitcast);
+  auto toBitcast =
+      rewriter.create<vector::BitCastOp>(loc, emulatedIOType, select);
+  rewriter.create<vector::StoreOp>(loc, toBitcast, emulatedMemref,
+                                   emulatedIndex);
+}
+
+static_assert(std::is_same_v<decltype(atomicStore), decltype(rmwStore)> &&
+              "`atomicStore` and `rmwStore` must have the same function type.");
+
+/// Extract a slice of a vector, and insert it into a byte vector.
+static Value extractSliceIntoByte(ConversionPatternRewriter &rewriter,
+                                  Location loc, TypedValue<VectorType> vector,
+                                  int64_t sliceOffset, int64_t sliceNumElements,
+                                  int64_t byteOffset) {
+  auto vectorElementType = vector.getType().getElementType();
+  assert(8 % vectorElementType.getIntOrFloatBitWidth() == 0 &&
+         "vector element must be a valid sub-byte type");
+  auto scale = 8 / vectorElementType.getIntOrFloatBitWidth();
+  auto emptyByteVector = rewriter.create<arith::ConstantOp>(
+      loc, VectorType::get({scale}, vectorElementType),
+      rewriter.getZeroAttr(VectorType::get({scale}, vectorElementType)));
+  auto extracted = staticallyExtractSubvector(rewriter, loc, vector,
+                                              sliceOffset, sliceNumElements);
+  auto inserted = staticallyInsertSubvector(rewriter, loc, extracted,
+                                            emptyByteVector, byteOffset);
+  return inserted;
+}
+
 namespace {
 
 //===----------------------------------------------------------------------===//
@@ -315,6 +383,10 @@ namespace {
 struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
   using OpConversionPattern::OpConversionPattern;
 
+  ConvertVectorStore(MLIRContext *context, bool useAtomicWrites)
+      : OpConversionPattern<vector::StoreOp>(context),
+        useAtomicWrites_(useAtomicWrites) {}
+
   LogicalResult
   matchAndRewrite(vector::StoreOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
@@ -326,7 +398,8 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
 
     auto loc = op.getLoc();
     auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
-    Type oldElementType = op.getValueToStore().getType().getElementType();
+    auto valueToStore = op.getValueToStore();
+    Type oldElementType = valueToStore.getType().getElementType();
     Type newElementType = convertedType.getElementType();
     int srcBits = oldElementType.getIntOrFloatBitWidth();
     int dstBits = newElementType.getIntOrFloatBitWidth();
@@ -335,7 +408,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
       return rewriter.notifyMatchFailure(
          op, "only dstBits % srcBits == 0 supported");
     }
-    int scale = dstBits / srcBits;
+    int numSrcElemsPerDest = dstBits / srcBits;
 
     // Adjust the number of elements to store when emulating narrow types.
     // Here only the 1-D vector store is considered, and the N-D memref types
@@ -350,15 +423,15 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
     //   vector.store %bitcast, %alloc[%linear_index] : memref<16xi8>,
     //   vector<4xi8>
 
-    auto origElements = op.getValueToStore().getType().getNumElements();
-    if (origElements % scale != 0)
-      return failure();
+    auto origElements = valueToStore.getType().getNumElements();
+    bool isUnalignedEmulation = origElements % numSrcElemsPerDest != 0;
 
     auto stridedMetadata =
         rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
 
     OpFoldResult linearizedIndices;
-    std::tie(std::ignore, linearizedIndices) =
+    memref::LinearizedMemRefInfo linearizedInfo;
+    std::tie(linearizedInfo, linearizedIndices) =
         memref::getLinearizedMemRefOffsetAndSize(
             rewriter, loc, srcBits, dstBits,
             stridedMetadata.getConstifiedMixedOffset(),
@@ -366,16 +439,138 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
             stridedMetadata.getConstifiedMixedStrides(),
             getAsOpFoldResult(adaptor.getIndices()));
 
-    auto numElements = origElements / scale;
-    auto bitCast = rewriter.create<vector::BitCastOp>(
-        loc, VectorType::get(numElements, newElementType),
-        op.getValueToStore());
+    auto foldedNumFrontPadElems =
+        isUnalignedEmulation
+            ? getConstantIntValue(linearizedInfo.intraDataOffset)
+            : 0;
+
+    if (!foldedNumFrontPadElems) {
+      // Unimplemented case: dynamic front padding size != 0.
+      return failure();
+    }
+
+    TypedValue<MemRefType> emulatedMemref =
+        cast<TypedValue<MemRefType>>(adaptor.getBase());
+
+    // Shortcut: no partial (sub-byte) stores are needed when:
+    // 1. The source vector size is a multiple of the byte size, and
+    // 2. The store address is aligned to the emulated width boundary.
+    if (!isUnalignedEmulation && *foldedNumFrontPadElems == 0) {
+      auto numElements = origElements / numSrcElemsPerDest;
+      auto bitCast = rewriter.create<vector::BitCastOp>(
+          loc, VectorType::get(numElements, newElementType),
+          op.getValueToStore());
+      rewriter.replaceOpWithNewOp<vector::StoreOp>(
+          op, bitCast.getResult(), emulatedMemref,
+          getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
+      return success();
+    }
+
+    // The index into the target memref we are storing to.
+    Value currentDestIndex =
+        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices);
+    auto constantOne = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+    auto subWidthStoreMaskType =
+        VectorType::get({numSrcElemsPerDest}, rewriter.getI1Type());
+    // The index into the source vector we are currently processing.
+    auto currentSourceIndex = 0;
+
+    // 1. Partial-width store for the first byte. When the store address is not
+    // aligned to the emulated width boundary, deal with the unaligned part
+    // first so that the remaining elements are aligned to the width boundary.
+    auto frontSubWidthStoreElem =
+        (numSrcElemsPerDest - *foldedNumFrontPadElems) % numSrcElemsPerDest;
+    if (frontSubWidthStoreElem != 0) {
+      auto frontMaskValues = llvm::SmallVector<bool>(numSrcElemsPerDest, false);
+      if (*foldedNumFrontPadElems + origElements < numSrcElemsPerDest) {
+        std::fill_n(frontMaskValues.begin() + *foldedNumFrontPadElems,
+                    origElements, true);
+        frontSubWidthStoreElem = origElements;
+      } else {
+        std::fill_n(frontMaskValues.end() - frontSubWidthStoreElem,
+                    frontSubWidthStoreElem, true);
+      }
+      auto frontMask = rewriter.create<arith::ConstantOp>(
+          loc, DenseElementsAttr::get(subWidthStoreMaskType, frontMaskValues));
 
-    rewriter.replaceOpWithNewOp<vector::StoreOp>(
-        op, bitCast.getResult(), adaptor.getBase(),
-        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
+      currentSourceIndex = numSrcElemsPerDest - (*foldedNumFrontPadElems);
+      auto value = extractSliceIntoByte(
+          rewriter, loc, cast<TypedValue<VectorType>>(valueToStore), 0,
+          frontSubWidthStoreElem, *foldedNumFrontPadElems);
+
+      subEmulatedWidthStore(rewriter, loc, emulatedMemref, currentDestIndex,
+                            cast<TypedValue<VectorType>>(value),
+                            frontMask.getResult(), numSrcElemsPerDest);
+
+      currentDestIndex = rewriter.create<arith::AddIOp>(
+          loc, rewriter.getIndexType(), currentDestIndex, constantOne);
+    }
+
+    if (currentSourceIndex >= origElements) {
+      rewriter.eraseOp(op);
+      return success();
+    }
+
+    // 2. Full-width store. After the previous step, the store address is
+    // aligned to the emulated width boundary.
+    int64_t fullWidthStoreSize =
+        (origElements - currentSourceIndex) / numSrcElemsPerDest;
+    int64_t numNonFullWidthElements = fullWidthStoreSize * numSrcElemsPerDest;
+    if (fullWidthStoreSize != 0) {
+      auto fullWidthStorePart = staticallyExtractSubvector(
+          rewriter, loc, cast<TypedValue<VectorType>>(valueToStore),
+          currentSourceIndex, numNonFullWidthElements);
+
+      auto originType = cast<VectorType>(fullWidthStorePart.getType());
+      auto memrefElemType = emulatedMemref.getType().getElementType();
+      auto storeType = VectorType::get(
+          {originType.getNumElements() / numSrcElemsPerDest}, memrefElemType);
+      auto bitCast = rewriter.create<vector::BitCastOp>(loc, storeType,
+                                                        fullWidthStorePart);
+      rewriter.create<vector::StoreOp>(loc, bitCast.getResult(), emulatedMemref,
+                                       currentDestIndex);
+
+      currentSourceIndex += numNonFullWidthElements;
+      currentDestIndex = rewriter.create<arith::AddIOp>(
+          loc, rewriter.getIndexType(), currentDestIndex,
+          rewriter.create<arith::ConstantIndexOp>(loc, fullWidthStoreSize));
+    }
+
+    // 3. Partial-width store for the trailing elements: they start at the
+    // emulated width boundary but do not fill a full emulated element.
+    auto remainingElements = origElements - currentSourceIndex;
+    if (remainingElements != 0) {
+      auto subWidthStorePart = extractSliceIntoByte(
+          rewriter, loc, cast<TypedValue<VectorType>>(valueToStore),
+          currentSourceIndex, remainingElements, 0);
+
+      // Generate the mask for the trailing elements.
+      auto maskValues = llvm::SmallVector<bool>(numSrcElemsPerDest, false);
+      std::fill_n(maskValues.begin(), remainingElements, true);
+      auto backMask = rewriter.create<arith::ConstantOp>(
+          loc, DenseElementsAttr::get(subWidthStoreMaskType, maskValues));
+
+      subEmulatedWidthStore(rewriter, loc, emulatedMemref, currentDestIndex,
+                            cast<TypedValue<VectorType>>(subWidthStorePart),
+                            backMask.getResult(), numSrcElemsPerDest);
+    }
+
+    rewriter.eraseOp(op);
     return success();
   }
+
+  /// Store a sub-byte-sized value to memory, with a mask. Depending on the
+  /// configuration, this is either an atomic store or a plain RMW sequence.
+  template <typename... Args>
+  void subEmulatedWidthStore(Args &&...args) const {
+    std::function<decltype(atomicStore)> storeFunc =
+        useAtomicWrites_ ? atomicStore : rmwStore;
+    storeFunc(std::forward<Args>(args)...);
+  }
+
+private:
+  const bool useAtomicWrites_;
 };
 
 //===----------------------------------------------------------------------===//
@@ -581,9 +776,8 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
           rewriter, loc, dyn_cast<TypedValue<VectorType>>(result), resultVector,
           linearizedInfo.intraDataOffset, origElements);
     } else if (isUnalignedEmulation) {
-      result =
-          staticallyExtractSubvector(rewriter, loc, op.getType(), result,
-                                     *foldedIntraVectorOffset, origElements);
+      result = staticallyExtractSubvector(
+          rewriter, loc, result, *foldedIntraVectorOffset, origElements);
     }
     rewriter.replaceOp(op, result);
     return success();
@@ -742,9 +936,8 @@ struct ConvertVectorMaskedLoad final
           rewriter, loc, dyn_cast<TypedValue<VectorType>>(result),
           op.getPassThru(), linearizedInfo.intraDataOffset, origElements);
     } else if (isUnalignedEmulation) {
-      result =
-          staticallyExtractSubvector(rewriter, loc, op.getType(), result,
-                                     *foldedIntraVectorOffset, origElements);
+      result = staticallyExtractSubvector(
+          rewriter, loc, result, *foldedIntraVectorOffset, origElements);
     }
 
     rewriter.replaceOp(op, result);
@@ -827,9 +1020,8 @@ struct ConvertVectorTransferRead final
           linearizedInfo.intraDataOffset, origElements);
     } else if (isUnalignedEmulation) {
-      result =
-          staticallyExtractSubvector(rewriter, loc, op.getType(), result,
-                                     *foldedIntraVectorOffset, origElements);
+      result = staticallyExtractSubvector(
+          rewriter, loc, result, *foldedIntraVectorOffset, origElements);
     }
 
     rewriter.replaceOp(op, result);
@@ -1574,12 +1766,17 @@ struct RewriteVectorTranspose : OpRewritePattern<vector::TransposeOp> {
 void vector::populateVectorNarrowTypeEmulationPatterns(
     const arith::NarrowTypeEmulationConverter &typeConverter,
-    RewritePatternSet &patterns) {
+    RewritePatternSet &patterns, bool useAtomicWrites) {
 
-  // Populate `vector.*` conversion patterns.
-  patterns.add<ConvertVectorLoad, ConvertVectorMaskedLoad, ConvertVectorStore,
+  // Populate `vector.*` conversion patterns, except for `vector.store`.
+  patterns.add<ConvertVectorLoad, ConvertVectorMaskedLoad,
               ConvertVectorMaskedStore, ConvertVectorTransferRead>(
      typeConverter, patterns.getContext());
+
+  // Populate the `vector.store` conversion pattern. The caller can choose
+  // to avoid emitting atomic operations and reduce them to read-modify-write
+  // sequences when it is known that there is no thread contention.
+  patterns.insert<ConvertVectorStore>(patterns.getContext(), useAtomicWrites);
 }
 
 void vector::populateVectorNarrowTypeRewritePatterns(
diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned-rmw.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned-rmw.mlir
new file mode 100644
index 0000000000000..fa4d9cb5e4d4c
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned-rmw.mlir
@@ -0,0 +1,104 @@
+// RUN: mlir-opt --test-emulate-narrow-int="arith-compute-bitwidth=1 memref-load-bitwidth=8 atomic-store=false" --cse --split-input-file %s | FileCheck %s
+
+// TODO: remove memref.alloc() in the tests to reduce noise.
+// memref.alloc exists here because sub-byte vector data types such as i2
+// are currently not supported as input arguments.
+
+func.func @vector_store_i2_const_rmw(%arg0: vector<3xi2>) {
+  %0 = memref.alloc() : memref<3x3xi2>
+  %c0 = arith.constant 0 : index
+  %c2 = arith.constant 2 : index
+  vector.store %arg0, %0[%c2, %c0] : memref<3x3xi2>, vector<3xi2>
+  return
+}
+// CHECK: func @vector_store_i2_const_rmw(
+// CHECK-SAME: %[[ARG0:.+]]: vector<3xi2>)
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<3xi8>
+// CHECK: %[[C1:.+]] = arith.constant 1 : index
+// CHECK: %[[CST:.+]] = arith.constant dense<[false, false, true, true]>
+// CHECK: %[[CST0:.+]] = arith.constant dense<0> : vector<4xi2>
+// CHECK: %[[EXTRACT:.+]] = vector.extract_strided_slice %[[ARG0]]
+// CHECK-SAME: {offsets = [0], sizes = [2], strides = [1]} : vector<3xi2> to vector<2xi2>
+// CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %[[EXTRACT]], %[[CST0]]
+// CHECK-SAME: {offsets = [2], strides = [1]} : vector<2xi2> into vector<4xi2>
+// CHECK: %[[LOAD:.+]] = vector.load
+
+// The actual RMW sequence:
+// CHECK: %[[UPCAST:.+]] = vector.bitcast %[[LOAD]] : vector<1xi8> to vector<4xi2>
+// CHECK: %[[SELECT:.+]] = arith.select %[[CST]], %[[INSERT]], %[[UPCAST]]
+// CHECK: %[[DOWNCAST:.+]] = vector.bitcast %[[SELECT]]
+// CHECK: vector.store %[[DOWNCAST]], %[[ALLOC]][%[[C1]]]
+
+// -----
+
+func.func @vector_store_i2_rmw(%arg0: vector<7xi2>) {
+  %0 = memref.alloc() : memref<3x7xi2>
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  vector.store %arg0, %0[%c1, %c0] : memref<3x7xi2>, vector<7xi2>
+  return
+}
+
+// CHECK: func @vector_store_i2_rmw(
+// CHECK-SAME: %[[ARG0:.+]]:
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<6xi8>
+// CHECK: %[[C1:.+]] = arith.constant 1 : index
+// CHECK: %[[CST:.+]] = arith.constant dense<[false, false, false, true]>
+// CHECK: %[[CST0:.+]] = arith.constant dense<0> : vector<4xi2>
+// CHECK: %[[EXTRACT:.+]] = vector.extract_strided_slice %[[ARG0]]
+// CHECK-SAME: {offsets = [0], sizes = [1], strides = [1]}
+// CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %[[EXTRACT]], %[[CST0]]
+// CHECK-SAME: {offsets = [3], strides = [1]}
+// First sub-width RMW:
+// CHECK: %[[LOAD:.+]] = vector.load %[[ALLOC]][%[[C1]]]
+// CHECK: %[[UPCAST:.+]] = vector.bitcast %[[LOAD]] : vector<1xi8> to vector<4xi2>
+// CHECK: %[[SELECT:.+]] = arith.select %[[CST]], %[[INSERT]], %[[UPCAST]]
+// CHECK: %[[DOWNCAST:.+]] = vector.bitcast %[[SELECT]]
+// CHECK: vector.store %[[DOWNCAST]], %[[ALLOC]][%[[C1]]]
+
+// Full-width store:
+// CHECK: %[[INDEX:.+]] = arith.addi %[[C1]], %[[C1]]
+// CHECK: %[[EXTRACT1:.+]] = vector.extract_strided_slice %[[ARG0]]
+// CHECK-SAME: {offsets = [1], sizes = [4], strides = [1]}
+// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[EXTRACT1]]
+// CHECK: vector.store %[[BITCAST]], %[[ALLOC]][%[[INDEX]]]
+
+// Second sub-width RMW:
+// CHECK: %[[INDEX2:.+]] = arith.addi %[[INDEX]], %[[C1]]
+// CHECK: %[[EXTRACT2:.+]] = vector.extract_strided_slice %[[ARG0]]
+// CHECK-SAME: {offsets = [5], sizes = [2], strides = [1]}
+// CHECK: %[[INSERT2:.+]] = vector.insert_strided_slice %[[EXTRACT2]]
+// CHECK-SAME: {offsets = [0], strides = [1]}
+// CHECK: %[[CST1:.+]] = arith.constant dense<[true, true, false, false]>
+// CHECK: %[[LOAD1:.+]] = vector.load %[[ALLOC]][%[[INDEX2]]]
+// CHECK: %[[UPCAST1:.+]] = vector.bitcast %[[LOAD1]]
+// CHECK: %[[SELECT1:.+]] = arith.select %[[CST1]], %[[INSERT2]], %[[UPCAST1]]
+// CHECK: %[[DOWNCAST1:.+]] = vector.bitcast %[[SELECT1]]
+// CHECK: vector.store %[[DOWNCAST1]], %[[ALLOC]][%[[INDEX2]]]
+
+// -----
+
+func.func @vector_store_i2_single_rmw(%arg0: vector<1xi2>) {
+  %0 = memref.alloc() : memref<4x1xi2>
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  vector.store %arg0, %0[%c1, %c0] : memref<4x1xi2>, vector<1xi2>
+  return
+}
+
+// In this test, only one RMW store is emitted.
+// CHECK: func @vector_store_i2_single_rmw(
+// CHECK-SAME: %[[ARG0:.+]]: vector<1xi2>)
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<1xi8>
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[C1:.+]] = arith.constant 1 : index
+// CHECK: %[[CST:.+]] = arith.constant dense<[false, true, false, false]>
+// CHECK: %[[CST0:.+]] = arith.constant dense<0> : vector<4xi2>
+// CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %[[ARG0]], %[[CST0]]
+// CHECK-SAME: {offsets = [1], strides = [1]} : vector<1xi2> into vector<4xi2>
+// CHECK: %[[LOAD:.+]] = vector.load %[[ALLOC]][%[[C0]]] : memref<1xi8>, vector<1xi8>
+// CHECK: %[[UPCAST:.+]] = vector.bitcast %[[LOAD]] : vector<1xi8> to vector<4xi2>
+// CHECK: %[[SELECT:.+]] = arith.select %[[CST]], %[[INSERT]], %[[UPCAST]]
+// CHECK: %[[DOWNCAST:.+]] = vector.bitcast %[[SELECT]]
+// CHECK: vector.store %[[DOWNCAST]], %[[ALLOC]][%[[C0]]]
+
diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
index b1a0d4f924f3c..ef04ab17ef755 100644
--- a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
@@ -287,3 +287,93 @@ func.func @vector_maskedload_i4_constant_mask_unaligned(%passthru: vector<5xi2>)
 // CHECK: %[[RESULT:.+]] = vector.extract_strided_slice %[[SELECT]]
 // CHECK-SAME: {offsets = [1], sizes = [5], strides = [1]} : vector<8xi2> to vector<5xi2>
 // CHECK: return %[[RESULT]] : vector<5xi2>
+
+///----------------------------------------------------------------------------------------
+/// vector.store
+///----------------------------------------------------------------------------------------
+
+// -----
+
+func.func @vector_store_i2_atomic(%arg0: vector<7xi2>) {
+  %0 = memref.alloc() : memref<3x7xi2>
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  vector.store %arg0, %0[%c1, %c0] : memref<3x7xi2>, vector<7xi2>
+  return
+}
+
+// In this example, emit 2 atomic RMWs and 1 non-atomic store.
+
+// CHECK: func @vector_store_i2_atomic(
+// CHECK-SAME: %[[ARG0:.+]]: vector<7xi2>)
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<6xi8>
+// CHECK: %[[C1:.+]] = arith.constant 1 : index
+// CHECK: %[[CST:.+]] = arith.constant dense<[false, false, false, true]> : vector<4xi1>
+// CHECK: %[[CST0:.+]] = arith.constant dense<0> : vector<4xi2>
+
+// First atomic RMW:
+// CHECK: %[[EXTRACT:.+]] = vector.extract_strided_slice %[[ARG0]]
+// CHECK-SAME: {offsets = [0], sizes = [1], strides = [1]} : vector<7xi2> to vector<1xi2>
+// CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %[[EXTRACT]], %[[CST0]]
+// CHECK-SAME: {offsets = [3], strides = [1]} : vector<1xi2> into vector<4xi2>
+// CHECK: %[[ATOMIC_RMW:.+]] = memref.generic_atomic_rmw %[[ALLOC]][%[[C1]]] : memref<6xi8> {
+// CHECK: %[[ARG:.+]]: i8):
+// CHECK: %[[FROM_ELEM:.+]] = vector.from_elements %[[ARG]] : vector<1xi8>
+// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[FROM_ELEM]] : vector<1xi8> to vector<4xi2>
+// CHECK: %[[SELECT:.+]] = arith.select %[[CST]], %[[INSERT]], %[[BITCAST]] : vector<4xi1>, vector<4xi2>
+// CHECK: %[[BITCAST2:.+]] = vector.bitcast %[[SELECT]] : vector<4xi2> to vector<1xi8>
+// CHECK: %[[EXTRACT2:.+]] = vector.extract %[[BITCAST2]][0] : i8 from vector<1xi8>
+// CHECK: memref.atomic_yield %[[EXTRACT2]] : i8
+
+// Non-atomic full-width store part:
+// CHECK: %[[ADDR:.+]] = arith.addi %[[C1]], %[[C1]] : index
+// CHECK: %[[EXTRACT2:.+]] = vector.extract_strided_slice %[[ARG0]]
+// CHECK-SAME: {offsets = [1], sizes = [4], strides = [1]} : vector<7xi2> to vector<4xi2>
+// CHECK: %[[BITCAST3:.+]] = vector.bitcast %[[EXTRACT2]] : vector<4xi2> to vector<1xi8>
+// CHECK: vector.store %[[BITCAST3]], %[[ALLOC]][%[[ADDR]]] : memref<6xi8>, vector<1xi8>
+
+// Second atomic RMW:
+// CHECK: %[[ADDR2:.+]] = arith.addi %[[ADDR]], %[[C1]] : index
+// CHECK: %[[EXTRACT3:.+]] = vector.extract_strided_slice %[[ARG0]]
+// CHECK-SAME: {offsets = [5], sizes = [2], strides = [1]} : vector<7xi2> to vector<2xi2>
+// CHECK: %[[INSERT2:.+]] = vector.insert_strided_slice %[[EXTRACT3]], %[[CST0]]
+// CHECK-SAME: {offsets = [0], strides = [1]} : vector<2xi2> into vector<4xi2>
+// CHECK: %[[CST1:.+]] = arith.constant dense<[true, true, false, false]> : vector<4xi1>
+// CHECK: %[[ATOMIC_RMW2:.+]] = memref.generic_atomic_rmw %[[ALLOC]][%[[ADDR2]]] : memref<6xi8> {
+// CHECK: %[[ARG2:.+]]: i8):
+// CHECK: %[[FROM_ELEM2:.+]] = vector.from_elements %[[ARG2]] : vector<1xi8>
+// CHECK: %[[BITCAST4:.+]] = vector.bitcast %[[FROM_ELEM2]] : vector<1xi8> to vector<4xi2>
+// CHECK: %[[SELECT2:.+]] = arith.select %[[CST1]], %[[INSERT2]], %[[BITCAST4]] :
+// CHECK-SAME: vector<4xi1>, vector<4xi2>
+// CHECK: %[[BITCAST5:.+]] = vector.bitcast %[[SELECT2]] : vector<4xi2> to vector<1xi8>
+// CHECK: %[[EXTRACT4:.+]] = vector.extract %[[BITCAST5]][0] : i8 from vector<1xi8>
+// CHECK: memref.atomic_yield %[[EXTRACT4]] : i8
+
+// -----
+
+func.func @vector_store_i2_single_atomic(%arg0: vector<1xi2>) {
+  %0 = memref.alloc() : memref<4x1xi2>
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  vector.store %arg0, %0[%c1, %c0] : memref<4x1xi2>, vector<1xi2>
+  return
+}
+
+// In this example, only 1 atomic RMW is emitted.
+// CHECK: func @vector_store_i2_single_atomic(
+// CHECK-SAME: %[[ARG0:.+]]: vector<1xi2>)
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<1xi8>
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[C1:.+]] = arith.constant 1 : index
+// CHECK: %[[CST:.+]] = arith.constant dense<[false, true, false, false]> : vector<4xi1>
+// CHECK: %[[CST0:.+]] = arith.constant dense<0> : vector<4xi2>
+// CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %[[ARG0]], %[[CST0]]
+// CHECK-SAME: {offsets = [1], strides = [1]} : vector<1xi2> into vector<4xi2>
+
+// CHECK: %[[ATOMIC_RMW:.+]] = memref.generic_atomic_rmw %[[ALLOC]][%[[C0]]] : memref<1xi8> {
+// CHECK: %[[ARG:.+]]: i8):
+// CHECK: %[[FROM_ELEM:.+]] = vector.from_elements %[[ARG]] : vector<1xi8>
+// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[FROM_ELEM]] : vector<1xi8> to vector<4xi2>
+// CHECK: %[[SELECT:.+]] = arith.select %[[CST]], %[[INSERT]], %[[BITCAST]] : vector<4xi1>, vector<4xi2>
+// CHECK: %[[BITCAST2:.+]] = vector.bitcast %[[SELECT]] : vector<4xi2> to vector<1xi8>
+// CHECK: %[[EXTRACT2:.+]] = vector.extract %[[BITCAST2]][0] : i8 from vector<1xi8>
+// CHECK: memref.atomic_yield %[[EXTRACT2]] : i8
diff --git a/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp b/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp
index 7401e470ed4f2..9a3fac623fbd7 100644
--- a/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp
+++ b/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp
@@ -99,7 +99,8 @@ struct TestEmulateNarrowTypePass
     arith::populateArithNarrowTypeEmulationPatterns(typeConverter, patterns);
     memref::populateMemRefNarrowTypeEmulationPatterns(typeConverter, patterns);
-    vector::populateVectorNarrowTypeEmulationPatterns(typeConverter, patterns);
+    vector::populateVectorNarrowTypeEmulationPatterns(typeConverter, patterns,
+                                                      atomicStore);
 
     if (failed(applyPartialConversion(op, target, std::move(patterns))))
       signalPassFailure();
@@ -118,6 +119,11 @@ struct TestEmulateNarrowTypePass
       *this, "skip-memref-type-conversion",
       llvm::cl::desc("disable memref type conversion (to test failures)"),
      llvm::cl::init(false)};
+
+  Option<bool> atomicStore{
+      *this, "atomic-store",
+      llvm::cl::desc("use atomic stores instead of read-modify-write sequences"),
+      llvm::cl::init(true)};
 };
 } // namespace
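
Usage note (not part of the patch): a downstream pipeline opts out of the atomic lowering the same way the test pass above does, by forwarding a flag into `populateVectorNarrowTypeEmulationPatterns`. The sketch below shows one plausible way to wire that up when the caller knows there is no cross-thread contention on the emulated bytes. The wrapper function name, the simplified legality rule, and the exact header list are illustrative assumptions; the type converter and the `populate*` entry points are the MLIR APIs this patch touches.

// Sketch only: narrow-type emulation over i8 storage, using plain
// read-modify-write sequences (instead of memref.generic_atomic_rmw) for
// partial sub-byte stores. Safe only without concurrent writers.
#include "mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h"
#include "mlir/Dialect/Arith/Transforms/Passes.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

// Hypothetical helper; only the populate* calls mirror the patch and the
// existing test pass.
static LogicalResult emulateNarrowTypesNonAtomic(Operation *op) {
  MLIRContext *ctx = op->getContext();

  // Emulate narrow (e.g. i2/i4) element types with 8-bit wide accesses.
  arith::NarrowTypeEmulationConverter typeConverter(/*targetBitwidth=*/8);
  memref::populateMemRefNarrowTypeEmulationConversions(typeConverter);

  // Simplified legality: an op is legal once its types are converted.
  ConversionTarget target(*ctx);
  target.markUnknownOpDynamicallyLegal(
      [&typeConverter](Operation *op) { return typeConverter.isLegal(op); });

  RewritePatternSet patterns(ctx);
  arith::populateArithNarrowTypeEmulationPatterns(typeConverter, patterns);
  memref::populateMemRefNarrowTypeEmulationPatterns(typeConverter, patterns);
  // `useAtomicWrites = false` selects the rmwStore path added in this patch;
  // the default (true) keeps the memref.generic_atomic_rmw lowering.
  vector::populateVectorNarrowTypeEmulationPatterns(typeConverter, patterns,
                                                    /*useAtomicWrites=*/false);

  return applyPartialConversion(op, target, std::move(patterns));
}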