diff --git a/doc/DaphneDSL/LanguageRef.md b/doc/DaphneDSL/LanguageRef.md index 337f63c49..ded48a280 100644 --- a/doc/DaphneDSL/LanguageRef.md +++ b/doc/DaphneDSL/LanguageRef.md @@ -757,6 +757,27 @@ At call sites, a value of any type, or any value type, can be passed to an untyp As a consequence, an untyped function is compiled and specialized on demand according to the types at a call site. Consistently, the types of untyped return values are infered from the parameter types and operations. +## Compiler Hints + +One of DAPHNE's strengths is its (WIP) ability to make various decisions on its own, e.g., regarding physical data representation (such as dense/sparse), physical operators (kernels), and data/operator placement (such as local/distributed, CPU/GPU/FPGA, computational storage). +However, expert users may optionally provide hints to influence compiler decisions. +This feature is useful for experimentation and in the context of DAPHNE's extensibility. +For instance, a user could force the use of a certain custom kernel at a certain point in a larger DaphneDSL script to measure the impact of that custom kernel, even if the DAPHNE compiler would normally not choose that kernel in that situation. + +*The support for compiler hints is still experimental and it is currently not guaranteed that the DAPHNE compiler respects these hints.* + +### Kernel Hints + +Users can provide hints on the physical kernel that should be used for a specific occurrence of a DaphneDSL operation. +So far, kernel hints are only supported for DaphneDSL built-in functions. +Here, the name of the pre-compiled kernel function can optionally be attached to the name of the built-in function, separated by `::`. 
+ +*Examples:* + +```r +res = sum::my_custom_sum_kernel(X); +``` + ## Example Scripts A few example DaphneDSL scripts can be found in: diff --git a/src/compiler/lowering/RewriteToCallKernelOpPass.cpp b/src/compiler/lowering/RewriteToCallKernelOpPass.cpp index 0a45342c0..669c5cb2f 100644 --- a/src/compiler/lowering/RewriteToCallKernelOpPass.cpp +++ b/src/compiler/lowering/RewriteToCallKernelOpPass.cpp @@ -444,60 +444,89 @@ namespace const std::string opMnemonic = op->getName().stripDialect().data(); std::vector kernelInfos = kc.getKernelInfos(opMnemonic); - if(kernelInfos.empty()) - throw CompilerUtils::makeError(loc, "no kernels registered for operation `" + opMnemonic + "`"); - - std::string backend; - if(op->hasAttr("cuda_device")) - backend = "CUDA"; - else if(op->hasAttr("fpgaopencl_device")) - backend = "FPGAOPENCL"; - else - backend = "CPP"; - - const size_t numArgs = lookupArgTys.size(); - const size_t numRess = lookupResTys.size(); - int chosenKernelIdx = -1; - for(size_t i = 0; i < kernelInfos.size() && chosenKernelIdx == -1; i++) { - auto ki = kernelInfos[i]; - if(ki.backend != backend) - continue; - if(numArgs != ki.argTypes.size()) - continue; - if(numRess != ki.resTypes.size()) - continue; - - bool mismatch = false; - for(size_t i = 0; i < numArgs && !mismatch; i++) - if(lookupArgTys[i] != ki.argTypes[i]) - mismatch = true; - for(size_t i = 0; i < numRess && !mismatch; i++) - if(lookupResTys[i] != ki.resTypes[i]) - mismatch = true; - if(!mismatch) - chosenKernelIdx = i; + std::string libPath; + std::string kernelFuncName; + // TODO Don't hardcode the attribute name, put it in a central place. + if(op->hasAttr("kernel_hint")) { + // The operation has a kernel hint. Lower to the hinted kernel if possible. + + // TODO Check if the attribute has the right type. 
+ kernelFuncName = op->getAttrOfType<mlir::StringAttr>("kernel_hint").getValue().str(); + bool found = false; + for(size_t i = 0; i < kernelInfos.size() && !found; i++) { + auto ki = kernelInfos[i]; + if(ki.kernelFuncName == kernelFuncName) { + libPath = ki.libPath; + found = true; + } + } + if(!found) + throw CompilerUtils::makeError( + loc, + "no kernel found for operation `" + opMnemonic + + "` with hinted name `" + kernelFuncName + "`" + ); } - if(chosenKernelIdx == -1) { - std::stringstream s; - s << "no kernel for operation `" << opMnemonic - << "` available for the required input types `("; - for(size_t i = 0; i < numArgs; i++) { - s << lookupArgTys[i]; - if(i < numArgs - 1) - s << ", "; + else { + // The operation does not have a kernel hint. Search for a kernel + // for this operation and the given result/argument types and backend. + + if(kernelInfos.empty()) + throw CompilerUtils::makeError(loc, "no kernels registered for operation `" + opMnemonic + "`"); + + std::string backend; + if(op->hasAttr("cuda_device")) + backend = "CUDA"; + else if(op->hasAttr("fpgaopencl_device")) + backend = "FPGAOPENCL"; + else + backend = "CPP"; + + const size_t numArgs = lookupArgTys.size(); + const size_t numRess = lookupResTys.size(); + int chosenKernelIdx = -1; + for(size_t i = 0; i < kernelInfos.size() && chosenKernelIdx == -1; i++) { + auto ki = kernelInfos[i]; + if(ki.backend != backend) + continue; + if(numArgs != ki.argTypes.size()) + continue; + if(numRess != ki.resTypes.size()) + continue; + + bool mismatch = false; + for(size_t i = 0; i < numArgs && !mismatch; i++) + if(lookupArgTys[i] != ki.argTypes[i]) + mismatch = true; + for(size_t i = 0; i < numRess && !mismatch; i++) + if(lookupResTys[i] != ki.resTypes[i]) + mismatch = true; + if(!mismatch) + chosenKernelIdx = i; } - s << + ")` and output types `("; - for(size_t i = 0; i < numRess; i++) { - s << lookupResTys[i]; - if(i < numRess - 1) - s << ", "; + if(chosenKernelIdx == -1) { + std::stringstream s; + s << "no kernel for 
operation `" << opMnemonic + << "` available for the required input types `("; + for(size_t i = 0; i < numArgs; i++) { + s << lookupArgTys[i]; + if(i < numArgs - 1) + s << ", "; + } + s << + ")` and output types `("; + for(size_t i = 0; i < numRess; i++) { + s << lookupResTys[i]; + if(i < numRess - 1) + s << ", "; + } + s << ")` for backend `" << backend << "`, registered kernels for this op:" << std::endl; + kc.dump(opMnemonic, s); + throw CompilerUtils::makeError(loc, s.str()); } - s << ")` for backend `" << backend << "`, registered kernels for this op:" << std::endl; - kc.dump(opMnemonic, s); - throw CompilerUtils::makeError(loc, s.str()); + KernelInfo chosenKI = kernelInfos[chosenKernelIdx]; + libPath = chosenKI.libPath; + kernelFuncName = chosenKI.kernelFuncName; } - KernelInfo chosenKI = kernelInfos[chosenKernelIdx]; - std::string libPath = chosenKI.libPath; // ***************************************************************************** // Create the CallKernelOp @@ -510,7 +539,7 @@ namespace // Create a CallKernelOp for the kernel function to call and return success(). 
auto kernel = rewriter.create( loc, - chosenKI.kernelFuncName, + kernelFuncName, kernelArgs, opResTys ); diff --git a/src/parser/daphnedsl/DaphneDSLBuiltins.cpp b/src/parser/daphnedsl/DaphneDSLBuiltins.cpp index cbddab71a..22bdaba38 100644 --- a/src/parser/daphnedsl/DaphneDSLBuiltins.cpp +++ b/src/parser/daphnedsl/DaphneDSLBuiltins.cpp @@ -716,7 +716,6 @@ antlrcpp::Any DaphneDSLBuiltins::build(mlir::Location loc, const std::string & f if( func == "eigen" ) { checkNumArgsExact(loc, func, numArgs, 1); - //TODO JIT-Engine invocation failed: Failed to materialize symbols return builder.create(loc, args[0].getType(), args[0].getType(), args[0]).getResults(); } @@ -888,7 +887,7 @@ antlrcpp::Any DaphneDSLBuiltins::build(mlir::Location loc, const std::string & f loc, builder.getStringAttr(viewName), view - ); + ).getOperation(); } // -------------------------------------------------------------------- @@ -1033,7 +1032,7 @@ antlrcpp::Any DaphneDSLBuiltins::build(mlir::Location loc, const std::string & f : utils.castBoolIf(args[2]); return builder.create( loc, arg, newline, err - ); + ).getOperation(); } if (func == "readMatrix") { @@ -1054,7 +1053,7 @@ antlrcpp::Any DaphneDSLBuiltins::build(mlir::Location loc, const std::string & f checkNumArgsExact(loc, func, numArgs, 2); mlir::Value arg = args[0]; mlir::Value filename = args[1]; - return builder.create(loc, arg, filename); + return builder.create(loc, arg, filename).getOperation(); } if(func == "receiveFromNumpy") { checkNumArgsExact(loc, func, numArgs, 5); @@ -1097,7 +1096,7 @@ antlrcpp::Any DaphneDSLBuiltins::build(mlir::Location loc, const std::string & f if(func == "saveDaphneLibResult") { checkNumArgsExact(loc, func, numArgs, 1); mlir::Value arg = args[0]; - return builder.create(loc, arg); + return builder.create(loc, arg).getOperation(); } // -------------------------------------------------------------------- @@ -1131,7 +1130,7 @@ antlrcpp::Any DaphneDSLBuiltins::build(mlir::Location loc, const std::string & f 
mlir::Value fileOrTarget = args[0]; return builder.create( loc, fileOrTarget - ); + ).getOperation(); } if(func == "readCsv") { checkNumArgsExact(loc, func, numArgs, 4); diff --git a/src/parser/daphnedsl/DaphneDSLGrammar.g4 b/src/parser/daphnedsl/DaphneDSLGrammar.g4 index 00c73bf75..0ab7fa8e7 100644 --- a/src/parser/daphnedsl/DaphneDSLGrammar.g4 +++ b/src/parser/daphnedsl/DaphneDSLGrammar.g4 @@ -83,7 +83,7 @@ expr: | '$' arg=IDENTIFIER # argExpr | (( IDENTIFIER '.' )* IDENTIFIER) # identifierExpr | '(' expr ')' # paranthesesExpr - | (( IDENTIFIER '.' )* IDENTIFIER) '(' (expr (',' expr)*)? ')' # callExpr + | ( ns=IDENTIFIER '.' )* func=IDENTIFIER ('::' kernel=IDENTIFIER)? '(' (expr (',' expr)*)? ')' # callExpr | KW_AS (('.' DATA_TYPE) | ('.' VALUE_TYPE) | ('.' DATA_TYPE '<' VALUE_TYPE '>')) '(' expr ')' # castExpr | obj=expr '[[' (rows=expr)? ',' (cols=expr)? ']]' # rightIdxFilterExpr | obj=expr idx=indexing # rightIdxExtractExpr diff --git a/src/parser/daphnedsl/DaphneDSLVisitor.cpp b/src/parser/daphnedsl/DaphneDSLVisitor.cpp index 6a0ccf288..2703d931e 100644 --- a/src/parser/daphnedsl/DaphneDSLVisitor.cpp +++ b/src/parser/daphnedsl/DaphneDSLVisitor.cpp @@ -949,12 +949,13 @@ antlrcpp::Any DaphneDSLVisitor::handleMapOpCall(DaphneDSLGrammarParser::CallExpr return builtins.build(loc, func, args); } - antlrcpp::Any DaphneDSLVisitor::visitCallExpr(DaphneDSLGrammarParser::CallExprContext * ctx) { std::string func; const auto& identifierVec = ctx->IDENTIFIER(); - for(size_t s = 0; s < identifierVec.size(); s++) - func += (s < identifierVec.size() - 1) ? identifierVec[s]->getText() + '.' 
: identifierVec[s]->getText(); + bool hasKernelHint = ctx->kernel != nullptr; + for(size_t s = 0; s < identifierVec.size() - 1 - hasKernelHint; s++) + func += identifierVec[s]->getText() + '.'; + func += ctx->func->getText(); mlir::Location loc = utils.getLoc(ctx->start); if (func == "map") @@ -968,6 +969,12 @@ antlrcpp::Any DaphneDSLVisitor::visitCallExpr(DaphneDSLGrammarParser::CallExprCo auto maybeUDF = findMatchingUDF(func, args_vec, loc); if (maybeUDF) { + if(hasKernelHint) + throw CompilerUtils::makeError( + loc, + "kernel hints are not supported for calls to user-defined functions" + ); + auto funcTy = maybeUDF->getFunctionType(); auto co = builder .create(loc, @@ -984,7 +991,47 @@ antlrcpp::Any DaphneDSLVisitor::visitCallExpr(DaphneDSLGrammarParser::CallExprCo } // Create DaphneIR operation for the built-in function. - return builtins.build(loc, func, args_vec); + antlrcpp::Any res = builtins.build(loc, func, args_vec); + + if(hasKernelHint) { + std::string kernel = ctx->kernel->getText(); + + // We deliberately don't check if the specified kernel + // is registered for the created kind of operation, + // since this is checked in RewriteToCallKernelOpPass. + + mlir::Operation* op; + if(res.is<mlir::Operation*>()) // DaphneIR ops with exactly zero results + op = res.as<mlir::Operation*>(); + else if(res.is<mlir::Value>()) // DaphneIR ops with exactly one result + op = res.as<mlir::Value>().getDefiningOp(); + else if(res.is<mlir::ResultRange>()) { // DaphneIR ops with more than one result + auto rr = res.as<mlir::ResultRange>(); + op = rr[0].getDefiningOp(); + // Normally, all values in the ResultRange should be results of + // the same op, but we check it nevertheless, just to be sure. 
+ for(size_t i = 1; i < rr.size(); i++) + if(rr[i].getDefiningOp() != op) + throw CompilerUtils::makeError( + loc, + "the given kernel hint `" + kernel + + "` cannot be applied since the DaphneIR operation created for the built-in function `" + + func + "` is ambiguous" + ); + } + else // unexpected case + throw CompilerUtils::makeError( + loc, + "the given kernel hint `" + kernel + + "` cannot be applied since the DaphneIR operation created for the built-in function `" + + func + "` was not returned in a supported way" + ); + + // TODO Don't hardcode the attribute name. + op->setAttr("kernel_hint", builder.getStringAttr(kernel)); + } + + return res; } antlrcpp::Any DaphneDSLVisitor::visitCastExpr(DaphneDSLGrammarParser::CastExprContext * ctx) { diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 58aabba5f..94bb41cc4 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -27,6 +27,7 @@ set(TEST_SOURCES api/cli/expressions/CastTest.cpp api/cli/expressions/CondTest.cpp api/cli/expressions/MatrixLiteralTest.cpp + api/cli/extensibility/HintTest.cpp api/cli/functions/FunctionsTest.cpp api/cli/functions/RecursiveFunctionsTest.cpp api/cli/io/ReadTest.cpp diff --git a/test/api/cli/extensibility/HintTest.cpp b/test/api/cli/extensibility/HintTest.cpp new file mode 100644 index 000000000..5629dac0a --- /dev/null +++ b/test/api/cli/extensibility/HintTest.cpp @@ -0,0 +1,64 @@ +/* + * Copyright 2024 The DAPHNE Consortium + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include + +#include + +#include +#include + +const std::string dirPath = "test/api/cli/extensibility/"; + +#define MAKE_SUCCESS_TEST_CASE(name, count) \ + TEST_CASE(name ", success", TAG_EXTENSIBILITY) { \ + for(unsigned i = 1; i <= count; i++) { \ + DYNAMIC_SECTION(name "_success_" << i << ".daphne") { \ + compareDaphneToRefSimple(dirPath, name "_success", i); \ + } \ + } \ + } + +#define MAKE_FAILURE_TEST_CASE(name, count) \ + TEST_CASE(name ", failure", TAG_EXTENSIBILITY) { \ + for(unsigned i = 1; i <= count; i++) { \ + DYNAMIC_SECTION(name "_failure_" << i << ".daphne") { \ + checkDaphneFailsSimple(dirPath, name "_failure", i); \ + } \ + } \ + } + +#define MAKE_IR_TEST_CASE(idx, kernelName) \ + TEST_CASE("hint_kernel_success_" #idx ".daphne, hint presence", TAG_EXTENSIBILITY) { \ + std::stringstream out; \ + std::stringstream err; \ + int status = runDaphne(out, err, "--explain", "parsing_simplified", (dirPath + "hint_kernel_success_" #idx ".daphne").c_str()); \ + CHECK(status == StatusCode::SUCCESS); \ + CHECK_THAT(err.str(), Catch::Contains("kernel_hint = \"" kernelName "\"")); \ + } + +// Check if DAPHNE fails when expected. +MAKE_FAILURE_TEST_CASE("hint_kernel", 3) + +// Check if DAPHNE terminates normally when expected and produces the expected output. +MAKE_SUCCESS_TEST_CASE("hint_kernel", 3) + +// Check if DAPHNE terminates normally when expected and if the IR really contains the kernel hint. 
+MAKE_IR_TEST_CASE(1, "_print__int64_t__bool__bool"); +MAKE_IR_TEST_CASE(2, "_sumAll__int64_t__DenseMatrix_int64_t"); +MAKE_IR_TEST_CASE(3, "_recode__DenseMatrix_int64_t__DenseMatrix_double__DenseMatrix_double__bool"); \ No newline at end of file diff --git a/test/api/cli/extensibility/hint_kernel_failure_1.daphne b/test/api/cli/extensibility/hint_kernel_failure_1.daphne new file mode 100644 index 000000000..70f8ca900 --- /dev/null +++ b/test/api/cli/extensibility/hint_kernel_failure_1.daphne @@ -0,0 +1,4 @@ +// Hint to use a non-existing pre-compiled kernel. + +res = sum::nonExistingKernel([42]); +print(res); \ No newline at end of file diff --git a/test/api/cli/extensibility/hint_kernel_failure_2.daphne b/test/api/cli/extensibility/hint_kernel_failure_2.daphne new file mode 100644 index 000000000..419a2a8bd --- /dev/null +++ b/test/api/cli/extensibility/hint_kernel_failure_2.daphne @@ -0,0 +1,4 @@ +// Hint to use an existing pre-compiled kernel of another operation. + +res = aggMin::_sumAll__int64_t__DenseMatrix_int64_t([42]); +print(res); \ No newline at end of file diff --git a/test/api/cli/extensibility/hint_kernel_failure_3.daphne b/test/api/cli/extensibility/hint_kernel_failure_3.daphne new file mode 100644 index 000000000..8795b34f1 --- /dev/null +++ b/test/api/cli/extensibility/hint_kernel_failure_3.daphne @@ -0,0 +1,7 @@ +// Hint to use a pre-compiled kernel for a user-defined function. + +def hello() { + print("world"); +} + +hello::someKernel(); \ No newline at end of file diff --git a/test/api/cli/extensibility/hint_kernel_success_1.daphne b/test/api/cli/extensibility/hint_kernel_success_1.daphne new file mode 100644 index 000000000..bc235a592 --- /dev/null +++ b/test/api/cli/extensibility/hint_kernel_success_1.daphne @@ -0,0 +1,3 @@ +// Hint to use an existing pre-compiled kernel for a DaphneIR op with exactly zero results. 
+ +print::_print__int64_t__bool__bool(42); \ No newline at end of file diff --git a/test/api/cli/extensibility/hint_kernel_success_1.txt b/test/api/cli/extensibility/hint_kernel_success_1.txt new file mode 100644 index 000000000..d81cc0710 --- /dev/null +++ b/test/api/cli/extensibility/hint_kernel_success_1.txt @@ -0,0 +1 @@ +42 diff --git a/test/api/cli/extensibility/hint_kernel_success_2.daphne b/test/api/cli/extensibility/hint_kernel_success_2.daphne new file mode 100644 index 000000000..5bff4eb99 --- /dev/null +++ b/test/api/cli/extensibility/hint_kernel_success_2.daphne @@ -0,0 +1,4 @@ +// Hint to use an existing pre-compiled kernel for a DaphneIR op with exactly one result. + +res = sum::_sumAll__int64_t__DenseMatrix_int64_t([21, 21]); +print(res); \ No newline at end of file diff --git a/test/api/cli/extensibility/hint_kernel_success_2.txt b/test/api/cli/extensibility/hint_kernel_success_2.txt new file mode 100644 index 000000000..d81cc0710 --- /dev/null +++ b/test/api/cli/extensibility/hint_kernel_success_2.txt @@ -0,0 +1 @@ +42 diff --git a/test/api/cli/extensibility/hint_kernel_success_3.daphne b/test/api/cli/extensibility/hint_kernel_success_3.daphne new file mode 100644 index 000000000..f17fa1353 --- /dev/null +++ b/test/api/cli/extensibility/hint_kernel_success_3.daphne @@ -0,0 +1,5 @@ +// Hint to use an existing pre-compiled kernel for a DaphneIR op with more than one result. 
+ +codes, dict = recode::_recode__DenseMatrix_int64_t__DenseMatrix_double__DenseMatrix_double__bool([1.1, 3.3, 1.1, 2.2], false); +print(codes); +print(dict); \ No newline at end of file diff --git a/test/api/cli/extensibility/hint_kernel_success_3.txt b/test/api/cli/extensibility/hint_kernel_success_3.txt new file mode 100644 index 000000000..e971c8cc1 --- /dev/null +++ b/test/api/cli/extensibility/hint_kernel_success_3.txt @@ -0,0 +1,9 @@ +DenseMatrix(4x1, int64_t) +0 +1 +0 +2 +DenseMatrix(3x1, double) +1.1 +3.3 +2.2 diff --git a/test/tags.h b/test/tags.h index 14c490cbc..9af0d1c89 100644 --- a/test/tags.h +++ b/test/tags.h @@ -29,6 +29,7 @@ #define TAG_CONTROLFLOW "[controlflow]" #define TAG_DATASTRUCTURES "[datastructures]" #define TAG_DISTRIBUTED "[distributed]" +#define TAG_EXTENSIBILITY "[extensibility]" #define TAG_MATRIX_LITERAL "[matrixliterals]" #define TAG_TERNARY "[ternary]" #define TAG_FUNCTIONS "[functions]"