diff --git a/src/ir/daphneir/DaphneInferTypesOpInterface.cpp b/src/ir/daphneir/DaphneInferTypesOpInterface.cpp index 2d8062fe7..450e14767 100644 --- a/src/ir/daphneir/DaphneInferTypesOpInterface.cpp +++ b/src/ir/daphneir/DaphneInferTypesOpInterface.cpp @@ -64,19 +64,47 @@ Type getFrameColumnTypeByLabel(daphne::FrameType ft, Value labelVal) { // **************************************************************************** std::vector daphne::CastOp::inferTypes() { - auto ftArg = getArg().getType().dyn_cast(); - auto mtRes = getRes().getType().dyn_cast(); - if(ftArg && mtRes && mtRes.getElementType().isa()) { - std::vector ctsArg = ftArg.getColumnTypes(); - if(ctsArg.size() == 1) - return {daphne::MatrixType::get(getContext(), ctsArg[0])}; + Type argTy = getArg().getType(); + Type resTy = getRes().getType(); + auto mtArg = argTy.dyn_cast(); + auto ftArg = argTy.dyn_cast(); + auto mtRes = resTy.dyn_cast(); + + // If the result type is a matrix with so far unknown value type, then we + // infer the value type from the argument. + if(mtRes && mtRes.getElementType().isa()) { + Type resVt; + + if(mtArg) + // The argument is a matrix; we use its value type for the result. + resVt = mtArg.getElementType(); + else if(ftArg) { + // The argument is a frame, we use the value type of its only + // column for the results; if the argument has more than one + // column, we throw an exception. + std::vector ctsArg = ftArg.getColumnTypes(); + if(ctsArg.size() == 1) + resVt = ctsArg[0]; + else + // TODO We could use the most general of the column types. + throw std::runtime_error( + "currently CastOp cannot infer the value type of its " + "output matrix, if the input is a multi-column frame" + ); + } else - throw std::runtime_error( - "currently CastOp cannot infer the value type of its " - "output matrix, if the input is a multi-column frame" - ); + // The argument is a scalar, we use its type for the value type + // of the result. + // TODO double-check if it is really a scalar + resVt = argTy; + + return {daphne::MatrixType::get(getContext(), resVt)}; } - return {daphne::UnknownType::get(getContext())}; + + // Otherwise, we leave the result type as it is. We do not reset it to + // unknown, since this could drop information that was explicitly + // encoded in the CastOp. + return {resTy}; } std::vector daphne::ExtractColOp::inferTypes() {