diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp index 95064083b21d4..7ca88f1e0a0df 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp @@ -45,6 +45,9 @@ using namespace mlir; #define DBGSNL() (llvm::dbgs() << "\n") #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") +using VectorValue = TypedValue; +using MemRefValue = TypedValue; + /// Returns a compressed mask for the emulated vector. For example, when /// emulating an eight-element `i8` vector with `i32` (i.e. when the source /// elements span two dest elements), this method compresses `vector<8xi1>` @@ -194,13 +197,10 @@ static FailureOr getCompressedMaskOp(OpBuilder &rewriter, /// Extracts 1-D subvector from a 1-D vector. It is a wrapper function for /// emitting `vector.extract_strided_slice`. static Value staticallyExtractSubvector(OpBuilder &rewriter, Location loc, - VectorType extractType, Value source, - int64_t frontOffset, + Value source, int64_t frontOffset, int64_t subvecSize) { auto vectorType = cast(source.getType()); - assert((vectorType.getRank() == 1 && extractType.getRank() == 1) && - "expected 1-D source and destination types"); - (void)vectorType; + assert(vectorType.getRank() == 1 && "expected 1-D source types"); assert(frontOffset + subvecSize <= vectorType.getNumElements() && "subvector out of bounds"); @@ -211,9 +211,12 @@ static Value staticallyExtractSubvector(OpBuilder &rewriter, Location loc, auto offsets = rewriter.getI64ArrayAttr({frontOffset}); auto sizes = rewriter.getI64ArrayAttr({subvecSize}); auto strides = rewriter.getI64ArrayAttr({1}); + + auto resultVectorType = + VectorType::get({subvecSize}, vectorType.getElementType()); return rewriter - .create(loc, extractType, source, offsets, - sizes, strides) + .create(loc, resultVectorType, source, + offsets, sizes, strides) ->getResult(0); } @@ -237,9 +240,10 @@ static Value staticallyInsertSubvector(OpBuilder &rewriter, Location loc, /// function emits multiple `vector.extract` and `vector.insert` ops, so only /// use it when `offset` cannot be folded into a constant value. static Value dynamicallyExtractSubVector(OpBuilder &rewriter, Location loc, - TypedValue source, - Value dest, OpFoldResult offset, + Value source, Value dest, + OpFoldResult offset, int64_t numElementsToExtract) { + assert(isa(source) && "expected `source` to be a vector type"); for (int i = 0; i < numElementsToExtract; ++i) { Value extractLoc = (i == 0) ? offset.dyn_cast() @@ -255,9 +259,10 @@ static Value dynamicallyExtractSubVector(OpBuilder &rewriter, Location loc, /// Inserts a 1-D subvector into a 1-D `dest` vector at index `destOffsetVar`. static Value dynamicallyInsertSubVector(RewriterBase &rewriter, Location loc, - TypedValue source, - Value dest, OpFoldResult destOffsetVar, + Value source, Value dest, + OpFoldResult destOffsetVar, size_t length) { + assert(isa(source) && "expected `source` to be a vector type"); assert(length > 0 && "length must be greater than 0"); Value destOffsetVal = getValueOrCreateConstantIndexOp(rewriter, loc, destOffsetVar); @@ -277,11 +282,12 @@ static Value dynamicallyInsertSubVector(RewriterBase &rewriter, Location loc, /// specifically, use `emulatedElemType` for loading a vector of `origElemType`. /// The load location is given by `base` and `linearizedIndices`, and the /// load size is given by `numEmulatedElementsToLoad`. 
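To make the compressed-mask idea above concrete, here is a small standalone C++ sketch (not part of the patch; names are illustrative) of what the mask returned by getCompressedMaskOp represents in the plain dense case: each compressed lane covers `scale` source lanes and is enabled when any of them is enabled. The real helper instead rewrites `vector.create_mask` / `vector.constant_mask` ops and also accounts for an intra-byte offset, which this sketch omits.

#include <cassert>
#include <vector>

// Any-of reduction of `scale` narrow lanes per emulated (wide) element.
std::vector<bool> compressMask(const std::vector<bool> &srcMask, int scale) {
  assert(scale > 0 && srcMask.size() % scale == 0);
  std::vector<bool> compressed(srcMask.size() / scale, false);
  for (size_t i = 0; i < srcMask.size(); ++i)
    if (srcMask[i])
      compressed[i / scale] = true;
  return compressed;
}

int main() {
  // Emulating an eight-element i8 vector with i32 (scale = 32 / 8 = 4):
  // vector<8xi1> compresses to vector<2xi1>, as in the example above.
  std::vector<bool> lanes = {true, true, true, false, false, false, false, false};
  assert(compressMask(lanes, 4) == (std::vector<bool>{true, false}));
}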
-static TypedValue -emulatedVectorLoad(OpBuilder &rewriter, Location loc, Value base, - OpFoldResult linearizedIndices, - int64_t numEmultedElementsToLoad, Type origElemType, - Type emulatedElemType) { +static VectorValue emulatedVectorLoad(OpBuilder &rewriter, Location loc, + Value base, + OpFoldResult linearizedIndices, + int64_t numEmultedElementsToLoad, + Type origElemType, + Type emulatedElemType) { auto scale = emulatedElemType.getIntOrFloatBitWidth() / origElemType.getIntOrFloatBitWidth(); auto newLoad = rewriter.create( @@ -292,6 +298,104 @@ emulatedVectorLoad(OpBuilder &rewriter, Location loc, Value base, newLoad); } +/// Downcast two values to `downcastType`, then select values +/// based on `mask`, and casts the result to `upcastType`. +static Value downcastSelectAndUpcast(OpBuilder &builder, Location loc, + VectorType downcastType, + VectorType upcastType, Value mask, + Value trueValue, Value falseValue) { + assert( + downcastType.getNumElements() * downcastType.getElementTypeBitWidth() == + upcastType.getNumElements() * upcastType.getElementTypeBitWidth() && + "expected input and output number of bits to match"); + if (trueValue.getType() != downcastType) { + trueValue = builder.create(loc, downcastType, trueValue); + } + if (falseValue.getType() != downcastType) { + falseValue = + builder.create(loc, downcastType, falseValue); + } + Value selectedType = + builder.create(loc, mask, trueValue, falseValue); + // Upcast the selected value to the new type. + return builder.create(loc, upcastType, selectedType); +} + +/// Emits `memref.generic_atomic_rmw` op to store a subbyte-sized value to a +/// byte in `linearizedMemref`, with a mask. The `valueToStore` is a vector of +/// subbyte-sized elements, with size of 8 bits, and the mask is used to select +/// which elements to store. +/// +/// Inputs: +/// linearizedMemref = |2|2|2|2| : <4xi2> (<1xi8>) +/// storeIdx = 2 +/// valueToStore = |3|3|3|3| : vector<4xi2> +/// mask = |0|0|1|1| : vector<4xi1> +/// +/// Result: +/// linearizedMemref = |2|2|3|3| : <4xi2> (<1xi8>) +static void atomicStore(OpBuilder &builder, Location loc, + MemRefValue linearizedMemref, Value storeIdx, + VectorValue valueToStore, Value mask) { + assert(valueToStore.getType().getRank() == 1 && "expected 1-D vector"); + + // Create an atomic load-modify-write region using + // `memref.generic_atomic_rmw`. + auto atomicOp = builder.create( + loc, linearizedMemref, ValueRange{storeIdx}); + Value origValue = atomicOp.getCurrentValue(); + + OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPointToStart(atomicOp.getBody()); + + // Load the original value from memory, and cast it to the original element + // type. + auto oneElemVecType = VectorType::get({1}, origValue.getType()); + Value origVecValue = builder.create( + loc, oneElemVecType, ValueRange{origValue}); + + // Construct the final masked value and yield it. + Value maskedValue = + downcastSelectAndUpcast(builder, loc, valueToStore.getType(), + oneElemVecType, mask, valueToStore, origVecValue); + auto scalarMaskedValue = + builder.create(loc, maskedValue, 0); + builder.create(loc, scalarMaskedValue); +} + +/// Extract `sliceNumElements` from source `vector` at `extractOffset`, +/// and insert it into an empty vector at `insertOffset`. 
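The select-and-merge that atomicStore builds inside the memref.generic_atomic_rmw region can be modelled lane by lane. The sketch below is plain C++, not the pass itself (I2Lanes and maskedMerge are made-up names); it keeps the four i2 lanes of one emulated byte as an array so it stays agnostic of the bit order used by vector.bitcast, and it reproduces the example from the comment above atomicStore.

#include <array>
#include <cassert>
#include <cstdint>

using I2Lanes = std::array<uint8_t, 4>; // one emulated i8 viewed as four i2 lanes

// Lane-wise equivalent of downcastSelectAndUpcast: masked lanes come from the
// value being stored, the remaining lanes keep the byte's original contents.
I2Lanes maskedMerge(I2Lanes original, I2Lanes toStore, std::array<bool, 4> mask) {
  I2Lanes result;
  for (int i = 0; i < 4; ++i)
    result[i] = mask[i] ? toStore[i] : original[i];
  return result;
}

int main() {
  // linearizedMemref = |2|2|2|2|, valueToStore = |3|3|3|3|, mask = |0|0|1|1|
  I2Lanes out = maskedMerge({2, 2, 2, 2}, {3, 3, 3, 3},
                            {false, false, true, true});
  assert((out == I2Lanes{2, 2, 3, 3})); // |2|2|3|3|, as in the example above
}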
+/// Inputs: +/// vec_in = |0|1|2|3| : vector<4xi2> +/// extractOffset = 1 +/// sliceNumElements = 2 +/// insertOffset = 2 +/// Output: +/// vec_out = |0|0|1|2| : vector<4xi2> +static Value extractSliceIntoByte(ConversionPatternRewriter &rewriter, + Location loc, VectorValue vector, + int64_t extractOffset, + int64_t sliceNumElements, + int64_t insertOffset) { + assert(vector.getType().getRank() == 1 && "expected 1-D vector"); + auto vectorElementType = vector.getType().getElementType(); + // TODO: update and use `alignedConversionPrecondition` in the place of + // these asserts. + assert( + sliceNumElements * vectorElementType.getIntOrFloatBitWidth() <= 8 && + "sliceNumElements * vector element size must be less than or equal to 8"); + assert(8 % vectorElementType.getIntOrFloatBitWidth() == 0 && + "vector element must be a valid sub-byte type"); + auto scale = 8 / vectorElementType.getIntOrFloatBitWidth(); + auto emptyByteVector = rewriter.create( + loc, VectorType::get({scale}, vectorElementType), + rewriter.getZeroAttr(VectorType::get({scale}, vectorElementType))); + auto extracted = staticallyExtractSubvector(rewriter, loc, vector, + extractOffset, sliceNumElements); + return staticallyInsertSubvector(rewriter, loc, extracted, emptyByteVector, + insertOffset); +} + namespace { //===----------------------------------------------------------------------===// @@ -311,9 +415,10 @@ struct ConvertVectorStore final : OpConversionPattern { "only 1-D vectors are supported ATM"); auto loc = op.getLoc(); - auto convertedType = cast(adaptor.getBase().getType()); - Type oldElementType = op.getValueToStore().getType().getElementType(); - Type newElementType = convertedType.getElementType(); + auto valueToStore = cast(op.getValueToStore()); + auto oldElementType = valueToStore.getType().getElementType(); + auto newElementType = + cast(adaptor.getBase().getType()).getElementType(); int srcBits = oldElementType.getIntOrFloatBitWidth(); int dstBits = newElementType.getIntOrFloatBitWidth(); @@ -321,7 +426,7 @@ struct ConvertVectorStore final : OpConversionPattern { return rewriter.notifyMatchFailure( op, "only dstBits % srcBits == 0 supported"); } - int scale = dstBits / srcBits; + int numSrcElemsPerDest = dstBits / srcBits; // Adjust the number of elements to store when emulating narrow types. // Here only the 1-D vector store is considered, and the N-D memref types @@ -336,15 +441,15 @@ struct ConvertVectorStore final : OpConversionPattern { // vector.store %bitcast, %alloc[%linear_index] : memref<16xi8>, // vector<4xi8> - auto origElements = op.getValueToStore().getType().getNumElements(); - if (origElements % scale != 0) - return failure(); + auto origElements = valueToStore.getType().getNumElements(); + bool isAlignedEmulation = origElements % numSrcElemsPerDest == 0; auto stridedMetadata = rewriter.create(loc, op.getBase()); OpFoldResult linearizedIndices; - std::tie(std::ignore, linearizedIndices) = + memref::LinearizedMemRefInfo linearizedInfo; + std::tie(linearizedInfo, linearizedIndices) = memref::getLinearizedMemRefOffsetAndSize( rewriter, loc, srcBits, dstBits, stridedMetadata.getConstifiedMixedOffset(), @@ -352,14 +457,165 @@ struct ConvertVectorStore final : OpConversionPattern { stridedMetadata.getConstifiedMixedStrides(), getAsOpFoldResult(adaptor.getIndices())); - auto numElements = origElements / scale; - auto bitCast = rewriter.create( - loc, VectorType::get(numElements, newElementType), - op.getValueToStore()); + std::optional foldedNumFrontPadElems = + isAlignedEmulation + ? 
0 + : getConstantIntValue(linearizedInfo.intraDataOffset); + + if (!foldedNumFrontPadElems) { + return rewriter.notifyMatchFailure( + op, "subbyte store emulation: dynamic front padding size is " + "not yet implemented"); + } + + auto memrefBase = cast(adaptor.getBase()); + + // Conditions when atomic RMWs are not needed: + // 1. The source vector size (in bits) is a multiple of byte size. + // 2. The address of the store is aligned to the emulated width boundary. + // + // For example, to store a vector<4xi2> to <13xi2> at offset 4, does not + // need unaligned emulation because the store address is aligned and the + // source is a whole byte. + bool emulationRequiresPartialStores = + !isAlignedEmulation || *foldedNumFrontPadElems != 0; + if (!emulationRequiresPartialStores) { + // Basic case: storing full bytes. + auto numElements = origElements / numSrcElemsPerDest; + auto bitCast = rewriter.create( + loc, VectorType::get(numElements, newElementType), + op.getValueToStore()); + rewriter.replaceOpWithNewOp( + op, bitCast.getResult(), memrefBase, + getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices)); + return success(); + } + + // Next, handle the case when sub-byte read-modify-write + // sequences are needed to emulate a vector store. + // Here is an example: + // + // Vector to store: vector<7xi2> + // Value to store: 11 11 11 11 11 11 11 (all ones) + // + // Destination: memref<12xi2> + // Store offset: 2 (i.e. 4 bits into the 1st emulated byte). + // + // Input MLIR: vector.store %val, %dest[%c2] : memref<12xi2>, vector<7xi2> + // + // Destination memref before: + // + // Byte 0 Byte 1 Byte 2 + // +----------+----------+----------+ + // | 00000000 | 00000000 | 00000000 | + // +----------+----------+----------+ + // + // Destination memref after: + // + // Byte 0 Byte 1 Byte 2 + // +----------+----------+----------+ + // | 00001111 | 11111111 | 11000000 | + // +----------+----------+----------+ + // + // Note, stores to Byte 1 are "full-width" and hence don't require RMW (no + // need for atomicity). Stores to Bytes 0 and Byte 2 are "partial", hence + // requiring RMW access (atomicity is required). + + // The index into the target memref we are storing to. + Value currentDestIndex = + getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices); + // The index into the source vector we are currently processing. + auto currentSourceIndex = 0; + + // Build a mask used for rmw. + auto subWidthStoreMaskType = + VectorType::get({numSrcElemsPerDest}, rewriter.getI1Type()); + + // 1. Partial width store for the leading byte. + // When the store address is not aligned to emulated width boundary, deal + // with the unaligned part so that the rest elements are aligned to width + // boundary. 
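The split into the three steps that follow (a leading partial store, full-width stores for the middle bytes, and a trailing partial store) can be sanity-checked with the arithmetic alone. Below is a standalone sketch (planSubByteStore is a hypothetical helper, not code from the patch) that recomputes, for the vector<7xi2>-at-offset-2 example above, how many elements complete the leading byte, how many whole bytes are stored directly, and how many elements remain for the trailing byte.

#include <cassert>

struct StorePlan {
  int frontElems;    // written into the leading byte via RMW
  int fullBytes;     // whole emulated bytes written with a plain vector.store
  int trailingElems; // written into the trailing byte via RMW
};

StorePlan planSubByteStore(int origElements, int frontPadElems,
                           int numSrcElemsPerDest) {
  StorePlan plan;
  plan.frontElems = (numSrcElemsPerDest - frontPadElems) % numSrcElemsPerDest;
  if (plan.frontElems > origElements)
    plan.frontElems = origElements; // the whole store fits in the first byte
  int remaining = origElements - plan.frontElems;
  plan.fullBytes = remaining / numSrcElemsPerDest;
  plan.trailingElems = remaining % numSrcElemsPerDest;
  return plan;
}

int main() {
  // vector<7xi2> stored at intra-byte element offset 2, 4 i2 lanes per byte:
  // 2 elements finish Byte 0, 4 elements fill Byte 1, 1 element starts Byte 2.
  StorePlan plan = planSubByteStore(/*origElements=*/7, /*frontPadElems=*/2,
                                    /*numSrcElemsPerDest=*/4);
  assert(plan.frontElems == 2 && plan.fullBytes == 1 && plan.trailingElems == 1);
}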
+ auto frontSubWidthStoreElem = + (numSrcElemsPerDest - *foldedNumFrontPadElems) % numSrcElemsPerDest; + if (frontSubWidthStoreElem > 0) { + SmallVector frontMaskValues(numSrcElemsPerDest, false); + if (*foldedNumFrontPadElems + origElements < numSrcElemsPerDest) { + std::fill_n(frontMaskValues.begin() + *foldedNumFrontPadElems, + origElements, true); + frontSubWidthStoreElem = origElements; + } else { + std::fill_n(frontMaskValues.end() - frontSubWidthStoreElem, + *foldedNumFrontPadElems, true); + } + auto frontMask = rewriter.create( + loc, DenseElementsAttr::get(subWidthStoreMaskType, frontMaskValues)); + + currentSourceIndex = numSrcElemsPerDest - (*foldedNumFrontPadElems); + auto value = + extractSliceIntoByte(rewriter, loc, valueToStore, 0, + frontSubWidthStoreElem, *foldedNumFrontPadElems); - rewriter.replaceOpWithNewOp( - op, bitCast.getResult(), adaptor.getBase(), - getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices)); + atomicStore(rewriter, loc, memrefBase, currentDestIndex, + cast(value), frontMask.getResult()); + } + + if (currentSourceIndex >= origElements) { + rewriter.eraseOp(op); + return success(); + } + + // Increment the destination index by 1 to align to the emulated width + // boundary. + auto constantOne = rewriter.create(loc, 1); + currentDestIndex = rewriter.create( + loc, rewriter.getIndexType(), currentDestIndex, constantOne); + + // 2. Full width store for the inner output bytes. + // After the previous step, the store address is aligned to the emulated + // width boundary. + int64_t fullWidthStoreSize = + (origElements - currentSourceIndex) / numSrcElemsPerDest; + int64_t numNonFullWidthElements = fullWidthStoreSize * numSrcElemsPerDest; + if (fullWidthStoreSize > 0) { + auto fullWidthStorePart = staticallyExtractSubvector( + rewriter, loc, valueToStore, currentSourceIndex, + numNonFullWidthElements); + + auto originType = cast(fullWidthStorePart.getType()); + auto memrefElemType = getElementTypeOrSelf(memrefBase.getType()); + auto storeType = VectorType::get( + {originType.getNumElements() / numSrcElemsPerDest}, memrefElemType); + auto bitCast = rewriter.create(loc, storeType, + fullWidthStorePart); + rewriter.create(loc, bitCast.getResult(), memrefBase, + currentDestIndex); + + currentSourceIndex += numNonFullWidthElements; + currentDestIndex = rewriter.create( + loc, rewriter.getIndexType(), currentDestIndex, + rewriter.create(loc, fullWidthStoreSize)); + } + + // 3. Partial width store for the trailing output byte. + // It is needed when the residual length is smaller than the emulated width, + // which is not covered in step 2 above. + auto remainingElements = origElements - currentSourceIndex; + if (remainingElements != 0) { + auto subWidthStorePart = + extractSliceIntoByte(rewriter, loc, cast(valueToStore), + currentSourceIndex, remainingElements, 0); + + // Generate back mask. + auto maskValues = SmallVector(numSrcElemsPerDest, 0); + std::fill_n(maskValues.begin(), remainingElements, 1); + auto backMask = rewriter.create( + loc, DenseElementsAttr::get(subWidthStoreMaskType, maskValues)); + + atomicStore(rewriter, loc, memrefBase, currentDestIndex, + cast(subWidthStorePart), backMask.getResult()); + } + + rewriter.eraseOp(op); return success(); } }; @@ -532,7 +788,7 @@ struct ConvertVectorLoad final : OpConversionPattern { // compile time as they must be constants. 
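The load-side patterns below all rely on the same idea: when the start of the narrow vector is not aligned to an emulated byte, load enough whole bytes to cover both the intra-byte offset and the requested elements, then extract the requested slice after bit-casting. Here is a lane-level sketch of that arithmetic (plain C++ with illustrative names, not code from the patch); the numbers match the vector<5xi2>-at-offset-1 case exercised by the unaligned tests in this file.

#include <cassert>
#include <vector>

// Number of emulated (wide) elements to load so that the narrow elements,
// shifted by the intra-byte offset, are fully covered (ceiling division).
int numWideElemsToLoad(int origElements, int intraOffset, int scale) {
  return (origElements + intraOffset + scale - 1) / scale;
}

// After bit-casting the wide load back to narrow lanes, the requested vector
// is the contiguous slice starting at the intra-byte offset.
std::vector<int> extractRequested(const std::vector<int> &loadedLanes,
                                  int intraOffset, int origElements) {
  return std::vector<int>(loadedLanes.begin() + intraOffset,
                          loadedLanes.begin() + intraOffset + origElements);
}

int main() {
  // vector<5xi2> starting at intra-byte offset 1, with 4 i2 lanes per byte:
  // ceil((5 + 1) / 4) = 2 bytes are loaded, then lanes 1..5 are handed back.
  assert(numWideElemsToLoad(/*origElements=*/5, /*intraOffset=*/1, /*scale=*/4) == 2);
  std::vector<int> lanes = {9, 0, 1, 2, 3, 4, 9, 9}; // two loaded bytes = 8 lanes
  assert(extractRequested(lanes, 1, 5) == (std::vector<int>{0, 1, 2, 3, 4}));
}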
auto origElements = op.getVectorType().getNumElements(); - bool isUnalignedEmulation = origElements % scale != 0; + bool isAlignedEmulation = origElements % scale == 0; auto stridedMetadata = rewriter.create(loc, op.getBase()); @@ -548,9 +804,9 @@ struct ConvertVectorLoad final : OpConversionPattern { getAsOpFoldResult(adaptor.getIndices())); std::optional foldedIntraVectorOffset = - isUnalignedEmulation - ? getConstantIntValue(linearizedInfo.intraDataOffset) - : 0; + isAlignedEmulation + ? 0 + : getConstantIntValue(linearizedInfo.intraDataOffset); // Always load enough elements which can cover the original elements. int64_t maxintraDataOffset = foldedIntraVectorOffset.value_or(scale - 1); @@ -563,13 +819,12 @@ struct ConvertVectorLoad final : OpConversionPattern { if (!foldedIntraVectorOffset) { auto resultVector = rewriter.create( loc, op.getType(), rewriter.getZeroAttr(op.getType())); - result = dynamicallyExtractSubVector( - rewriter, loc, dyn_cast>(result), resultVector, - linearizedInfo.intraDataOffset, origElements); - } else if (isUnalignedEmulation) { - result = - staticallyExtractSubvector(rewriter, loc, op.getType(), result, - *foldedIntraVectorOffset, origElements); + result = dynamicallyExtractSubVector(rewriter, loc, result, resultVector, + linearizedInfo.intraDataOffset, + origElements); + } else if (!isAlignedEmulation) { + result = staticallyExtractSubvector( + rewriter, loc, result, *foldedIntraVectorOffset, origElements); } rewriter.replaceOp(op, result); return success(); @@ -649,7 +904,7 @@ struct ConvertVectorMaskedLoad final // subvector at the proper offset after bit-casting. auto origType = op.getVectorType(); auto origElements = origType.getNumElements(); - bool isUnalignedEmulation = origElements % scale != 0; + bool isAlignedEmulation = origElements % scale == 0; auto stridedMetadata = rewriter.create(loc, op.getBase()); @@ -664,9 +919,9 @@ struct ConvertVectorMaskedLoad final getAsOpFoldResult(adaptor.getIndices())); std::optional foldedIntraVectorOffset = - isUnalignedEmulation - ? getConstantIntValue(linearizedInfo.intraDataOffset) - : 0; + isAlignedEmulation + ? 
0 + : getConstantIntValue(linearizedInfo.intraDataOffset); int64_t maxIntraDataOffset = foldedIntraVectorOffset.value_or(scale - 1); FailureOr newMask = getCompressedMaskOp( @@ -685,9 +940,9 @@ struct ConvertVectorMaskedLoad final loc, newBitcastType, rewriter.getZeroAttr(newBitcastType)); if (!foldedIntraVectorOffset) { passthru = dynamicallyInsertSubVector( - rewriter, loc, dyn_cast>(passthru), - emptyVector, linearizedInfo.intraDataOffset, origElements); - } else if (isUnalignedEmulation) { + rewriter, loc, passthru, emptyVector, linearizedInfo.intraDataOffset, + origElements); + } else if (!isAlignedEmulation) { passthru = staticallyInsertSubvector(rewriter, loc, passthru, emptyVector, *foldedIntraVectorOffset); } @@ -712,10 +967,10 @@ struct ConvertVectorMaskedLoad final auto emptyMask = rewriter.create( loc, newSelectMaskType, rewriter.getZeroAttr(newSelectMaskType)); if (!foldedIntraVectorOffset) { - mask = dynamicallyInsertSubVector( - rewriter, loc, dyn_cast>(mask), emptyMask, - linearizedInfo.intraDataOffset, origElements); - } else if (isUnalignedEmulation) { + mask = dynamicallyInsertSubVector(rewriter, loc, mask, emptyMask, + linearizedInfo.intraDataOffset, + origElements); + } else if (!isAlignedEmulation) { mask = staticallyInsertSubvector(rewriter, loc, op.getMask(), emptyMask, *foldedIntraVectorOffset); } @@ -724,12 +979,11 @@ struct ConvertVectorMaskedLoad final rewriter.create(loc, mask, bitCast, passthru); if (!foldedIntraVectorOffset) { result = dynamicallyExtractSubVector( - rewriter, loc, dyn_cast>(result), - op.getPassThru(), linearizedInfo.intraDataOffset, origElements); - } else if (isUnalignedEmulation) { - result = - staticallyExtractSubvector(rewriter, loc, op.getType(), result, - *foldedIntraVectorOffset, origElements); + rewriter, loc, result, op.getPassThru(), + linearizedInfo.intraDataOffset, origElements); + } else if (!isAlignedEmulation) { + result = staticallyExtractSubvector( + rewriter, loc, result, *foldedIntraVectorOffset, origElements); } rewriter.replaceOp(op, result); @@ -769,7 +1023,7 @@ struct ConvertVectorTransferRead final auto origElements = op.getVectorType().getNumElements(); - bool isUnalignedEmulation = origElements % scale != 0; + bool isAlignedEmulation = origElements % scale == 0; auto newPadding = rewriter.create(loc, newElementType, adaptor.getPadding()); @@ -788,9 +1042,9 @@ struct ConvertVectorTransferRead final getAsOpFoldResult(adaptor.getIndices())); std::optional foldedIntraVectorOffset = - isUnalignedEmulation - ? getConstantIntValue(linearizedInfo.intraDataOffset) - : 0; + isAlignedEmulation + ? 
0 + : getConstantIntValue(linearizedInfo.intraDataOffset); int64_t maxIntraDataOffset = foldedIntraVectorOffset.value_or(scale - 1); auto numElements = @@ -811,10 +1065,9 @@ struct ConvertVectorTransferRead final result = dynamicallyExtractSubVector(rewriter, loc, bitCast, zeros, linearizedInfo.intraDataOffset, origElements); - } else if (isUnalignedEmulation) { - result = - staticallyExtractSubvector(rewriter, loc, op.getType(), result, - *foldedIntraVectorOffset, origElements); + } else if (!isAlignedEmulation) { + result = staticallyExtractSubvector( + rewriter, loc, result, *foldedIntraVectorOffset, origElements); } rewriter.replaceOp(op, result); diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir index 4332e80feed42..89cb8e0bde875 100644 --- a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir +++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir @@ -356,3 +356,142 @@ func.func @vector_maskedload_i2_constant_mask_unaligned(%passthru: vector<5xi2>) // CHECK: %[[RESULT:.+]] = vector.extract_strided_slice %[[SELECT]] // CHECK-SAME: {offsets = [1], sizes = [5], strides = [1]} : vector<8xi2> to vector<5xi2> // CHECK: return %[[RESULT]] : vector<5xi2> + +///---------------------------------------------------------------------------------------- +/// vector.store +///---------------------------------------------------------------------------------------- + +func.func @vector_store_i2_const_index_two_partial_stores(%arg0: vector<3xi2>) { + %src = memref.alloc() : memref<3x3xi2> + %c0 = arith.constant 0 : index + %c2 = arith.constant 2 : index + vector.store %arg0, %src[%c2, %c0] :memref<3x3xi2>, vector<3xi2> + return +} + +// In this example, emit 2 atomic RMWs. +// +// Note, sizeof(%src) = 18 bits. 
This is modelled as %src_as_bytes: +// <3xi8> (bits [0, 18) with the input values from %src, and [18, 24) are masked out) + +// CHECK-LABEL: func @vector_store_i2_const_index_two_partial_stores( +// CHECK-SAME: %[[ARG0:.+]]: vector<3xi2>) +// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<3xi8> +// CHECK: %[[C1:.+]] = arith.constant 1 : index +// CHECK: %[[CST:.+]] = arith.constant dense<[false, false, true, true]> : vector<4xi1> +// CHECK: %[[CST_0:.+]] = arith.constant dense<0> : vector<4xi2> + +// Part 1 atomic RMW sequence (load bits [12, 16) from %src_as_bytes[1]) +// CHECK: %[[EXTRACT:.+]] = vector.extract_strided_slice %[[ARG0]] +// CHECK-SAME: {offsets = [0], sizes = [2], strides = [1]} : vector<3xi2> to vector<2xi2> +// CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %[[EXTRACT]], %[[CST_0]] +// CHECK-SAME: {offsets = [2], strides = [1]} : vector<2xi2> into vector<4xi2> +// CHECK: %[[ATOMIC_RMW:.+]] = memref.generic_atomic_rmw %[[ALLOC]][%[[C1]]] : memref<3xi8> { +// CHECK: %[[ARG:.+]]: i8): +// CHECK: %[[FROM_ELEM:.+]] = vector.from_elements %[[ARG]] : vector<1xi8> +// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[FROM_ELEM]] : vector<1xi8> to vector<4xi2> +// CHECK: %[[SELECT:.+]] = arith.select %[[CST]], %[[INSERT]], %[[BITCAST]] : vector<4xi1>, vector<4xi2> +// CHECK: %[[BITCAST2:.+]] = vector.bitcast %[[SELECT]] : vector<4xi2> to vector<1xi8> +// CHECK: %[[EXTRACT2:.+]] = vector.extract %[[BITCAST2]][0] : i8 from vector<1xi8> +// CHECK: memref.atomic_yield %[[EXTRACT2]] : i8 + +// Part 2 atomic RMW sequence (load bits [16, 18) from %src_as_bytes[2]) +// CHECK: %[[ADDR2:.+]] = arith.addi %[[C1]], %[[C1]] : index +// CHECK: %[[EXTRACT3:.+]] = vector.extract_strided_slice %[[ARG0]] +// CHECK-SAME: {offsets = [2], sizes = [1], strides = [1]} : vector<3xi2> to vector<1xi2> +// CHECK: %[[INSERT2:.+]] = vector.insert_strided_slice %[[EXTRACT3]], %[[CST_0]] +// CHECK-SAME: {offsets = [0], strides = [1]} : vector<1xi2> into vector<4xi2> +// CHECK: %[[CST1:.+]] = arith.constant dense<[true, false, false, false]> : vector<4xi1> +// CHECK: %[[ATOMIC_RMW2:.+]] = memref.generic_atomic_rmw %[[ALLOC]][%[[ADDR2]]] : memref<3xi8> { +// CHECK: %[[ARG2:.+]]: i8): +// CHECK: %[[FROM_ELEM2:.+]] = vector.from_elements %[[ARG2]] : vector<1xi8> +// CHECK: %[[BITCAST4:.+]] = vector.bitcast %[[FROM_ELEM2]] : vector<1xi8> to vector<4xi2> +// CHECK: %[[SELECT2:.+]] = arith.select %[[CST1]], %[[INSERT2]], %[[BITCAST4]] : vector<4xi1>, vector<4xi2> +// CHECK: %[[BITCAST5:.+]] = vector.bitcast %[[SELECT2]] : vector<4xi2> to vector<1xi8> +// CHECK: %[[EXTRACT4:.+]] = vector.extract %[[BITCAST5]][0] : i8 from vector<1xi8> +// CHECK: memref.atomic_yield %[[EXTRACT4]] : i8 + +// ----- + +func.func @vector_store_i2_two_partial_one_full_stores(%arg0: vector<7xi2>) { + %0 = memref.alloc() : memref<3x7xi2> + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + vector.store %arg0, %0[%c1, %c0] :memref<3x7xi2>, vector<7xi2> + return +} + +// In this example, emit 2 atomic RMWs and 1 non-atomic store: +// CHECK-LABEL: func @vector_store_i2_two_partial_one_full_stores( +// CHECK-SAME: %[[ARG0:.+]]: vector<7xi2>) +// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<6xi8> +// CHECK: %[[C1:.+]] = arith.constant 1 : index +// CHECK: %[[CST:.+]] = arith.constant dense<[false, false, false, true]> : vector<4xi1> +// CHECK: %[[CST0:.+]] = arith.constant dense<0> : vector<4xi2> + +// First atomic RMW: +// CHECK: %[[EXTRACT:.+]] = vector.extract_strided_slice %[[ARG0]] +// CHECK-SAME: {offsets = [0], sizes = [1], 
strides = [1]} : vector<7xi2> to vector<1xi2> +// CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %[[EXTRACT]], %[[CST0]] +// CHECK-SAME: {offsets = [3], strides = [1]} : vector<1xi2> into vector<4xi2> +// CHECK: %[[ATOMIC_RMW:.+]] = memref.generic_atomic_rmw %[[ALLOC]][%[[C1]]] : memref<6xi8> { +// CHECK: %[[ARG:.+]]: i8): +// CHECK: %[[FROM_ELEM:.+]] = vector.from_elements %[[ARG]] : vector<1xi8> +// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[FROM_ELEM]] : vector<1xi8> to vector<4xi2> +// CHECK: %[[SELECT:.+]] = arith.select %[[CST]], %[[INSERT]], %[[BITCAST]] : vector<4xi1>, vector<4xi2> +// CHECK: %[[BITCAST2:.+]] = vector.bitcast %[[SELECT]] : vector<4xi2> to vector<1xi8> +// CHECK: %[[EXTRACT2:.+]] = vector.extract %[[BITCAST2]][0] : i8 from vector<1xi8> +// CHECK: memref.atomic_yield %[[EXTRACT2]] : i8 + +// Non-atomic store: +// CHECK: %[[ADDR:.+]] = arith.addi %[[C1]], %[[C1]] : index +// CHECK: %[[EXTRACT2:.+]] = vector.extract_strided_slice %[[ARG0]] +// CHECK-SAME: {offsets = [1], sizes = [4], strides = [1]} : vector<7xi2> to vector<4xi2> +// CHECK: %[[BITCAST3:.+]] = vector.bitcast %[[EXTRACT2]] : vector<4xi2> to vector<1xi8> +// CHECK: vector.store %[[BITCAST3]], %[[ALLOC]][%[[ADDR]]] : memref<6xi8>, vector<1xi8> + +// Second atomic RMW: +// CHECK: %[[ADDR2:.+]] = arith.addi %[[ADDR]], %[[C1]] : index +// CHECK: %[[EXTRACT3:.+]] = vector.extract_strided_slice %[[ARG0]] +// CHECK-SAME: {offsets = [5], sizes = [2], strides = [1]} : vector<7xi2> to vector<2xi2> +// CHECK: %[[INSERT2:.+]] = vector.insert_strided_slice %[[EXTRACT3]], %[[CST0]] +// CHECK-SAME: {offsets = [0], strides = [1]} : vector<2xi2> into vector<4xi2> +// CHECK: %[[CST1:.+]] = arith.constant dense<[true, true, false, false]> : vector<4xi1> +// CHECK: %[[ATOMIC_RMW2:.+]] = memref.generic_atomic_rmw %[[ALLOC]][%[[ADDR2]]] : memref<6xi8> { +// CHECK: %[[ARG2:.+]]: i8): +// CHECK: %[[FROM_ELEM2:.+]] = vector.from_elements %[[ARG2]] : vector<1xi8> +// CHECK: %[[BITCAST4:.+]] = vector.bitcast %[[FROM_ELEM2]] : vector<1xi8> to vector<4xi2> +// CHECK: %[[SELECT2:.+]] = arith.select %[[CST1]], %[[INSERT2]], %[[BITCAST4]] : +// CHECK-SAME: vector<4xi1>, vector<4xi2> +// CHECK: %[[BITCAST5:.+]] = vector.bitcast %[[SELECT2]] : vector<4xi2> to vector<1xi8> +// CHECK: %[[EXTRACT4:.+]] = vector.extract %[[BITCAST5]][0] : i8 from vector<1xi8> +// CHECK: memref.atomic_yield %[[EXTRACT4]] : i8 + +// ----- + +func.func @vector_store_i2_const_index_one_partial_store(%arg0: vector<1xi2>) { + %0 = memref.alloc() : memref<4x1xi2> + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + vector.store %arg0, %0[%c1, %c0] : memref<4x1xi2>, vector<1xi2> + return +} + +// In this example, only emit 1 atomic store +// CHECK-LABEL: func @vector_store_i2_const_index_one_partial_store( +// CHECK-SAME: %[[ARG0:.+]]: vector<1xi2>) +// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<1xi8> +// CHECK: %[[C0:.+]] = arith.constant 0 : index +// CHECK: %[[CST:.+]] = arith.constant dense<[false, true, false, false]> : vector<4xi1> +// CHECK: %[[CST0:.+]] = arith.constant dense<0> : vector<4xi2> +// CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %[[ARG0]], %[[CST0]] +// CHECK-SAME: {offsets = [1], strides = [1]} : vector<1xi2> into vector<4xi2> + +// CHECK: %[[ATOMIC_RMW:.+]] = memref.generic_atomic_rmw %[[ALLOC]][%[[C0]]] : memref<1xi8> { +// CHECK: %[[ARG:.+]]: i8): +// CHECK: %[[FROM_ELEM:.+]] = vector.from_elements %[[ARG]] : vector<1xi8> +// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[FROM_ELEM]] : vector<1xi8> to vector<4xi2> 
+// CHECK: %[[SELECT:.+]] = arith.select %[[CST]], %[[INSERT]], %[[BITCAST]] : vector<4xi1>, vector<4xi2> +// CHECK: %[[BITCAST2:.+]] = vector.bitcast %[[SELECT]] : vector<4xi2> to vector<1xi8> +// CHECK: %[[EXTRACT2:.+]] = vector.extract %[[BITCAST2]][0] : i8 from vector<1xi8> +// CHECK: memref.atomic_yield %[[EXTRACT2]] : i8
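As a quick standalone cross-check of the constants in the second test above (plain C++, not MLIR): storing vector<7xi2> at [%c1, %c0] of memref<3x7xi2> linearizes to i2 element index 7, i.e. emulated byte 1 with an intra-byte offset of 3, which yields exactly the plan the CHECK lines encode: an RMW on byte 1 with mask [false,false,false,true], one full vector<1xi8> store at byte 2, and an RMW on byte 3 with mask [true,true,false,false].

#include <cassert>

int main() {
  const int lanesPerByte = 4;            // four i2 lanes per emulated i8
  const int linearElemIndex = 1 * 7 + 0; // row-major [1, 0] in memref<3x7xi2>
  const int numElements = 7;             // vector<7xi2>

  int firstByte = linearElemIndex / lanesPerByte;               // 1
  int frontPad = linearElemIndex % lanesPerByte;                // 3
  int frontElems = (lanesPerByte - frontPad) % lanesPerByte;    // 1
  int fullBytes = (numElements - frontElems) / lanesPerByte;    // 1
  int trailingElems = (numElements - frontElems) % lanesPerByte; // 2

  assert(firstByte == 1 && frontPad == 3);
  assert(frontElems == 1);    // RMW on byte 1, mask [false,false,false,true]
  assert(fullBytes == 1);     // plain vector.store of vector<1xi8> at byte 2
  assert(trailingElems == 2); // RMW on byte 3, mask [true,true,false,false]
  assert(firstByte + 1 == 2 && firstByte + 1 + fullBytes == 3); // byte addresses
}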