From 4f37c9a961b86737a231530eeb31cc6218727534 Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk@intel.com>
Date: Fri, 1 Nov 2024 12:56:30 +0100
Subject: [PATCH] Lower GEMM to BRGEMM kernel (#13)

Extends contraction lowering to XSMM by rewriting plain GEMM into
a BRGEMM kernel when possible.
The rewrite improves performance of larger K block sizes thanks to
extra reduction dim tiling.

Use of BRGEMM kernel also enables online VNNI packing for BF16.
---
 third_party/cpu/include/Xsmm/Passes.td        |  7 +--
 .../cpu/lib/Xsmm/ConvertTritonToXsmm.cpp      | 29 +++++++++++
 .../cpu/lib/Xsmm/ConvertVectorToXsmm.cpp      | 49 +++++++++++++++----
 3 files changed, 72 insertions(+), 13 deletions(-)

diff --git a/third_party/cpu/include/Xsmm/Passes.td b/third_party/cpu/include/Xsmm/Passes.td
index b07fb93cc72ab..5527c233de7ac 100644
--- a/third_party/cpu/include/Xsmm/Passes.td
+++ b/third_party/cpu/include/Xsmm/Passes.td
@@ -8,7 +8,8 @@ def ConvertVectorToXsmm : Pass<"triton-cpu-convert-vector-to-xsmm", "mlir::ModuleOp"> {
   let description = [{
     Convert vector operations to XSMM operations.
   }];
-  let dependentDialects = ["func::FuncDialect",
+  let dependentDialects = ["arith::ArithDialect",
+                           "func::FuncDialect",
                            "memref::MemRefDialect",
                            "vector::VectorDialect",
                            "LLVM::LLVMDialect"];
@@ -19,9 +20,9 @@ def ConvertTritonToXsmm : Pass<"triton-cpu-convert-triton-to-xsmm", "mlir::ModuleOp"> {
   let description = [{
     Convert triton operations to XSMM operations.
   }];
-  let dependentDialects = ["func::FuncDialect",
+  let dependentDialects = ["arith::ArithDialect",
+                           "func::FuncDialect",
                            "memref::MemRefDialect",
-                           "triton::TritonDialect",
                            "triton::cpu::TritonCPUDialect",
                            "LLVM::LLVMDialect"];
 }
diff --git a/third_party/cpu/lib/Xsmm/ConvertTritonToXsmm.cpp b/third_party/cpu/lib/Xsmm/ConvertTritonToXsmm.cpp
index c77362e9ab5df..d2e7a3522ca42 100644
--- a/third_party/cpu/lib/Xsmm/ConvertTritonToXsmm.cpp
+++ b/third_party/cpu/lib/Xsmm/ConvertTritonToXsmm.cpp
@@ -12,6 +12,7 @@
 #include "VnniUtils.h"
 #include "XsmmUtils.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
@@ -222,6 +223,34 @@ struct DotToXsmm : public OpRewritePattern<triton::DotOp> {
     SmallVector<Value> inputs{lhsBuf, rhsBuf, accBuf};
     SmallVector<Value> outputs{nullptr};
 
+    // Rewrite matmul into a BRGEMM.
+    // This allows for additional reduction dimension tiling driven
+    // by a microkernel.
+    //
+    // TODO: Expand heuristics about brgemm rewrite profitability.
+    // TODO: Allow for batch dimension.
+    int64_t kDim = lhs.getType().getShape().back();
+    auto accShape = acc.getType().getShape();
+    constexpr int64_t kTile = 32;
+    int64_t numTiles = kDim / kTile;
+    if (rank == 2 && (kDim % kTile) == 0 && numTiles > 1) {
+      // Split reduction dimension into tiles.
+      // The number of tiles represents the batch dimension.
+      inputs[0] = rewriter.create<memref::ExpandShapeOp>(
+          loc, SmallVector<int64_t>{accShape[0], numTiles, kTile}, inputs[0],
+          SmallVector<ReassociationIndices>{{0}, {1, 2}});
+      inputs[1] = rewriter.create<memref::ExpandShapeOp>(
+          loc, SmallVector<int64_t>{numTiles, kTile, accShape[1]}, inputs[1],
+          SmallVector<ReassociationIndices>{{0, 1}, {2}});
+
+      // Update maps with BRGEMM indexing.
+      auto mapA = AffineMap::getMultiDimMapWithTargets(4, {1, 0, 3}, ctx);
+      auto mapB = AffineMap::getMultiDimMapWithTargets(4, {0, 3, 2}, ctx);
+      auto mapC = AffineMap::getMultiDimMapWithTargets(4, {1, 2}, ctx);
+      indexingMaps = SmallVector<AffineMap>{mapA, mapB, mapC};
+    }
+
+    // TODO: Perform this check much earlier before any rewrites.
     auto brgemmInfo = xsmm::utils::isMappableToBrgemm(rewriter, dotOp, inputs,
                                                       outputs, indexingMaps);
     if (failed(brgemmInfo)) {
diff --git a/third_party/cpu/lib/Xsmm/ConvertVectorToXsmm.cpp b/third_party/cpu/lib/Xsmm/ConvertVectorToXsmm.cpp
index 1e059930c748c..bbc6412a142ba 100644
--- a/third_party/cpu/lib/Xsmm/ConvertVectorToXsmm.cpp
+++ b/third_party/cpu/lib/Xsmm/ConvertVectorToXsmm.cpp
@@ -157,13 +157,14 @@ struct ContractToXsmm : public OpRewritePattern<vector::ContractionOp> {
   LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                 PatternRewriter &rewriter) const override {
     Location loc = contractOp.getLoc();
+    MLIRContext *ctx = contractOp.getContext();
     TypedValue<VectorType> lhs = contractOp.getLhs();
     TypedValue<VectorType> rhs = contractOp.getRhs();
     TypedValue<Type> acc = contractOp.getAcc();
 
-    auto vecTy = dyn_cast<VectorType>(acc.getType());
-    if (!vecTy)
+    auto accVecTy = dyn_cast<VectorType>(acc.getType());
+    if (!accVecTy)
       return rewriter.notifyMatchFailure(contractOp,
                                          "expects to accumulate on vector");
 
@@ -177,17 +178,45 @@ struct ContractToXsmm : public OpRewritePattern<vector::ContractionOp> {
     SmallVector<Value> inputs{lhsBuf, rhsBuf, accBuf};
     SmallVector<Value> outputs{nullptr};
 
-    auto brgemmInfo =
-        xsmm::utils::isMappableToBrgemm(rewriter, contractOp, inputs, outputs,
-                                        contractOp.getIndexingMapsArray());
+    SmallVector<AffineMap> indexingMaps = contractOp.getIndexingMapsArray();
+
+    // Rewrite matmul into a BRGEMM.
+    // This allows for additional reduction dimension tiling driven
+    // by a microkernel.
+    //
+    // TODO: Expand heuristics about brgemm rewrite profitability.
+    // TODO: Allow for batch dimension.
+    int64_t kDim = lhs.getType().getShape().back();
+    auto accShape = accVecTy.getShape();
+    constexpr int64_t kTile = 32;
+    int64_t numTiles = kDim / kTile;
+    uint32_t rank = accVecTy.getRank();
+    if (rank == 2 && (kDim % kTile) == 0 && numTiles > 1) {
+      // Split reduction dimension into tiles.
+      // The number of tiles represents the batch dimension.
+      inputs[0] = rewriter.create<memref::ExpandShapeOp>(
+          loc, SmallVector<int64_t>{accShape[0], numTiles, kTile}, inputs[0],
+          SmallVector<ReassociationIndices>{{0}, {1, 2}});
+      inputs[1] = rewriter.create<memref::ExpandShapeOp>(
+          loc, SmallVector<int64_t>{numTiles, kTile, accShape[1]}, inputs[1],
+          SmallVector<ReassociationIndices>{{0, 1}, {2}});
+
+      // Update maps with BRGEMM indexing.
+      auto mapA = AffineMap::getMultiDimMapWithTargets(4, {1, 0, 3}, ctx);
+      auto mapB = AffineMap::getMultiDimMapWithTargets(4, {0, 3, 2}, ctx);
+      auto mapC = AffineMap::getMultiDimMapWithTargets(4, {1, 2}, ctx);
+      indexingMaps = SmallVector<AffineMap>{mapA, mapB, mapC};
+    }
+
+    auto brgemmInfo = xsmm::utils::isMappableToBrgemm(
+        rewriter, contractOp, inputs, outputs, indexingMaps);
     if (failed(brgemmInfo)) {
       assert(false); // FIXME: getMemrefSource above already modified IR...
       // return rewriter.notifyMatchFailure(contractOp, "not mappable to XSMM");
     }
 
-    auto xsmmFuncs = xsmm::utils::buildBrgemmCalls(
-        rewriter, contractOp, ValueRange{lhsBuf, rhsBuf, accBuf},
-        contractOp.getIndexingMapsArray(), flags);
+    auto xsmmFuncs = xsmm::utils::buildBrgemmCalls(rewriter, contractOp, inputs,
+                                                   indexingMaps, flags);
 
     if (hoistedAcc) {
       // Hoisting already updated all uses correctly.
@@ -198,8 +227,8 @@ struct ContractToXsmm : public OpRewritePattern<vector::ContractionOp> {
     Value zeroIdx = rewriter.create<arith::ConstantIndexOp>(loc, 0);
     SmallVector<Value> indices(
         dyn_cast<MemRefType>(accBuf.getType()).getRank(), zeroIdx);
-    auto readOp =
-        rewriter.create<vector::TransferReadOp>(loc, vecTy, accBuf, indices);
+    auto readOp = rewriter.create<vector::TransferReadOp>(loc, accVecTy,
+                                                          accBuf, indices);
     rewriter.replaceOp(contractOp, readOp);
   }