Skip to content

Commit 2c4dd89

Browse files
authored
[mlir][Vector] Introduce poison in LowerVectorBitCast/Broadcast/Transpose (#126180)
This PR continues with the introduction of poison as initialization vector, in this particular case, in LowerVectorBitCast, LowerVectorBroadcast and LowerVectorTranspose.
1 parent e566313 commit 2c4dd89

7 files changed

+53
-75
lines changed

mlir/lib/Dialect/Vector/Transforms/LowerVectorBitCast.cpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
//
1212
//===----------------------------------------------------------------------===//
1313

14+
#include "mlir/Dialect/UB/IR/UBOps.h"
1415
#include "mlir/Dialect/Vector/IR/VectorOps.h"
1516
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
1617
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
@@ -32,7 +33,7 @@ namespace {
3233
///
3334
/// Would be unrolled to:
3435
///
35-
/// %result = arith.constant dense<0> : vector<1x2x3x8xi32>
36+
/// %result = ub.poison : vector<1x2x3x8xi32>
3637
/// %0 = vector.extract %a[0, 0, 0] ─┐
3738
/// : vector<4xi64> from vector<1x2x3x4xi64> |
3839
/// %1 = vector.bitcast %0 | - Repeated 6x for
@@ -63,8 +64,7 @@ class UnrollBitCastOp final : public OpRewritePattern<vector::BitCastOp> {
6364
VectorType::get(shape, resultType.getElementType(), scalableDims);
6465

6566
Location loc = op.getLoc();
66-
Value result = rewriter.create<arith::ConstantOp>(
67-
loc, resultType, rewriter.getZeroAttr(resultType));
67+
Value result = rewriter.create<ub::PoisonOp>(loc, resultType);
6868
for (auto position : *unrollIterator) {
6969
Value extract =
7070
rewriter.create<vector::ExtractOp>(loc, op.getSource(), position);

mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp

+3-16
Original file line numberDiff line numberDiff line change
@@ -11,27 +11,16 @@
1111
//
1212
//===----------------------------------------------------------------------===//
1313

14-
#include "mlir/Dialect/Affine/IR/AffineOps.h"
15-
#include "mlir/Dialect/Arith/IR/Arith.h"
16-
#include "mlir/Dialect/Arith/Utils/Utils.h"
17-
#include "mlir/Dialect/Linalg/IR/Linalg.h"
1814
#include "mlir/Dialect/MemRef/IR/MemRef.h"
19-
#include "mlir/Dialect/SCF/IR/SCF.h"
20-
#include "mlir/Dialect/Tensor/IR/Tensor.h"
21-
#include "mlir/Dialect/Utils/IndexingUtils.h"
22-
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
15+
#include "mlir/Dialect/UB/IR/UBOps.h"
2316
#include "mlir/Dialect/Vector/IR/VectorOps.h"
2417
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
2518
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
2619
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
27-
#include "mlir/IR/BuiltinAttributeInterfaces.h"
2820
#include "mlir/IR/BuiltinTypes.h"
29-
#include "mlir/IR/ImplicitLocOpBuilder.h"
3021
#include "mlir/IR/Location.h"
31-
#include "mlir/IR/Matchers.h"
3222
#include "mlir/IR/PatternMatch.h"
3323
#include "mlir/IR/TypeUtilities.h"
34-
#include "mlir/Interfaces/VectorInterfaces.h"
3524

3625
#define DEBUG_TYPE "vector-broadcast-lowering"
3726

@@ -86,8 +75,7 @@ class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
8675
VectorType resType = VectorType::Builder(dstType).dropDim(0);
8776
Value bcst =
8877
rewriter.create<vector::BroadcastOp>(loc, resType, op.getSource());
89-
Value result = rewriter.create<arith::ConstantOp>(
90-
loc, dstType, rewriter.getZeroAttr(dstType));
78+
Value result = rewriter.create<ub::PoisonOp>(loc, dstType);
9179
for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d)
9280
result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
9381
rewriter.replaceOp(op, result);
@@ -127,8 +115,7 @@ class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
127115
VectorType resType =
128116
VectorType::get(dstType.getShape().drop_front(), eltType,
129117
dstType.getScalableDims().drop_front());
130-
Value result = rewriter.create<arith::ConstantOp>(
131-
loc, dstType, rewriter.getZeroAttr(dstType));
118+
Value result = rewriter.create<ub::PoisonOp>(loc, dstType);
132119
if (m == 0) {
133120
// Stetch at start.
134121
Value ext = rewriter.create<vector::ExtractOp>(loc, op.getSource(), 0);

mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp

+3-12
Original file line numberDiff line numberDiff line change
@@ -11,26 +11,19 @@
1111
//
1212
//===----------------------------------------------------------------------===//
1313

14-
#include "mlir/Dialect/Affine/IR/AffineOps.h"
1514
#include "mlir/Dialect/Arith/IR/Arith.h"
16-
#include "mlir/Dialect/Arith/Utils/Utils.h"
17-
#include "mlir/Dialect/Linalg/IR/Linalg.h"
1815
#include "mlir/Dialect/MemRef/IR/MemRef.h"
19-
#include "mlir/Dialect/SCF/IR/SCF.h"
20-
#include "mlir/Dialect/Tensor/IR/Tensor.h"
16+
#include "mlir/Dialect/UB/IR/UBOps.h"
2117
#include "mlir/Dialect/Utils/IndexingUtils.h"
2218
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
2319
#include "mlir/Dialect/Vector/IR/VectorOps.h"
2420
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
2521
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
26-
#include "mlir/IR/BuiltinAttributeInterfaces.h"
2722
#include "mlir/IR/BuiltinTypes.h"
2823
#include "mlir/IR/ImplicitLocOpBuilder.h"
2924
#include "mlir/IR/Location.h"
30-
#include "mlir/IR/Matchers.h"
3125
#include "mlir/IR/PatternMatch.h"
3226
#include "mlir/IR/TypeUtilities.h"
33-
#include "mlir/Interfaces/VectorInterfaces.h"
3427

3528
#define DEBUG_TYPE "lower-vector-transpose"
3629

@@ -291,8 +284,7 @@ static Value transposeToShuffle16x16(OpBuilder &builder, Value source, int m,
291284

292285
auto reshInputType = VectorType::get(
293286
{m, n}, cast<VectorType>(source.getType()).getElementType());
294-
Value res =
295-
b.create<arith::ConstantOp>(reshInputType, b.getZeroAttr(reshInputType));
287+
Value res = b.create<ub::PoisonOp>(reshInputType);
296288
for (int64_t i = 0; i < m; ++i)
297289
res = b.create<vector::InsertOp>(vs[i], res, i);
298290
return res;
@@ -368,8 +360,7 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
368360
// of the leftmost transposed dimensions. We traverse every transpose
369361
// element using a linearized index that we delinearize to generate the
370362
// appropriate indices for the extract/insert operations.
371-
Value result = rewriter.create<arith::ConstantOp>(
372-
loc, resType, rewriter.getZeroAttr(resType));
363+
Value result = rewriter.create<ub::PoisonOp>(loc, resType);
373364
int64_t numTransposedElements = ShapedType::getNumElements(prunedInShape);
374365

375366
for (int64_t linearIdx = 0; linearIdx < numTransposedElements;

mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir

+18-18
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ func.func @broadcast_vec2d_from_vec0d(%arg0: vector<f32>) -> vector<3x2xf32> {
185185
// CHECK-LABEL: @broadcast_vec2d_from_vec0d(
186186
// CHECK-SAME: %[[A:.*]]: vector<f32>)
187187
// CHECK: %[[T0:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<f32> to vector<1xf32>
188-
// CHECK: %[[T1:.*]] = arith.constant dense<0.000000e+00> : vector<3x2xf32>
188+
// CHECK: %[[T1:.*]] = ub.poison : vector<3x2xf32>
189189
// CHECK: %[[T2:.*]] = builtin.unrealized_conversion_cast %[[T1]] : vector<3x2xf32> to !llvm.array<3 x vector<2xf32>>
190190
// CHECK: %[[T4:.*]] = llvm.mlir.constant(0 : index) : i64
191191
// CHECK: %[[T5:.*]] = llvm.extractelement %[[T0]][%[[T4]] : i64] : vector<1xf32>
@@ -205,7 +205,7 @@ func.func @broadcast_vec2d_from_vec1d(%arg0: vector<2xf32>) -> vector<3x2xf32> {
205205
}
206206
// CHECK-LABEL: @broadcast_vec2d_from_vec1d(
207207
// CHECK-SAME: %[[A:.*]]: vector<2xf32>)
208-
// CHECK: %[[T0:.*]] = arith.constant dense<0.000000e+00> : vector<3x2xf32>
208+
// CHECK: %[[T0:.*]] = ub.poison : vector<3x2xf32>
209209
// CHECK: %[[T1:.*]] = builtin.unrealized_conversion_cast %[[T0]] : vector<3x2xf32> to !llvm.array<3 x vector<2xf32>>
210210
// CHECK: %[[T2:.*]] = llvm.insertvalue %[[A]], %[[T1]][0] : !llvm.array<3 x vector<2xf32>>
211211
// CHECK: %[[T3:.*]] = llvm.insertvalue %[[A]], %[[T2]][1] : !llvm.array<3 x vector<2xf32>>
@@ -221,7 +221,7 @@ func.func @broadcast_vec2d_from_vec1d_scalable(%arg0: vector<[2]xf32>) -> vector
221221
}
222222
// CHECK-LABEL: @broadcast_vec2d_from_vec1d_scalable(
223223
// CHECK-SAME: %[[A:.*]]: vector<[2]xf32>)
224-
// CHECK: %[[T0:.*]] = arith.constant dense<0.000000e+00> : vector<3x[2]xf32>
224+
// CHECK: %[[T0:.*]] = ub.poison : vector<3x[2]xf32>
225225
// CHECK: %[[T1:.*]] = builtin.unrealized_conversion_cast %[[T0]] : vector<3x[2]xf32> to !llvm.array<3 x vector<[2]xf32>>
226226
// CHECK: %[[T2:.*]] = llvm.insertvalue %[[A]], %[[T1]][0] : !llvm.array<3 x vector<[2]xf32>>
227227
// CHECK: %[[T3:.*]] = llvm.insertvalue %[[A]], %[[T2]][1] : !llvm.array<3 x vector<[2]xf32>>
@@ -238,7 +238,7 @@ func.func @broadcast_vec2d_from_index_vec1d(%arg0: vector<2xindex>) -> vector<3x
238238
// CHECK-LABEL: @broadcast_vec2d_from_index_vec1d(
239239
// CHECK-SAME: %[[A:.*]]: vector<2xindex>)
240240
// CHECK: %[[T1:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<2xindex> to vector<2xi64>
241-
// CHECK: %[[T0:.*]] = arith.constant dense<0> : vector<3x2xindex>
241+
// CHECK: %[[T0:.*]] = ub.poison : vector<3x2xindex>
242242
// CHECK: %[[T2:.*]] = builtin.unrealized_conversion_cast %[[T0]] : vector<3x2xindex> to !llvm.array<3 x vector<2xi64>>
243243
// CHECK: %[[T3:.*]] = llvm.insertvalue %[[T1]], %[[T2]][0] : !llvm.array<3 x vector<2xi64>>
244244

@@ -254,7 +254,7 @@ func.func @broadcast_vec2d_from_index_vec1d_scalable(%arg0: vector<[2]xindex>) -
254254
// CHECK-LABEL: @broadcast_vec2d_from_index_vec1d_scalable(
255255
// CHECK-SAME: %[[A:.*]]: vector<[2]xindex>)
256256
// CHECK: %[[T1:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<[2]xindex> to vector<[2]xi64>
257-
// CHECK: %[[T0:.*]] = arith.constant dense<0> : vector<3x[2]xindex>
257+
// CHECK: %[[T0:.*]] = ub.poison : vector<3x[2]xindex>
258258
// CHECK: %[[T2:.*]] = builtin.unrealized_conversion_cast %[[T0]] : vector<3x[2]xindex> to !llvm.array<3 x vector<[2]xi64>>
259259
// CHECK: %[[T3:.*]] = llvm.insertvalue %[[T1]], %[[T2]][0] : !llvm.array<3 x vector<[2]xi64>>
260260

@@ -269,9 +269,9 @@ func.func @broadcast_vec3d_from_vec1d(%arg0: vector<2xf32>) -> vector<4x3x2xf32>
269269
}
270270
// CHECK-LABEL: @broadcast_vec3d_from_vec1d(
271271
// CHECK-SAME: %[[A:.*]]: vector<2xf32>)
272-
// CHECK-DAG: %[[T0:.*]] = arith.constant dense<0.000000e+00> : vector<3x2xf32>
272+
// CHECK-DAG: %[[T0:.*]] = ub.poison : vector<3x2xf32>
273273
// CHECK-DAG: %[[T2:.*]] = builtin.unrealized_conversion_cast %[[T0]] : vector<3x2xf32> to !llvm.array<3 x vector<2xf32>>
274-
// CHECK-DAG: %[[T1:.*]] = arith.constant dense<0.000000e+00> : vector<4x3x2xf32>
274+
// CHECK-DAG: %[[T1:.*]] = ub.poison : vector<4x3x2xf32>
275275
// CHECK-DAG: %[[T6:.*]] = builtin.unrealized_conversion_cast %[[T1]] : vector<4x3x2xf32> to !llvm.array<4 x array<3 x vector<2xf32>>>
276276

277277
// CHECK: %[[T3:.*]] = llvm.insertvalue %[[A]], %[[T2]][0] : !llvm.array<3 x vector<2xf32>>
@@ -294,9 +294,9 @@ func.func @broadcast_vec3d_from_vec1d_scalable(%arg0: vector<[2]xf32>) -> vector
294294
}
295295
// CHECK-LABEL: @broadcast_vec3d_from_vec1d_scalable(
296296
// CHECK-SAME: %[[A:.*]]: vector<[2]xf32>)
297-
// CHECK-DAG: %[[T0:.*]] = arith.constant dense<0.000000e+00> : vector<3x[2]xf32>
297+
// CHECK-DAG: %[[T0:.*]] = ub.poison : vector<3x[2]xf32>
298298
// CHECK-DAG: %[[T2:.*]] = builtin.unrealized_conversion_cast %[[T0]] : vector<3x[2]xf32> to !llvm.array<3 x vector<[2]xf32>>
299-
// CHECK-DAG: %[[T1:.*]] = arith.constant dense<0.000000e+00> : vector<4x3x[2]xf32>
299+
// CHECK-DAG: %[[T1:.*]] = ub.poison : vector<4x3x[2]xf32>
300300
// CHECK-DAG: %[[T6:.*]] = builtin.unrealized_conversion_cast %[[T1]] : vector<4x3x[2]xf32> to !llvm.array<4 x array<3 x vector<[2]xf32>>>
301301

302302
// CHECK: %[[T3:.*]] = llvm.insertvalue %[[A]], %[[T2]][0] : !llvm.array<3 x vector<[2]xf32>>
@@ -320,7 +320,7 @@ func.func @broadcast_vec3d_from_vec2d(%arg0: vector<3x2xf32>) -> vector<4x3x2xf3
320320
// CHECK-LABEL: @broadcast_vec3d_from_vec2d(
321321
// CHECK-SAME: %[[A:.*]]: vector<3x2xf32>)
322322
// CHECK: %[[T1:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<3x2xf32> to !llvm.array<3 x vector<2xf32>>
323-
// CHECK: %[[T0:.*]] = arith.constant dense<0.000000e+00> : vector<4x3x2xf32>
323+
// CHECK: %[[T0:.*]] = ub.poison : vector<4x3x2xf32>
324324
// CHECK: %[[T2:.*]] = builtin.unrealized_conversion_cast %[[T0]] : vector<4x3x2xf32> to !llvm.array<4 x array<3 x vector<2xf32>>>
325325
// CHECK: %[[T3:.*]] = llvm.insertvalue %[[T1]], %[[T2]][0] : !llvm.array<4 x array<3 x vector<2xf32>>>
326326
// CHECK: %[[T5:.*]] = llvm.insertvalue %[[T1]], %[[T3]][1] : !llvm.array<4 x array<3 x vector<2xf32>>>
@@ -338,7 +338,7 @@ func.func @broadcast_vec3d_from_vec2d_scalable(%arg0: vector<3x[2]xf32>) -> vect
338338
// CHECK-LABEL: @broadcast_vec3d_from_vec2d_scalable(
339339
// CHECK-SAME: %[[A:.*]]: vector<3x[2]xf32>)
340340
// CHECK: %[[T1:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<3x[2]xf32> to !llvm.array<3 x vector<[2]xf32>>
341-
// CHECK: %[[T0:.*]] = arith.constant dense<0.000000e+00> : vector<4x3x[2]xf32>
341+
// CHECK: %[[T0:.*]] = ub.poison : vector<4x3x[2]xf32>
342342
// CHECK: %[[T2:.*]] = builtin.unrealized_conversion_cast %[[T0]] : vector<4x3x[2]xf32> to !llvm.array<4 x array<3 x vector<[2]xf32>>>
343343
// CHECK: %[[T3:.*]] = llvm.insertvalue %[[T1]], %[[T2]][0] : !llvm.array<4 x array<3 x vector<[2]xf32>>>
344344
// CHECK: %[[T5:.*]] = llvm.insertvalue %[[T1]], %[[T3]][1] : !llvm.array<4 x array<3 x vector<[2]xf32>>>
@@ -385,7 +385,7 @@ func.func @broadcast_stretch_at_start(%arg0: vector<1x4xf32>) -> vector<3x4xf32>
385385
// CHECK-LABEL: @broadcast_stretch_at_start(
386386
// CHECK-SAME: %[[A:.*]]: vector<1x4xf32>)
387387
// CHECK: %[[T2:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<1x4xf32> to !llvm.array<1 x vector<4xf32>>
388-
// CHECK: %[[T1:.*]] = arith.constant dense<0.000000e+00> : vector<3x4xf32>
388+
// CHECK: %[[T1:.*]] = ub.poison : vector<3x4xf32>
389389
// CHECK: %[[T4:.*]] = builtin.unrealized_conversion_cast %[[T1]] : vector<3x4xf32> to !llvm.array<3 x vector<4xf32>>
390390
// CHECK: %[[T3:.*]] = llvm.extractvalue %[[T2]][0] : !llvm.array<1 x vector<4xf32>>
391391
// CHECK: %[[T5:.*]] = llvm.insertvalue %[[T3]], %[[T4]][0] : !llvm.array<3 x vector<4xf32>>
@@ -403,7 +403,7 @@ func.func @broadcast_stretch_at_start_scalable(%arg0: vector<1x[4]xf32>) -> vect
403403
// CHECK-LABEL: @broadcast_stretch_at_start_scalable(
404404
// CHECK-SAME: %[[A:.*]]: vector<1x[4]xf32>)
405405
// CHECK: %[[T2:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<1x[4]xf32> to !llvm.array<1 x vector<[4]xf32>>
406-
// CHECK: %[[T1:.*]] = arith.constant dense<0.000000e+00> : vector<3x[4]xf32>
406+
// CHECK: %[[T1:.*]] = ub.poison : vector<3x[4]xf32>
407407
// CHECK: %[[T4:.*]] = builtin.unrealized_conversion_cast %[[T1]] : vector<3x[4]xf32> to !llvm.array<3 x vector<[4]xf32>>
408408
// CHECK: %[[T3:.*]] = llvm.extractvalue %[[T2]][0] : !llvm.array<1 x vector<[4]xf32>>
409409
// CHECK: %[[T5:.*]] = llvm.insertvalue %[[T3]], %[[T4]][0] : !llvm.array<3 x vector<[4]xf32>>
@@ -421,7 +421,7 @@ func.func @broadcast_stretch_at_end(%arg0: vector<4x1xf32>) -> vector<4x3xf32> {
421421
// CHECK-LABEL: @broadcast_stretch_at_end(
422422
// CHECK-SAME: %[[A:.*]]: vector<4x1xf32>)
423423
// CHECK: %[[T2:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<4x1xf32> to !llvm.array<4 x vector<1xf32>>
424-
// CHECK: %[[T1:.*]] = arith.constant dense<0.000000e+00> : vector<4x3xf32>
424+
// CHECK: %[[T1:.*]] = ub.poison : vector<4x3xf32>
425425
// CHECK: %[[T7:.*]] = builtin.unrealized_conversion_cast %[[T1]] : vector<4x3xf32> to !llvm.array<4 x vector<3xf32>>
426426
// CHECK: %[[T3:.*]] = llvm.extractvalue %[[T2]][0] : !llvm.array<4 x vector<1xf32>>
427427
// CHECK: %[[T4:.*]] = llvm.mlir.constant(0 : i64) : i64
@@ -469,9 +469,9 @@ func.func @broadcast_stretch_in_middle(%arg0: vector<4x1x2xf32>) -> vector<4x3x2
469469
// CHECK-LABEL: @broadcast_stretch_in_middle(
470470
// CHECK-SAME: %[[A:.*]]: vector<4x1x2xf32>) -> vector<4x3x2xf32> {
471471
// CHECK: %[[T3:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<4x1x2xf32> to !llvm.array<4 x array<1 x vector<2xf32>>>
472-
// CHECK: %[[T1:.*]] = arith.constant dense<0.000000e+00> : vector<4x3x2xf32>
472+
// CHECK: %[[T1:.*]] = ub.poison : vector<4x3x2xf32>
473473
// CHECK: %[[T9:.*]] = builtin.unrealized_conversion_cast %[[T1]] : vector<4x3x2xf32> to !llvm.array<4 x array<3 x vector<2xf32>>>
474-
// CHECK: %[[T2:.*]] = arith.constant dense<0.000000e+00> : vector<3x2xf32>
474+
// CHECK: %[[T2:.*]] = ub.poison : vector<3x2xf32>
475475
// CHECK: %[[T5:.*]] = builtin.unrealized_conversion_cast %[[T2]] : vector<3x2xf32> to !llvm.array<3 x vector<2xf32>>
476476
// CHECK: %[[T4:.*]] = llvm.extractvalue %[[T3]][0, 0] : !llvm.array<4 x array<1 x vector<2xf32>>>
477477
// CHECK: %[[T6:.*]] = llvm.insertvalue %[[T4]], %[[T5]][0] : !llvm.array<3 x vector<2xf32>>
@@ -505,9 +505,9 @@ func.func @broadcast_stretch_in_middle_scalable_v1(%arg0: vector<4x1x[2]xf32>) -
505505
// CHECK-LABEL: @broadcast_stretch_in_middle_scalable_v1(
506506
// CHECK-SAME: %[[A:.*]]: vector<4x1x[2]xf32>) -> vector<4x3x[2]xf32> {
507507
// CHECK: %[[T3:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<4x1x[2]xf32> to !llvm.array<4 x array<1 x vector<[2]xf32>>>
508-
// CHECK: %[[T1:.*]] = arith.constant dense<0.000000e+00> : vector<4x3x[2]xf32>
508+
// CHECK: %[[T1:.*]] = ub.poison : vector<4x3x[2]xf32>
509509
// CHECK: %[[T9:.*]] = builtin.unrealized_conversion_cast %[[T1]] : vector<4x3x[2]xf32> to !llvm.array<4 x array<3 x vector<[2]xf32>>>
510-
// CHECK: %[[T2:.*]] = arith.constant dense<0.000000e+00> : vector<3x[2]xf32>
510+
// CHECK: %[[T2:.*]] = ub.poison : vector<3x[2]xf32>
511511
// CHECK: %[[T5:.*]] = builtin.unrealized_conversion_cast %[[T2]] : vector<3x[2]xf32> to !llvm.array<3 x vector<[2]xf32>>
512512
// CHECK: %[[T4:.*]] = llvm.extractvalue %[[T3]][0, 0] : !llvm.array<4 x array<1 x vector<[2]xf32>>>
513513
// CHECK: %[[T6:.*]] = llvm.insertvalue %[[T4]], %[[T5]][0] : !llvm.array<3 x vector<[2]xf32>>

mlir/test/Dialect/Vector/vector-bitcast-lowering-transforms.mlir

+3-3
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ func.func @vector_bitcast_2d(%arg0: vector<2x4xi32>) -> vector<2x2xi64> {
2424
}
2525
// CHECK-LABEL: func.func @vector_bitcast_2d
2626
// CHECK-SAME: %[[IN:[a-zA-Z0-9]+]]
27-
// CHECK: %[[INIT:.+]] = arith.constant {{.+}} : vector<2x2xi64>
27+
// CHECK: %[[INIT:.+]] = ub.poison : vector<2x2xi64>
2828
// CHECK: %[[V1:.+]] = vector.extract %[[IN]][0] : vector<4xi32> from vector<2x4xi32>
2929
// CHECK: %[[B1:.+]] = vector.bitcast %[[V1]] : vector<4xi32> to vector<2xi64>
3030
// CHECK: %[[R1:.+]] = vector.insert %[[B1]], %[[INIT]] [0]
@@ -39,7 +39,7 @@ func.func @vector_bitcast_4d_with_scalable_dim(%arg0: vector<1x2x[3]x4xi64>) ->
3939
}
4040
// CHECK-LABEL: func.func @vector_bitcast_4d_with_scalable_dim
4141
// CHECK-SAME: %[[IN:[a-zA-Z0-9]+]]
42-
// CHECK: %[[INIT:.+]] = arith.constant dense<0> : vector<1x2x[3]x8xi32>
42+
// CHECK: %[[INIT:.+]] = ub.poison : vector<1x2x[3]x8xi32>
4343
// CHECK: %[[V1:.+]] = vector.extract %[[IN]][0, 0] : vector<[3]x4xi64> from vector<1x2x[3]x4xi64>
4444
// CHECK: %[[B1:.+]] = vector.bitcast %[[V1]] : vector<[3]x4xi64> to vector<[3]x8xi32>
4545
// CHECK: %[[R1:.+]] = vector.insert %[[B1]], %[[INIT]] [0, 0] : vector<[3]x8xi32> into vector<1x2x[3]x8xi32>
@@ -54,7 +54,7 @@ func.func @vector_bitcast_2d_trailing_scalable_dim(%arg0: vector<2x[2]xi64>) ->
5454
}
5555
// CHECK-LABEL: func.func @vector_bitcast_2d_trailing_scalable_dim
5656
// CHECK-SAME: %[[IN:[a-zA-Z0-9]+]]
57-
// CHECK: %[[INIT:.+]] = arith.constant dense<0> : vector<2x[4]xi32>
57+
// CHECK: %[[INIT:.+]] = ub.poison : vector<2x[4]xi32>
5858
// CHECK: %[[V1:.+]] = vector.extract %[[IN]][0] : vector<[2]xi64> from vector<2x[2]xi64>
5959
// CHECK: %[[B1:.+]] = vector.bitcast %[[V1]] : vector<[2]xi64> to vector<[4]xi32>
6060
// CHECK: %[[R1:.+]] = vector.insert %[[B1]], %[[INIT]] [0] : vector<[4]xi32> into vector<2x[4]xi32>

0 commit comments

Comments
 (0)