Skip to content

Commit

Permalink
[DAPHNE-daphne-eu#595] Fix crash when calling print(f()) with empty f…
Browse files Browse the repository at this point in the history
…unction

The core issue seems to be that functions without return value internally returned an invalid mlir::Value. Upon doing anything with this value, the program would crash (e.g. accessing any fields or calling any methods). Now we return nullptr and check for that, which gives a somewhat more deterministic behavior.

Closes daphne-eu#595
  • Loading branch information
corepointer committed Jul 4, 2024
1 parent 5a47688 commit e1d244c
Showing 1 changed file with 24 additions and 23 deletions.
47 changes: 24 additions & 23 deletions src/parser/daphnedsl/DaphneDSLVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -412,7 +412,7 @@ antlrcpp::Any DaphneDSLVisitor::visitImportStatement(DaphneDSLGrammarParser::Imp
symbolTable.put(origScope);

importedFiles = origImportedFiles;

for(std::pair<std::string, mlir::func::FuncOp> funcSymbol : functionsSymbolMap)
if(funcSymbol.first.find('.') == std::string::npos)
origFuncMap.insert({finalPrefix + funcSymbol.first, funcSymbol.second});
Expand Down Expand Up @@ -969,7 +969,7 @@ antlrcpp::Any DaphneDSLVisitor::handleMapOpCall(DaphneDSLGrammarParser::CallExpr
"called 'handleMapOpCall' for function "
+ func + " instead of 'map'"
);

if (ctx->expr().size() != 2) {
throw ErrorHandler::compilerError(loc, "DSLVisitor",
"built-in function 'map' expects exactly 2 argument(s), but got " +
Expand Down Expand Up @@ -1014,11 +1014,15 @@ antlrcpp::Any DaphneDSLVisitor::visitCallExpr(DaphneDSLGrammarParser::CallExprCo

// Parse arguments.
std::vector<mlir::Value> args_vec;
for(unsigned i = 0; i < ctx->expr().size(); i++)
args_vec.push_back(utils.valueOrError(visit(ctx->expr(i))));
for(unsigned i = 0; i < ctx->expr().size(); i++) {
auto a = visit(ctx->expr(i));
if(a.isNotNull()) {
auto b = utils.valueOrError(a);
args_vec.push_back(b);
}
}

auto maybeUDF = findMatchingUDF(func, args_vec, loc);

if (maybeUDF) {
if(hasKernelHint)
throw ErrorHandler::compilerError(
Expand All @@ -1028,26 +1032,24 @@ antlrcpp::Any DaphneDSLVisitor::visitCallExpr(DaphneDSLGrammarParser::CallExprCo
);

auto funcTy = maybeUDF->getFunctionType();
auto co = builder
.create<mlir::daphne::GenericCallOp>(loc,
maybeUDF->getSymName(),
args_vec,
funcTy.getResults());
auto co = builder.create<mlir::daphne::GenericCallOp>(loc, maybeUDF->getSymName(), args_vec, funcTy.getResults());
if(funcTy.getNumResults() > 1)
return co.getResults();
else if(funcTy.getNumResults() == 1)
return co.getResult(0);
else
// If the UDF has no return values, the value returned here
// is invalid. But that seems to be okay, since it is never
// used as a mlir::Value in that case.
return co.getResult(0);
return nullptr;
}

// Create DaphneIR operation for the built-in function.
antlrcpp::Any res = builtins.build(loc, func, args_vec);

if(hasKernelHint) {
std::string kernel = ctx->kernel->getText();

// We deliberately don't check if the specified kernel
// is registered for the created kind of operation,
// since this is checked in RewriteToCallKernelOpPass.
Expand Down Expand Up @@ -1487,7 +1489,7 @@ mlir::Value DaphneDSLVisitor::buildColMatrixFromValues(mlir::Location loc, const
// Maybe later these InsertOps can be fused into a single one
// or replaced with InsertOps that support scalar input.
result = static_cast<mlir::Value>(builder.create<mlir::daphne::InsertRowOp>(loc, utils.matrixOf(matrixVt),
result,
result,
ins,
builder.create<mlir::daphne::ConstantOp>(loc, idx),
builder.create<mlir::daphne::ConstantOp>(loc, idx+1)));
Expand All @@ -1496,7 +1498,7 @@ mlir::Value DaphneDSLVisitor::buildColMatrixFromValues(mlir::Location loc, const
return result;
}

antlrcpp::Any DaphneDSLVisitor::visitMatrixLiteralExpr(DaphneDSLGrammarParser::MatrixLiteralExprContext * ctx) {
antlrcpp::Any DaphneDSLVisitor::visitMatrixLiteralExpr(DaphneDSLGrammarParser::MatrixLiteralExprContext * ctx) {
mlir::Location loc = utils.getLoc(ctx->start);
mlir::Value rows;
mlir::Value cols;
Expand All @@ -1511,7 +1513,7 @@ antlrcpp::Any DaphneDSLVisitor::visitMatrixLiteralExpr(DaphneDSLGrammarParser::M
rows = builder.create<mlir::daphne::ConstantOp>(loc, static_cast<size_t>(ctx->expr().size()));
}
else {
numMatElems = (ctx->rows && ctx->cols) ? ctx->expr().size() - 2 : ctx->expr().size() - 1;
numMatElems = (ctx->rows && ctx->cols) ? ctx->expr().size() - 2 : ctx->expr().size() - 1;
if (ctx->cols && ctx->rows) {
cols = utils.valueOrError(visit(ctx->cols));
rows = utils.valueOrError(visit(ctx->rows));
Expand Down Expand Up @@ -1540,7 +1542,7 @@ antlrcpp::Any DaphneDSLVisitor::visitMatrixLiteralExpr(DaphneDSLGrammarParser::M
values.emplace_back(currentValue);
valueTypes.emplace_back(currentValue.getType());
}

mlir::Type valueType = mostGeneralVt(valueTypes);
mlir::Value colMatrix;

Expand Down Expand Up @@ -1581,7 +1583,7 @@ antlrcpp::Any DaphneDSLVisitor::visitMatrixLiteralExpr(DaphneDSLGrammarParser::M
else {
throw ErrorHandler::compilerError(loc, "DSLVisitor", "matrix literal of invalid value type");
}

// TODO: omit ReshapeOp if rows=1 (not always known at parse-time)
mlir::Value result = static_cast<mlir::Value>(builder.create<mlir::daphne::ReshapeOp>(loc, utils.matrixOf(valueType), colMatrix, rows, cols));

Expand All @@ -1593,7 +1595,7 @@ antlrcpp::Any DaphneDSLVisitor::visitColMajorFrameLiteralExpr(DaphneDSLGrammarPa

size_t labelCount = ctx->labels.size();
size_t colCount = ctx->cols.size();

if (labelCount != colCount)
throw ErrorHandler::compilerError(loc, "DSLVisitor", "frame literals must have an equal number of column labels and column matrices");

Expand All @@ -1619,7 +1621,7 @@ antlrcpp::Any DaphneDSLVisitor::visitColMajorFrameLiteralExpr(DaphneDSLGrammarPa
}

mlir::Type frameColTypes = mlir::daphne::FrameType::get(builder.getContext(), columnMatElemType);

mlir::Value result = static_cast<mlir::Value>(builder.create<mlir::daphne::CreateFrameOp>(loc, frameColTypes, columnMatrices, parsedLabels));

return result;
Expand Down Expand Up @@ -1722,7 +1724,7 @@ antlrcpp::Any DaphneDSLVisitor::visitRowMajorFrameLiteralExpr(DaphneDSLGrammarPa
}

mlir::Type frameColTypes = mlir::daphne::FrameType::get(builder.getContext(), colTypes);

mlir::Value result = static_cast<mlir::Value>(builder.create<mlir::daphne::CreateFrameOp>(loc, frameColTypes, colValues, parsedLabels));

return result;
Expand Down Expand Up @@ -2133,11 +2135,10 @@ antlrcpp::Any DaphneDSLVisitor::visitFunctionStatement(DaphneDSLGrammarParser::F
visitBlockStatement(ctx->bodyStmt);

rectifyEarlyReturns(funcBlock);
if(funcBlock->getOperations().empty()
|| !funcBlock->getOperations().back().hasTrait<mlir::OpTrait::IsTerminator>()) {

if(funcBlock->getOperations().empty() || !funcBlock->getOperations().back().hasTrait<mlir::OpTrait::IsTerminator>()) {
builder.create<mlir::daphne::ReturnOp>(utils.getLoc(ctx->stop));
}

auto terminator = funcBlock->getTerminator();
auto returnOpTypes = terminator->getOperandTypes();
if(!functionOperation) {
Expand Down

0 comments on commit e1d244c

Please sign in to comment.