diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp index 33370566996ee..60a002c41bfb2 100644 --- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp +++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp @@ -102,20 +102,23 @@ void ExtFOnFloat8RewritePattern::rewrite(arith::ExtFOp op, return rewriter.replaceOp(op, result); } int64_t numElements = inType.getNumElements(); + Value zero = rewriter.create( loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0)); + VectorType outType = cast(op.getOut().getType()); + if (inType.getShape().empty()) { + Value zerodSplat = + rewriter.createOrFold(loc, outType, zero); Value scalarIn = rewriter.create(loc, in, ArrayRef{}); - // Recurse to send the 0-D vector case to the 1-D vector case Value scalarExt = rewriter.create(loc, outElemType, scalarIn); - Value result = rewriter.create(loc, scalarExt, zero, + Value result = rewriter.create(loc, scalarExt, zerodSplat, ArrayRef{}); return rewriter.replaceOp(op, result); } - VectorType outType = cast(op.getOut().getType()); VectorType flatTy = VectorType::get(SmallVector{numElements}, outType.getElementType()); Value result = rewriter.createOrFold(loc, flatTy, zero); diff --git a/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats.mlir b/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats.mlir index bd90facb61544..985fb532ea74a 100644 --- a/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats.mlir +++ b/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats.mlir @@ -10,7 +10,19 @@ func.func @scalar_ext(%v: f8E5M2FNUZ) -> f16 { return %w : f16 } -// No 0-D test because arith.extf hasn't been extended to support it. +// ----- + +// CHECK-LABEL: func.func @vector_zero_d( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: vector) -> vector +// CHECK: %[[CONST:.+]] = arith.constant dense<0.000000e+00> : vector +// CHECK: %[[EXTRACT:.+]] = vector.extract %[[ARG0]][] : f8E5M2FNUZ from vector +// CHECK: %[[CONVERT:.+]] = amdgpu.ext_packed_fp8 %[[EXTRACT]][0] : f8E5M2FNUZ to f32 +// CHECK: %[[RESULT:.+]] = vector.insert %[[CONVERT]], %[[CONST]] [] : f32 into vector +// CHECK: return %[[RESULT]] : vector +func.func @vector_zero_d(%v: vector) -> vector { + %w = arith.extf %v : vector to vector + return %w : vector +} // -----