forked from daphne-eu/daphne
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[codegen] Refactor of codegen passes [daphne-eu#889]
This PR majorly reworks codegen for AllAgg* and EwOps as well as add lowering for TransposeOp and Row/ColAgg*. All of these passes are added to the optional MLIR codegen pipeline that can be enabled using the --mlir-codegen flag and offer alternative lowering of these operations to MLIR rather than calls to precompiled C++ kernels. Currently, they only support DenseMatrix with dimensions that are known at compile-time and any value type (except Booleans). Except for IdxMin, IdxMax which are directly lowered to affine loops and TransposeOp which lowers to a named linalg op all passes make use of linalg GenericOps which are then lowered to affine loops in a later pass in the codegen pipeline. They convert the input DenseMatrix to a MemRef and create a new MemRef for the output that is converted into a DenseMatrix. Changes: - Add codegen for AllAgg*Op, Row/ColAgg*Op, Ew*Op and TransposeOp (see below for details) - Added passes to TableGen files and codegen pipeline - Added script level test cases / MLIR test cases (using FileCheck) - Replaced old tests Renamed some old test scripts for EwOps for better organization - Edited fusion.mlir test to lower Linalg to affine loops before applying fusion pass - Added Canonicalization passes for floor, ceil, round that removes the respective ops when input type is an integer (this also simplifies codegen) - Added some necessary instantiations in kernels.json - Restored alphabetic sorting of codegen passes in ir/daphneir/Passes.h Ops with new codegen: - AllAgg*Op Sum, Min, Max - Row/ColAgg*Op Sum, Min, Max, IdxMin, IdxMax - Ew*Op Unary (scalar/matrix): Abs, Sqrt, Exp, Ln, Sin, Cos, Floor, Ceil, Round Binary (scalar-scalar/matrix-matrix/matrix-scalar broadcasting): Add, Sub, Mul, Div, Pow, Max, Min - TransposeOp Known limitations are listed in the PR description [daphne-eu#889] Co-authored-by: philipportner [email protected]
- Loading branch information
1 parent
287f4c5
commit 576bde3
Showing
62 changed files
with
3,225 additions
and
675 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,156 @@ | ||
/* | ||
* Copyright 2024 The DAPHNE Consortium | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
#include "compiler/utils/LoweringUtils.h" | ||
#include "ir/daphneir/Daphne.h" | ||
#include "ir/daphneir/Passes.h" | ||
|
||
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h" | ||
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" | ||
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" | ||
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h" | ||
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h" | ||
#include "mlir/Conversion/LLVMCommon/TypeConverter.h" | ||
#include "mlir/Dialect/Func/Transforms/FuncConversions.h" | ||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h" | ||
#include "mlir/Dialect/Linalg/IR/Linalg.h" | ||
#include "mlir/Dialect/MemRef/IR/MemRef.h" | ||
#include "mlir/Dialect/Utils/StructuredOpsUtils.h" | ||
#include "mlir/Dialect/Vector/IR/VectorOps.h" | ||
#include "mlir/IR/AffineExpr.h" | ||
#include "mlir/IR/AffineMap.h" | ||
#include "mlir/IR/Builders.h" | ||
#include "mlir/IR/BuiltinAttributes.h" | ||
#include "mlir/IR/BuiltinDialect.h" | ||
#include "mlir/IR/BuiltinTypes.h" | ||
#include "mlir/IR/Location.h" | ||
#include "mlir/IR/MLIRContext.h" | ||
#include "mlir/IR/OpDefinition.h" | ||
#include "mlir/IR/Operation.h" | ||
#include "mlir/IR/PatternMatch.h" | ||
#include "mlir/IR/UseDefLists.h" | ||
#include "mlir/IR/Value.h" | ||
#include "mlir/IR/ValueRange.h" | ||
#include "mlir/Pass/Pass.h" | ||
#include "mlir/Support/LLVM.h" | ||
#include "mlir/Support/LogicalResult.h" | ||
#include "mlir/Transforms/DialectConversion.h" | ||
|
||
using namespace mlir; | ||
|
||
class TransposeOpLowering : public OpConversionPattern<daphne::TransposeOp> { | ||
public: | ||
using OpConversionPattern::OpConversionPattern; | ||
|
||
explicit TransposeOpLowering(TypeConverter &typeConverter, MLIRContext *ctx) | ||
: mlir::OpConversionPattern<daphne::TransposeOp>(typeConverter, ctx, PatternBenefit(1)) { | ||
this->setDebugName("TransposeOpLowering"); | ||
} | ||
|
||
/** | ||
* @brief Replaces a Transpose operation with a Linalg TransposeOp if possible. | ||
* | ||
* @return mlir::success if Transpose has been replaced, else mlir::failure. | ||
*/ | ||
LogicalResult matchAndRewrite(daphne::TransposeOp op, OpAdaptor adaptor, | ||
ConversionPatternRewriter &rewriter) const override { | ||
|
||
daphne::MatrixType matrixType = adaptor.getArg().getType().dyn_cast<daphne::MatrixType>(); | ||
if (!matrixType) { | ||
return failure(); | ||
} | ||
|
||
Location loc = op->getLoc(); | ||
|
||
Type matrixElementType = matrixType.getElementType(); | ||
ssize_t numRows = matrixType.getNumRows(); | ||
ssize_t numCols = matrixType.getNumCols(); | ||
|
||
if (numRows < 0 || numCols < 0) { | ||
return rewriter.notifyMatchFailure( | ||
op, "transposeOp codegen currently only works with matrix dimensions that are known at compile time"); | ||
} | ||
|
||
Value argMemref = rewriter.create<daphne::ConvertDenseMatrixToMemRef>( | ||
loc, MemRefType::get({numRows, numCols}, matrixElementType), adaptor.getArg()); | ||
|
||
Value resMemref = rewriter.create<memref::AllocOp>(loc, MemRefType::get({numCols, numRows}, matrixElementType)); | ||
|
||
DenseI64ArrayAttr permutation = rewriter.getDenseI64ArrayAttr({1, 0}); | ||
rewriter.create<linalg::TransposeOp>(loc, argMemref, resMemref, permutation); | ||
|
||
Value resDenseMatrix = convertMemRefToDenseMatrix(loc, rewriter, resMemref, op.getType()); | ||
|
||
rewriter.replaceOp(op, resDenseMatrix); | ||
|
||
return success(); | ||
} | ||
}; | ||
|
||
namespace { | ||
/** | ||
* @brief Lowers the daphne::Transpose operator to a Linalg TransposeOp. | ||
* | ||
* This rewrite may enable loop fusion on the affine loops TransposeOp is | ||
* lowered to by running the loop fusion pass. | ||
*/ | ||
struct TransposeLoweringPass : public mlir::PassWrapper<TransposeLoweringPass, mlir::OperationPass<mlir::ModuleOp>> { | ||
explicit TransposeLoweringPass() {} | ||
|
||
StringRef getArgument() const final { return "lower-transpose"; } | ||
StringRef getDescription() const final { return "Lowers Transpose operators to a Linalg TransposeOp."; } | ||
|
||
void getDependentDialects(mlir::DialectRegistry ®istry) const override { | ||
registry.insert<mlir::LLVM::LLVMDialect, mlir::linalg::LinalgDialect, mlir::memref::MemRefDialect>(); | ||
} | ||
void runOnOperation() final; | ||
}; | ||
} // end anonymous namespace | ||
|
||
void TransposeLoweringPass::runOnOperation() { | ||
mlir::ConversionTarget target(getContext()); | ||
mlir::RewritePatternSet patterns(&getContext()); | ||
LowerToLLVMOptions llvmOptions(&getContext()); | ||
LLVMTypeConverter typeConverter(&getContext(), llvmOptions); | ||
|
||
typeConverter.addConversion(convertInteger); | ||
typeConverter.addConversion(convertFloat); | ||
typeConverter.addConversion([](Type type) { return type; }); | ||
typeConverter.addArgumentMaterialization(materializeCastFromIllegal); | ||
typeConverter.addSourceMaterialization(materializeCastToIllegal); | ||
typeConverter.addTargetMaterialization(materializeCastFromIllegal); | ||
|
||
target.addLegalDialect<BuiltinDialect, daphne::DaphneDialect, linalg::LinalgDialect, memref::MemRefDialect>(); | ||
|
||
target.addDynamicallyLegalOp<daphne::TransposeOp>([](Operation *op) { | ||
Type operand = op->getOperand(0).getType(); | ||
daphne::MatrixType matType = operand.dyn_cast<daphne::MatrixType>(); | ||
if (matType && matType.getRepresentation() == daphne::MatrixRepresentation::Dense) { | ||
return false; | ||
} | ||
return true; | ||
}); | ||
|
||
patterns.insert<TransposeOpLowering>(typeConverter, &getContext()); | ||
auto module = getOperation(); | ||
if (failed(applyPartialConversion(module, target, std::move(patterns)))) { | ||
signalPassFailure(); | ||
} | ||
} | ||
|
||
std::unique_ptr<mlir::Pass> daphne::createTransposeOpLoweringPass() { | ||
return std::make_unique<TransposeLoweringPass>(); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.