Skip to content

Commit 991397c

Browse files
pashu123Icohedron
authored andcommitted
[mlir][amdgpu] Support for 8bit extf for 0d vector type (llvm#126102)
For 0d vector type the rewrite crashes.
1 parent 2f7947c commit 991397c

File tree

2 files changed

+19
-4
lines changed

2 files changed

+19
-4
lines changed

mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp

+6-3
Original file line numberDiff line numberDiff line change
@@ -102,20 +102,23 @@ void ExtFOnFloat8RewritePattern::rewrite(arith::ExtFOp op,
102102
return rewriter.replaceOp(op, result);
103103
}
104104
int64_t numElements = inType.getNumElements();
105+
105106
Value zero = rewriter.create<arith::ConstantOp>(
106107
loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
108+
VectorType outType = cast<VectorType>(op.getOut().getType());
109+
107110
if (inType.getShape().empty()) {
111+
Value zerodSplat =
112+
rewriter.createOrFold<vector::SplatOp>(loc, outType, zero);
108113
Value scalarIn =
109114
rewriter.create<vector::ExtractOp>(loc, in, ArrayRef<int64_t>{});
110-
// Recurse to send the 0-D vector case to the 1-D vector case
111115
Value scalarExt =
112116
rewriter.create<arith::ExtFOp>(loc, outElemType, scalarIn);
113-
Value result = rewriter.create<vector::InsertOp>(loc, scalarExt, zero,
117+
Value result = rewriter.create<vector::InsertOp>(loc, scalarExt, zerodSplat,
114118
ArrayRef<int64_t>{});
115119
return rewriter.replaceOp(op, result);
116120
}
117121

118-
VectorType outType = cast<VectorType>(op.getOut().getType());
119122
VectorType flatTy = VectorType::get(SmallVector<int64_t>{numElements},
120123
outType.getElementType());
121124
Value result = rewriter.createOrFold<vector::SplatOp>(loc, flatTy, zero);

mlir/test/Conversion/ArithToAMDGPU/8-bit-floats.mlir

+13-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,19 @@ func.func @scalar_ext(%v: f8E5M2FNUZ) -> f16 {
1010
return %w : f16
1111
}
1212

13-
// No 0-D test because arith.extf hasn't been extended to support it.
13+
// -----
14+
15+
// CHECK-LABEL: func.func @vector_zero_d(
16+
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: vector<f8E5M2FNUZ>) -> vector<f32>
17+
// CHECK: %[[CONST:.+]] = arith.constant dense<0.000000e+00> : vector<f32>
18+
// CHECK: %[[EXTRACT:.+]] = vector.extract %[[ARG0]][] : f8E5M2FNUZ from vector<f8E5M2FNUZ>
19+
// CHECK: %[[CONVERT:.+]] = amdgpu.ext_packed_fp8 %[[EXTRACT]][0] : f8E5M2FNUZ to f32
20+
// CHECK: %[[RESULT:.+]] = vector.insert %[[CONVERT]], %[[CONST]] [] : f32 into vector<f32>
21+
// CHECK: return %[[RESULT]] : vector<f32>
22+
func.func @vector_zero_d(%v: vector<f8E5M2FNUZ>) -> vector<f32> {
23+
%w = arith.extf %v : vector<f8E5M2FNUZ> to vector<f32>
24+
return %w : vector<f32>
25+
}
1426

1527
// -----
1628

0 commit comments

Comments
 (0)