-
Notifications
You must be signed in to change notification settings - Fork 13k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][Vector] Introduce poison in LowerVectorBitCast/Broadcast/Transpose #126180
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-vector Author: Diego Caballero (dcaballe) ChangesThis PR continues with the introduction of poison as initialization Patch is 29.82 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/126180.diff 7 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBitCast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBitCast.cpp
index d8c4939dc742a7..89930a6bd35fa3 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBitCast.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBitCast.cpp
@@ -11,6 +11,7 @@
//
//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
@@ -32,7 +33,7 @@ namespace {
///
/// Would be unrolled to:
///
-/// %result = arith.constant dense<0> : vector<1x2x3x8xi32>
+/// %result = ub.poison : vector<1x2x3x8xi32>
/// %0 = vector.extract %a[0, 0, 0] ─┐
/// : vector<4xi64> from vector<1x2x3x4xi64> |
/// %1 = vector.bitcast %0 | - Repeated 6x for
@@ -63,8 +64,7 @@ class UnrollBitCastOp final : public OpRewritePattern<vector::BitCastOp> {
VectorType::get(shape, resultType.getElementType(), scalableDims);
Location loc = op.getLoc();
- Value result = rewriter.create<arith::ConstantOp>(
- loc, resultType, rewriter.getZeroAttr(resultType));
+ Value result = rewriter.create<ub::PoisonOp>(loc, resultType);
for (auto position : *unrollIterator) {
Value extract =
rewriter.create<vector::ExtractOp>(loc, op.getSource(), position);
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
index 6c36bbaee85237..fec3c6c52e5e43 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
@@ -11,27 +11,16 @@
//
//===----------------------------------------------------------------------===//
-#include "mlir/Dialect/Affine/IR/AffineOps.h"
-#include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/Arith/Utils/Utils.h"
-#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
-#include "mlir/Dialect/SCF/IR/SCF.h"
-#include "mlir/Dialect/Tensor/IR/Tensor.h"
-#include "mlir/Dialect/Utils/IndexingUtils.h"
-#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
-#include "mlir/IR/BuiltinAttributeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Location.h"
-#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
-#include "mlir/Interfaces/VectorInterfaces.h"
#define DEBUG_TYPE "vector-broadcast-lowering"
@@ -86,8 +75,7 @@ class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
VectorType resType = VectorType::Builder(dstType).dropDim(0);
Value bcst =
rewriter.create<vector::BroadcastOp>(loc, resType, op.getSource());
- Value result = rewriter.create<arith::ConstantOp>(
- loc, dstType, rewriter.getZeroAttr(dstType));
+ Value result = rewriter.create<ub::PoisonOp>(loc, dstType);
for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d)
result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
rewriter.replaceOp(op, result);
@@ -127,8 +115,7 @@ class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
VectorType resType =
VectorType::get(dstType.getShape().drop_front(), eltType,
dstType.getScalableDims().drop_front());
- Value result = rewriter.create<arith::ConstantOp>(
- loc, dstType, rewriter.getZeroAttr(dstType));
+ Value result = rewriter.create<ub::PoisonOp>(loc, dstType);
if (m == 0) {
// Stetch at start.
Value ext = rewriter.create<vector::ExtractOp>(loc, op.getSource(), 0);
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
index 3c92b222e6bc80..d05c08bf640077 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
@@ -11,26 +11,19 @@
//
//===----------------------------------------------------------------------===//
-#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/Arith/Utils/Utils.h"
-#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
-#include "mlir/Dialect/SCF/IR/SCF.h"
-#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
-#include "mlir/IR/BuiltinAttributeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Location.h"
-#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
-#include "mlir/Interfaces/VectorInterfaces.h"
#define DEBUG_TYPE "lower-vector-transpose"
@@ -291,8 +284,7 @@ static Value transposeToShuffle16x16(OpBuilder &builder, Value source, int m,
auto reshInputType = VectorType::get(
{m, n}, cast<VectorType>(source.getType()).getElementType());
- Value res =
- b.create<arith::ConstantOp>(reshInputType, b.getZeroAttr(reshInputType));
+ Value res = b.create<ub::PoisonOp>(reshInputType);
for (int64_t i = 0; i < m; ++i)
res = b.create<vector::InsertOp>(vs[i], res, i);
return res;
@@ -368,8 +360,7 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
// of the leftmost transposed dimensions. We traverse every transpose
// element using a linearized index that we delinearize to generate the
// appropriate indices for the extract/insert operations.
- Value result = rewriter.create<arith::ConstantOp>(
- loc, resType, rewriter.getZeroAttr(resType));
+ Value result = rewriter.create<ub::PoisonOp>(loc, resType);
int64_t numTransposedElements = ShapedType::getNumElements(prunedInShape);
for (int64_t linearIdx = 0; linearIdx < numTransposedElements;
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index a5ca2b3764ffd4..0eb14aa295d9d2 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -286,7 +286,7 @@ func.func @broadcast_vec2d_from_vec0d(%arg0: vector<f32>) -> vector<3x2xf32> {
// CHECK-LABEL: @broadcast_vec2d_from_vec0d(
// CHECK-SAME: %[[A:.*]]: vector<f32>)
// CHECK: %[[T0:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<f32> to vector<1xf32>
-// CHECK: %[[T1:.*]] = arith.constant dense<0.000000e+00> : vector<3x2xf32>
+// CHECK: %[[T1:.*]] = ub.poison : vector<3x2xf32>
// CHECK: %[[T2:.*]] = builtin.unrealized_conversion_cast %[[T1]] : vector<3x2xf32> to !llvm.array<3 x vector<2xf32>>
// CHECK: %[[T4:.*]] = llvm.mlir.constant(0 : index) : i64
// CHECK: %[[T5:.*]] = llvm.extractelement %[[T0]][%[[T4]] : i64] : vector<1xf32>
@@ -306,7 +306,7 @@ func.func @broadcast_vec2d_from_vec1d(%arg0: vector<2xf32>) -> vector<3x2xf32> {
}
// CHECK-LABEL: @broadcast_vec2d_from_vec1d(
// CHECK-SAME: %[[A:.*]]: vector<2xf32>)
-// CHECK: %[[T0:.*]] = arith.constant dense<0.000000e+00> : vector<3x2xf32>
+// CHECK: %[[T0:.*]] = ub.poison : vector<3x2xf32>
// CHECK: %[[T1:.*]] = builtin.unrealized_conversion_cast %[[T0]] : vector<3x2xf32> to !llvm.array<3 x vector<2xf32>>
// CHECK: %[[T2:.*]] = llvm.insertvalue %[[A]], %[[T1]][0] : !llvm.array<3 x vector<2xf32>>
// CHECK: %[[T3:.*]] = llvm.insertvalue %[[A]], %[[T2]][1] : !llvm.array<3 x vector<2xf32>>
@@ -322,7 +322,7 @@ func.func @broadcast_vec2d_from_vec1d_scalable(%arg0: vector<[2]xf32>) -> vector
}
// CHECK-LABEL: @broadcast_vec2d_from_vec1d_scalable(
// CHECK-SAME: %[[A:.*]]: vector<[2]xf32>)
-// CHECK: %[[T0:.*]] = arith.constant dense<0.000000e+00> : vector<3x[2]xf32>
+// CHECK: %[[T0:.*]] = ub.poison : vector<3x[2]xf32>
// CHECK: %[[T1:.*]] = builtin.unrealized_conversion_cast %[[T0]] : vector<3x[2]xf32> to !llvm.array<3 x vector<[2]xf32>>
// CHECK: %[[T2:.*]] = llvm.insertvalue %[[A]], %[[T1]][0] : !llvm.array<3 x vector<[2]xf32>>
// CHECK: %[[T3:.*]] = llvm.insertvalue %[[A]], %[[T2]][1] : !llvm.array<3 x vector<[2]xf32>>
@@ -339,7 +339,7 @@ func.func @broadcast_vec2d_from_index_vec1d(%arg0: vector<2xindex>) -> vector<3x
// CHECK-LABEL: @broadcast_vec2d_from_index_vec1d(
// CHECK-SAME: %[[A:.*]]: vector<2xindex>)
// CHECK: %[[T1:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<2xindex> to vector<2xi64>
-// CHECK: %[[T0:.*]] = arith.constant dense<0> : vector<3x2xindex>
+// CHECK: %[[T0:.*]] = ub.poison : vector<3x2xindex>
// CHECK: %[[T2:.*]] = builtin.unrealized_conversion_cast %[[T0]] : vector<3x2xindex> to !llvm.array<3 x vector<2xi64>>
// CHECK: %[[T3:.*]] = llvm.insertvalue %[[T1]], %[[T2]][0] : !llvm.array<3 x vector<2xi64>>
@@ -355,7 +355,7 @@ func.func @broadcast_vec2d_from_index_vec1d_scalable(%arg0: vector<[2]xindex>) -
// CHECK-LABEL: @broadcast_vec2d_from_index_vec1d_scalable(
// CHECK-SAME: %[[A:.*]]: vector<[2]xindex>)
// CHECK: %[[T1:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<[2]xindex> to vector<[2]xi64>
-// CHECK: %[[T0:.*]] = arith.constant dense<0> : vector<3x[2]xindex>
+// CHECK: %[[T0:.*]] = ub.poison : vector<3x[2]xindex>
// CHECK: %[[T2:.*]] = builtin.unrealized_conversion_cast %[[T0]] : vector<3x[2]xindex> to !llvm.array<3 x vector<[2]xi64>>
// CHECK: %[[T3:.*]] = llvm.insertvalue %[[T1]], %[[T2]][0] : !llvm.array<3 x vector<[2]xi64>>
@@ -370,9 +370,9 @@ func.func @broadcast_vec3d_from_vec1d(%arg0: vector<2xf32>) -> vector<4x3x2xf32>
}
// CHECK-LABEL: @broadcast_vec3d_from_vec1d(
// CHECK-SAME: %[[A:.*]]: vector<2xf32>)
-// CHECK-DAG: %[[T0:.*]] = arith.constant dense<0.000000e+00> : vector<3x2xf32>
+// CHECK-DAG: %[[T0:.*]] = ub.poison : vector<3x2xf32>
// CHECK-DAG: %[[T2:.*]] = builtin.unrealized_conversion_cast %[[T0]] : vector<3x2xf32> to !llvm.array<3 x vector<2xf32>>
-// CHECK-DAG: %[[T1:.*]] = arith.constant dense<0.000000e+00> : vector<4x3x2xf32>
+// CHECK-DAG: %[[T1:.*]] = ub.poison : vector<4x3x2xf32>
// CHECK-DAG: %[[T6:.*]] = builtin.unrealized_conversion_cast %[[T1]] : vector<4x3x2xf32> to !llvm.array<4 x array<3 x vector<2xf32>>>
// CHECK: %[[T3:.*]] = llvm.insertvalue %[[A]], %[[T2]][0] : !llvm.array<3 x vector<2xf32>>
@@ -395,9 +395,9 @@ func.func @broadcast_vec3d_from_vec1d_scalable(%arg0: vector<[2]xf32>) -> vector
}
// CHECK-LABEL: @broadcast_vec3d_from_vec1d_scalable(
// CHECK-SAME: %[[A:.*]]: vector<[2]xf32>)
-// CHECK-DAG: %[[T0:.*]] = arith.constant dense<0.000000e+00> : vector<3x[2]xf32>
+// CHECK-DAG: %[[T0:.*]] = ub.poison : vector<3x[2]xf32>
// CHECK-DAG: %[[T2:.*]] = builtin.unrealized_conversion_cast %[[T0]] : vector<3x[2]xf32> to !llvm.array<3 x vector<[2]xf32>>
-// CHECK-DAG: %[[T1:.*]] = arith.constant dense<0.000000e+00> : vector<4x3x[2]xf32>
+// CHECK-DAG: %[[T1:.*]] = ub.poison : vector<4x3x[2]xf32>
// CHECK-DAG: %[[T6:.*]] = builtin.unrealized_conversion_cast %[[T1]] : vector<4x3x[2]xf32> to !llvm.array<4 x array<3 x vector<[2]xf32>>>
// CHECK: %[[T3:.*]] = llvm.insertvalue %[[A]], %[[T2]][0] : !llvm.array<3 x vector<[2]xf32>>
@@ -421,7 +421,7 @@ func.func @broadcast_vec3d_from_vec2d(%arg0: vector<3x2xf32>) -> vector<4x3x2xf3
// CHECK-LABEL: @broadcast_vec3d_from_vec2d(
// CHECK-SAME: %[[A:.*]]: vector<3x2xf32>)
// CHECK: %[[T1:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<3x2xf32> to !llvm.array<3 x vector<2xf32>>
-// CHECK: %[[T0:.*]] = arith.constant dense<0.000000e+00> : vector<4x3x2xf32>
+// CHECK: %[[T0:.*]] = ub.poison : vector<4x3x2xf32>
// CHECK: %[[T2:.*]] = builtin.unrealized_conversion_cast %[[T0]] : vector<4x3x2xf32> to !llvm.array<4 x array<3 x vector<2xf32>>>
// CHECK: %[[T3:.*]] = llvm.insertvalue %[[T1]], %[[T2]][0] : !llvm.array<4 x array<3 x vector<2xf32>>>
// CHECK: %[[T5:.*]] = llvm.insertvalue %[[T1]], %[[T3]][1] : !llvm.array<4 x array<3 x vector<2xf32>>>
@@ -439,7 +439,7 @@ func.func @broadcast_vec3d_from_vec2d_scalable(%arg0: vector<3x[2]xf32>) -> vect
// CHECK-LABEL: @broadcast_vec3d_from_vec2d_scalable(
// CHECK-SAME: %[[A:.*]]: vector<3x[2]xf32>)
// CHECK: %[[T1:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<3x[2]xf32> to !llvm.array<3 x vector<[2]xf32>>
-// CHECK: %[[T0:.*]] = arith.constant dense<0.000000e+00> : vector<4x3x[2]xf32>
+// CHECK: %[[T0:.*]] = ub.poison : vector<4x3x[2]xf32>
// CHECK: %[[T2:.*]] = builtin.unrealized_conversion_cast %[[T0]] : vector<4x3x[2]xf32> to !llvm.array<4 x array<3 x vector<[2]xf32>>>
// CHECK: %[[T3:.*]] = llvm.insertvalue %[[T1]], %[[T2]][0] : !llvm.array<4 x array<3 x vector<[2]xf32>>>
// CHECK: %[[T5:.*]] = llvm.insertvalue %[[T1]], %[[T3]][1] : !llvm.array<4 x array<3 x vector<[2]xf32>>>
@@ -486,7 +486,7 @@ func.func @broadcast_stretch_at_start(%arg0: vector<1x4xf32>) -> vector<3x4xf32>
// CHECK-LABEL: @broadcast_stretch_at_start(
// CHECK-SAME: %[[A:.*]]: vector<1x4xf32>)
// CHECK: %[[T2:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<1x4xf32> to !llvm.array<1 x vector<4xf32>>
-// CHECK: %[[T1:.*]] = arith.constant dense<0.000000e+00> : vector<3x4xf32>
+// CHECK: %[[T1:.*]] = ub.poison : vector<3x4xf32>
// CHECK: %[[T4:.*]] = builtin.unrealized_conversion_cast %[[T1]] : vector<3x4xf32> to !llvm.array<3 x vector<4xf32>>
// CHECK: %[[T3:.*]] = llvm.extractvalue %[[T2]][0] : !llvm.array<1 x vector<4xf32>>
// CHECK: %[[T5:.*]] = llvm.insertvalue %[[T3]], %[[T4]][0] : !llvm.array<3 x vector<4xf32>>
@@ -504,7 +504,7 @@ func.func @broadcast_stretch_at_start_scalable(%arg0: vector<1x[4]xf32>) -> vect
// CHECK-LABEL: @broadcast_stretch_at_start_scalable(
// CHECK-SAME: %[[A:.*]]: vector<1x[4]xf32>)
// CHECK: %[[T2:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<1x[4]xf32> to !llvm.array<1 x vector<[4]xf32>>
-// CHECK: %[[T1:.*]] = arith.constant dense<0.000000e+00> : vector<3x[4]xf32>
+// CHECK: %[[T1:.*]] = ub.poison : vector<3x[4]xf32>
// CHECK: %[[T4:.*]] = builtin.unrealized_conversion_cast %[[T1]] : vector<3x[4]xf32> to !llvm.array<3 x vector<[4]xf32>>
// CHECK: %[[T3:.*]] = llvm.extractvalue %[[T2]][0] : !llvm.array<1 x vector<[4]xf32>>
// CHECK: %[[T5:.*]] = llvm.insertvalue %[[T3]], %[[T4]][0] : !llvm.array<3 x vector<[4]xf32>>
@@ -522,7 +522,7 @@ func.func @broadcast_stretch_at_end(%arg0: vector<4x1xf32>) -> vector<4x3xf32> {
// CHECK-LABEL: @broadcast_stretch_at_end(
// CHECK-SAME: %[[A:.*]]: vector<4x1xf32>)
// CHECK: %[[T2:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<4x1xf32> to !llvm.array<4 x vector<1xf32>>
-// CHECK: %[[T1:.*]] = arith.constant dense<0.000000e+00> : vector<4x3xf32>
+// CHECK: %[[T1:.*]] = ub.poison : vector<4x3xf32>
// CHECK: %[[T7:.*]] = builtin.unrealized_conversion_cast %[[T1]] : vector<4x3xf32> to !llvm.array<4 x vector<3xf32>>
// CHECK: %[[T3:.*]] = llvm.extractvalue %[[T2]][0] : !llvm.array<4 x vector<1xf32>>
// CHECK: %[[T4:.*]] = llvm.mlir.constant(0 : i64) : i64
@@ -570,9 +570,9 @@ func.func @broadcast_stretch_in_middle(%arg0: vector<4x1x2xf32>) -> vector<4x3x2
// CHECK-LABEL: @broadcast_stretch_in_middle(
// CHECK-SAME: %[[A:.*]]: vector<4x1x2xf32>) -> vector<4x3x2xf32> {
// CHECK: %[[T3:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<4x1x2xf32> to !llvm.array<4 x array<1 x vector<2xf32>>>
-// CHECK: %[[T1:.*]] = arith.constant dense<0.000000e+00> : vector<4x3x2xf32>
+// CHECK: %[[T1:.*]] = ub.poison : vector<4x3x2xf32>
// CHECK: %[[T9:.*]] = builtin.unrealized_conversion_cast %[[T1]] : vector<4x3x2xf32> to !llvm.array<4 x array<3 x vector<2xf32>>>
-// CHECK: %[[T2:.*]] = arith.constant dense<0.000000e+00> : vector<3x2xf32>
+// CHECK: %[[T2:.*]] = ub.poison : vector<3x2xf32>
// CHECK: %[[T5:.*]] = builtin.unrealized_conversion_cast %[[T2]] : vector<3x2xf32> to !llvm.array<3 x vector<2xf32>>
// CHECK: %[[T4:.*]] = llvm.extractvalue %[[T3]][0, 0] : !llvm.array<4 x array<1 x vector<2xf32>>>
// CHECK: %[[T6:.*]] = llvm.insertvalue %[[T4]], %[[T5]][0] : !llvm.array<3 x vector<2xf32>>
@@ -606,9 +606,9 @@ func.func @broadcast_stretch_in_middle_scalable_v1(%arg0: vector<4x1x[2]xf32>) -
// CHECK-LABEL: @broadcast_stretch_in_middle_scalable_v1(
// CHECK-SAME: %[[A:.*]]: vector<4x1x[2]xf32>) -> vector<4x3x[2]xf32> {
// CHECK: %[[T3:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<4x1x[2]xf32> to !llvm.array<4 x array<1 x vector<[2]xf32>>>
-// CHECK: %[[T1:.*]] = arith.constant dense<0.000000e+00> : vector<4x3x[2]xf32>
+// CHECK: %[[T1:.*]] = ub.poison : vector<4x3x[2]xf32>
// CHECK: %[[T9:.*]] = builtin.unrealized_conversion_cast %[[T1]] : vector<4x3x[2]xf32> to !llvm.array<4 x array<3 x vector<[2]xf32>>>
-// CHECK: %[[T2:.*]] = arith.constant dense<0.000000e+00> : vector<3x[2]xf32>
+// CHECK: %[[T2:.*]] = ub.poison : vector<3x[2]xf32>
// CHECK: %[[T5:.*]] = builtin.unrealized_conversion_cast %[[T2]] : vector<3x[2]xf32> to !llvm.array<3 x vector<[2]xf32>>
// CHECK: %[[T4:.*]] = llvm.extractvalue %[[T3]][0, 0] : !llvm.array<4 x array<1 x vector<[2]xf32>>>
// CHECK: %[[T6:.*]] = llvm.insertvalue %[[T4]], %[[T5]][0] : !llvm.array<3 x vector<[2]xf32>>
diff --git a/mlir/test/Dialect/Vector/vector-bitcast-lowering-transforms.mlir b/mlir/test/Dialect/Vector/vector-bitcast-lowering-transforms.mlir
index 346291019451c7..29e7007666e878 100644
--- a/mlir/test/Dialect/Vector/vector-bitcast-lowering-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-bitcast-lowering-transforms.mlir
@@ -24,7 +24,7 @@ func.func @vector_bitcast_2d(%arg0: vector<2x4xi32>) -> vector<2x2xi64> {
}
// CHECK-LABEL: func.func @vector_bitcast_2d
// CHECK-SAME: %[[IN:[a-zA-Z0-9]+]]
-// CHECK: %[[INIT:.+]] = arith.constant {{.+}} : vector<2x2xi64>
+// CHECK: %[[INIT:.+]] = ub.poison : vector<2x2xi64>
// CHECK: %[[V1:.+]] = vector.extract %[[IN]][0] : vector<4xi32> from vector<2x4xi32>
// CHECK: %[[B1:.+]] = vector.bitcast %[[V1]] : vector<4xi32> to vector<2xi64>
// CHECK: %[[R1:.+]] = vector.insert %[[B1]], %[[INIT]] [0]
@@ -39,7 +39,7 @@ func.func @vector_bitcast_4d_with_scalable_dim(%arg0: vector<1x2x[3]x4xi64>) ->
}
// CHECK-LABEL: func.func @vector_bitcast_4d_with_scalable_dim
// CHECK-SAME: %[[IN:[a-zA-Z0-9]+]]
-// CHECK: %[[INIT:.+]] = arith.constant dense<0> : vector<1x2x[3]x8xi32>
+// CHECK: %[[INIT:.+]] = ub.poison : vector<1x2x[3]x8xi32>
// CHECK: %[[V1:.+]] = vector.extract %[[IN]][0, 0] : vector<[3]x4xi64> from vector<1x2x[3]x4xi64>
// CHECK: %[[B1:.+]] = vector.bitcast %[[V1]] : vector<[3]x4xi64> to vector<[3]x8xi32>
// CHECK: %[[R1:.+]] = vector.insert %[[B1]], %[[INIT]] [0, 0] : vector<[3]x8xi32> into vector<1x2x[3]x8xi32>
@@ -54,7 +54,7 @@ func.func @vector_bitcast_2d_trailing_scalable_dim(%arg0: vector<2x[2]xi64>) ->
}
// CHECK-LABEL: func.func @vector_bitcast_2d_trailing_scalable_dim
// CHECK-SAME: %[[IN:[a-zA-Z0-9]+]]
-// CHECK: %[[INIT:.+]] = arith.constant dense<0> : vector<2x[4]xi32>
+// CHECK: %[[INIT:.+]] = ub.poison : vector<2x[4]xi32>
// CHECK: %[[V1:.+]] = vector.extract %[[IN]][0] : vector<[2]xi64> from vector<2x[2]xi64>
// CHECK: %[[B1:.+]] = vector.bitcast %[[V1]] : vector<[2]xi64> to vector<[4]xi32>
// CHECK: %[[R1:.+]] = vector.insert %[[B1]], %[[INIT]] [0] : vector<[4]xi32> into vecto...
[truncated]
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks!
…pose This PR continues with the introduction of poison as initialization vector, in this particular case, in LowerVectorBitCast, LowerVectorBroadcast and LowerVectorTranspose.
68d203b
to
312077c
Compare
…pose (llvm#126180) This PR continues with the introduction of poison as initialization vector, in this particular case, in LowerVectorBitCast, LowerVectorBroadcast and LowerVectorTranspose.
This PR continues with the introduction of poison as initialization
vector, in this particular case, in LowerVectorBitCast,
LowerVectorBroadcast and LowerVectorTranspose.