Lower GEMM to BRGEMM kernel (triton-lang#13)
Extends contraction lowering to XSMM by rewriting plain GEMM into
a BRGEMM kernel when possible.

The rewrite improves performance for larger K block sizes thanks to
extra reduction-dimension tiling. Use of the BRGEMM kernel also enables
online VNNI packing for BF16.
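
To see what the rewrite computes, here is a minimal scalar sketch (plain C++, illustrative sizes, not part of the patch; the pass itself uses kTile = 32 on memrefs): splitting GEMM's K dimension into kTile-sized chunks and looping over the chunk count as a batch dimension accumulates the same result as the plain GEMM.

// Scalar reference: splitting GEMM's K dimension into tiles yields a
// batch-reduce GEMM (BRGEMM) that accumulates the same result.
#include <cassert>
#include <vector>

int main() {
  const int M = 4, N = 4, K = 8, kTile = 2;
  const int numTiles = K / kTile; // batch dimension of the BRGEMM
  std::vector<float> A(M * K), B(K * N), C(M * N, 0.0f), Cref(M * N, 0.0f);
  for (int i = 0; i < M * K; ++i) A[i] = float(i % 7);
  for (int i = 0; i < K * N; ++i) B[i] = float(i % 5);

  // Plain GEMM: Cref[m][n] += A[m][k] * B[k][n].
  for (int m = 0; m < M; ++m)
    for (int n = 0; n < N; ++n)
      for (int k = 0; k < K; ++k)
        Cref[m * N + n] += A[m * K + k] * B[k * N + n];

  // BRGEMM over expanded views A[M][numTiles][kTile], B[numTiles][kTile][N]:
  // each batch b reduces one K tile into the shared accumulator.
  for (int b = 0; b < numTiles; ++b)
    for (int m = 0; m < M; ++m)
      for (int n = 0; n < N; ++n)
        for (int k = 0; k < kTile; ++k)
          C[m * N + n] +=
              A[m * K + b * kTile + k] * B[(b * kTile + k) * N + n];

  for (int i = 0; i < M * N; ++i) assert(C[i] == Cref[i]);
  return 0;
}
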
adam-smnk authored and Devjiu committed Nov 13, 2024
1 parent e6c26cb commit 4f37c9a
Showing 3 changed files with 72 additions and 13 deletions.
third_party/cpu/include/Xsmm/Passes.td (4 additions, 3 deletions)
@@ -8,7 +8,8 @@ def ConvertVectorToXsmm : Pass<"triton-cpu-convert-vector-to-xsmm", "mlir::Modul
   let description = [{
     Convert vector operations to XSMM operations.
   }];
-  let dependentDialects = ["func::FuncDialect",
+  let dependentDialects = ["arith::ArithDialect",
+                           "func::FuncDialect",
                            "memref::MemRefDialect",
                            "vector::VectorDialect",
                            "LLVM::LLVMDialect"];
@@ -19,9 +20,9 @@ def ConvertTritonToXsmm : Pass<"triton-cpu-convert-triton-to-xsmm", "mlir::Modul
   let description = [{
     Convert triton operations to XSMM operations.
   }];
-  let dependentDialects = ["func::FuncDialect",
+  let dependentDialects = ["arith::ArithDialect",
+                           "func::FuncDialect",
                            "memref::MemRefDialect",
                            "triton::TritonDialect",
                            "triton::cpu::TritonCPUDialect",
                            "LLVM::LLVMDialect"];
 }
third_party/cpu/lib/Xsmm/ConvertTritonToXsmm.cpp (29 additions, 0 deletions)
@@ -12,6 +12,7 @@
 #include "VnniUtils.h"
 #include "XsmmUtils.h"
 
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
@@ -222,6 +223,34 @@ struct DotToXsmm : public OpRewritePattern<triton::DotOp> {
     SmallVector<Value> inputs{lhsBuf, rhsBuf, accBuf};
     SmallVector<Value> outputs{nullptr};
 
+    // Rewrite matmul into a BRGEMM.
+    // This allows for additional reduction dimension tiling driven
+    // by a microkernel.
+    //
+    // TODO: Expand heuristics about brgemm rewrite profitability.
+    // TODO: Allow for batch dimension.
+    int64_t kDim = lhs.getType().getShape().back();
+    auto accShape = acc.getType().getShape();
+    constexpr int64_t kTile = 32;
+    int64_t numTiles = kDim / kTile;
+    if (rank == 2 && (kDim % kTile) == 0 && numTiles > 1) {
+      // Split reduction dimension into tiles.
+      // The number of tiles represents the batch dimension.
+      inputs[0] = rewriter.create<memref::ExpandShapeOp>(
+          loc, SmallVector<int64_t>{accShape[0], numTiles, kTile}, inputs[0],
+          SmallVector<ReassociationIndices>{{0}, {1, 2}});
+      inputs[1] = rewriter.create<memref::ExpandShapeOp>(
+          loc, SmallVector<int64_t>{numTiles, kTile, accShape[1]}, inputs[1],
+          SmallVector<ReassociationIndices>{{0, 1}, {2}});
+
+      // Update maps with BRGEMM indexing.
+      auto mapA = AffineMap::getMultiDimMapWithTargets(4, {1, 0, 3}, ctx);
+      auto mapB = AffineMap::getMultiDimMapWithTargets(4, {0, 3, 2}, ctx);
+      auto mapC = AffineMap::getMultiDimMapWithTargets(4, {1, 2}, ctx);
+      indexingMaps = SmallVector<AffineMap>{mapA, mapB, mapC};
+    }
+
+    // TODO: Perform this check much earlier before any rewrites.
     auto brgemmInfo = xsmm::utils::isMappableToBrgemm(rewriter, dotOp, inputs,
                                                       outputs, indexingMaps);
     if (failed(brgemmInfo)) {
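For reference, the three maps built above can be reproduced standalone. The sketch below assumes an MLIR development setup; the main() wrapper and the printed forms are illustrative, not part of the patch. With the iteration space (d0, d1, d2, d3) = (batch, m, n, k), the maps index A[m][batch][k], B[batch][k][n], and C[m][n].

// Minimal illustration of the BRGEMM indexing maps built by the pass.
// Prints (modulo formatting):
//   (d0, d1, d2, d3) -> (d1, d0, d3)   // A: [m][batch][k]
//   (d0, d1, d2, d3) -> (d0, d3, d2)   // B: [batch][k][n]
//   (d0, d1, d2, d3) -> (d1, d2)       // C: [m][n]
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/MLIRContext.h"

int main() {
  mlir::MLIRContext ctx;
  auto mapA = mlir::AffineMap::getMultiDimMapWithTargets(4, {1, 0, 3}, &ctx);
  auto mapB = mlir::AffineMap::getMultiDimMapWithTargets(4, {0, 3, 2}, &ctx);
  auto mapC = mlir::AffineMap::getMultiDimMapWithTargets(4, {1, 2}, &ctx);
  mapA.dump();
  mapB.dump();
  mapC.dump();
  return 0;
}
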
third_party/cpu/lib/Xsmm/ConvertVectorToXsmm.cpp (39 additions, 10 deletions)
@@ -157,13 +157,14 @@ struct ContractToXsmm : public OpRewritePattern<vector::ContractionOp> {
   LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                 PatternRewriter &rewriter) const override {
     Location loc = contractOp.getLoc();
+    MLIRContext *ctx = contractOp.getContext();
 
     TypedValue<VectorType> lhs = contractOp.getLhs();
     TypedValue<VectorType> rhs = contractOp.getRhs();
     TypedValue<Type> acc = contractOp.getAcc();
 
-    auto vecTy = dyn_cast<VectorType>(acc.getType());
-    if (!vecTy)
+    auto accVecTy = dyn_cast<VectorType>(acc.getType());
+    if (!accVecTy)
       return rewriter.notifyMatchFailure(contractOp,
                                          "expects to accumulate on vector");
 
@@ -177,17 +178,45 @@ struct ContractToXsmm : public OpRewritePattern<vector::ContractionOp> {
 
     SmallVector<Value> inputs{lhsBuf, rhsBuf, accBuf};
     SmallVector<Value> outputs{nullptr};
-    auto brgemmInfo =
-        xsmm::utils::isMappableToBrgemm(rewriter, contractOp, inputs, outputs,
-                                        contractOp.getIndexingMapsArray());
+    SmallVector<AffineMap> indexingMaps = contractOp.getIndexingMapsArray();
+
+    // Rewrite matmul into a BRGEMM.
+    // This allows for additional reduction dimension tiling driven
+    // by a microkernel.
+    //
+    // TODO: Expand heuristics about brgemm rewrite profitability.
+    // TODO: Allow for batch dimension.
+    int64_t kDim = lhs.getType().getShape().back();
+    auto accShape = accVecTy.getShape();
+    constexpr int64_t kTile = 32;
+    int64_t numTiles = kDim / kTile;
+    uint32_t rank = accVecTy.getRank();
+    if (rank == 2 && (kDim % kTile) == 0 && numTiles > 1) {
+      // Split reduction dimension into tiles.
+      // The number of tiles represents the batch dimension.
+      inputs[0] = rewriter.create<memref::ExpandShapeOp>(
+          loc, SmallVector<int64_t>{accShape[0], numTiles, kTile}, inputs[0],
+          SmallVector<ReassociationIndices>{{0}, {1, 2}});
+      inputs[1] = rewriter.create<memref::ExpandShapeOp>(
+          loc, SmallVector<int64_t>{numTiles, kTile, accShape[1]}, inputs[1],
+          SmallVector<ReassociationIndices>{{0, 1}, {2}});
+
+      // Update maps with BRGEMM indexing.
+      auto mapA = AffineMap::getMultiDimMapWithTargets(4, {1, 0, 3}, ctx);
+      auto mapB = AffineMap::getMultiDimMapWithTargets(4, {0, 3, 2}, ctx);
+      auto mapC = AffineMap::getMultiDimMapWithTargets(4, {1, 2}, ctx);
+      indexingMaps = SmallVector<AffineMap>{mapA, mapB, mapC};
+    }
+
+    auto brgemmInfo = xsmm::utils::isMappableToBrgemm(
+        rewriter, contractOp, inputs, outputs, indexingMaps);
     if (failed(brgemmInfo)) {
       assert(false); // FIXME: getMemrefSource above already modified IR...
       // return rewriter.notifyMatchFailure(contractOp, "not mappable to XSMM");
     }
 
-    auto xsmmFuncs = xsmm::utils::buildBrgemmCalls(
-        rewriter, contractOp, ValueRange{lhsBuf, rhsBuf, accBuf},
-        contractOp.getIndexingMapsArray(), flags);
+    auto xsmmFuncs = xsmm::utils::buildBrgemmCalls(rewriter, contractOp, inputs,
+                                                   indexingMaps, flags);
 
     if (hoistedAcc) {
       // Hoisting already updated all uses correctly.
@@ -198,8 +227,8 @@ struct ContractToXsmm : public OpRewritePattern<vector::ContractionOp> {
     Value zeroIdx = rewriter.create<arith::ConstantIndexOp>(loc, 0);
     SmallVector<Value> indices(
         dyn_cast<MemRefType>(accBuf.getType()).getRank(), zeroIdx);
-    auto readOp =
-        rewriter.create<vector::TransferReadOp>(loc, vecTy, accBuf, indices);
+    auto readOp = rewriter.create<vector::TransferReadOp>(loc, accVecTy,
+                                                          accBuf, indices);
     rewriter.replaceOp(contractOp, readOp);
   }

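Both patterns reshape operands with memref::ExpandShapeOp, which only creates a view; no data moves. Below is a small standalone sketch (plain C++, hypothetical sizes, not part of the patch) of the same idea for the B operand: reinterpreting a row-major [K][N] buffer as [numTiles][kTile][N] merely regroups the index arithmetic.

// Like memref.expand_shape, reinterpreting the row-major buffer is
// metadata-only: the 3-D "view" addresses the exact same storage.
#include <cassert>

int main() {
  const int K = 8, N = 4, kTile = 2, numTiles = K / kTile;
  float B[K][N];
  for (int k = 0; k < K; ++k)
    for (int n = 0; n < N; ++n)
      B[k][n] = float(k * N + n);

  // The expanded view: [numTiles][kTile][N] over the same memory.
  auto *B3 = reinterpret_cast<float (*)[kTile][N]>(&B[0][0]);
  for (int b = 0; b < numTiles; ++b)
    for (int k = 0; k < kTile; ++k)
      for (int n = 0; n < N; ++n)
        assert(B3[b][k][n] == B[b * kTile + k][n]);
  return 0;
}
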
