From a7df43460fa658bd87be5a48a65a6bc9fd88d5f0 Mon Sep 17 00:00:00 2001 From: Patrick Damme Date: Thu, 7 Sep 2023 16:13:35 +0200 Subject: [PATCH] [DAPHNE-455] Introducing a kernel catalog for the DAPHNE compiler. - The DAPHNE compiler usually lowers most domain-specific operations to calls to pre-compiled kernels. - So far, the DAPHNE compiler did not know which kernel instantiations are available in pre-compiled form. - Instead, it generated the expected function name of a kernel based on the DaphneIR operation's mnemonic, its result/argument types, and the processing backend (e.g., CPP or CUDA). - If the expected kernel was not available, an error of the form "JIT session error: Symbols not found: ..." occurred during LLVM JIT compilation. - This commit introduces a kernel catalog that informs the DAPHNE compiler about the available pre-compiled kernels. - The kernel catalog stores a mapping from DaphneIR ops (represented by their mnemonic) to information on kernels registered for the op. - The information stored for each kernel comprises: the name of the pre-compiled C/C++ function, the result/argument types, the processing backend (e.g., CPP or CUDA) - The kernel catalog provides methods for registering a kernel, retrieving the registered kernels for a specific op, and for dumping the catalog. - The kernel catalog is stored inside the DaphneUserConfig. - Makes sense since users will be able to configure the available kernels in the future. - That way, the kernel catalog is accessible in all parts of the DAPHNE compiler and runtime. - The information on the available kernels is stored in a JSON file named catalog.json (or CUDAcatalog.json). - Currently, catalog.json is generated by genKernelInst.py; thus, the system has access to the same kernel specializations as before. - catalog.json is read at DAPHNE system start-up in the coordinator and distributed workers. - Added a parser for the kernel catalog JSON file. 
- RewriteToCallKernelOpPass uses the kernel catalog to obtain the kernel function name for an operation, instead of relying on a naming convention. - However, there are still a few points where kernel function names are built by convention (to be addressed later): - lowering of DistributedPipelineOp in RewriteToCallKernelOpPass - lowering of MapOp in LowerToLLVMPass - lowering of VectorizedPipelineOp in LowerToLLVMPass - Directly related misc changes - DaphneIrExecutor has getters for its DaphneUserConfig. - CompilerUtils::mlirTypeToCppTypeName() allows generating either underscores (as before) or angle brackets (new) for template parameters. - This is a first step towards extensibility w.r.t. the kernels, for now the main contribution is the representation of the available kernels in a data structure (the kernel catalog). - Contributes to #455, but doesn't close it yet. --- CMakeLists.txt | 1 + src/api/cli/DaphneUserConfig.h | 3 + src/api/internal/CMakeLists.txt | 1 + src/api/internal/daphne_internal.cpp | 29 ++- src/compiler/catalog/KernelCatalog.h | 163 ++++++++++++ src/compiler/execution/DaphneIrExecutor.cpp | 2 +- src/compiler/execution/DaphneIrExecutor.h | 9 + src/compiler/lowering/LowerToLLVMPass.cpp | 12 +- .../lowering/RewriteToCallKernelOpPass.cpp | 237 +++++++++++++----- src/compiler/utils/CompilerUtils.h | 40 ++- src/ir/daphneir/DaphneDialect.cpp | 4 +- src/ir/daphneir/DaphneTypes.td | 5 + src/ir/daphneir/Passes.h | 2 +- src/parser/catalog/CMakeLists.txt | 17 ++ src/parser/catalog/KernelCatalogParser.cpp | 121 +++++++++ src/parser/catalog/KernelCatalogParser.h | 71 ++++++ src/runtime/distributed/worker/CMakeLists.txt | 1 + src/runtime/distributed/worker/WorkerImpl.cpp | 7 + src/runtime/local/kernels/CMakeLists.txt | 8 +- src/runtime/local/kernels/genKernelInst.py | 73 ++++-- 20 files changed, 688 insertions(+), 118 deletions(-) create mode 100644 src/compiler/catalog/KernelCatalog.h create mode 100644 src/parser/catalog/CMakeLists.txt create mode 100644 
src/parser/catalog/KernelCatalogParser.cpp create mode 100644 src/parser/catalog/KernelCatalogParser.h diff --git a/CMakeLists.txt b/CMakeLists.txt index 446ebf9bc..616ba696e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -183,6 +183,7 @@ add_subdirectory(src/compiler/inference) add_subdirectory(src/compiler/lowering) add_subdirectory(src/compiler/utils) add_subdirectory(src/parser) +add_subdirectory(src/parser/catalog) add_subdirectory(src/parser/config) add_subdirectory(src/parser/metadata) add_subdirectory(src/runtime/distributed/proto) diff --git a/src/api/cli/DaphneUserConfig.h b/src/api/cli/DaphneUserConfig.h index 3b7a2de93..a308f6f4b 100644 --- a/src/api/cli/DaphneUserConfig.h +++ b/src/api/cli/DaphneUserConfig.h @@ -18,6 +18,7 @@ #pragma once #include +#include #include #include #include @@ -101,4 +102,6 @@ struct DaphneUserConfig { // TODO Maybe the DaphneLib result should better reside in the DaphneContext, // but having it here is simpler for now. DaphneLibResult* result_struct = nullptr; + + KernelCatalog kernelCatalog; }; diff --git a/src/api/internal/CMakeLists.txt b/src/api/internal/CMakeLists.txt index 95e3ec247..343f8e5b6 100644 --- a/src/api/internal/CMakeLists.txt +++ b/src/api/internal/CMakeLists.txt @@ -26,6 +26,7 @@ set(LIBS DaphneDSLParser DaphneIrExecutor DaphneConfigParser + DaphneCatalogParser DaphneMetaDataParser Util WorkerImpl diff --git a/src/api/internal/daphne_internal.cpp b/src/api/internal/daphne_internal.cpp index 5ba81007c..5f783d907 100644 --- a/src/api/internal/daphne_internal.cpp +++ b/src/api/internal/daphne_internal.cpp @@ -25,6 +25,7 @@ #include #include "compiler/execution/DaphneIrExecutor.h" #include +#include #include #include @@ -43,8 +44,11 @@ #include #include #include + #include #include +#include +#include // global logger handle for this executable static std::unique_ptr logger; @@ -505,18 +509,35 @@ int startDAPHNE(int argc, const char** argv, DaphneLibResult* daphneLibRes, int } // 
************************************************************************ - // Parse, compile and execute DaphneDSL script + // Create DaphneIrExecutor and get MLIR context // ************************************************************************ - clock::time_point tpBegPars = clock::now(); - // Creates an MLIR context and loads the required MLIR dialects. DaphneIrExecutor executor(selectMatrixRepr, user_config); + mlir::MLIRContext * mctx = executor.getContext(); + + // ************************************************************************ + // Populate kernel extension catalog + // ************************************************************************ + + KernelCatalog & kc = executor.getUserConfig().kernelCatalog; + // kc.dump(); + KernelCatalogParser kcp(mctx); + kcp.parseKernelCatalog("build/src/runtime/local/kernels/catalog.json", kc); + if(user_config.use_cuda) + kcp.parseKernelCatalog("build/src/runtime/local/kernels/CUDAcatalog.json", kc); + // kc.dump(); + + // ************************************************************************ + // Parse, compile and execute DaphneDSL script + // ************************************************************************ + + clock::time_point tpBegPars = clock::now(); // Create an OpBuilder and an MLIR module and set the builder's insertion // point to the module's body, such that subsequently created DaphneIR // operations are inserted into the module. 
- OpBuilder builder(executor.getContext()); + OpBuilder builder(mctx); auto loc = mlir::FileLineColLoc::get(builder.getStringAttr(inputFile), 0, 0); auto moduleOp = ModuleOp::create(loc); auto * body = moduleOp.getBody(); diff --git a/src/compiler/catalog/KernelCatalog.h b/src/compiler/catalog/KernelCatalog.h new file mode 100644 index 000000000..c58666454 --- /dev/null +++ b/src/compiler/catalog/KernelCatalog.h @@ -0,0 +1,163 @@ +/* + * Copyright 2023 The DAPHNE Consortium + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +#include + +#include +#include +#include +#include + +/** + * @brief Stores information on a single kernel. + */ +struct KernelInfo { + /** + * @brief The name of the pre-compiled kernel function. + */ + const std::string kernelFuncName; + + // TODO Add the path to the shared library containing the kernel function. + + /** + * @brief The kernel's result types. + */ + const std::vector resTypes; + + /** + * @brief The kernel's argument types. + */ + const std::vector argTypes; + + // TODO Maybe unify this with ALLOCATION_TYPE. + /** + * @brief The targeted backend (e.g., hardware accelerator). + */ + const std::string backend; + + KernelInfo( + const std::string kernelFuncName, + const std::vector resTypes, + const std::vector argTypes, + const std::string backend + // TODO Add the path to the shared library containing the kernel function. 
+ ) : + kernelFuncName(kernelFuncName), resTypes(resTypes), argTypes(argTypes), backend(backend) + { + // + } +}; + +/** + * @brief Stores information on kernels registered in the DAPHNE compiler. + */ +class KernelCatalog { + /** + * @brief The central data structure mapping DaphneIR operations to registered kernels. + * + * The DaphneIR operation is represented by its mnemonic. The kernels are represented + * by their kernel information. + */ + std::unordered_map> kernelInfosByOp; + + /** + * @brief Prints the given kernel information. + * + * @param opMnemonic The mnemonic of the corresponding DaphneIR operation. + * @param kernelInfos The kernel information to print. + * @param os The stream to print to. Defaults to `std::cerr`. + */ + void dumpKernelInfos(const std::string & opMnemonic, const std::vector & kernelInfos, std::ostream & os = std::cerr) const { + os << "- operation `" << opMnemonic << "` (" << kernelInfos.size() << " kernels)" << std::endl; + for(KernelInfo ki : kernelInfos) { + os << " - kernel `" << ki.kernelFuncName << "`: ("; + for(size_t i = 0; i < ki.argTypes.size(); i++) { + os << ki.argTypes[i]; + if(i < ki.argTypes.size() - 1) + os << ", "; + } + os << ") -> ("; + for(size_t i = 0; i < ki.resTypes.size(); i++) { + os << ki.resTypes[i]; + if(i < ki.resTypes.size() - 1) + os << ", "; + } + os << ") for backend `" << ki.backend << '`' << std::endl; + } + } + +public: + /** + * @brief Registers the given kernel information as a kernel for the DaphneIR + * operation with the given mnemonic. + * + * @param opMnemonic The DaphneIR operation's mnemonic. + * @param kernelInfo The information on the kernel. + */ + void registerKernel(std::string opMnemonic, KernelInfo kernelInfo) { + kernelInfosByOp[opMnemonic].push_back(kernelInfo); + } + + /** + * @brief Retrieves information on all kernels registered for the given DaphneIR operation. + * + * @param opMnemonic The mnemonic of the DaphneIR operation. 
+ * @return A vector of kernel information, or an empty vector if no kernels are registered + * for the given operation. + */ + const std::vector getKernelInfosByOp(const std::string & opMnemonic) const { + auto it = kernelInfosByOp.find(opMnemonic); + if(it != kernelInfosByOp.end()) + return it->second; + else + return {}; + } + + /** + * @brief Prints high-level statistics on the kernel catalog. + * + * @param os The stream to print to. Defaults to `std::cerr`. + */ + void stats(std::ostream & os = std::cerr) const { + const size_t numOps = kernelInfosByOp.size(); + size_t numKernels = 0; + for(auto it = kernelInfosByOp.begin(); it != kernelInfosByOp.end(); it++) + numKernels += it->second.size(); + os << "KernelCatalog (" << numOps << " ops, " << numKernels << " kernels)" << std::endl; + } + + /** + * @brief Prints this kernel catalog. + * + * @param opMnemonic If an empty string, print registered kernels for all DaphneIR + * operations; otherwise, consider only the specified DaphneIR operation. + * @param os The stream to print to. Defaults to `std::cerr`. + */ + void dump(std::string opMnemonic = "", std::ostream & os = std::cerr) const { + stats(os); + if(opMnemonic.empty()) + // Print info on all ops. + for(auto it = kernelInfosByOp.begin(); it != kernelInfosByOp.end(); it++) + dumpKernelInfos(it->first, it->second, os); + else + // Print info on specified op only. 
+ dumpKernelInfos(opMnemonic, getKernelInfosByOp(opMnemonic), os); + } +}; \ No newline at end of file diff --git a/src/compiler/execution/DaphneIrExecutor.cpp b/src/compiler/execution/DaphneIrExecutor.cpp index 1c5ab19f5..a01c0cf3c 100644 --- a/src/compiler/execution/DaphneIrExecutor.cpp +++ b/src/compiler/execution/DaphneIrExecutor.cpp @@ -213,7 +213,7 @@ bool DaphneIrExecutor::runPasses(mlir::ModuleOp module) { "IR after managing object references:")); pm.addNestedPass( - mlir::daphne::createRewriteToCallKernelOpPass()); + mlir::daphne::createRewriteToCallKernelOpPass(userConfig_)); if (userConfig_.explain_kernels) pm.addPass( mlir::daphne::createPrintIRPass("IR after kernel lowering:")); diff --git a/src/compiler/execution/DaphneIrExecutor.h b/src/compiler/execution/DaphneIrExecutor.h index ef1c32d13..89396396b 100644 --- a/src/compiler/execution/DaphneIrExecutor.h +++ b/src/compiler/execution/DaphneIrExecutor.h @@ -31,6 +31,15 @@ class DaphneIrExecutor mlir::MLIRContext *getContext() { return &context_; } + + DaphneUserConfig & getUserConfig() { + return userConfig_; + } + + const DaphneUserConfig & getUserConfig() const { + return userConfig_; + } + private: mlir::MLIRContext context_; DaphneUserConfig userConfig_; diff --git a/src/compiler/lowering/LowerToLLVMPass.cpp b/src/compiler/lowering/LowerToLLVMPass.cpp index 6fd9c975e..6ea82811f 100644 --- a/src/compiler/lowering/LowerToLLVMPass.cpp +++ b/src/compiler/lowering/LowerToLLVMPass.cpp @@ -526,10 +526,10 @@ class MapOpLowering : public OpConversionPattern callee << '_' << op->getName().stripDialect().str(); // Result Matrix - callee << "__" << CompilerUtils::mlirTypeToCppTypeName(op.getType()); + callee << "__" << CompilerUtils::mlirTypeToCppTypeName(op.getType(), false); // Input Matrix - callee << "__" << CompilerUtils::mlirTypeToCppTypeName(op.getArg().getType()); + callee << "__" << CompilerUtils::mlirTypeToCppTypeName(op.getArg().getType(), false); // Pointer to UDF callee << "__void"; @@ -740,7 
+740,7 @@ class VectorizedPipelineOpLowering : public OpConversionPattern(loc, daphne::VariadicPackType::get(rewriter.getContext(), operandType), @@ -805,11 +805,11 @@ class VectorizedPipelineOpLowering : public OpConversionPattern #include "ir/daphneir/Daphne.h" #include "ir/daphneir/Passes.h" @@ -32,6 +33,7 @@ #include #include #include +#include #include #include #include @@ -134,38 +136,79 @@ namespace */ Value dctx; + const DaphneUserConfig & userConfig; + + mlir::Type adaptType(mlir::Type t, bool generalizeToStructure) const { + MLIRContext * mctx = t.getContext(); + if(generalizeToStructure && t.isa()) + return mlir::daphne::StructureType::get(mctx); + if(auto mt = t.dyn_cast()) + return mt.withSameElementTypeAndRepr(); + if(t.isa()) + return mlir::daphne::FrameType::get(mctx, {mlir::daphne::UnknownType::get(mctx)}); + if(auto mrt = t.dyn_cast()) + // Remove any dimension information ({0, 0}), but retain the element type. + return mlir::MemRefType::get({0, 0}, mrt.getElementType()); + return t; + } + public: /** * Creates a new KernelReplacement rewrite pattern. * * @param mctx The MLIR context. * @param dctx The DaphneContext to pass to the kernels. + * @param userConfig The user config. * @param benefit */ - KernelReplacement(MLIRContext * mctx, Value dctx, PatternBenefit benefit = 1) - : RewritePattern(Pattern::MatchAnyOpTypeTag(), benefit, mctx), dctx(dctx) + KernelReplacement(MLIRContext * mctx, Value dctx, const DaphneUserConfig & userConfig, PatternBenefit benefit = 1) + : RewritePattern(Pattern::MatchAnyOpTypeTag(), benefit, mctx), dctx(dctx), userConfig(userConfig) { } + /** + * @brief Rewrites the given operation to a `CallKernelOp`. + * + * This involves looking up a matching kernel from the kernel catalog based on the + * mnemonic, argument/result types, and backend (e.g., hardware accelerator) of the + * given operation. Variadic operands are also taken into account. + * + * @param op The operation to rewrite. + * @param rewriter The rewriter. 
+ * @result Always returns `mlir::success()` unless an exception is thrown. + */ LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override { Location loc = op->getLoc(); - // Determine the name of the kernel function to call by convention - // based on the DaphneIR operation and the types of its results and - // arguments. - std::stringstream callee; - - // check CUDA support and valid device ID - if(op->hasAttr("cuda_device")) { - callee << "CUDA"; - } - else if(op->hasAttr("fpgaopencl_device")) { - callee << "FPGAOPENCL"; - } - - callee << '_' << op->getName().stripDialect().data(); - - + // The argument/result types of the given operation. + Operation::operand_type_range opArgTys = op->getOperandTypes(); + Operation::result_type_range opResTys = op->getResultTypes(); + + // The argument/result types to use for kernel look-up. + std::vector lookupArgTys; + std::vector lookupResTys; + // Differences between op argument types and look-up argument types: + // - The look-up argument types summarize n occurrences of a variadic operand into + // one variadic pack and one number of occurrences. + // - The look-up argument types omit most of the properties of the op argument types, + // because those would complicate the search for matching kernels. + // Differences between op result types and look-up result types: + // - The look-up result types omit most of the properties of the op result types, + // because those would complicate the search for matching kernels. + + // The operands to use for the CallKernelOp to be created. These may differ from + // the operands of the given operation, if it has a variadic operand. 
+ std::vector kernelArgs; + + // ***************************************************************************** + // Prepare the kernel look-up and the creation of the CallKernelOp + // ***************************************************************************** + // Determine the argument/result types for the kernel look-up as well as + // the arguments of the CallKernelOp to be created. Variadic operands are taken + // into account. + + // Find out if argument types shall the generalized from matrix/frame to the + // supertype structure. // TODO Don't enumerate all ops, decide based on a trait. const bool generalizeInputTypes = llvm::isa(op) || @@ -176,19 +219,13 @@ namespace llvm::isa(op) || llvm::isa(op); - // Append names of result types to the kernel name. - Operation::result_type_range resultTypes = op->getResultTypes(); - for(size_t i = 0; i < resultTypes.size(); i++) - callee << "__" << CompilerUtils::mlirTypeToCppTypeName(resultTypes[i], false); + // Append converted op result types to the look-up result types. + for(size_t i = 0; i < opResTys.size(); i++) + lookupResTys.push_back(adaptType(opResTys[i], false)); - // Append names of operand types to the kernel name. Variadic - // operands, which can have an arbitrary number of occurrences, are + // Append converted op argument types to the look-up argument types. + // Variadic operands, which can have an arbitrary number of occurrences, are // treated specially. - Operation::operand_type_range operandTypes = op->getOperandTypes(); - // The operands of the CallKernelOp may differ from the operands - // of the given operation, if it has a variadic operand. - std::vector newOperands; - if( // TODO Unfortunately, one needs to know the exact N for // AtLeastNOperands... There seems to be no simple way to @@ -213,16 +250,17 @@ namespace // expectation of a aggregation. 
To make the group operation possible without aggregations, // we have to use this workaround to create the correct name and skip the creation // of the variadic pack ops. Should be changed when reworking the lowering to kernels. - if(llvm::dyn_cast(op) && idx >= operandTypes.size()) { - callee << "__char_variadic__size_t"; + if(llvm::dyn_cast(op) && idx >= opArgTys.size()) { + lookupArgTys.push_back(adaptType(mlir::daphne::StringType::get(getContext()), generalizeInputTypes)); + lookupArgTys.push_back(rewriter.getIndexType()); continue; } else { - callee << "__" << CompilerUtils::mlirTypeToCppTypeName(operandTypes[idx], generalizeInputTypes); + lookupArgTys.push_back(adaptType(opArgTys[idx], generalizeInputTypes)); } if(isVariadic) { // Variadic operand. - callee << "_variadic__size_t"; + lookupArgTys.push_back(rewriter.getIndexType()); auto cvpOp = rewriter.create( loc, daphne::VariadicPackType::get( @@ -238,33 +276,31 @@ namespace op->getOperand(idx + k), rewriter.getI64IntegerAttr(k) ); - newOperands.push_back(cvpOp); - newOperands.push_back(rewriter.create( + kernelArgs.push_back(cvpOp); + kernelArgs.push_back(rewriter.create( loc, rewriter.getIndexType(), rewriter.getIndexAttr(len)) ); } else // Non-variadic operand. - newOperands.push_back(op->getOperand(i)); + kernelArgs.push_back(op->getOperand(i)); } } else // For operations without variadic operands, we simply append - // the name of the type of each operand and pass all operands - // to the CallKernelOp as-is. - for(size_t i = 0; i < operandTypes.size(); i++) { - callee << "__" << CompilerUtils::mlirTypeToCppTypeName(operandTypes[i], generalizeInputTypes); - newOperands.push_back(op->getOperand(i)); + // the type of each operand to the vector of types to use for + // kernel look-up, and pass all operands to the CallKernelOp as-is. 
+ for(size_t i = 0; i < opArgTys.size(); i++) { + lookupArgTys.push_back(adaptType(opArgTys[i], generalizeInputTypes)); + kernelArgs.push_back(op->getOperand(i)); } if(auto groupOp = llvm::dyn_cast(op)) { // GroupOp carries the aggregation functions to apply as an - // attribute. Since attributes to not automatically become + // attribute. Since attributes do not automatically become // inputs to the kernel call, we need to add them explicitly // here. - callee << "__GroupEnum_variadic__size_t"; - ArrayAttr aggFuncs = groupOp.getAggFuncs(); const size_t numAggFuncs = aggFuncs.size(); const Type t = rewriter.getIntegerType(32, false); @@ -290,22 +326,18 @@ namespace ), rewriter.getI64IntegerAttr(k++) ); - newOperands.push_back(cvpOp); - newOperands.push_back(rewriter.create( + kernelArgs.push_back(cvpOp); + kernelArgs.push_back(rewriter.create( loc, rewriter.getIndexType(), rewriter.getIndexAttr(numAggFuncs)) ); } - if(auto thetaJoinOp = llvm::dyn_cast(op)) { // ThetaJoinOp carries multiple CompareOperation as an - // attribute. Since attributes to not automatically become + // attribute. Since attributes do not automatically become // inputs to the kernel call, we need to add them explicitly // here. - // manual mapping of attributes to function header - callee << "__CompareOperation__size_t"; - // get array of CompareOperations ArrayAttr compareOperations = thetaJoinOp.getCmp(); const size_t numCompareOperations = compareOperations.size(); @@ -336,14 +368,14 @@ namespace ); // add created variadic pack and size of this pack as // new operands / parameters of the ThetaJoin-Kernel call - newOperands.push_back(cvpOp); - newOperands.push_back(rewriter.create( + kernelArgs.push_back(cvpOp); + kernelArgs.push_back(rewriter.create( loc, rewriter.getIndexType(), rewriter.getIndexAttr(numCompareOperations)) ); } if(auto distCompOp = llvm::dyn_cast(op)) { - MLIRContext newContext; + MLIRContext newContext; // TODO Reuse the existing context. 
OpBuilder tempBuilder(&newContext); std::string funcName = "dist"; @@ -363,24 +395,89 @@ namespace auto strTy = daphne::StringType::get(rewriter.getContext()); Value rewriteStr = rewriter.create(loc, strTy, rewriter.getStringAttr(stream.str())); - callee << "__" << CompilerUtils::mlirTypeToCppTypeName(strTy, false); - newOperands.push_back(rewriteStr); + lookupArgTys.push_back(mlir::daphne::StringType::get(&newContext)); + kernelArgs.push_back(rewriteStr); } // Inject the current DaphneContext as the last input parameter to // all kernel calls, unless it's a CreateDaphneContextOp. if(!llvm::isa(op)) - newOperands.push_back(dctx); + kernelArgs.push_back(dctx); + + // ***************************************************************************** + // Look up a matching kernel from the kernel catalog. + // ***************************************************************************** + + const KernelCatalog & kc = userConfig.kernelCatalog; + const std::string opMnemonic = op->getName().stripDialect().data(); + std::vector kernelInfos = kc.getKernelInfosByOp(opMnemonic); + + if(kernelInfos.empty()) + throw std::runtime_error("no kernels registered for operation " + opMnemonic); + + std::string backend; + if(op->hasAttr("cuda_device")) + backend = "CUDA"; + else if(op->hasAttr("fpgaopencl_device")) + backend = "FPGAOPENCL"; + else + backend = "CPP"; + + const size_t numArgs = lookupArgTys.size(); + const size_t numRess = lookupResTys.size(); + int chosenKernelIdx = -1; + for(size_t i = 0; i < kernelInfos.size() && chosenKernelIdx == -1; i++) { + auto ki = kernelInfos[i]; + if(ki.backend != backend) + continue; + if(numArgs != ki.argTypes.size()) + continue; + if(numRess != ki.resTypes.size()) + continue; + + bool mismatch = false; + for(size_t i = 0; i < numArgs && !mismatch; i++) + if(lookupArgTys[i] != ki.argTypes[i]) + mismatch = true; + for(size_t i = 0; i < numRess && !mismatch; i++) + if(lookupResTys[i] != ki.resTypes[i]) + mismatch = true; + if(!mismatch) + 
chosenKernelIdx = i; + } + if(chosenKernelIdx == -1) { + std::stringstream s; + s << "no kernel for operation `" << opMnemonic + << "` available for the required input types `("; + for(size_t i = 0; i < numArgs; i++) { + s << lookupArgTys[i]; + if(i < numArgs - 1) + s << ", "; + } + s << + ")` and output types `("; + for(size_t i = 0; i < numRess; i++) { + s << lookupResTys[i]; + if(i < numRess - 1) + s << ", "; + } + s << ")` for backend `" << backend << "`, registered kernels for this op:" << std::endl; + kc.dump(opMnemonic, s); + throw std::runtime_error(s.str()); + } + KernelInfo chosenKI = kernelInfos[chosenKernelIdx]; - // Create a CallKernelOp for the kernel function to call and return - // success(). + // ***************************************************************************** + // Create the CallKernelOp + // ***************************************************************************** + + // Create a CallKernelOp for the kernel function to call and return success(). auto kernel = rewriter.create( loc, - callee.str(), - newOperands, - op->getResultTypes() - ); + chosenKI.kernelFuncName, + kernelArgs, + opResTys + ); rewriter.replaceOp(op, kernel.getResults()); return success(); } @@ -388,11 +485,12 @@ namespace class DistributedPipelineKernelReplacement : public OpConversionPattern { Value dctx; + const DaphneUserConfig & userConfig; public: using OpConversionPattern::OpConversionPattern; - DistributedPipelineKernelReplacement(MLIRContext * mctx, Value dctx, PatternBenefit benefit = 2) - : OpConversionPattern(mctx, benefit), dctx(dctx) + DistributedPipelineKernelReplacement(MLIRContext * mctx, Value dctx, const DaphneUserConfig & userConfig, PatternBenefit benefit = 2) + : OpConversionPattern(mctx, benefit), dctx(dctx), userConfig(userConfig) { } @@ -486,7 +584,8 @@ namespace struct RewriteToCallKernelOpPass : public PassWrapper> { - RewriteToCallKernelOpPass() = default; + const DaphneUserConfig& userConfig; + explicit RewriteToCallKernelOpPass(const 
DaphneUserConfig& cfg) : userConfig(cfg) {} void runOnOperation() final; }; } @@ -534,13 +633,13 @@ void RewriteToCallKernelOpPass::runOnOperation() patterns.insert< KernelReplacement, DistributedPipelineKernelReplacement - >(&getContext(), dctx); + >(&getContext(), dctx, userConfig); if (failed(applyPartialConversion(func, target, std::move(patterns)))) signalPassFailure(); } -std::unique_ptr daphne::createRewriteToCallKernelOpPass() +std::unique_ptr daphne::createRewriteToCallKernelOpPass(const DaphneUserConfig& cfg) { - return std::make_unique(); + return std::make_unique(cfg); } diff --git a/src/compiler/utils/CompilerUtils.h b/src/compiler/utils/CompilerUtils.h index 13e4973b6..7c38dde3a 100644 --- a/src/compiler/utils/CompilerUtils.h +++ b/src/compiler/utils/CompilerUtils.h @@ -123,10 +123,15 @@ struct CompilerUtils { * might complain about recursion. * * @param t MLIR type name - * @param generalizeToStructure If true, "Structure" is used instead of derived types DenseMatrix et al. - * @return string representation of the C++ type names + * @param angleBrackets If `true` (default), angle brackets are used for C++ template types (e.g., `DenseMatrix`); + * Otherwise, underscores are used (e.g., `DenseMatrix_float`). + * @param generalizeToStructure If `true`, `Structure` is used instead of derived types like `DenseMatrix` etc. + * @return A string representation of the C++ type names */ - static std::string mlirTypeToCppTypeName(mlir::Type t, bool generalizeToStructure = false) { // NOLINT(misc-no-recursion) + // TODO The parameter generalizeToStructure seems to be used only by some remaining kernel name generation + // in LowerToLLVMPass. Once those call-sites have been refactored to use the kernel catalog, this feature + // can be removed here. 
+ static std::string mlirTypeToCppTypeName(mlir::Type t, bool angleBrackets = true, bool generalizeToStructure = false) { // NOLINT(misc-no-recursion) if(t.isF64()) return "double"; else if(t.isF32()) @@ -147,22 +152,30 @@ struct CompilerUtils { return "bool"; else if(t.isIndex()) return "size_t"; - else if(auto matTy = t.dyn_cast()) + else if(t.isa()) + return "Structure"; + else if(auto matTy = t.dyn_cast()) { if(generalizeToStructure) return "Structure"; else { switch (matTy.getRepresentation()) { - case mlir::daphne::MatrixRepresentation::Dense: - return "DenseMatrix_" + mlirTypeToCppTypeName(matTy.getElementType(), false); - case mlir::daphne::MatrixRepresentation::Sparse: - return "CSRMatrix_" + mlirTypeToCppTypeName(matTy.getElementType(), false); + case mlir::daphne::MatrixRepresentation::Dense: { + const std::string vtName = mlirTypeToCppTypeName(matTy.getElementType(), angleBrackets, false); + return angleBrackets ? ("DenseMatrix<" + vtName + ">") : ("DenseMatrix_" + vtName); + } + case mlir::daphne::MatrixRepresentation::Sparse: { + const std::string vtName = mlirTypeToCppTypeName(matTy.getElementType(), angleBrackets, false); + return angleBrackets ? ("CSRMatrix<" + vtName + ">") : ("CSRMatrix_" + vtName); + } } } - else if(t.isa()) + } + else if(t.isa()) { if(generalizeToStructure) return "Structure"; else return "Frame"; + } else if(t.isa()) // This becomes "const char *" (which makes perfect sense for // strings) when inserted into the typical "const DT *" template of @@ -170,8 +183,10 @@ struct CompilerUtils { return "char"; else if(t.isa()) return "DaphneContext"; - else if(auto handleTy = t.dyn_cast()) - return "Handle_" + mlirTypeToCppTypeName(handleTy.getDataType(), generalizeToStructure); + else if(auto handleTy = t.dyn_cast()) { + const std::string tName = mlirTypeToCppTypeName(handleTy.getDataType(), angleBrackets, generalizeToStructure); + return angleBrackets ? 
("Handle<" + tName + ">") : ("Handle_" + tName); + } else if(t.isa()) return "File"; else if(t.isa()) @@ -179,7 +194,8 @@ struct CompilerUtils { else if(t.isa()) return "Target"; else if(auto memRefType = t.dyn_cast()) { - return "StridedMemRefType_" + mlirTypeToCppTypeName(memRefType.getElementType(), false) + "_2"; + const std::string vtName = mlirTypeToCppTypeName(memRefType.getElementType(), angleBrackets, false); + return angleBrackets ? ("StridedMemRefType<" + vtName + ",2>") : ("StridedMemRefType_" + vtName + "_2"); } std::string typeName; diff --git a/src/ir/daphneir/DaphneDialect.cpp b/src/ir/daphneir/DaphneDialect.cpp index 205e7c4e9..e6a4bc5d1 100644 --- a/src/ir/daphneir/DaphneDialect.cpp +++ b/src/ir/daphneir/DaphneDialect.cpp @@ -243,7 +243,9 @@ std::string unknownStrIf(double val) { void mlir::daphne::DaphneDialect::printType(mlir::Type type, mlir::DialectAsmPrinter &os) const { - if (auto t = type.dyn_cast()) { + if (type.isa()) + os << "Structure"; + else if (auto t = type.dyn_cast()) { os << "Matrix<" << unknownStrIf(t.getNumRows()) << 'x' << unknownStrIf(t.getNumCols()) << 'x' diff --git a/src/ir/daphneir/DaphneTypes.td b/src/ir/daphneir/DaphneTypes.td index 4dd037d29..52e84c03f 100644 --- a/src/ir/daphneir/DaphneTypes.td +++ b/src/ir/daphneir/DaphneTypes.td @@ -45,6 +45,11 @@ def Unknown : Daphne_Type<"Unknown"> { // Data types // **************************************************************************** +def Structure : Daphne_Type<"Structure"> { + // Don't use this in TableGen! + let summary = "structure"; +} + // A matrix type. 
def Matrix : Daphne_Type<"Matrix"> { let summary = "matrix"; diff --git a/src/ir/daphneir/Passes.h b/src/ir/daphneir/Passes.h index ec9c5f45a..11b911e7b 100644 --- a/src/ir/daphneir/Passes.h +++ b/src/ir/daphneir/Passes.h @@ -58,7 +58,7 @@ namespace mlir::daphne { std::unique_ptr createPhyOperatorSelectionPass(); std::unique_ptr createPrintIRPass(std::string message = ""); std::unique_ptr createRewriteSqlOpPass(); - std::unique_ptr createRewriteToCallKernelOpPass(); + std::unique_ptr createRewriteToCallKernelOpPass(const DaphneUserConfig& cfg); std::unique_ptr createSelectMatrixRepresentationsPass(); std::unique_ptr createSpecializeGenericFunctionsPass(const DaphneUserConfig& cfg); std::unique_ptr createVectorizeComputationsPass(); diff --git a/src/parser/catalog/CMakeLists.txt b/src/parser/catalog/CMakeLists.txt new file mode 100644 index 000000000..fcb4e31de --- /dev/null +++ b/src/parser/catalog/CMakeLists.txt @@ -0,0 +1,17 @@ +# Copyright 2023 The DAPHNE Consortium +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +add_library(DaphneCatalogParser STATIC + KernelCatalogParser.cpp +) diff --git a/src/parser/catalog/KernelCatalogParser.cpp b/src/parser/catalog/KernelCatalogParser.cpp new file mode 100644 index 000000000..a593aca8f --- /dev/null +++ b/src/parser/catalog/KernelCatalogParser.cpp @@ -0,0 +1,121 @@ +/* + * Copyright 2023 The DAPHNE Consortium + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include + +#include + +#include +#include +#include +#include +#include +#include +#include + +KernelCatalogParser::KernelCatalogParser(mlir::MLIRContext * mctx) { + // Initialize the mapping from C++ type name strings to MLIR types for parsing. + + mlir::OpBuilder builder(mctx); + + // Scalars and matrices. + std::vector scalarTypes = { + builder.getF64Type(), + builder.getF32Type(), + builder.getIntegerType(64, true), + builder.getIntegerType(32, true), + builder.getIntegerType(8, true), + builder.getIntegerType(64, false), + builder.getIntegerType(32, false), + builder.getIntegerType(8, false), + builder.getI1Type(), + builder.getIndexType(), + mlir::daphne::StringType::get(mctx) + }; + for(mlir::Type st : scalarTypes) { + // Scalar type. + typeMap.emplace(CompilerUtils::mlirTypeToCppTypeName(st), st); + // Matrix type for DenseMatrix. + // TODO This should have withRepresentation(mlir::daphne::MatrixRepresentation::Dense). 
+ mlir::Type mtDense = mlir::daphne::MatrixType::get(mctx, st); + typeMap.emplace(CompilerUtils::mlirTypeToCppTypeName(mtDense), mtDense); + // Matrix type for CSRMatrix. + mlir::Type mtCSR = mlir::daphne::MatrixType::get(mctx, st).withRepresentation(mlir::daphne::MatrixRepresentation::Sparse); + typeMap.emplace(CompilerUtils::mlirTypeToCppTypeName(mtCSR), mtCSR); + // MemRef type. + if(!st.isa()) { + // DAPHNE's StringType is not supported as the element type of a MemRef. + // The dimensions of the MemRef are irrelevant here, so we use {0, 0}. + mlir::Type mrt = mlir::MemRefType::get({0, 0}, st); + typeMap.emplace(CompilerUtils::mlirTypeToCppTypeName(mrt), mrt); + } + } + + // Structure, Frame, DaphneContext (MemRef types are registered per element type above). + std::vector otherTypes = { + mlir::daphne::StructureType::get(mctx), + mlir::daphne::FrameType::get(mctx, {mlir::daphne::UnknownType::get(mctx)}), + mlir::daphne::DaphneContextType::get(mctx), + }; + for(mlir::Type t : otherTypes) { + typeMap.emplace(CompilerUtils::mlirTypeToCppTypeName(t), t); + } +} + +void KernelCatalogParser::mapTypes( + const std::vector & in, + std::vector & out, + const std::string & word, + const std::string & kernelFuncName, + const std::string & opMnemonic, + const std::string & backend +) const { + for(size_t i = 0; i < in.size(); i++) { + const std::string name = in[i]; + auto it = typeMap.find(name); + if(it != typeMap.end()) + out.push_back(it->second); + else { + std::stringstream s; + s << "KernelCatalogParser: error while parsing " + word + " types of kernel `" + << kernelFuncName << "` for operation `" << opMnemonic << "` (backend `" + << backend << "`): unknown type for " << word << " #" << i << ": `" << name << '`'; + throw std::runtime_error(s.str()); + } + } +} + +void KernelCatalogParser::parseKernelCatalog(const std::string & filePath, KernelCatalog & kc) const { + std::ifstream kernelsConfigFile(filePath); + nlohmann::json kernelsConfigData = nlohmann::json::parse(kernelsConfigFile); + for(auto kernelData :
kernelsConfigData) { + const std::string opMnemonic = kernelData["opMnemonic"].get(); + // TODO Remove this workaround. + // Skip these two problematic operations, which return multiple results in the wrong way. + if(opMnemonic == "Avg_Forward" || opMnemonic == "Max_Forward") + continue; + const std::string kernelFuncName = kernelData["kernelFuncName"].get(); + const std::string backend = kernelData["backend"].get(); + std::vector resTypes; + mapTypes(kernelData["resTypes"], resTypes, "result", kernelFuncName, opMnemonic, backend); + std::vector argTypes; + mapTypes(kernelData["argTypes"], argTypes, "argument", kernelFuncName, opMnemonic, backend); + kc.registerKernel(opMnemonic, KernelInfo(kernelFuncName, resTypes, argTypes, backend)); + } +} \ No newline at end of file diff --git a/src/parser/catalog/KernelCatalogParser.h b/src/parser/catalog/KernelCatalogParser.h new file mode 100644 index 000000000..7e93ae260 --- /dev/null +++ b/src/parser/catalog/KernelCatalogParser.h @@ -0,0 +1,71 @@ +/* + * Copyright 2023 The DAPHNE Consortium + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +#include +#include + +#include +#include +#include + +/** + * @brief A parser for kernel information. + */ +class KernelCatalogParser { + + /** + * @brief A mapping from C++ type name strings to MLIR types used for parsing input/output types of kernels. 
+ */ + std::unordered_map typeMap; + + /** + * @brief Maps the given C++ type names to MLIR types. + * + * @param in The vector of C++ type names. + * @param out The vector of corresponding MLIR types. + * @param word Typically either `"argument"` or `"result"`. + * @param kernelFuncName The name of the kernel function for which this method is called (for error message). + * @param opMnemonic The mnemonic of the operation for which this method is called (for error message). + * @param backend The backend for which this method is called (for error message). + */ + void mapTypes( + const std::vector & in, + std::vector & out, + const std::string & word, + const std::string & kernelFuncName, + const std::string & opMnemonic, + const std::string & backend + ) const; + +public: + + /** + * @brief Creates a new kernel catalog parser. + */ + KernelCatalogParser(mlir::MLIRContext * mctx); + + /** + * @brief Parses kernel information from the given file and registers them with the given kernel catalog. + * + * @param filePath The path to the file to extract kernel information from. + * @param kc The kernel catalog to register the kernels with. 
+ */ + void parseKernelCatalog(const std::string & filePath, KernelCatalog & kc) const; +}; \ No newline at end of file diff --git a/src/runtime/distributed/worker/CMakeLists.txt b/src/runtime/distributed/worker/CMakeLists.txt index dd2b81688..319cd6685 100644 --- a/src/runtime/distributed/worker/CMakeLists.txt +++ b/src/runtime/distributed/worker/CMakeLists.txt @@ -39,6 +39,7 @@ set(LIBS IO CallData Proto + DaphneCatalogParser DaphneMetaDataParser Arrow::arrow_shared Parquet::parquet_shared diff --git a/src/runtime/distributed/worker/WorkerImpl.cpp b/src/runtime/distributed/worker/WorkerImpl.cpp index 059347b8d..d13f82826 100644 --- a/src/runtime/distributed/worker/WorkerImpl.cpp +++ b/src/runtime/distributed/worker/WorkerImpl.cpp @@ -23,6 +23,7 @@ #include #include +#include #include "WorkerImpl.h" @@ -75,6 +76,12 @@ WorkerImpl::Status WorkerImpl::Compute(std::vector *outp // the FreeOps at the coordinator already. DaphneIrExecutor executor(false, cfg); + KernelCatalog & kc = executor.getUserConfig().kernelCatalog; + KernelCatalogParser kcp(executor.getContext()); + kcp.parseKernelCatalog("build/src/runtime/local/kernels/catalog.json", kc); + if(executor.getUserConfig().use_cuda) + kcp.parseKernelCatalog("build/src/runtime/local/kernels/CUDAcatalog.json", kc); + mlir::OwningOpRef module(mlir::parseSourceString(mlirCode, executor.getContext())); if (!module) { auto message = "Failed to parse source string.\n"; diff --git a/src/runtime/local/kernels/CMakeLists.txt b/src/runtime/local/kernels/CMakeLists.txt index 58841a563..173a3a0bb 100644 --- a/src/runtime/local/kernels/CMakeLists.txt +++ b/src/runtime/local/kernels/CMakeLists.txt @@ -22,9 +22,9 @@ set(CMAKE_POSITION_INDEPENDENT_CODE ON) # The library of pre-compiled CUDA kernels if(USE_CUDA AND CMAKE_CUDA_COMPILER) add_custom_command( - OUTPUT ${PROJECT_BINARY_DIR}/src/runtime/local/kernels/CUDAkernels.cpp + OUTPUT ${PROJECT_BINARY_DIR}/src/runtime/local/kernels/CUDAkernels.cpp 
${PROJECT_BINARY_DIR}/src/runtime/local/kernels/CUDAcatalog.json COMMAND python3 ARGS genKernelInst.py kernels.json - ${PROJECT_BINARY_DIR}/src/runtime/local/kernels/CUDAkernels.cpp CUDA + ${PROJECT_BINARY_DIR}/src/runtime/local/kernels/CUDAkernels.cpp ${PROJECT_BINARY_DIR}/src/runtime/local/kernels/CUDAcatalog.json CUDA MAIN_DEPENDENCY ${PROJECT_SOURCE_DIR}/src/runtime/local/kernels/kernels.json DEPENDS ${PROJECT_SOURCE_DIR}/src/runtime/local/kernels/genKernelInst.py WORKING_DIRECTORY ${PROJECT_SOURCE_DIR}/src/runtime/local/kernels/ @@ -74,8 +74,8 @@ if(USE_CUDA AND CMAKE_CUDA_COMPILER) endif() add_custom_command( - OUTPUT ${PROJECT_BINARY_DIR}/src/runtime/local/kernels/kernels.cpp - COMMAND python3 ARGS genKernelInst.py kernels.json ${PROJECT_BINARY_DIR}/src/runtime/local/kernels/kernels.cpp CPP + OUTPUT ${PROJECT_BINARY_DIR}/src/runtime/local/kernels/kernels.cpp ${PROJECT_BINARY_DIR}/src/runtime/local/kernels/catalog.json + COMMAND python3 ARGS genKernelInst.py kernels.json ${PROJECT_BINARY_DIR}/src/runtime/local/kernels/kernels.cpp ${PROJECT_BINARY_DIR}/src/runtime/local/kernels/catalog.json CPP MAIN_DEPENDENCY ${PROJECT_SOURCE_DIR}/src/runtime/local/kernels/kernels.json DEPENDS ${PROJECT_SOURCE_DIR}/src/runtime/local/kernels/genKernelInst.py WORKING_DIRECTORY ${PROJECT_SOURCE_DIR}/src/runtime/local/kernels/ diff --git a/src/runtime/local/kernels/genKernelInst.py b/src/runtime/local/kernels/genKernelInst.py index d2b2c2a44..657e0c276 100755 --- a/src/runtime/local/kernels/genKernelInst.py +++ b/src/runtime/local/kernels/genKernelInst.py @@ -15,13 +15,18 @@ # limitations under the License. """ -Generates the C++ code for the pre-compiled kernels library. +Generates the C++ code for the pre-compiled kernels library as well as a JSON +file for the kernel catalog. This script generates C++ code instantiating the kernel templates that shall be part of a pre-compiled kernel library. 
Each kernel instantiation is wrapped by a shallow function that can be called from the JIT-compiled user program. An input JSON-file specifies which kernel shall be instantiated with which template arguments. + +Furthermore, a JSON file is generated that contains information about the +pre-compiled kernels. This file is used to populate the kernel catalog at +system start-up. """ # TODO Note that this script currently makes strong assumptions about the @@ -50,7 +55,7 @@ def toCppType(t): return t -def generateKernelInstantiation(kernelTemplateInfo, templateValues, opCodes, outFile, API): +def generateKernelInstantiation(kernelTemplateInfo, templateValues, opCodes, outFile, catalogEntries, API): # Extract some information. opName = kernelTemplateInfo["opName"] returnType = kernelTemplateInfo["returnType"] @@ -135,18 +140,20 @@ def generateFunction(opCode): # Obtain the name of the function to be generated from the opName by # removing suffices "Sca"/"Mat"/"Obj" (they are not required here), and # potentially by inserting the opCode into the name. 
+ concreteOpName = opName + while concreteOpName[-3:] in ["Sca", "Mat", "Obj"]: + concreteOpName = concreteOpName[:-3] + concreteOpName = concreteOpName.replace("::", "_") + if opCode is not None: + opCodeWord = opCodeType[:-len("OpCode")] + concreteOpName = concreteOpName.replace(opCodeWord, opCode[0].upper() + opCode[1:].lower()) + concreteOpName = concreteOpName.replace(opCodeWord.lower(), opCode.lower()) + if API != "CPP": - funcName = API + "_" + opName + funcName = API + "_" + concreteOpName else: - funcName = "_" + opName - while funcName[-3:] in ["Sca", "Mat", "Obj"]: - funcName = funcName[:-3] - funcName = funcName.replace("::", "_") + funcName = "_" + concreteOpName - if opCode is not None: - opCodeWord = opCodeType[:-len("OpCode")] - funcName = funcName.replace(opCodeWord, opCode[0].upper() + opCode[1:].lower()) - funcName = funcName.replace(opCodeWord.lower(), opCode.lower()) # Signature of the function wrapping the kernel instantiation. outFile.write(INDENT + "void {}{}({}) {{\n".format( @@ -190,6 +197,25 @@ def generateFunction(opCode): )) outFile.write(INDENT + "}\n") + argTypes = [rtp["type"].replace(" **", "").replace(" *", "").replace("const ", "") for rtp in extendedRuntimeParams if not rtp["isOutput"]] + resTypes = [rtp["type"].replace(" **", "").replace(" *", "").replace("const ", "") for rtp in extendedRuntimeParams if rtp["isOutput"]] + + argTypesTmp = [] + for t in argTypes: + # TODO Don't hardcode these exceptions. + if t in ["void", "mlir::daphne::GroupEnum", "CompareOperation"]: + break + argTypesTmp.append(t) + argTypes = argTypesTmp + + catalogEntries.append({ + "opMnemonic": concreteOpName, + "kernelFuncName": funcName + typesForName, + "resTypes": resTypes, + "argTypes": argTypes, + "backend": API, + }) + # Generate the function(s). 
if opCodes is None: generateFunction(None) @@ -200,7 +226,7 @@ def generateFunction(opCode): def printHelp(): - print("Usage: python3 {} INPUT_SPEC_FILE OUTPUT_CPP_FILE API".format(sys.argv[0])) + print("Usage: python3 {} INPUT_SPEC_FILE OUTPUT_CPP_FILE OUTPUT_CATALOG_FILE API".format(sys.argv[0])) print(__doc__) @@ -208,21 +234,23 @@ def printHelp(): if len(sys.argv) == 2 and (sys.argv[1] == "-h" or sys.argv[1] == "--help"): printHelp() sys.exit(0) - elif len(sys.argv) != 4: + elif len(sys.argv) != 5: print("Wrong number of arguments.") print() printHelp() sys.exit(1) # Parse arguments. - inFilePath = sys.argv[1] - outFilePath = sys.argv[2] - API = sys.argv[3] + inSpecPath = sys.argv[1] + outCppPath = sys.argv[2] + outCatalogPath = sys.argv[3] + API = sys.argv[4] ops_inst_str = "" header_str = "" + catalog_entries = [] # Load the specification (which kernel template shall be instantiated # with which template arguments) from a JSON-file. - with open(inFilePath, "r") as inFile: + with open(inSpecPath, "r") as inFile: kernelsInfo = json.load(inFile) for kernelInfo in kernelsInfo: @@ -250,7 +278,7 @@ def printHelp(): outBuf = io.StringIO() for instantiation in api["instantiations"]: generateKernelInstantiation(kernelTemplateInfo, instantiation, - api.get("opCodes", None), outBuf, API) + api.get("opCodes", None), outBuf, catalog_entries, API) ops_inst_str += outBuf.getvalue() else: if API == "CPP": @@ -265,14 +293,19 @@ def printHelp(): opCodes = kernelInfo.get("opCodes", None) outBuf = io.StringIO() for instantiation in kernelInfo["instantiations"]: - generateKernelInstantiation(kernelTemplateInfo, instantiation, opCodes, outBuf, API) + generateKernelInstantiation(kernelTemplateInfo, instantiation, opCodes, outBuf, catalog_entries, API) ops_inst_str += outBuf.getvalue() - with open(outFilePath, "w") as outFile: + # Store the C++ code of the kernel instantiations in a CPP-file. + with open(outCppPath, "w") as outFile: outFile.write("// This file was generated by {}. 
Don't edit manually!\n\n".format(sys.argv[0])) outFile.write("#include \n") outFile.write(header_str) outFile.write("\nextern \"C\" {\n") outFile.write(ops_inst_str) outFile.write("}\n") + + # Store the information on the kernels in a JSON-file. + with open(outCatalogPath, "w") as outCatalog: + outCatalog.write(json.dumps(catalog_entries, indent=2)) \ No newline at end of file