Pass to align transfer_reads (#867)
This PR adds the necessary pass to align `transfer_read`. It is based on
mlir-aie's aievec, see:


https://github.com/Xilinx/mlir-aie/blob/d3da586305ebc22e5ecdf1d3e682b44853436e91/lib/Dialect/AIEVec/Transforms/VectorToVectorConversions.cpp#L123

Some changes were needed for our use case, however. The main one is that
the lowering in this PR skips the `vector.extract_strided_slice`
operation, because we have an offset which is not constant. That is, the
offsets in
https://mlir.llvm.org/docs/Dialects/Vector/#vectorextract_strided_slice-vectorextractstridedsliceop
cannot be compile-time integers for us, because they are determined from
loop induction variables. The pass implemented here goes straight to
aievec extract and shift operations, where MLIR Values are used for
offsets.

Also included in this PR: an aievec.shift folder. I can make this a
separate PR if preferred.

This PR enables vectorization for convolution and resolves
#820
newling authored Nov 5, 2024
1 parent fded307 commit 2086718
Showing 8 changed files with 446 additions and 13 deletions.
3 changes: 3 additions & 0 deletions compiler/plugins/target/AMD-AIE/aievec/CMakeLists.txt
@@ -83,6 +83,9 @@ iree_cc_library(
MLIREmitCDialect
::AIEVecDialectIR
::AIEVecXLLVMOpsGen
iree-amd-aie::aie_runtime::iree_aie_runtime_static
iree::compiler::Dialect::HAL::IR
iree::compiler::Dialect::HAL::IR::HALDialect
)

add_subdirectory(test)
58 changes: 54 additions & 4 deletions compiler/plugins/target/AMD-AIE/aievec/Passes.h
@@ -22,15 +22,17 @@
namespace mlir::iree_compiler::aievec {

/**
* Append pass(es) for canonicalizing operations in the vector dialect to a form
* Append passes for canonicalizing operations in the vector dialect to a form
* that can be lowered to the AIEVec dialect.
*/
void buildCanonicalizeVectorForAIEVec(mlir::OpPassManager &);

/**
* A pass containing patterns for canonicalizing operations in the vector
* A pass containing some patterns for canonicalizing operations in the vector
* dialect to a form that can be lowered to the AIEVec dialect. This pass is
* named `canonicalize-vector-for-aievec`.
* named `canonicalize-vector-for-aievec`. To ensure all required vector dialect
* canonicalizations take place, PassManagers should use
* `buildCanonicalizeVectorForAIEVec`.
*/
std::unique_ptr<mlir::Pass> createCanonicalizeVectorForAIEVecPass();

@@ -39,6 +41,54 @@ std::unique_ptr<mlir::Pass> createCanonicalizeVectorForAIEVecPass();
*/
void registerCanonicalizeVectorForAIEVecPass();

/**
* This pass ensures that reads from AIE tile memory are aligned according to
* hardware constraints. For example, suppose we have 128 bytes in tile memory,
* represented in hex as:
*
* 0x00 0x01 ... 0x7E 0x7F
*
* On AIE-2, the (vector) read instructions from the tile memory into registers
* must be aligned to 256-bits (32-bytes). So if we want to read 64 bytes
* starting from 0x00 that is fine, but if we want to read 64 bytes starting
* from 0x01, then we cannot use a vector read instruction directly. To work
* around this constraint, we do the following:
*
* 1. Perform a wider read, that loads 128 bytes (2x as many as we want)
* starting from 0x00 into a larger register. That is, bytes 0x00-0x7F are
* loaded, so we have 1 'junk' byte at the beginning and 63 'junk' bytes at
* the end.
*
* 2. Extract the target bytes 0x01 ... 0x40 from the larger register into a
* smaller register in 2 steps, using 2 AIE specific instructions:
*
* a) Extract:
* https://www.xilinx.com/htmldocs/xilinx2023_2/aiengine_ml_intrinsics/intrinsics/group__intr__gpvectorconv__elem.html
*
* b) Shift:
* https://www.xilinx.com/htmldocs/xilinx2023_2/aiengine_ml_intrinsics/intrinsics/group__intr__gpvectorop__shift.html
*
* First, we use the extract instruction to split the read 128-bytes into two
* halves, 0x00-0x3F and 0x40-0x7F, each in its own 64-byte register. Then, we
 * use a shift operation to combine the upper 63 bytes from the first half
 * and the lowest byte from the second half into a new 64-byte register.
* This new register contains exactly the 64 bytes we want to read, starting
* from 0x01.
*
* If we want to read 32 bytes starting from 0x01, we can use a similar
 * approach. The only consideration is that the shift operation requires 64-byte
 * inputs, so the order of the shift and extract operations is reversed.
*
* We do not currently support unaligned reads of vectors which are not 32-bytes
* or 64-bytes in length.
*
* TODO(newling) use this same approach to align writes to unaligned memory.
 */

std::unique_ptr<mlir::Pass> createAlignTransferReadsPass();

void registerAlignTransferReadsPass();

/**
* Append pass(es) for lowering operations in the vector dialect to the AIEVec
* dialect. Vector dialect ops are expected to be in a canonical form
@@ -48,7 +98,7 @@ void buildLowerVectorToAIEVec(mlir::OpPassManager &pm);

/**
* A pass containing patterns for lowering operations in the vector dialect to
* the AIEVec dialect. The pass is currently named `test-lower-vector-to-aievec`.
* the AIEVec dialect. The pass is currently named `test-lower-vector-to-aievec`
*/
static std::unique_ptr<mlir::Pass> createLowerVectorToAIEVec();

184 changes: 181 additions & 3 deletions compiler/plugins/target/AMD-AIE/aievec/VectorToVectorConversions.cpp
@@ -15,6 +15,9 @@
#include <memory>

#include "Passes.h"
#include "aievec/AIEVecOps.h"
#include "iree-amd-aie/Transforms/AMDAIEUtils.h"
#include "iree-amd-aie/aie_runtime/iree_aie_runtime.h"
#include "llvm/ADT/STLExtras.h"
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
#include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
@@ -300,7 +303,6 @@ class FlattenContiguousRowMajorTransferWritePattern

} // namespace copied_from_mlir


static bool isGemmBTransposedContractionOp(vector::ContractionOp op) {
if (op.getKind() != vector::CombiningKind::ADD) return false;

@@ -897,6 +899,7 @@ struct CanonicalizeVectorForAIEVecPass
populateBubbleSignExtensionsLate(patterns);
(void)applyPatternsAndFoldGreedily(op, std::move(patterns));
}

{
RewritePatternSet patterns(context);
patterns
@@ -914,6 +917,7 @@ struct CanonicalizeVectorForAIEVecPass
mlir::vector::populateVectorBroadcastLoweringPatterns(patterns);
(void)applyPatternsAndFoldGreedily(op, std::move(patterns));
}

{
// These must run after 'populateFlattenVectorTransferPatterns' because
// vector.shape_casts are introduced. Merging into a single pass creates
@@ -925,6 +929,170 @@
}
};

/// Returns one of:
/// 1) failure, if there is definitely an error that should be propagated.
/// 2) a new transfer_read operation that is sufficiently aligned, if the old
/// transfer_read is determined to be insufficiently aligned and it is
/// possible to create a new transfer_read.
/// 3) the original transfer_read operation, otherwise.
FailureOr<Value> getAlignedTransferRead(
vector::TransferReadOp readOp, IRRewriter &rewriter,
const AMDAIE::AMDAIEDeviceModel &deviceModel) {
uint32_t vectorLoadStoreAlignmentBits =
deviceModel.getVectorLoadStoreAlignmentBits();
uint32_t maxVectorSizeBits = deviceModel.getMaxVectorSizeBits();
uint32_t shiftOperandBits = deviceModel.getShiftOperandBits();

// Check that it's not a splat transfer read.
if (readOp.getPermutationMap().isConstant()) return readOp.getVector();

MLIRContext *ctx = readOp.getContext();
VectorType shortType = readOp.getVectorType();
Location loc = readOp.getLoc();
Value padding = readOp.getPadding();
ShapedType sourceType = readOp.getSource().getType();
Type elementType = shortType.getElementType();

if (sourceType.getRank() != 1 || shortType.getRank() != 1) {
return readOp.emitOpError(
"does not have rank-1 source and rank-1 vector type.");
}

uint32_t elementBits = elementType.getIntOrFloatBitWidth();
int64_t shortLength = shortType.getShape().back();
int64_t shortBits = shortLength * elementBits;
uint32_t alignElements = vectorLoadStoreAlignmentBits / elementBits;

rewriter.setInsertionPoint(readOp);

AffineMap moduloMap =
AffineMap::get(1, 0, getAffineDimExpr(0, ctx) % alignElements);

Value oldIndex = readOp.getIndices().back();

Value offset = rewriter.createOrFold<affine::AffineApplyOp>(
loc, moduloMap, SmallVector<Value, 1>{oldIndex});

// If the offset is constant and zero, the read is already aligned.
if (auto offsetConstantOp = offset.getDefiningOp<arith::ConstantIndexOp>())
if (offsetConstantOp.getValue() == 0) return readOp.getVector();

// Verify that we can load a vector 2x as long as the original vector.
int64_t longBits = 2 * shortBits;
int64_t longLength = 2 * shortLength;
VectorType longType = VectorType::get(longLength, elementType);
if (longBits > maxVectorSizeBits) {
// Not returning failure, as it is possible that the read is already
// aligned, and we just couldn't prove it.
readOp.emitWarning()
<< "`transfer_read` can't be aligned with a read twice "
<< "as large because " << longBits
<< " bits is greater than the maximum vector size of "
<< maxVectorSizeBits << " bits.";

return readOp.getVector();
}

SmallVector<bool> inBounds = readOp.getInBoundsValues();
bool allInBounds =
std::all_of(inBounds.begin(), inBounds.end(), [](bool b) { return b; });

if (shortBits != shiftOperandBits / 2 && shortBits != shiftOperandBits) {
// Not returning failure, as it is possible that the read is already
// aligned, and we just couldn't prove it.
readOp.emitWarning() << "`transfer_read` doesn't have a vector with "
<< shiftOperandBits / 2 << " or " << shiftOperandBits
<< " bits."
<< "This case is not currently handled.";
return readOp.getVector();
}

Value newIndex = rewriter.createOrFold<arith::SubIOp>(loc, oldIndex, offset);

// Create the aligned transfer read for a vector 2x as long that covers the
// elements of the unaligned vector.
Value longVec = rewriter.create<vector::TransferReadOp>(
loc, longType, readOp.getSource(), SmallVector<Value>{newIndex}, padding,
SmallVector<bool>{allInBounds});

Value elementBytes =
rewriter.create<arith::ConstantIndexOp>(loc, elementBits / 8);

Value offsetBytes =
rewriter.createOrFold<arith::MulIOp>(loc, offset, elementBytes);

Value offsetBytes_i32 = rewriter.createOrFold<arith::IndexCastOp>(
loc, rewriter.getIntegerType(32), offsetBytes);

Value replacement;
if (shortBits == shiftOperandBits) {
// - Extract lower 64 bytes
// - Extract upper 64 bytes
// - Apply shift to obtain new 64 bytes
Value low = rewriter.create<ExtOp>(loc, shortType, longVec,
rewriter.getI8IntegerAttr(0));
Value upp = rewriter.create<ExtOp>(loc, shortType, longVec,
rewriter.getI8IntegerAttr(1));
replacement = rewriter.createOrFold<ShiftOp>(loc, shortType, low, upp,
offsetBytes_i32);
} else if (shortBits == shiftOperandBits / 2) {
// - Apply shift to obtain new 64 bytes, bottom 32 being the required ones
// - Extract lower 32 bytes
Value shift = rewriter.createOrFold<ShiftOp>(loc, longType, longVec,
longVec, offsetBytes_i32);
replacement = rewriter.create<ExtOp>(loc, shortType, shift,
rewriter.getI8IntegerAttr(0));
} else {
assert(false &&
       "unreachable: already checked that shortBits is equal to or half "
       "of shiftOperandBits");
}

rewriter.replaceOp(readOp, replacement);

return replacement;
}

struct AlignTransferReadsPass
: public PassWrapper<AlignTransferReadsPass, OperationPass<>> {
StringRef getArgument() const final { return "align-transfer-reads"; }

StringRef getDescription() const final {
return "Align `vector.transfer_read` operations.";
}

void getDependentDialects(DialectRegistry &registry) const override {
registry
.insert<arith::ArithDialect, memref::MemRefDialect,
vector::VectorDialect, affine::AffineDialect, AIEVecDialect>();
}

void runOnOperation() override {
Operation *op = getOperation();

auto targetAttr = IREE::HAL::ExecutableTargetAttr::lookup(op);
std::optional<AMDAIE::AMDAIEDevice> maybeDevice =
mlir::iree_compiler::AMDAIE::getConfigAMDAIEDevice(targetAttr);
if (!maybeDevice) {
op->emitOpError()
<< "has no AMDAIEDevice in the target attribute configuration. This "
"device-specific information is required to determine what vector "
"sizes and alignments are supported.";
return signalPassFailure();
}
AMDAIE::AMDAIEDeviceModel deviceModel =
AMDAIE::getDeviceModel(maybeDevice.value());

IRRewriter rewriter(&getContext());
op->walk([&](vector::TransferReadOp transferReadOp) {
if (failed(
getAlignedTransferRead(transferReadOp, rewriter, deviceModel))) {
signalPassFailure();
}
});
}
};

struct DetectNonCanonicalOpsPass
: public PassWrapper<DetectNonCanonicalOpsPass, OperationPass<>> {
StringRef getArgument() const final {
@@ -943,7 +1111,7 @@ struct DetectNonCanonicalOpsPass
}

void runOnOperation() override {
auto op = getOperation();
Operation *op = getOperation();
MLIRContext *context = &getContext();
RewritePatternSet patterns(context);
ConversionTarget target(*context);
@@ -955,8 +1123,8 @@
};

void buildCanonicalizeVectorForAIEVec(OpPassManager &pm) {
// TODO: Add passes to split vectors that won't fit in registers
pm.addPass(createCanonicalizeVectorForAIEVecPass());
pm.addPass(createAlignTransferReadsPass());
pm.addPass(mlir::createCanonicalizerPass());
pm.addPass(std::make_unique<DetectNonCanonicalOpsPass>());
}
@@ -971,4 +1139,14 @@
});
}

std::unique_ptr<::mlir::Pass> createAlignTransferReadsPass() {
return std::make_unique<AlignTransferReadsPass>();
}

void registerAlignTransferReadsPass() {
::mlir::registerPass([]() -> std::unique_ptr<mlir::Pass> {
return std::make_unique<AlignTransferReadsPass>();
});
}

} // namespace mlir::iree_compiler::aievec
11 changes: 10 additions & 1 deletion compiler/plugins/target/AMD-AIE/aievec/test/CMakeLists.txt
@@ -1,4 +1,13 @@
file(GLOB _mlir_files *.mlir)
set(_mlir_files
align-transfer-reads.mlir
fold-ops.mlir
matmul.mlir
precanonicalization-aieml-llvmir.mlir
test-mac_elem.mlir
test-shuffle.mlir
test-srs.mlir
test-ups.mlir
)

iree_lit_test_suite(
NAME
