@@ -102,20 +102,23 @@ void ExtFOnFloat8RewritePattern::rewrite(arith::ExtFOp op,
102
102
return rewriter.replaceOp (op, result);
103
103
}
104
104
int64_t numElements = inType.getNumElements ();
105
+
105
106
Value zero = rewriter.create <arith::ConstantOp>(
106
107
loc, outElemType, rewriter.getFloatAttr (outElemType, 0.0 ));
108
+ VectorType outType = cast<VectorType>(op.getOut ().getType ());
109
+
107
110
if (inType.getShape ().empty ()) {
111
+ Value zerodSplat =
112
+ rewriter.createOrFold <vector::SplatOp>(loc, outType, zero);
108
113
Value scalarIn =
109
114
rewriter.create <vector::ExtractOp>(loc, in, ArrayRef<int64_t >{});
110
- // Recurse to send the 0-D vector case to the 1-D vector case
111
115
Value scalarExt =
112
116
rewriter.create <arith::ExtFOp>(loc, outElemType, scalarIn);
113
- Value result = rewriter.create <vector::InsertOp>(loc, scalarExt, zero ,
117
+ Value result = rewriter.create <vector::InsertOp>(loc, scalarExt, zerodSplat ,
114
118
ArrayRef<int64_t >{});
115
119
return rewriter.replaceOp (op, result);
116
120
}
117
121
118
- VectorType outType = cast<VectorType>(op.getOut ().getType ());
119
122
VectorType flatTy = VectorType::get (SmallVector<int64_t >{numElements},
120
123
outType.getElementType ());
121
124
Value result = rewriter.createOrFold <vector::SplatOp>(loc, flatTy, zero);
0 commit comments