Implement vector stores
lialan committed Nov 26, 2024
1 parent 2ed8c5d commit 72bdea0
Showing 5 changed files with 436 additions and 36 deletions.
@@ -364,10 +364,11 @@ void populateVectorMaskMaterializationPatterns(RewritePatternSet &patterns,
PatternBenefit benefit = 1);

/// Appends patterns for emulating vector operations over narrow types with ops
/// over wider types.
/// over wider types. `useAtomicWrites` indicates whether to use atomic
/// operations in the places where thread contention is possible.
void populateVectorNarrowTypeEmulationPatterns(
const arith::NarrowTypeEmulationConverter &typeConverter,
RewritePatternSet &patterns);
RewritePatternSet &patterns, bool useAtomicWrites = true);

/// Rewrite a vector `bitcast(trunci)` to use a more efficient sequence of
/// vector operations comprising `shuffle` and `bitwise` ops.
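
Below is a minimal sketch of how a caller might populate these patterns and opt out of atomic writes. The helper name, the i8 target bitwidth, and the include paths are illustrative assumptions, not part of this change.

#include "mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

// Hypothetical helper: collects the narrow-type emulation patterns for a pass.
static void populateNarrowTypePatterns(RewritePatternSet &patterns,
                                       bool useAtomicWrites) {
  // Emulate sub-byte (e.g. i2/i4) element types using i8 as the storage unit.
  arith::NarrowTypeEmulationConverter typeConverter(/*targetBitwidth=*/8);
  // `useAtomicWrites = false` lowers partially covered bytes with a plain
  // load-modify-write sequence instead of memref.generic_atomic_rmw.
  vector::populateVectorNarrowTypeEmulationPatterns(typeConverter, patterns,
                                                    useAtomicWrites);
}
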
265 changes: 232 additions & 33 deletions mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -33,6 +33,7 @@
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
#include <cstdint>
@@ -211,13 +212,10 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
/// Extracts 1-D subvector from a 1-D vector. It is a wrapper function for
/// emitting `vector.extract_strided_slice`.
static Value staticallyExtractSubvector(OpBuilder &rewriter, Location loc,
VectorType extractType, Value source,
int64_t frontOffset,
Value source, int64_t frontOffset,
int64_t subvecSize) {
auto vectorType = cast<VectorType>(source.getType());
assert((vectorType.getRank() == 1 && extractType.getRank() == 1) &&
"expected 1-D source and destination types");
(void)vectorType;
assert(vectorType.getRank() == 1 && "expected 1-D source type");
assert(frontOffset + subvecSize <= vectorType.getNumElements() &&
"subvector out of bounds");

@@ -228,9 +226,12 @@ static Value staticallyExtractSubvector(OpBuilder &rewriter, Location loc,
auto offsets = rewriter.getI64ArrayAttr({frontOffset});
auto sizes = rewriter.getI64ArrayAttr({subvecSize});
auto strides = rewriter.getI64ArrayAttr({1});

auto resultVectorType =
VectorType::get({subvecSize}, vectorType.getElementType());
return rewriter
.create<vector::ExtractStridedSliceOp>(loc, extractType, source, offsets,
sizes, strides)
.create<vector::ExtractStridedSliceOp>(loc, resultVectorType, source,
offsets, sizes, strides)
->getResult(0);
}

@@ -309,6 +310,76 @@ emulatedVectorLoad(OpBuilder &rewriter, Location loc, Value base,
newLoad);
}

/// Atomically store a subbyte-sized value to memory, with a mask.
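/// For example, with i2 elements packed four to a byte, the body of the
/// generated `memref.generic_atomic_rmw` looks roughly as follows (SSA names
/// and exact syntax are illustrative only):
///
///   ^bb0(%current : i8):
///     %full = vector.from_elements %current : vector<1xi8>
///     %unpacked = vector.bitcast %full : vector<1xi8> to vector<4xi2>
///     %merged = arith.select %mask, %value, %unpacked
///         : vector<4xi1>, vector<4xi2>
///     %packed = vector.bitcast %merged : vector<4xi2> to vector<1xi8>
///     %byte = vector.extract %packed[0] : i8 from vector<1xi8>
///     memref.atomic_yield %byte : i8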
static void atomicStore(OpBuilder &builder, Location loc,
TypedValue<MemRefType> emulatedMemref,
Value linearizedIndex, TypedValue<VectorType> value,
Value mask, int64_t /*numSrcElemsPerDest*/) {
auto atomicOp = builder.create<memref::GenericAtomicRMWOp>(
loc, emulatedMemref, ValueRange{linearizedIndex});
Value origValue = atomicOp.getCurrentValue();

OpBuilder::InsertionGuard guard(builder);
builder.setInsertionPointToStart(atomicOp.getBody());

// i8 -> vector<1xi8> -> vector<numSrcElemsPerDest x iN>, where iN is the
// sub-byte element type of `value`.
auto oneVectorType = VectorType::get({1}, origValue.getType());
auto fromElem = builder.create<vector::FromElementsOp>(loc, oneVectorType,
ValueRange{origValue});
auto vectorBitCast =
builder.create<vector::BitCastOp>(loc, value.getType(), fromElem);

auto select =
builder.create<arith::SelectOp>(loc, mask, value, vectorBitCast);
auto bitcast2 = builder.create<vector::BitCastOp>(loc, oneVectorType, select);
auto extract = builder.create<vector::ExtractOp>(loc, bitcast2, 0);
builder.create<memref::AtomicYieldOp>(loc, extract.getResult());
}

/// Generate a non-atomic read-modify-write sequence for subbyte storing.
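/// The emitted sequence is the non-atomic counterpart of `atomicStore`,
/// roughly (types shown for i2 elements, illustrative only):
///
///   %byte = vector.load %dest[%index] : memref<?xi8>, vector<1xi8>
///   %unpacked = vector.bitcast %byte : vector<1xi8> to vector<4xi2>
///   %merged = arith.select %mask, %value, %unpacked
///       : vector<4xi1>, vector<4xi2>
///   %packed = vector.bitcast %merged : vector<4xi2> to vector<1xi8>
///   vector.store %packed, %dest[%index] : memref<?xi8>, vector<1xi8>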
static void rmwStore(OpBuilder &rewriter, Location loc,
TypedValue<MemRefType> emulatedMemref,
Value linearizedIndex, TypedValue<VectorType> value,
Value mask, int64_t numSrcElemsPerDest) {
auto emulatedIOType =
VectorType::get({1}, emulatedMemref.getType().getElementType());
auto elemLoad = rewriter.create<vector::LoadOp>(
loc, emulatedIOType, emulatedMemref, ValueRange{linearizedIndex});
auto fromBitcast = rewriter.create<vector::BitCastOp>(
loc,
VectorType::get({numSrcElemsPerDest}, value.getType().getElementType()),
elemLoad);
auto select = rewriter.create<arith::SelectOp>(loc, mask, value, fromBitcast);
auto toBitcast =
rewriter.create<vector::BitCastOp>(loc, emulatedIOType, select);
rewriter.create<vector::StoreOp>(loc, toBitcast, emulatedMemref,
linearizedIndex);
}

static_assert(std::is_same_v<decltype(atomicStore), decltype(rmwStore)> &&
"`atomicStore` and `rmwStore` must have the same signature, "
"since the `useAtomicWrites` flag selects which one is called.");

/// Extracts a slice of a 1-D vector and inserts it into a zero-initialized,
/// byte-sized vector at `byteOffset`.
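/// For example, extracting 3 elements of a vector<7xi2> starting at
/// `sliceOffset` = 4 with `byteOffset` = 1 produces a vector<4xi2> holding
/// [0, v4, v5, v6].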
static Value extractSliceIntoByte(ConversionPatternRewriter &rewriter,
Location loc, TypedValue<VectorType> vector,
int64_t sliceOffset, int64_t sliceNumElements,
int64_t byteOffset) {
auto vectorElementType = vector.getType().getElementType();
assert(8 % vectorElementType.getIntOrFloatBitWidth() == 0 &&
"vector element must be a valid sub-byte type");
auto scale = 8 / vectorElementType.getIntOrFloatBitWidth();
auto emptyByteVector = rewriter.create<arith::ConstantOp>(
loc, VectorType::get({scale}, vectorElementType),
rewriter.getZeroAttr(VectorType::get({scale}, vectorElementType)));
auto extracted = staticallyExtractSubvector(rewriter, loc, vector,
sliceOffset, sliceNumElements);
auto inserted = staticallyInsertSubvector(rewriter, loc, extracted,
emptyByteVector, byteOffset);
return inserted;
}

namespace {

//===----------------------------------------------------------------------===//
@@ -318,6 +389,10 @@ namespace {
struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
using OpConversionPattern::OpConversionPattern;

ConvertVectorStore(MLIRContext *context, bool useAtomicWrites)
: OpConversionPattern<vector::StoreOp>(context),
useAtomicWrites_(useAtomicWrites) {}

LogicalResult
matchAndRewrite(vector::StoreOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
@@ -329,16 +404,17 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {

auto loc = op.getLoc();
auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
Type oldElementType = op.getValueToStore().getType().getElementType();
Type newElementType = convertedType.getElementType();
auto valueToStore = cast<TypedValue<VectorType>>(op.getValueToStore());
auto oldElementType = valueToStore.getType().getElementType();
auto newElementType = convertedType.getElementType();
int srcBits = oldElementType.getIntOrFloatBitWidth();
int dstBits = newElementType.getIntOrFloatBitWidth();

if (dstBits % srcBits != 0) {
return rewriter.notifyMatchFailure(
op, "only dstBits % srcBits == 0 supported");
}
int scale = dstBits / srcBits;
int numSrcElemsPerDest = dstBits / srcBits;

// Adjust the number of elements to store when emulating narrow types.
// Here only the 1-D vector store is considered, and the N-D memref types
@@ -353,32 +429,153 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
// vector.store %bitcast, %alloc[%linear_index] : memref<16xi8>,
// vector<4xi8>

auto origElements = op.getValueToStore().getType().getNumElements();
if (origElements % scale != 0)
return failure();
auto origElements = valueToStore.getType().getNumElements();
bool isUnalignedEmulation = origElements % numSrcElemsPerDest != 0;

auto stridedMetadata =
rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());

OpFoldResult linearizedIndices;
std::tie(std::ignore, linearizedIndices) =
memref::LinearizedMemRefInfo linearizedInfo;
std::tie(linearizedInfo, linearizedIndices) =
memref::getLinearizedMemRefOffsetAndSize(
rewriter, loc, srcBits, dstBits,
stridedMetadata.getConstifiedMixedOffset(),
stridedMetadata.getConstifiedMixedSizes(),
stridedMetadata.getConstifiedMixedStrides(),
getAsOpFoldResult(adaptor.getIndices()));

auto numElements = origElements / scale;
auto bitCast = rewriter.create<vector::BitCastOp>(
loc, VectorType::get(numElements, newElementType),
op.getValueToStore());
auto foldedNumFrontPadElems =
isUnalignedEmulation
? getConstantIntValue(linearizedInfo.intraDataOffset)
: 0;

rewriter.replaceOpWithNewOp<vector::StoreOp>(
op, bitCast.getResult(), adaptor.getBase(),
getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
if (!foldedNumFrontPadElems) {
// Unimplemented case for dynamic front padding size != 0
return failure();
}

auto emulatedMemref = cast<TypedValue<MemRefType>>(adaptor.getBase());

// Shortcut: no partial sub-byte store is needed at the front when:
// 1. the source vector size is a multiple of the byte size, and
// 2. the address of the store is aligned to the emulated width boundary.
if (!isUnalignedEmulation && *foldedNumFrontPadElems == 0) {
auto numElements = origElements / numSrcElemsPerDest;
auto bitCast = rewriter.create<vector::BitCastOp>(
loc, VectorType::get(numElements, newElementType),
op.getValueToStore());
rewriter.replaceOpWithNewOp<vector::StoreOp>(
op, bitCast.getResult(), emulatedMemref,
getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
return success();
}

// The index into the target memref we are storing to
Value currentDestIndex =
getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices);
auto constantOne = rewriter.create<arith::ConstantIndexOp>(loc, 1);
auto subWidthStoreMaskType =
VectorType::get({numSrcElemsPerDest}, rewriter.getI1Type());
// The index into the source vector we are currently processing
auto currentSourceIndex = 0;

// 1. Partial-width store for the first byte. When the store address is not
// aligned to the emulated width boundary, handle the unaligned part first so
// that the remaining elements are aligned to the width boundary.
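// For example, with i2 elements (4 per byte) and 1 element of front padding,
// 3 elements of the source are stored into the first byte under the mask
// [0, 1, 1, 1].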
auto frontSubWidthStoreElem =
(numSrcElemsPerDest - *foldedNumFrontPadElems) % numSrcElemsPerDest;
if (frontSubWidthStoreElem != 0) {
SmallVector<bool> frontMaskValues(numSrcElemsPerDest, false);
if (*foldedNumFrontPadElems + origElements < numSrcElemsPerDest) {
std::fill_n(frontMaskValues.begin() + *foldedNumFrontPadElems,
origElements, true);
frontSubWidthStoreElem = origElements;
} else {
std::fill_n(frontMaskValues.end() - frontSubWidthStoreElem,
frontSubWidthStoreElem, true);
}
auto frontMask = rewriter.create<arith::ConstantOp>(
loc, DenseElementsAttr::get(subWidthStoreMaskType, frontMaskValues));

currentSourceIndex = numSrcElemsPerDest - (*foldedNumFrontPadElems);
auto value =
extractSliceIntoByte(rewriter, loc, valueToStore, 0,
frontSubWidthStoreElem, *foldedNumFrontPadElems);

subEmulatedWidthStore(rewriter, loc, emulatedMemref, currentDestIndex,
cast<TypedValue<VectorType>>(value),
frontMask.getResult(), numSrcElemsPerDest);

currentDestIndex = rewriter.create<arith::AddIOp>(
loc, rewriter.getIndexType(), currentDestIndex, constantOne);
}

if (currentSourceIndex >= origElements) {
rewriter.eraseOp(op);
return success();
}

// 2. Full width store. After the previous step, the store address is
// aligned to the emulated width boundary.
int64_t fullWidthStoreSize =
(origElements - currentSourceIndex) / numSrcElemsPerDest;
int64_t numNonFullWidthElements = fullWidthStoreSize * numSrcElemsPerDest;
if (fullWidthStoreSize != 0) {
auto fullWidthStorePart = staticallyExtractSubvector(
rewriter, loc, valueToStore, currentSourceIndex,
numNonFullWidthElements);

auto originType = cast<VectorType>(fullWidthStorePart.getType());
auto memrefElemType = emulatedMemref.getType().getElementType();
auto storeType = VectorType::get(
{originType.getNumElements() / numSrcElemsPerDest}, memrefElemType);
auto bitCast = rewriter.create<vector::BitCastOp>(loc, storeType,
fullWidthStorePart);
rewriter.create<vector::StoreOp>(loc, bitCast.getResult(), emulatedMemref,
currentDestIndex);

currentSourceIndex += numNonFullWidthElements;
currentDestIndex = rewriter.create<arith::AddIOp>(
loc, rewriter.getIndexType(), currentDestIndex,
rewriter.create<arith::ConstantIndexOp>(loc, fullWidthStoreSize));
}

// 3. Deal with trailing elements that start at an emulated width boundary
// but do not fill a full emulated element.
auto remainingElements = origElements - currentSourceIndex;
if (remainingElements != 0) {
auto subWidthStorePart = extractSliceIntoByte(
rewriter, loc, cast<TypedValue<VectorType>>(valueToStore),
currentSourceIndex, remainingElements, 0);

// Generate the back mask: only the first `remainingElements` lanes of the
// final byte are written.
auto maskValues = SmallVector<bool>(numSrcElemsPerDest, false);
std::fill_n(maskValues.begin(), remainingElements, true);
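// For example, with i2 elements (4 per byte) and 2 trailing elements, the
// back mask is [1, 1, 0, 0].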
auto backMask = rewriter.create<arith::ConstantOp>(
loc, DenseElementsAttr::get(subWidthStoreMaskType, maskValues));

subEmulatedWidthStore(rewriter, loc, emulatedMemref, currentDestIndex,
cast<TypedValue<VectorType>>(subWidthStorePart),
backMask.getResult(), numSrcElemsPerDest);
}

rewriter.eraseOp(op);
return success();
}

/// Store a subbyte-sized value to memory, with a mask. Depending on the
/// configuration, it could be an atomic store or an RMW sequence.
template <typename... Args>
void subEmulatedWidthStore(Args &&...args) const {
std::function<decltype(atomicStore)> storeFunc =
useAtomicWrites_ ? atomicStore : rmwStore;
storeFunc(std::forward<Args>(args)...);
}

private:
const bool useAtomicWrites_;
};

//===----------------------------------------------------------------------===//
@@ -584,9 +781,8 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
rewriter, loc, dyn_cast<TypedValue<VectorType>>(result), resultVector,
linearizedInfo.intraDataOffset, origElements);
} else if (isUnalignedEmulation) {
result =
staticallyExtractSubvector(rewriter, loc, op.getType(), result,
*foldedIntraVectorOffset, origElements);
result = staticallyExtractSubvector(
rewriter, loc, result, *foldedIntraVectorOffset, origElements);
}
rewriter.replaceOp(op, result);
return success();
@@ -745,9 +941,8 @@ struct ConvertVectorMaskedLoad final
rewriter, loc, dyn_cast<TypedValue<VectorType>>(result),
op.getPassThru(), linearizedInfo.intraDataOffset, origElements);
} else if (isUnalignedEmulation) {
result =
staticallyExtractSubvector(rewriter, loc, op.getType(), result,
*foldedIntraVectorOffset, origElements);
result = staticallyExtractSubvector(
rewriter, loc, result, *foldedIntraVectorOffset, origElements);
}
rewriter.replaceOp(op, result);

@@ -830,9 +1025,8 @@ struct ConvertVectorTransferRead final
linearizedInfo.intraDataOffset,
origElements);
} else if (isUnalignedEmulation) {
result =
staticallyExtractSubvector(rewriter, loc, op.getType(), result,
*foldedIntraVectorOffset, origElements);
result = staticallyExtractSubvector(
rewriter, loc, result, *foldedIntraVectorOffset, origElements);
}
rewriter.replaceOp(op, result);

@@ -1577,12 +1771,17 @@ struct RewriteVectorTranspose : OpRewritePattern<vector::TransposeOp> {

void vector::populateVectorNarrowTypeEmulationPatterns(
const arith::NarrowTypeEmulationConverter &typeConverter,
RewritePatternSet &patterns) {
RewritePatternSet &patterns, bool useAtomicWrites) {

// Populate `vector.*` conversion patterns.
patterns.add<ConvertVectorLoad, ConvertVectorMaskedLoad, ConvertVectorStore,
// Populate the `vector.*` conversion patterns, except for `vector.store`,
// which is added below with its own configuration.
patterns.add<ConvertVectorLoad, ConvertVectorMaskedLoad,
ConvertVectorMaskedStore, ConvertVectorTransferRead>(
typeConverter, patterns.getContext());

// Populate the `vector.store` conversion pattern. The caller can choose not
// to emit atomic operations and instead lower to a plain load-modify-write
// sequence when it is known that there is no thread contention.
patterns.insert<ConvertVectorStore>(patterns.getContext(), useAtomicWrites);
}

void vector::populateVectorNarrowTypeRewritePatterns(
