Skip to content

Commit

Permalink
[DAPHNE-#595] Fix crash when using UDF with zero results where one va…
Browse files Browse the repository at this point in the history
…lue is needed (#605)

The core issue is that DaphneDSL UDFs without return value internally returned an invalid mlir::Value. Upon doing anything with this value (i.e., when using it in a context where exactly one value is expected, e.g., in expressions or in assignments), the program would crash (e.g., accessing any fields or calling any methods). Now we return nullptr as the result of a UDF with zero return values in visitCallExpr(). This nullptr is detected in valueOrError(), a helper function we anyway call everywhere in the DaphneDSL parser when a single value is needed, and causes an exception to be thrown

The error message is not informative yet ("[error]: While parsing: something was expected to be an mlir::Value, but it was none"), but this will be improved in a follow-up commit.

Added several script-level test cases, which check that DAPHNE does not accept UDFs with zero or more than one return value in places where exactly one value is expected. All of these tests on a UDF with zero return values used to crash DAPHNE with a segfault (#595 was not specific to call expressions). Actually, these are just a few examples, ideally we should have test cases for all uses of expr in the DaphneDSL grammar.

Closes #595.


---------

Co-authored-by: Patrick Damme <[email protected]>
  • Loading branch information
corepointer and pdamme authored Jul 11, 2024
1 parent 5a47688 commit 3e56ed5
Show file tree
Hide file tree
Showing 20 changed files with 163 additions and 47 deletions.
41 changes: 17 additions & 24 deletions src/parser/daphnedsl/DaphneDSLVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -412,7 +412,7 @@ antlrcpp::Any DaphneDSLVisitor::visitImportStatement(DaphneDSLGrammarParser::Imp
symbolTable.put(origScope);

importedFiles = origImportedFiles;

for(std::pair<std::string, mlir::func::FuncOp> funcSymbol : functionsSymbolMap)
if(funcSymbol.first.find('.') == std::string::npos)
origFuncMap.insert({finalPrefix + funcSymbol.first, funcSymbol.second});
Expand Down Expand Up @@ -969,7 +969,7 @@ antlrcpp::Any DaphneDSLVisitor::handleMapOpCall(DaphneDSLGrammarParser::CallExpr
"called 'handleMapOpCall' for function "
+ func + " instead of 'map'"
);

if (ctx->expr().size() != 2) {
throw ErrorHandler::compilerError(loc, "DSLVisitor",
"built-in function 'map' expects exactly 2 argument(s), but got " +
Expand Down Expand Up @@ -1018,7 +1018,6 @@ antlrcpp::Any DaphneDSLVisitor::visitCallExpr(DaphneDSLGrammarParser::CallExprCo
args_vec.push_back(utils.valueOrError(visit(ctx->expr(i))));

auto maybeUDF = findMatchingUDF(func, args_vec, loc);

if (maybeUDF) {
if(hasKernelHint)
throw ErrorHandler::compilerError(
Expand All @@ -1028,26 +1027,21 @@ antlrcpp::Any DaphneDSLVisitor::visitCallExpr(DaphneDSLGrammarParser::CallExprCo
);

auto funcTy = maybeUDF->getFunctionType();
auto co = builder
.create<mlir::daphne::GenericCallOp>(loc,
maybeUDF->getSymName(),
args_vec,
funcTy.getResults());
auto co = builder.create<mlir::daphne::GenericCallOp>(loc, maybeUDF->getSymName(), args_vec, funcTy.getResults());
if(funcTy.getNumResults() > 1)
return co.getResults();
else
// If the UDF has no return values, the value returned here
// is invalid. But that seems to be okay, since it is never
// used as a mlir::Value in that case.
else if(funcTy.getNumResults() == 1)
return co.getResult(0);
else
return nullptr;
}

// Create DaphneIR operation for the built-in function.
antlrcpp::Any res = builtins.build(loc, func, args_vec);

if(hasKernelHint) {
std::string kernel = ctx->kernel->getText();

// We deliberately don't check if the specified kernel
// is registered for the created kind of operation,
// since this is checked in RewriteToCallKernelOpPass.
Expand Down Expand Up @@ -1487,7 +1481,7 @@ mlir::Value DaphneDSLVisitor::buildColMatrixFromValues(mlir::Location loc, const
// Maybe later these InsertOps can be fused into a single one
// or replaced with InsertOps that support scalar input.
result = static_cast<mlir::Value>(builder.create<mlir::daphne::InsertRowOp>(loc, utils.matrixOf(matrixVt),
result,
result,
ins,
builder.create<mlir::daphne::ConstantOp>(loc, idx),
builder.create<mlir::daphne::ConstantOp>(loc, idx+1)));
Expand All @@ -1496,7 +1490,7 @@ mlir::Value DaphneDSLVisitor::buildColMatrixFromValues(mlir::Location loc, const
return result;
}

antlrcpp::Any DaphneDSLVisitor::visitMatrixLiteralExpr(DaphneDSLGrammarParser::MatrixLiteralExprContext * ctx) {
antlrcpp::Any DaphneDSLVisitor::visitMatrixLiteralExpr(DaphneDSLGrammarParser::MatrixLiteralExprContext * ctx) {
mlir::Location loc = utils.getLoc(ctx->start);
mlir::Value rows;
mlir::Value cols;
Expand All @@ -1511,7 +1505,7 @@ antlrcpp::Any DaphneDSLVisitor::visitMatrixLiteralExpr(DaphneDSLGrammarParser::M
rows = builder.create<mlir::daphne::ConstantOp>(loc, static_cast<size_t>(ctx->expr().size()));
}
else {
numMatElems = (ctx->rows && ctx->cols) ? ctx->expr().size() - 2 : ctx->expr().size() - 1;
numMatElems = (ctx->rows && ctx->cols) ? ctx->expr().size() - 2 : ctx->expr().size() - 1;
if (ctx->cols && ctx->rows) {
cols = utils.valueOrError(visit(ctx->cols));
rows = utils.valueOrError(visit(ctx->rows));
Expand Down Expand Up @@ -1540,7 +1534,7 @@ antlrcpp::Any DaphneDSLVisitor::visitMatrixLiteralExpr(DaphneDSLGrammarParser::M
values.emplace_back(currentValue);
valueTypes.emplace_back(currentValue.getType());
}

mlir::Type valueType = mostGeneralVt(valueTypes);
mlir::Value colMatrix;

Expand Down Expand Up @@ -1581,7 +1575,7 @@ antlrcpp::Any DaphneDSLVisitor::visitMatrixLiteralExpr(DaphneDSLGrammarParser::M
else {
throw ErrorHandler::compilerError(loc, "DSLVisitor", "matrix literal of invalid value type");
}

// TODO: omit ReshapeOp if rows=1 (not always known at parse-time)
mlir::Value result = static_cast<mlir::Value>(builder.create<mlir::daphne::ReshapeOp>(loc, utils.matrixOf(valueType), colMatrix, rows, cols));

Expand All @@ -1593,7 +1587,7 @@ antlrcpp::Any DaphneDSLVisitor::visitColMajorFrameLiteralExpr(DaphneDSLGrammarPa

size_t labelCount = ctx->labels.size();
size_t colCount = ctx->cols.size();

if (labelCount != colCount)
throw ErrorHandler::compilerError(loc, "DSLVisitor", "frame literals must have an equal number of column labels and column matrices");

Expand All @@ -1619,7 +1613,7 @@ antlrcpp::Any DaphneDSLVisitor::visitColMajorFrameLiteralExpr(DaphneDSLGrammarPa
}

mlir::Type frameColTypes = mlir::daphne::FrameType::get(builder.getContext(), columnMatElemType);

mlir::Value result = static_cast<mlir::Value>(builder.create<mlir::daphne::CreateFrameOp>(loc, frameColTypes, columnMatrices, parsedLabels));

return result;
Expand Down Expand Up @@ -1722,7 +1716,7 @@ antlrcpp::Any DaphneDSLVisitor::visitRowMajorFrameLiteralExpr(DaphneDSLGrammarPa
}

mlir::Type frameColTypes = mlir::daphne::FrameType::get(builder.getContext(), colTypes);

mlir::Value result = static_cast<mlir::Value>(builder.create<mlir::daphne::CreateFrameOp>(loc, frameColTypes, colValues, parsedLabels));

return result;
Expand Down Expand Up @@ -2133,11 +2127,10 @@ antlrcpp::Any DaphneDSLVisitor::visitFunctionStatement(DaphneDSLGrammarParser::F
visitBlockStatement(ctx->bodyStmt);

rectifyEarlyReturns(funcBlock);
if(funcBlock->getOperations().empty()
|| !funcBlock->getOperations().back().hasTrait<mlir::OpTrait::IsTerminator>()) {

if(funcBlock->getOperations().empty() || !funcBlock->getOperations().back().hasTrait<mlir::OpTrait::IsTerminator>()) {
builder.create<mlir::daphne::ReturnOp>(utils.getLoc(ctx->stop));
}

auto terminator = funcBlock->getTerminator();
auto returnOpTypes = terminator->getOperandTypes();
if(!functionOperation) {
Expand Down
2 changes: 1 addition & 1 deletion test/api/cli/functions/FunctionsTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,4 +51,4 @@ MAKE_TEST_CASE("typed", 5)
MAKE_TEST_CASE("untyped", 4)
MAKE_TEST_CASE("mixtyped", 2)
MAKE_TEST_CASE("early_return", 3)
MAKE_INVALID_TEST_CASE("invalid_parser", 11, StatusCode::PARSER_ERROR)
MAKE_INVALID_TEST_CASE("invalid_parser", 25, StatusCode::PARSER_ERROR)
11 changes: 6 additions & 5 deletions test/api/cli/functions/invalid_parser_10.daphne
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
// invalid number of results (expecting 1, getting 2)
// invalid number of results when used in assignment statement (expecting 2, getting 3)

def f(a:si64) -> si64, si64 {
return a + 1, a + 2;
def f(a:si64) -> si64, si64, si64 {
return a + 1, a + 2, a + 3;
}

x = f(123);
print(x);
x, y = f(123);
print(x);
print(y);
16 changes: 5 additions & 11 deletions test/api/cli/functions/invalid_parser_11.daphne
Original file line number Diff line number Diff line change
@@ -1,14 +1,8 @@
// overloading typed functions with different number of results, but same arguments
// invalid number of results when used in assignment statement (expecting 1, getting 0)

def f(a:si64) -> si64 {
return a + 1;
}
def f(a:si64) -> si64, si64 {
return a + 1, a + 2;
def f(a:si64) {
return;
}

x = f(100);
print(x);
y, z, = f(200);
print(y);
print(z);
x = f(123);
print(x);
8 changes: 8 additions & 0 deletions test/api/cli/functions/invalid_parser_12.daphne
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
// invalid number of results when used in assignment statement (expecting 1, getting 2)

def f(a:si64) -> si64, si64 {
return a + 1, a + 2;
}

x = f(123);
print(x);
7 changes: 7 additions & 0 deletions test/api/cli/functions/invalid_parser_13.daphne
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
// invalid number of results when used in call expression (expecting 1, getting 0)

def f(a:si64) {
return;
}

print(f(123));
7 changes: 7 additions & 0 deletions test/api/cli/functions/invalid_parser_14.daphne
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
// invalid number of results when used in call expression (expecting 1, getting 2)

def f(a:si64) -> si64, si64 {
return a + 1, a + 2;
}

print(f(123));
8 changes: 8 additions & 0 deletions test/api/cli/functions/invalid_parser_15.daphne
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
// invalid number of results when used with binary operator (expecting 1, getting 0)

def f(a:si64) {
return;
}

x = f(123) + 1;
print(x);
8 changes: 8 additions & 0 deletions test/api/cli/functions/invalid_parser_16.daphne
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
// invalid number of results when used with binary operator (expecting 1, getting 2)

def f(a:si64) -> si64, si64 {
return a + 1, a + 2;
}

x = f(123) + 1;
print(x);
8 changes: 8 additions & 0 deletions test/api/cli/functions/invalid_parser_17.daphne
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
// invalid number of results when used in matrix literal (expecting 1, getting 0)

def f(a:si64) {
return;
}

x = [f(123)];
print(x);
8 changes: 8 additions & 0 deletions test/api/cli/functions/invalid_parser_18.daphne
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
// invalid number of results when used in matrix literal (expecting 1, getting 2)

def f(a:si64) -> si64, si64 {
return a + 1, a + 2;
}

x = [f(123)];
print(x);
10 changes: 10 additions & 0 deletions test/api/cli/functions/invalid_parser_19.daphne
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// invalid number of results when used in if-statement (expecting 1, getting 0)

def f(a:si64) {
return;
}

if(f(123))
print("yes");
else
print("no");
10 changes: 10 additions & 0 deletions test/api/cli/functions/invalid_parser_20.daphne
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// invalid number of results when used in if-statement (expecting 1, getting 2)

def f(a:si64) -> si64, si64 {
return a + 1, a + 2;
}

if(f(123))
print("yes");
else
print("no");
8 changes: 8 additions & 0 deletions test/api/cli/functions/invalid_parser_21.daphne
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
// invalid number of results when used in while-statement (expecting 1, getting 0)

def f(a:si64) {
return;
}

while(f(123))
print("abc");
8 changes: 8 additions & 0 deletions test/api/cli/functions/invalid_parser_22.daphne
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
// invalid number of results when used in while-statement (expecting 1, getting 2)

def f(a:si64) -> si64, si64 {
return a + 1, a + 2;
}

while(f(123))
print("abc");
12 changes: 12 additions & 0 deletions test/api/cli/functions/invalid_parser_23.daphne
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
// invalid number of results when used in return-statement (expecting 1, getting 0)

def f(a:si64) {
return;
}

def g() -> si64 {
return f(123);
}

x = g();
print(x);
12 changes: 12 additions & 0 deletions test/api/cli/functions/invalid_parser_24.daphne
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
// invalid number of results when used in return-statement (expecting 1, getting 2)

def f(a:si64) -> si64, si64 {
return a + 1, a + 2;
}

def g() -> si64 {
return f(123);
}

x = g();
print(x);
14 changes: 14 additions & 0 deletions test/api/cli/functions/invalid_parser_25.daphne
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
// overloading typed functions with different number of results, but same arguments

def f(a:si64) -> si64 {
return a + 1;
}
def f(a:si64) -> si64, si64 {
return a + 1, a + 2;
}

x = f(100);
print(x);
y, z, = f(200);
print(y);
print(z);
6 changes: 3 additions & 3 deletions test/api/cli/functions/invalid_parser_8.daphne
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// invalid number of results (expecting 2, getting 1)
// invalid number of results when used in assignment statement (expecting 2, getting 0)

def f(a:si64) -> si64 {
return a + 1;
def f(a:si64) {
return;
}

x, y = f(123);
Expand Down
6 changes: 3 additions & 3 deletions test/api/cli/functions/invalid_parser_9.daphne
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// invalid number of results (expecting 2, getting 3)
// invalid number of results when used in assignment statement (expecting 2, getting 1)

def f(a:si64) -> si64, si64, si64 {
return a + 1, a + 2, a + 3;
def f(a:si64) -> si64 {
return a + 1;
}

x, y = f(123);
Expand Down

0 comments on commit 3e56ed5

Please sign in to comment.