Skip to content

Commit

Permalink
Improved type inference for CastOp.
Browse files Browse the repository at this point in the history
- Type inference for casts is not trivial, because the result type is usually explicitly given (main purpose of a cast).
- However, when the data type of the result is given, but the value type is unknown, we should infer the value type from the argument.
- So far, this was only done when the result is a matrix and the argument is a single-column frame.
- Furthermore, the result type was reset to unknown if we could not infer anything.
- This commit introduces two improvements:
  - If the result type is a matrix of unknown value type, the value type is also inferred for matrix and scalar arguments.
  - If we decide not to infer the result type, we do not reset it to unknown, but leave it as it is in order not to remove any information (e.g., on the data type) that was explicitly given.
  • Loading branch information
pdamme committed Jun 6, 2023
1 parent 66c9324 commit 19b2b44
Showing 1 changed file with 39 additions and 11 deletions.
50 changes: 39 additions & 11 deletions src/ir/daphneir/DaphneInferTypesOpInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,19 +64,47 @@ Type getFrameColumnTypeByLabel(daphne::FrameType ft, Value labelVal) {
// ****************************************************************************

std::vector<Type> daphne::CastOp::inferTypes() {
auto ftArg = getArg().getType().dyn_cast<daphne::FrameType>();
auto mtRes = getRes().getType().dyn_cast<daphne::MatrixType>();
if(ftArg && mtRes && mtRes.getElementType().isa<daphne::UnknownType>()) {
std::vector<Type> ctsArg = ftArg.getColumnTypes();
if(ctsArg.size() == 1)
return {daphne::MatrixType::get(getContext(), ctsArg[0])};
Type argTy = getArg().getType();
Type resTy = getRes().getType();
auto mtArg = argTy.dyn_cast<daphne::MatrixType>();
auto ftArg = argTy.dyn_cast<daphne::FrameType>();
auto mtRes = resTy.dyn_cast<daphne::MatrixType>();

// If the result type is a matrix with so far unknown value type, then we
// infer the value type from the argument.
if(mtRes && mtRes.getElementType().isa<daphne::UnknownType>()) {
Type resVt;

if(mtArg)
// The argument is a matrix; we use its value type for the result.
resVt = mtArg.getElementType();
else if(ftArg) {
// The argument is a frame, we use the value type of its only
// column for the results; if the argument has more than one
// column, we throw an exception.
std::vector<Type> ctsArg = ftArg.getColumnTypes();
if(ctsArg.size() == 1)
resVt = ctsArg[0];
else
// TODO We could use the most general of the column types.
throw std::runtime_error(
"currently CastOp cannot infer the value type of its "
"output matrix, if the input is a multi-column frame"
);
}
else
throw std::runtime_error(
"currently CastOp cannot infer the value type of its "
"output matrix, if the input is a multi-column frame"
);
// The argument is a scalar, we use its type for the value type
// of the result.
// TODO double-check if it is really a scalar
resVt = argTy;

return {daphne::MatrixType::get(getContext(), resVt)};
}
return {daphne::UnknownType::get(getContext())};

// Otherwise, we leave the result type as it is. We do not reset it to
// unknown, since this could drop information that was explicitly
// encoded in the CastOp.
return {resTy};
}

std::vector<Type> daphne::ExtractColOp::inferTypes() {
Expand Down

0 comments on commit 19b2b44

Please sign in to comment.