[MLIR] Support non-atomic RMW option for emulated vector stores #124887

Merged: 8 commits, Feb 6, 2025
@@ -364,10 +364,12 @@ void populateVectorMaskMaterializationPatterns(RewritePatternSet &patterns,
PatternBenefit benefit = 1);

/// Appends patterns for emulating vector operations over narrow types with ops
/// over wider types.
/// over wider types. The `disableAtomicRMW` flag indicates whether to use a
/// plain read-modify-write sequence instead of `memref.generic_atomic_rmw`
/// when performing sub-byte stores.
void populateVectorNarrowTypeEmulationPatterns(
const arith::NarrowTypeEmulationConverter &typeConverter,
RewritePatternSet &patterns);
RewritePatternSet &patterns, bool disableAtomicRMW = false);

/// Rewrite a vector `bitcast(trunci)` to use a more efficient sequence of
/// vector operations comprising `shuffle` and `bitwise` ops.
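Downstream users opt into the non-atomic lowering by passing one extra argument when populating the emulation patterns, mirroring the TestEmulateNarrowType.cpp change further down. A minimal sketch, where the wrapper name and surrounding setup are illustrative and only the populate call itself comes from this patch:

static void addNarrowTypeEmulationPatterns(
    const arith::NarrowTypeEmulationConverter &typeConverter,
    RewritePatternSet &patterns, bool noContention) {
  // Only request the non-atomic lowering when the caller knows the stored
  // bytes cannot be touched concurrently by another thread.
  vector::populateVectorNarrowTypeEmulationPatterns(
      typeConverter, patterns, /*disableAtomicRMW=*/noContention);
}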
54 changes: 45 additions & 9 deletions mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -334,9 +334,9 @@ static Value downcastSelectAndUpcast(OpBuilder &builder, Location loc,
///
/// Result:
/// linearizedMemref = |2|2|3|3| : <4xi2> (<1xi8>)
static void atomicStore(OpBuilder &builder, Location loc,
MemRefValue linearizedMemref, Value storeIdx,
VectorValue valueToStore, Value mask) {
static void atomicRMW(OpBuilder &builder, Location loc,
MemRefValue linearizedMemref, Value storeIdx,
VectorValue valueToStore, Value mask) {
assert(valueToStore.getType().getRank() == 1 && "expected 1-D vector");

// Create an atomic load-modify-write region using
@@ -363,6 +363,27 @@ static void atomicStore(OpBuilder &builder, Location loc,
builder.create<memref::AtomicYieldOp>(loc, scalarMaskedValue);
}

/// Generate a non-atomic read-modify-write sequence for storing to the emulated
/// type. It has the same logic as `atomicRMW`, but without atomicity.
static void nonAtomicRMW(OpBuilder &builder, Location loc,
MemRefValue linearizedMemref, Value linearizedIndex,
VectorValue valueToStore, Value mask) {
assert(valueToStore.getType().getRank() == 1 && "expected 1-D vector");

auto oneElemVecType =
VectorType::get({1}, linearizedMemref.getType().getElementType());
Value origVecValue = builder.create<vector::LoadOp>(
loc, oneElemVecType, linearizedMemref, ValueRange{linearizedIndex});
origVecValue = builder.create<vector::BitCastOp>(loc, valueToStore.getType(),
origVecValue);

Value maskedValue =
downcastSelectAndUpcast(builder, loc, valueToStore.getType(),
oneElemVecType, mask, valueToStore, origVecValue);
builder.create<vector::StoreOp>(loc, maskedValue, linearizedMemref,
linearizedIndex);
}

/// Extract `sliceNumElements` from source `vector` at `extractOffset`,
/// and insert it into an empty vector at `insertOffset`.
/// Inputs:
@@ -405,6 +426,10 @@ namespace {
struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
using OpConversionPattern::OpConversionPattern;

ConvertVectorStore(MLIRContext *context, bool disableAtomicRMW)
: OpConversionPattern<vector::StoreOp>(context),
disableAtomicRMW(disableAtomicRMW) {}

LogicalResult
matchAndRewrite(vector::StoreOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
@@ -531,6 +556,8 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
auto subWidthStoreMaskType =
VectorType::get({numSrcElemsPerDest}, rewriter.getI1Type());

auto storeFunc = disableAtomicRMW ? nonAtomicRMW : atomicRMW;
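// Both helpers share the same signature, so the selected lowering is applied
// uniformly to the partial (masked) stores below.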

// 1. Partial width store for the leading byte.
// When the store address is not aligned to emulated width boundary, deal
// with the unaligned part so that the rest elements are aligned to width
@@ -555,8 +582,8 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
extractSliceIntoByte(rewriter, loc, valueToStore, 0,
frontSubWidthStoreElem, *foldedNumFrontPadElems);

atomicStore(rewriter, loc, memrefBase, currentDestIndex,
cast<VectorValue>(value), frontMask.getResult());
storeFunc(rewriter, loc, memrefBase, currentDestIndex,
cast<VectorValue>(value), frontMask.getResult());
}

if (currentSourceIndex >= origElements) {
@@ -611,13 +638,16 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
auto backMask = rewriter.create<arith::ConstantOp>(
loc, DenseElementsAttr::get(subWidthStoreMaskType, maskValues));

atomicStore(rewriter, loc, memrefBase, currentDestIndex,
cast<VectorValue>(subWidthStorePart), backMask.getResult());
storeFunc(rewriter, loc, memrefBase, currentDestIndex,
cast<VectorValue>(subWidthStorePart), backMask.getResult());
}

rewriter.eraseOp(op);
return success();
}

private:
const bool disableAtomicRMW;
};

//===----------------------------------------------------------------------===//
@@ -1930,12 +1960,18 @@ struct RewriteVectorTranspose : OpRewritePattern<vector::TransposeOp> {

void vector::populateVectorNarrowTypeEmulationPatterns(
const arith::NarrowTypeEmulationConverter &typeConverter,
RewritePatternSet &patterns) {
RewritePatternSet &patterns, bool disableAtomicRMW) {

// Populate `vector.*` conversion patterns.
patterns.add<ConvertVectorLoad, ConvertVectorMaskedLoad, ConvertVectorStore,
// TODO: #119553 support atomicity
patterns.add<ConvertVectorLoad, ConvertVectorMaskedLoad,
ConvertVectorMaskedStore, ConvertVectorTransferRead>(
typeConverter, patterns.getContext());

// Populate `vector.*` store conversion patterns. The caller can choose to
// avoid emitting atomic operations and instead lower stores to a plain
// read-modify-write sequence if it is known there is no thread contention.
patterns.insert<ConvertVectorStore>(patterns.getContext(), disableAtomicRMW);
}

void vector::populateVectorNarrowTypeRewritePatterns(
@@ -0,0 +1,128 @@
// RUN: mlir-opt --test-emulate-narrow-int="arith-compute-bitwidth=1 memref-load-bitwidth=8 disable-atomic-rmw=true" --cse --split-input-file %s | FileCheck %s

// TODO: remove memref.alloc() in the tests to eliminate noise.
// memref.alloc exists here because sub-byte vector data types such as i2
// are currently not supported as input arguments.

///----------------------------------------------------------------------------------------
/// vector.store
///----------------------------------------------------------------------------------------

func.func @vector_store_i2_const_index_two_partial_stores(%arg0: vector<3xi2>) {
%0 = memref.alloc() : memref<3x3xi2>
%c0 = arith.constant 0 : index
%c2 = arith.constant 2 : index
vector.store %arg0, %0[%c2, %c0] : memref<3x3xi2>, vector<3xi2>
return
}

// Emit two non-atomic RMW partial stores. The 6 bits of the input vector are
// stored into bits [12, 18) of the output, i.e. into bytes 1 and 2 of the
// 3-byte output memref. Because the store is not byte-aligned, both bytes are
// accessed partially through masking.
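// For reference, the index arithmetic behind the CHECK lines below: the store
// address (2, 0) in a 3x3xi2 memref linearizes to element 2 * 3 + 0 = 6, i.e.
// bit offset 6 * 2 = 12, emulated byte index 12 / 8 = 1, and element offset 2
// within that byte. The 3 x i2 = 6 stored bits cover bits [12, 18), straddling
// byte 1 (bits [8, 16)) and byte 2 (bits [16, 24)), hence the two masked
// partial stores.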

// CHECK: func @vector_store_i2_const_index_two_partial_stores(
// CHECK-SAME: %[[ARG0:.+]]: vector<3xi2>)
// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<3xi8>
// CHECK: %[[C1:.+]] = arith.constant 1 : index

// Part 1 RMW sequence
// CHECK: %[[CST:.+]] = arith.constant dense<[false, false, true, true]>
// CHECK: %[[CST0:.+]] = arith.constant dense<0> : vector<4xi2>
// CHECK: %[[EXTRACT:.+]] = vector.extract_strided_slice %[[ARG0]]
// CHECK-SAME: {offsets = [0], sizes = [2], strides = [1]} : vector<3xi2> to vector<2xi2>
// CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %[[EXTRACT]], %[[CST0]]
// CHECK-SAME: {offsets = [2], strides = [1]} : vector<2xi2> into vector<4xi2>
// CHECK: %[[LOAD:.+]] = vector.load
// CHECK: %[[DOWNCAST:.+]] = vector.bitcast %[[LOAD]] : vector<1xi8> to vector<4xi2>
// CHECK: %[[SELECT:.+]] = arith.select %[[CST]], %[[INSERT]], %[[DOWNCAST]]
// CHECK: %[[UPCAST:.+]] = vector.bitcast %[[SELECT]]
// CHECK: vector.store %[[UPCAST]], %[[ALLOC]][%[[C1]]]

// Part 2 RMW sequence
// CHECK: %[[OFFSET:.+]] = arith.addi %[[C1]], %[[C1]] : index
// CHECK: %[[EXTRACT2:.+]] = vector.extract_strided_slice %[[ARG0]]
// CHECK-SAME: {offsets = [2], sizes = [1], strides = [1]} : vector<3xi2> to vector<1xi2>
// CHECK: %[[INSERT2:.+]] = vector.insert_strided_slice %[[EXTRACT2]], %[[CST0]]
// CHECK-SAME: {offsets = [0], strides = [1]} : vector<1xi2> into vector<4xi2>
// CHECK: %[[CST1:.+]] = arith.constant dense<[true, false, false, false]> : vector<4xi1>
// CHECK: %[[LOAD2:.+]] = vector.load
// CHECK: %[[UPCAST2:.+]] = vector.bitcast %[[LOAD2]] : vector<1xi8> to vector<4xi2>
// CHECK: %[[SELECT2:.+]] = arith.select %[[CST1]], %[[INSERT2]], %[[UPCAST2]]
// CHECK: %[[DOWNCAST2:.+]] = vector.bitcast %[[SELECT2]]
// CHECK: vector.store %[[DOWNCAST2]], %[[ALLOC]][%[[OFFSET]]]


// -----

func.func @vector_store_i2_two_partial_one_full_stores(%arg0: vector<7xi2>) {
%0 = memref.alloc() : memref<3x7xi2>
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
vector.store %arg0, %0[%c1, %c0] : memref<3x7xi2>, vector<7xi2>
return
}

// In this example, emit two RMW stores and one full-width store.
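// For reference: the store address (1, 0) in a 3x7xi2 memref linearizes to
// element 1 * 7 + 0 = 7, i.e. bit offset 14 and emulated byte index 1. The
// 7 x i2 = 14 stored bits cover bits [14, 28): 2 bits complete byte 1 (first
// partial RMW), the next 8 bits fill byte 2 exactly (full-width store), and
// the last 4 bits partially cover byte 3 (second partial RMW).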

// CHECK: func @vector_store_i2_two_partial_one_full_stores(
// CHECK-SAME: %[[ARG0:.+]]:
// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<6xi8>
// CHECK: %[[C1:.+]] = arith.constant 1 : index
// CHECK: %[[CST:.+]] = arith.constant dense<[false, false, false, true]>
// CHECK: %[[CST0:.+]] = arith.constant dense<0> : vector<4xi2>
// CHECK: %[[EXTRACT:.+]] = vector.extract_strided_slice %[[ARG0]]
// CHECK-SAME: {offsets = [0], sizes = [1], strides = [1]}
// CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %[[EXTRACT]], %[[CST0]]
// CHECK-SAME: {offsets = [3], strides = [1]}
// First sub-width RMW:
// CHECK: %[[LOAD:.+]] = vector.load %[[ALLOC]][%[[C1]]]
// CHECK: %[[UPCAST:.+]] = vector.bitcast %[[LOAD]] : vector<1xi8> to vector<4xi2>
// CHECK: %[[SELECT:.+]] = arith.select %[[CST]], %[[INSERT]], %[[UPCAST]]
// CHECK: %[[DOWNCAST:.+]] = vector.bitcast %[[SELECT]]
// CHECK: vector.store %[[DOWNCAST]], %[[ALLOC]][%[[C1]]]

// Full-width store:
// CHECK: %[[INDEX:.+]] = arith.addi %[[C1]], %[[C1]]
// CHECK: %[[EXTRACT1:.+]] = vector.extract_strided_slice %[[ARG0]]
// CHECK-SAME: {offsets = [1], sizes = [4], strides = [1]}
// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[EXTRACT1]]
// CHECK: vector.store %[[BITCAST]], %[[ALLOC]][%[[INDEX]]]

// Second sub-width RMW:
// CHECK: %[[INDEX2:.+]] = arith.addi %[[INDEX]], %[[C1]]
// CHECK: %[[EXTRACT2:.+]] = vector.extract_strided_slice %[[ARG0]]
// CHECK-SAME: {offsets = [5], sizes = [2], strides = [1]}
// CHECK: %[[INSERT2:.+]] = vector.insert_strided_slice %[[EXTRACT2]]
// CHECK-SAME: {offsets = [0], strides = [1]}
// CHECK: %[[CST1:.+]] = arith.constant dense<[true, true, false, false]>
// CHECK: %[[LOAD2:.+]] = vector.load %[[ALLOC]][%[[INDEX2]]]
// CHECK: %[[UPCAST2:.+]] = vector.bitcast %[[LOAD2]]
// CHECK: %[[SELECT2:.+]] = arith.select %[[CST1]], %[[INSERT2]], %[[UPCAST2]]
// CHECK: %[[DOWNCAST2:.+]] = vector.bitcast %[[SELECT2]]
// CHECK: vector.store %[[DOWNCAST2]], %[[ALLOC]][%[[INDEX2]]]

// -----

func.func @vector_store_i2_one_partial_store(%arg0: vector<1xi2>) {
%0 = memref.alloc() : memref<4x1xi2>
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
vector.store %arg0, %0[%c1, %c0] : memref<4x1xi2>, vector<1xi2>
return
}

// In this test, emit only one partial RMW store, since the store fits within
// a single byte.
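// For reference: the store address (1, 0) in a 4x1xi2 memref linearizes to
// element 1, i.e. bit offset 2 within byte 0, so the single i2 value is
// written at element offset 1 of that byte with a masked RMW
// (mask [false, true, false, false]).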

// CHECK: func @vector_store_i2_one_partial_store(
// CHECK-SAME: %[[ARG0:.+]]: vector<1xi2>)
// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<1xi8>
// CHECK: %[[C0:.+]] = arith.constant 0 : index
// CHECK: %[[CST:.+]] = arith.constant dense<[false, true, false, false]>
// CHECK: %[[CST0:.+]] = arith.constant dense<0> : vector<4xi2>
// CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %[[ARG0]], %[[CST0]]
// CHECK-SAME: {offsets = [1], strides = [1]} : vector<1xi2> into vector<4xi2>
// CHECK: %[[LOAD:.+]] = vector.load %[[ALLOC]][%[[C0]]] : memref<1xi8>, vector<1xi8>
// CHECK: %[[UPCAST:.+]] = vector.bitcast %[[LOAD]] : vector<1xi8> to vector<4xi2>
// CHECK: %[[SELECT:.+]] = arith.select %[[CST]], %[[INSERT]], %[[UPCAST]]
// CHECK: %[[DOWNCAST:.+]] = vector.bitcast %[[SELECT]]
// CHECK: vector.store %[[DOWNCAST]], %[[ALLOC]][%[[C0]]]
@@ -369,10 +369,9 @@ func.func @vector_store_i2_const_index_two_partial_stores(%arg0: vector<3xi2>) {
return
}

// In this example, emit 2 atomic RMWs.
//
// Note, sizeof(%src) = 18 bits. This is modelled as %src_as_bytes:
// <3xi8> (bits [0, 18) with the input values from %src, and [18, 24) are masked out)
// Emit two atomic RMW partial stores. The 6 bits of the input vector are
// stored into bits [12, 18) of the output, i.e. into bytes 1 and 2 of the
// 3-byte output memref. Because the store is not byte-aligned, both bytes are
// accessed partially through masking.

// CHECK-LABEL: func @vector_store_i2_const_index_two_partial_stores(
// CHECK-SAME: %[[ARG0:.+]]: vector<3xi2>)
9 changes: 8 additions & 1 deletion mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp
@@ -99,7 +99,8 @@ struct TestEmulateNarrowTypePass

arith::populateArithNarrowTypeEmulationPatterns(typeConverter, patterns);
memref::populateMemRefNarrowTypeEmulationPatterns(typeConverter, patterns);
vector::populateVectorNarrowTypeEmulationPatterns(typeConverter, patterns);
vector::populateVectorNarrowTypeEmulationPatterns(typeConverter, patterns,
disableAtomicRMW);

if (failed(applyPartialConversion(op, target, std::move(patterns))))
signalPassFailure();
@@ -118,6 +119,12 @@ struct TestEmulateNarrowTypePass
*this, "skip-memref-type-conversion",
llvm::cl::desc("disable memref type conversion (to test failures)"),
llvm::cl::init(false)};

Option<bool> disableAtomicRMW{
*this, "disable-atomic-rmw",
llvm::cl::desc("disable atomic read-modify-write and prefer generating "
"normal sequence"),
llvm::cl::init(false)};
};
} // namespace
