Initial support for kernel hints in DaphneDSL.
- Expert users can optionally provide a hint on which concrete pre-compiled kernel function to use for a particular operation.
- So far, this is only supported for DaphneDSL built-in functions.
- Added a few script-level test cases.
- Updated the DaphneDSL language reference.
- The concrete syntax may be changed in the future.
- As a side note: DaphneDSLBuiltins::build() now invokes getOperation() on ops with zero results before returning, which allows assigning kernel hints in an op-agnostic way.
pdamme committed Apr 22, 2024
1 parent 894f563 commit 9eb00f1
Showing 17 changed files with 262 additions and 62 deletions.
21 changes: 21 additions & 0 deletions doc/DaphneDSL/LanguageRef.md
@@ -757,6 +757,27 @@ At call sites, a value of any type, or any value type, can be passed to an untyped parameter.
As a consequence, an untyped function is compiled and specialized on demand according to the types at a call site.
Consistently, the types of untyped return values are inferred from the parameter types and operations.

## Compiler Hints

One of DAPHNE's strengths is its (WIP) ability to make various decisions on its own, e.g., regarding physical data representation (such as dense/sparse), physical operators (kernels), and data/operator placement (such as local/distributed, CPU/GPU/FPGA, computational storage).
However, expert users may optionally provide hints to influence compiler decisions.
This feature is useful for experimentation and in the context of DAPHNE's extensibility.
For instance, a user could force the use of a certain custom kernel at a certain point in a larger DaphneDSL script to measure the impact of that custom kernel, even if the DAPHNE compiler would normally not choose that kernel in that situation.

*The support for compiler hints is still experimental and it is currently not guaranteed that the DAPHNE compiler respects these hints.*

### Kernel Hints

Users can provide hints on the physical kernel that should be used for a specific occurrence of a DaphneDSL operation.
So far, kernel hints are only supported for DaphneDSL built-in functions.
Here, the name of the pre-compiled kernel function can optionally be attached to the name of the built-in function, separated by `::`.

*Examples:*

```r
res = sum::my_custom_sum_kernel(X);
```
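The hinted name must be the registered name of a pre-compiled kernel function. Judging from the kernel names used in this commit's test cases (e.g., `_sumAll__int64_t__DenseMatrix_int64_t`), these names appear to follow a mangling scheme that joins the operation name, result types, and argument types with double underscores. The following Python sketch of that scheme is an assumption inferred from those test names, not an official DAPHNE API:

```python
def mangle_kernel_name(op_name, res_types, arg_types):
    # Hypothetical reconstruction of the naming pattern seen in the tests:
    # a leading underscore, then the op name, the result types, and the
    # argument types, all joined by double underscores.
    return "_" + "__".join([op_name] + list(res_types) + list(arg_types))

# e.g., for print (zero results, three arguments):
# mangle_kernel_name("print", [], ["int64_t", "bool", "bool"])
```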

## Example Scripts

A few example DaphneDSL scripts can be found in:
131 changes: 80 additions & 51 deletions src/compiler/lowering/RewriteToCallKernelOpPass.cpp
@@ -444,60 +444,89 @@ namespace
const std::string opMnemonic = op->getName().stripDialect().data();
std::vector<KernelInfo> kernelInfos = kc.getKernelInfos(opMnemonic);

if(kernelInfos.empty())
throw CompilerUtils::makeError(loc, "no kernels registered for operation `" + opMnemonic + "`");

std::string backend;
if(op->hasAttr("cuda_device"))
backend = "CUDA";
else if(op->hasAttr("fpgaopencl_device"))
backend = "FPGAOPENCL";
else
backend = "CPP";

const size_t numArgs = lookupArgTys.size();
const size_t numRess = lookupResTys.size();
int chosenKernelIdx = -1;
for(size_t i = 0; i < kernelInfos.size() && chosenKernelIdx == -1; i++) {
auto ki = kernelInfos[i];
if(ki.backend != backend)
continue;
if(numArgs != ki.argTypes.size())
continue;
if(numRess != ki.resTypes.size())
continue;

bool mismatch = false;
for(size_t i = 0; i < numArgs && !mismatch; i++)
if(lookupArgTys[i] != ki.argTypes[i])
mismatch = true;
for(size_t i = 0; i < numRess && !mismatch; i++)
if(lookupResTys[i] != ki.resTypes[i])
mismatch = true;
if(!mismatch)
chosenKernelIdx = i;
std::string libPath;
std::string kernelFuncName;
// TODO Don't hardcode the attribute name, put it in a central place.
if(op->hasAttr("kernel_hint")) {
// The operation has a kernel hint. Lower to the hinted kernel if possible.

// TODO Check if the attribute has the right type.
kernelFuncName = op->getAttrOfType<mlir::StringAttr>("kernel_hint").getValue().str();
bool found = false;
for(size_t i = 0; i < kernelInfos.size() && !found; i++) {
auto ki = kernelInfos[i];
if(ki.kernelFuncName == kernelFuncName) {
libPath = ki.libPath;
found = true;
}
}
if(!found)
throw CompilerUtils::makeError(
loc,
"no kernel found for operation `" + opMnemonic +
"` with hinted name `" + kernelFuncName + "`"
);
}
if(chosenKernelIdx == -1) {
std::stringstream s;
s << "no kernel for operation `" << opMnemonic
<< "` available for the required input types `(";
for(size_t i = 0; i < numArgs; i++) {
s << lookupArgTys[i];
if(i < numArgs - 1)
s << ", ";
else {
// The operation does not have a kernel hint. Search for a kernel
// for this operation and the given result/argument types and backend.

if(kernelInfos.empty())
throw CompilerUtils::makeError(loc, "no kernels registered for operation `" + opMnemonic + "`");

std::string backend;
if(op->hasAttr("cuda_device"))
backend = "CUDA";
else if(op->hasAttr("fpgaopencl_device"))
backend = "FPGAOPENCL";
else
backend = "CPP";

const size_t numArgs = lookupArgTys.size();
const size_t numRess = lookupResTys.size();
int chosenKernelIdx = -1;
for(size_t i = 0; i < kernelInfos.size() && chosenKernelIdx == -1; i++) {
auto ki = kernelInfos[i];
if(ki.backend != backend)
continue;
if(numArgs != ki.argTypes.size())
continue;
if(numRess != ki.resTypes.size())
continue;

bool mismatch = false;
for(size_t i = 0; i < numArgs && !mismatch; i++)
if(lookupArgTys[i] != ki.argTypes[i])
mismatch = true;
for(size_t i = 0; i < numRess && !mismatch; i++)
if(lookupResTys[i] != ki.resTypes[i])
mismatch = true;
if(!mismatch)
chosenKernelIdx = i;
}
s << ")` and output types `(";
for(size_t i = 0; i < numRess; i++) {
s << lookupResTys[i];
if(i < numRess - 1)
s << ", ";
if(chosenKernelIdx == -1) {
std::stringstream s;
s << "no kernel for operation `" << opMnemonic
<< "` available for the required input types `(";
for(size_t i = 0; i < numArgs; i++) {
s << lookupArgTys[i];
if(i < numArgs - 1)
s << ", ";
}
s << ")` and output types `(";
for(size_t i = 0; i < numRess; i++) {
s << lookupResTys[i];
if(i < numRess - 1)
s << ", ";
}
s << ")` for backend `" << backend << "`, registered kernels for this op:" << std::endl;
kc.dump(opMnemonic, s);
throw CompilerUtils::makeError(loc, s.str());
}
s << ")` for backend `" << backend << "`, registered kernels for this op:" << std::endl;
kc.dump(opMnemonic, s);
throw CompilerUtils::makeError(loc, s.str());
KernelInfo chosenKI = kernelInfos[chosenKernelIdx];
libPath = chosenKI.libPath;
kernelFuncName = chosenKI.kernelFuncName;
}
KernelInfo chosenKI = kernelInfos[chosenKernelIdx];
std::string libPath = chosenKI.libPath;
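The kernel selection in this pass now has two paths: if a `kernel_hint` attribute is present, the pass looks the kernel up by its function name alone; otherwise it falls back to matching the backend and the exact argument/result types. A simplified Python sketch of that control flow (the class and parameter names are illustrative, not DAPHNE's actual kernel-catalog API):

```python
from dataclasses import dataclass

@dataclass
class KernelInfo:
    kernel_func_name: str
    lib_path: str
    backend: str = "CPP"
    arg_types: tuple = ()
    res_types: tuple = ()

def select_kernel(kernel_infos, kernel_hint=None, backend="CPP",
                  arg_types=(), res_types=()):
    """Return (lib_path, kernel_func_name) or raise ValueError."""
    if kernel_hint is not None:
        # Hint path: match on the kernel function name only.
        for ki in kernel_infos:
            if ki.kernel_func_name == kernel_hint:
                return ki.lib_path, ki.kernel_func_name
        raise ValueError(f"no kernel with hinted name `{kernel_hint}`")
    # Default path: match the backend and the exact argument/result types.
    for ki in kernel_infos:
        if (ki.backend == backend
                and ki.arg_types == arg_types
                and ki.res_types == res_types):
            return ki.lib_path, ki.kernel_func_name
    raise ValueError("no matching kernel")
```

Note that the hint path deliberately ignores backend and type information, mirroring the C++ code above, which only compares `ki.kernelFuncName` against the hinted name.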

// *****************************************************************************
// Create the CallKernelOp
@@ -510,7 +539,7 @@ namespace
// Create a CallKernelOp for the kernel function to call and return success().
auto kernel = rewriter.create<daphne::CallKernelOp>(
loc,
chosenKI.kernelFuncName,
kernelFuncName,
kernelArgs,
opResTys
);
11 changes: 5 additions & 6 deletions src/parser/daphnedsl/DaphneDSLBuiltins.cpp
@@ -716,7 +716,6 @@ antlrcpp::Any DaphneDSLBuiltins::build(mlir::Location loc, const std::string & f

if( func == "eigen" ) {
checkNumArgsExact(loc, func, numArgs, 1);
//TODO JIT-Engine invocation failed: Failed to materialize symbols
return builder.create<EigenOp>(loc,
args[0].getType(), args[0].getType(), args[0]).getResults();
}
@@ -888,7 +887,7 @@ antlrcpp::Any DaphneDSLBuiltins::build(mlir::Location loc, const std::string & f
loc,
builder.getStringAttr(viewName),
view
);
).getOperation();
}

// --------------------------------------------------------------------
@@ -1033,7 +1032,7 @@ antlrcpp::Any DaphneDSLBuiltins::build(mlir::Location loc, const std::string & f
: utils.castBoolIf(args[2]);
return builder.create<PrintOp>(
loc, arg, newline, err
);
).getOperation();
}

if (func == "readMatrix") {
@@ -1054,7 +1053,7 @@ antlrcpp::Any DaphneDSLBuiltins::build(mlir::Location loc, const std::string & f
checkNumArgsExact(loc, func, numArgs, 2);
mlir::Value arg = args[0];
mlir::Value filename = args[1];
return builder.create<WriteOp>(loc, arg, filename);
return builder.create<WriteOp>(loc, arg, filename).getOperation();
}
if(func == "receiveFromNumpy") {
checkNumArgsExact(loc, func, numArgs, 5);
@@ -1097,7 +1096,7 @@ antlrcpp::Any DaphneDSLBuiltins::build(mlir::Location loc, const std::string & f
if(func == "saveDaphneLibResult") {
checkNumArgsExact(loc, func, numArgs, 1);
mlir::Value arg = args[0];
return builder.create<SaveDaphneLibResultOp>(loc, arg);
return builder.create<SaveDaphneLibResultOp>(loc, arg).getOperation();
}

// --------------------------------------------------------------------
@@ -1131,7 +1130,7 @@ antlrcpp::Any DaphneDSLBuiltins::build(mlir::Location loc, const std::string & f
mlir::Value fileOrTarget = args[0];
return builder.create<CloseOp>(
loc, fileOrTarget
);
).getOperation();
}
if(func == "readCsv") {
checkNumArgsExact(loc, func, numArgs, 4);
2 changes: 1 addition & 1 deletion src/parser/daphnedsl/DaphneDSLGrammar.g4
@@ -83,7 +83,7 @@ expr:
| '$' arg=IDENTIFIER # argExpr
| (( IDENTIFIER '.' )* IDENTIFIER) # identifierExpr
| '(' expr ')' # paranthesesExpr
| (( IDENTIFIER '.' )* IDENTIFIER) '(' (expr (',' expr)*)? ')' # callExpr
| ( ns=IDENTIFIER '.' )* func=IDENTIFIER ('::' kernel=IDENTIFIER)? '(' (expr (',' expr)*)? ')' # callExpr
| KW_AS (('.' DATA_TYPE) | ('.' VALUE_TYPE) | ('.' DATA_TYPE '<' VALUE_TYPE '>')) '(' expr ')' # castExpr
| obj=expr '[[' (rows=expr)? ',' (cols=expr)? ']]' # rightIdxFilterExpr
| obj=expr idx=indexing # rightIdxExtractExpr
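The revised `callExpr` rule accepts an optional namespace prefix, the built-in function name, and an optional `::`-separated kernel hint. The shape of the accepted call syntax can be illustrated with a hypothetical, simplified regex-based recognizer in Python (this is not the ANTLR-generated parser, just a sketch of the token structure):

```python
import re

# Simplified recognizer for the extended callExpr syntax:
#   (ns '.')* func ('::' kernel)? '(' args ')'
CALL_RE = re.compile(
    r"^(?P<ns>(?:[A-Za-z_]\w*\.)*)"        # optional namespace prefix
    r"(?P<func>[A-Za-z_]\w*)"              # built-in function name
    r"(?:::(?P<kernel>[A-Za-z_]\w*))?"     # optional kernel hint
    r"\((?P<args>.*)\)$"                   # argument list
)

def parse_call(text):
    """Return (qualified function name, kernel hint or None)."""
    m = CALL_RE.match(text)
    if m is None:
        raise ValueError(f"not a call expression: {text!r}")
    return m.group("ns") + m.group("func"), m.group("kernel")
```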
55 changes: 51 additions & 4 deletions src/parser/daphnedsl/DaphneDSLVisitor.cpp
@@ -949,12 +949,13 @@ antlrcpp::Any DaphneDSLVisitor::handleMapOpCall(DaphneDSLGrammarParser::CallExpr
return builtins.build(loc, func, args);
}


antlrcpp::Any DaphneDSLVisitor::visitCallExpr(DaphneDSLGrammarParser::CallExprContext * ctx) {
std::string func;
const auto& identifierVec = ctx->IDENTIFIER();
for(size_t s = 0; s < identifierVec.size(); s++)
func += (s < identifierVec.size() - 1) ? identifierVec[s]->getText() + '.' : identifierVec[s]->getText();
bool hasKernelHint = ctx->kernel != nullptr;
for(size_t s = 0; s < identifierVec.size() - 1 - hasKernelHint; s++)
func += identifierVec[s]->getText() + '.';
func += ctx->func->getText();
mlir::Location loc = utils.getLoc(ctx->start);

if (func == "map")
@@ -968,6 +969,12 @@ antlrcpp::Any DaphneDSLVisitor::visitCallExpr(DaphneDSLGrammarParser::CallExprCo
auto maybeUDF = findMatchingUDF(func, args_vec, loc);

if (maybeUDF) {
if(hasKernelHint)
throw CompilerUtils::makeError(
loc,
"kernel hints are not supported for calls to user-defined functions"
);

auto funcTy = maybeUDF->getFunctionType();
auto co = builder
.create<mlir::daphne::GenericCallOp>(loc,
@@ -984,7 +991,47 @@ antlrcpp::Any DaphneDSLVisitor::visitCallExpr(DaphneDSLGrammarParser::CallExprCo
}

// Create DaphneIR operation for the built-in function.
return builtins.build(loc, func, args_vec);
antlrcpp::Any res = builtins.build(loc, func, args_vec);

if(hasKernelHint) {
std::string kernel = ctx->kernel->getText();

// We deliberately don't check if the specified kernel
// is registered for the created kind of operation,
// since this is checked in RewriteToCallKernelOpPass.

mlir::Operation* op;
if(res.is<mlir::Operation*>()) // DaphneIR ops with exactly zero results
op = res.as<mlir::Operation*>();
else if(res.is<mlir::Value>()) // DaphneIR ops with exactly one result
op = res.as<mlir::Value>().getDefiningOp();
else if(res.is<mlir::ResultRange>()) { // DaphneIR ops with more than one result
auto rr = res.as<mlir::ResultRange>();
op = rr[0].getDefiningOp();
// Normally, all values in the ResultRange should be results of
// the same op, but we check it nevertheless, just to be sure.
for(size_t i = 1; i < rr.size(); i++)
if(rr[i].getDefiningOp() != op)
throw CompilerUtils::makeError(
loc,
"the given kernel hint `" + kernel +
"` cannot be applied since the DaphneIR operation created for the built-in function `" +
func + "` is ambiguous"
);
}
else // unexpected case
throw CompilerUtils::makeError(
loc,
"the given kernel hint `" + kernel +
"` cannot be applied since the DaphneIR operation created for the built-in function `" +
func + "` was not returned in a supported way"
);

// TODO Don't hardcode the attribute name.
op->setAttr("kernel_hint", builder.getStringAttr(kernel));
}

return res;
}
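Since `builtins.build()` can return an `Operation*` (zero-result ops), a `Value` (one result), or a `ResultRange` (multiple results), the visitor normalizes all three cases to the single defining op before attaching the hint attribute. A language-agnostic sketch of that normalization (the `Op`/`Value` classes here are stand-ins, not MLIR's API):

```python
class Op:
    def __init__(self):
        self.attrs = {}

class Value:
    def __init__(self, op):
        self.defining_op = op

def attach_kernel_hint(res, kernel):
    """Find the single op behind a heterogeneous build() result
    and attach the kernel hint to it."""
    if isinstance(res, Op):                 # zero-result ops
        op = res
    elif isinstance(res, Value):            # one-result ops
        op = res.defining_op
    elif isinstance(res, (list, tuple)):    # multi-result ops
        ops = {v.defining_op for v in res}
        if len(ops) != 1:
            raise ValueError("ambiguous: results stem from different ops")
        op = ops.pop()
    else:
        raise ValueError("unsupported return kind")
    op.attrs["kernel_hint"] = kernel
    return op
```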

antlrcpp::Any DaphneDSLVisitor::visitCastExpr(DaphneDSLGrammarParser::CastExprContext * ctx) {
1 change: 1 addition & 0 deletions test/CMakeLists.txt
@@ -27,6 +27,7 @@ set(TEST_SOURCES
api/cli/expressions/CastTest.cpp
api/cli/expressions/CondTest.cpp
api/cli/expressions/MatrixLiteralTest.cpp
api/cli/extensibility/HintTest.cpp
api/cli/functions/FunctionsTest.cpp
api/cli/functions/RecursiveFunctionsTest.cpp
api/cli/io/ReadTest.cpp
64 changes: 64 additions & 0 deletions test/api/cli/extensibility/HintTest.cpp
@@ -0,0 +1,64 @@
/*
* Copyright 2024 The DAPHNE Consortium
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <api/cli/Utils.h>

#include <tags.h>

#include <catch.hpp>

#include <sstream>
#include <string>

const std::string dirPath = "test/api/cli/extensibility/";

#define MAKE_SUCCESS_TEST_CASE(name, count) \
TEST_CASE(name ", success", TAG_EXTENSIBILITY) { \
for(unsigned i = 1; i <= count; i++) { \
DYNAMIC_SECTION(name "_success_" << i << ".daphne") { \
compareDaphneToRefSimple(dirPath, name "_success", i); \
} \
} \
}

#define MAKE_FAILURE_TEST_CASE(name, count) \
TEST_CASE(name ", failure", TAG_EXTENSIBILITY) { \
for(unsigned i = 1; i <= count; i++) { \
DYNAMIC_SECTION(name "_failure_" << i << ".daphne") { \
checkDaphneFailsSimple(dirPath, name "_failure", i); \
} \
} \
}

#define MAKE_IR_TEST_CASE(idx, kernelName) \
TEST_CASE("hint_kernel_success_" #idx ".daphne, hint presence", TAG_EXTENSIBILITY) { \
std::stringstream out; \
std::stringstream err; \
int status = runDaphne(out, err, "--explain", "parsing_simplified", (dirPath + "hint_kernel_success_" #idx ".daphne").c_str()); \
CHECK(status == StatusCode::SUCCESS); \
CHECK_THAT(err.str(), Catch::Contains("kernel_hint = \"" kernelName "\"")); \
}

// Check if DAPHNE fails when expected.
MAKE_FAILURE_TEST_CASE("hint_kernel", 3)

// Check if DAPHNE terminates normally when expected and produces the expected output.
MAKE_SUCCESS_TEST_CASE("hint_kernel", 3)

// Check if DAPHNE terminates normally when expected and if the IR really contains the kernel hint.
MAKE_IR_TEST_CASE(1, "_print__int64_t__bool__bool");
MAKE_IR_TEST_CASE(2, "_sumAll__int64_t__DenseMatrix_int64_t");
MAKE_IR_TEST_CASE(3, "_recode__DenseMatrix_int64_t__DenseMatrix_double__DenseMatrix_double__bool");
4 changes: 4 additions & 0 deletions test/api/cli/extensibility/hint_kernel_failure_1.daphne
@@ -0,0 +1,4 @@
// Hint to use a non-existing pre-compiled kernel.

res = sum::nonExistingKernel([42]);
print(res);