Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][Vector] Introduce poison in LowerVectorBitCast/Broadcast/Transpose #126180

Merged
merged 1 commit into from
Feb 7, 2025

Conversation

dcaballe
Copy link
Contributor

@dcaballe dcaballe commented Feb 7, 2025

This PR continues with the introduction of poison as initialization
vector, in this particular case, in LowerVectorBitCast,
LowerVectorBroadcast and LowerVectorTranspose.

@llvmbot
Copy link
Member

llvmbot commented Feb 7, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-vector

Author: Diego Caballero (dcaballe)

Changes

This PR continues with the introduction of poison as initialization
vector, in this particular case, in LowerVectorBitCast,
LowerVectorBroadcast and LowerVectorTranspose.


Patch is 29.82 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/126180.diff

7 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorBitCast.cpp (+3-3)
  • (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp (+3-16)
  • (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp (+3-12)
  • (modified) mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir (+18-18)
  • (modified) mlir/test/Dialect/Vector/vector-bitcast-lowering-transforms.mlir (+3-3)
  • (modified) mlir/test/Dialect/Vector/vector-broadcast-lowering-transforms.mlir (+21-21)
  • (modified) mlir/test/Dialect/Vector/vector-transpose-lowering.mlir (+2-2)
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBitCast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBitCast.cpp
index d8c4939dc742a7..89930a6bd35fa3 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBitCast.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBitCast.cpp
@@ -11,6 +11,7 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/UB/IR/UBOps.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
@@ -32,7 +33,7 @@ namespace {
 ///
 /// Would be unrolled to:
 ///
-/// %result = arith.constant dense<0> : vector<1x2x3x8xi32>
+/// %result = ub.poison : vector<1x2x3x8xi32>
 /// %0 = vector.extract %a[0, 0, 0]                 ─┐
 ///        : vector<4xi64> from vector<1x2x3x4xi64>  |
 /// %1 = vector.bitcast %0                           | - Repeated 6x for
@@ -63,8 +64,7 @@ class UnrollBitCastOp final : public OpRewritePattern<vector::BitCastOp> {
         VectorType::get(shape, resultType.getElementType(), scalableDims);
 
     Location loc = op.getLoc();
-    Value result = rewriter.create<arith::ConstantOp>(
-        loc, resultType, rewriter.getZeroAttr(resultType));
+    Value result = rewriter.create<ub::PoisonOp>(loc, resultType);
     for (auto position : *unrollIterator) {
       Value extract =
           rewriter.create<vector::ExtractOp>(loc, op.getSource(), position);
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
index 6c36bbaee85237..fec3c6c52e5e43 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
@@ -11,27 +11,16 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Dialect/Affine/IR/AffineOps.h"
-#include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/Arith/Utils/Utils.h"
-#include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
-#include "mlir/Dialect/SCF/IR/SCF.h"
-#include "mlir/Dialect/Tensor/IR/Tensor.h"
-#include "mlir/Dialect/Utils/IndexingUtils.h"
-#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
-#include "mlir/IR/BuiltinAttributeInterfaces.h"
 #include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/ImplicitLocOpBuilder.h"
 #include "mlir/IR/Location.h"
-#include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
-#include "mlir/Interfaces/VectorInterfaces.h"
 
 #define DEBUG_TYPE "vector-broadcast-lowering"
 
@@ -86,8 +75,7 @@ class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
       VectorType resType = VectorType::Builder(dstType).dropDim(0);
       Value bcst =
           rewriter.create<vector::BroadcastOp>(loc, resType, op.getSource());
-      Value result = rewriter.create<arith::ConstantOp>(
-          loc, dstType, rewriter.getZeroAttr(dstType));
+      Value result = rewriter.create<ub::PoisonOp>(loc, dstType);
       for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d)
         result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
       rewriter.replaceOp(op, result);
@@ -127,8 +115,7 @@ class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
     VectorType resType =
         VectorType::get(dstType.getShape().drop_front(), eltType,
                         dstType.getScalableDims().drop_front());
-    Value result = rewriter.create<arith::ConstantOp>(
-        loc, dstType, rewriter.getZeroAttr(dstType));
+    Value result = rewriter.create<ub::PoisonOp>(loc, dstType);
     if (m == 0) {
       // Stetch at start.
       Value ext = rewriter.create<vector::ExtractOp>(loc, op.getSource(), 0);
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
index 3c92b222e6bc80..d05c08bf640077 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
@@ -11,26 +11,19 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/Arith/Utils/Utils.h"
-#include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
-#include "mlir/Dialect/SCF/IR/SCF.h"
-#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
-#include "mlir/IR/BuiltinAttributeInterfaces.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/ImplicitLocOpBuilder.h"
 #include "mlir/IR/Location.h"
-#include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
-#include "mlir/Interfaces/VectorInterfaces.h"
 
 #define DEBUG_TYPE "lower-vector-transpose"
 
@@ -291,8 +284,7 @@ static Value transposeToShuffle16x16(OpBuilder &builder, Value source, int m,
 
   auto reshInputType = VectorType::get(
       {m, n}, cast<VectorType>(source.getType()).getElementType());
-  Value res =
-      b.create<arith::ConstantOp>(reshInputType, b.getZeroAttr(reshInputType));
+  Value res = b.create<ub::PoisonOp>(reshInputType);
   for (int64_t i = 0; i < m; ++i)
     res = b.create<vector::InsertOp>(vs[i], res, i);
   return res;
@@ -368,8 +360,7 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
     // of the leftmost transposed dimensions. We traverse every transpose
     // element using a linearized index that we delinearize to generate the
     // appropriate indices for the extract/insert operations.
-    Value result = rewriter.create<arith::ConstantOp>(
-        loc, resType, rewriter.getZeroAttr(resType));
+    Value result = rewriter.create<ub::PoisonOp>(loc, resType);
     int64_t numTransposedElements = ShapedType::getNumElements(prunedInShape);
 
     for (int64_t linearIdx = 0; linearIdx < numTransposedElements;
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index a5ca2b3764ffd4..0eb14aa295d9d2 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -286,7 +286,7 @@ func.func @broadcast_vec2d_from_vec0d(%arg0: vector<f32>) -> vector<3x2xf32> {
 // CHECK-LABEL: @broadcast_vec2d_from_vec0d(
 // CHECK-SAME:  %[[A:.*]]: vector<f32>)
 //       CHECK: %[[T0:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<f32> to vector<1xf32>
-//       CHECK: %[[T1:.*]] = arith.constant dense<0.000000e+00> : vector<3x2xf32>
+//       CHECK: %[[T1:.*]] = ub.poison : vector<3x2xf32>
 //       CHECK: %[[T2:.*]] = builtin.unrealized_conversion_cast %[[T1]] : vector<3x2xf32> to !llvm.array<3 x vector<2xf32>>
 //       CHECK: %[[T4:.*]] = llvm.mlir.constant(0 : index) : i64
 //       CHECK: %[[T5:.*]] = llvm.extractelement %[[T0]][%[[T4]] : i64] : vector<1xf32>
@@ -306,7 +306,7 @@ func.func @broadcast_vec2d_from_vec1d(%arg0: vector<2xf32>) -> vector<3x2xf32> {
 }
 // CHECK-LABEL: @broadcast_vec2d_from_vec1d(
 // CHECK-SAME:  %[[A:.*]]: vector<2xf32>)
-// CHECK:       %[[T0:.*]] = arith.constant dense<0.000000e+00> : vector<3x2xf32>
+// CHECK:       %[[T0:.*]] = ub.poison : vector<3x2xf32>
 // CHECK:       %[[T1:.*]] = builtin.unrealized_conversion_cast %[[T0]] : vector<3x2xf32> to !llvm.array<3 x vector<2xf32>>
 // CHECK:       %[[T2:.*]] = llvm.insertvalue %[[A]], %[[T1]][0] : !llvm.array<3 x vector<2xf32>>
 // CHECK:       %[[T3:.*]] = llvm.insertvalue %[[A]], %[[T2]][1] : !llvm.array<3 x vector<2xf32>>
@@ -322,7 +322,7 @@ func.func @broadcast_vec2d_from_vec1d_scalable(%arg0: vector<[2]xf32>) -> vector
 }
 // CHECK-LABEL: @broadcast_vec2d_from_vec1d_scalable(
 // CHECK-SAME:  %[[A:.*]]: vector<[2]xf32>)
-// CHECK:       %[[T0:.*]] = arith.constant dense<0.000000e+00> : vector<3x[2]xf32>
+// CHECK:       %[[T0:.*]] = ub.poison : vector<3x[2]xf32>
 // CHECK:       %[[T1:.*]] = builtin.unrealized_conversion_cast %[[T0]] : vector<3x[2]xf32> to !llvm.array<3 x vector<[2]xf32>>
 // CHECK:       %[[T2:.*]] = llvm.insertvalue %[[A]], %[[T1]][0] : !llvm.array<3 x vector<[2]xf32>>
 // CHECK:       %[[T3:.*]] = llvm.insertvalue %[[A]], %[[T2]][1] : !llvm.array<3 x vector<[2]xf32>>
@@ -339,7 +339,7 @@ func.func @broadcast_vec2d_from_index_vec1d(%arg0: vector<2xindex>) -> vector<3x
 // CHECK-LABEL: @broadcast_vec2d_from_index_vec1d(
 // CHECK-SAME:  %[[A:.*]]: vector<2xindex>)
 // CHECK:       %[[T1:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<2xindex> to vector<2xi64>
-// CHECK:       %[[T0:.*]] = arith.constant dense<0> : vector<3x2xindex>
+// CHECK:       %[[T0:.*]] = ub.poison : vector<3x2xindex>
 // CHECK:       %[[T2:.*]] = builtin.unrealized_conversion_cast %[[T0]] : vector<3x2xindex> to !llvm.array<3 x vector<2xi64>>
 // CHECK:       %[[T3:.*]] = llvm.insertvalue %[[T1]], %[[T2]][0] : !llvm.array<3 x vector<2xi64>>
 
@@ -355,7 +355,7 @@ func.func @broadcast_vec2d_from_index_vec1d_scalable(%arg0: vector<[2]xindex>) -
 // CHECK-LABEL: @broadcast_vec2d_from_index_vec1d_scalable(
 // CHECK-SAME:  %[[A:.*]]: vector<[2]xindex>)
 // CHECK:       %[[T1:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<[2]xindex> to vector<[2]xi64>
-// CHECK:       %[[T0:.*]] = arith.constant dense<0> : vector<3x[2]xindex>
+// CHECK:       %[[T0:.*]] = ub.poison : vector<3x[2]xindex>
 // CHECK:       %[[T2:.*]] = builtin.unrealized_conversion_cast %[[T0]] : vector<3x[2]xindex> to !llvm.array<3 x vector<[2]xi64>>
 // CHECK:       %[[T3:.*]] = llvm.insertvalue %[[T1]], %[[T2]][0] : !llvm.array<3 x vector<[2]xi64>>
 
@@ -370,9 +370,9 @@ func.func @broadcast_vec3d_from_vec1d(%arg0: vector<2xf32>) -> vector<4x3x2xf32>
 }
 // CHECK-LABEL: @broadcast_vec3d_from_vec1d(
 // CHECK-SAME:  %[[A:.*]]: vector<2xf32>)
-// CHECK-DAG:   %[[T0:.*]] = arith.constant dense<0.000000e+00> : vector<3x2xf32>
+// CHECK-DAG:   %[[T0:.*]] = ub.poison : vector<3x2xf32>
 // CHECK-DAG:   %[[T2:.*]] = builtin.unrealized_conversion_cast %[[T0]] : vector<3x2xf32> to !llvm.array<3 x vector<2xf32>>
-// CHECK-DAG:   %[[T1:.*]] = arith.constant dense<0.000000e+00> : vector<4x3x2xf32>
+// CHECK-DAG:   %[[T1:.*]] = ub.poison : vector<4x3x2xf32>
 // CHECK-DAG:   %[[T6:.*]] = builtin.unrealized_conversion_cast %[[T1]] : vector<4x3x2xf32> to !llvm.array<4 x array<3 x vector<2xf32>>>
 
 // CHECK:       %[[T3:.*]] = llvm.insertvalue %[[A]], %[[T2]][0] : !llvm.array<3 x vector<2xf32>>
@@ -395,9 +395,9 @@ func.func @broadcast_vec3d_from_vec1d_scalable(%arg0: vector<[2]xf32>) -> vector
 }
 // CHECK-LABEL: @broadcast_vec3d_from_vec1d_scalable(
 // CHECK-SAME:  %[[A:.*]]: vector<[2]xf32>)
-// CHECK-DAG:   %[[T0:.*]] = arith.constant dense<0.000000e+00> : vector<3x[2]xf32>
+// CHECK-DAG:   %[[T0:.*]] = ub.poison : vector<3x[2]xf32>
 // CHECK-DAG:   %[[T2:.*]] = builtin.unrealized_conversion_cast %[[T0]] : vector<3x[2]xf32> to !llvm.array<3 x vector<[2]xf32>>
-// CHECK-DAG:   %[[T1:.*]] = arith.constant dense<0.000000e+00> : vector<4x3x[2]xf32>
+// CHECK-DAG:   %[[T1:.*]] = ub.poison : vector<4x3x[2]xf32>
 // CHECK-DAG:   %[[T6:.*]] = builtin.unrealized_conversion_cast %[[T1]] : vector<4x3x[2]xf32> to !llvm.array<4 x array<3 x vector<[2]xf32>>>
 
 // CHECK:       %[[T3:.*]] = llvm.insertvalue %[[A]], %[[T2]][0] : !llvm.array<3 x vector<[2]xf32>>
@@ -421,7 +421,7 @@ func.func @broadcast_vec3d_from_vec2d(%arg0: vector<3x2xf32>) -> vector<4x3x2xf3
 // CHECK-LABEL: @broadcast_vec3d_from_vec2d(
 // CHECK-SAME:  %[[A:.*]]: vector<3x2xf32>)
 // CHECK:       %[[T1:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<3x2xf32> to !llvm.array<3 x vector<2xf32>>
-// CHECK:       %[[T0:.*]] = arith.constant dense<0.000000e+00> : vector<4x3x2xf32>
+// CHECK:       %[[T0:.*]] = ub.poison : vector<4x3x2xf32>
 // CHECK:       %[[T2:.*]] = builtin.unrealized_conversion_cast %[[T0]] : vector<4x3x2xf32> to !llvm.array<4 x array<3 x vector<2xf32>>>
 // CHECK:       %[[T3:.*]] = llvm.insertvalue %[[T1]], %[[T2]][0] : !llvm.array<4 x array<3 x vector<2xf32>>>
 // CHECK:       %[[T5:.*]] = llvm.insertvalue %[[T1]], %[[T3]][1] : !llvm.array<4 x array<3 x vector<2xf32>>>
@@ -439,7 +439,7 @@ func.func @broadcast_vec3d_from_vec2d_scalable(%arg0: vector<3x[2]xf32>) -> vect
 // CHECK-LABEL: @broadcast_vec3d_from_vec2d_scalable(
 // CHECK-SAME:  %[[A:.*]]: vector<3x[2]xf32>)
 // CHECK:       %[[T1:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<3x[2]xf32> to !llvm.array<3 x vector<[2]xf32>>
-// CHECK:       %[[T0:.*]] = arith.constant dense<0.000000e+00> : vector<4x3x[2]xf32>
+// CHECK:       %[[T0:.*]] = ub.poison : vector<4x3x[2]xf32>
 // CHECK:       %[[T2:.*]] = builtin.unrealized_conversion_cast %[[T0]] : vector<4x3x[2]xf32> to !llvm.array<4 x array<3 x vector<[2]xf32>>>
 // CHECK:       %[[T3:.*]] = llvm.insertvalue %[[T1]], %[[T2]][0] : !llvm.array<4 x array<3 x vector<[2]xf32>>>
 // CHECK:       %[[T5:.*]] = llvm.insertvalue %[[T1]], %[[T3]][1] : !llvm.array<4 x array<3 x vector<[2]xf32>>>
@@ -486,7 +486,7 @@ func.func @broadcast_stretch_at_start(%arg0: vector<1x4xf32>) -> vector<3x4xf32>
 // CHECK-LABEL: @broadcast_stretch_at_start(
 // CHECK-SAME:  %[[A:.*]]: vector<1x4xf32>)
 // CHECK:       %[[T2:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<1x4xf32> to !llvm.array<1 x vector<4xf32>>
-// CHECK:       %[[T1:.*]] = arith.constant dense<0.000000e+00> : vector<3x4xf32>
+// CHECK:       %[[T1:.*]] = ub.poison : vector<3x4xf32>
 // CHECK:       %[[T4:.*]] = builtin.unrealized_conversion_cast %[[T1]] : vector<3x4xf32> to !llvm.array<3 x vector<4xf32>>
 // CHECK:       %[[T3:.*]] = llvm.extractvalue %[[T2]][0] : !llvm.array<1 x vector<4xf32>>
 // CHECK:       %[[T5:.*]] = llvm.insertvalue %[[T3]], %[[T4]][0] : !llvm.array<3 x vector<4xf32>>
@@ -504,7 +504,7 @@ func.func @broadcast_stretch_at_start_scalable(%arg0: vector<1x[4]xf32>) -> vect
 // CHECK-LABEL: @broadcast_stretch_at_start_scalable(
 // CHECK-SAME:  %[[A:.*]]: vector<1x[4]xf32>)
 // CHECK:       %[[T2:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<1x[4]xf32> to !llvm.array<1 x vector<[4]xf32>>
-// CHECK:       %[[T1:.*]] = arith.constant dense<0.000000e+00> : vector<3x[4]xf32>
+// CHECK:       %[[T1:.*]] = ub.poison : vector<3x[4]xf32>
 // CHECK:       %[[T4:.*]] = builtin.unrealized_conversion_cast %[[T1]] : vector<3x[4]xf32> to !llvm.array<3 x vector<[4]xf32>>
 // CHECK:       %[[T3:.*]] = llvm.extractvalue %[[T2]][0] : !llvm.array<1 x vector<[4]xf32>>
 // CHECK:       %[[T5:.*]] = llvm.insertvalue %[[T3]], %[[T4]][0] : !llvm.array<3 x vector<[4]xf32>>
@@ -522,7 +522,7 @@ func.func @broadcast_stretch_at_end(%arg0: vector<4x1xf32>) -> vector<4x3xf32> {
 // CHECK-LABEL: @broadcast_stretch_at_end(
 // CHECK-SAME:  %[[A:.*]]: vector<4x1xf32>)
 // CHECK:       %[[T2:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<4x1xf32> to !llvm.array<4 x vector<1xf32>>
-// CHECK:       %[[T1:.*]] = arith.constant dense<0.000000e+00> : vector<4x3xf32>
+// CHECK:       %[[T1:.*]] = ub.poison : vector<4x3xf32>
 // CHECK:       %[[T7:.*]] = builtin.unrealized_conversion_cast %[[T1]] : vector<4x3xf32> to !llvm.array<4 x vector<3xf32>>
 // CHECK:       %[[T3:.*]] = llvm.extractvalue %[[T2]][0] : !llvm.array<4 x vector<1xf32>>
 // CHECK:       %[[T4:.*]] = llvm.mlir.constant(0 : i64) : i64
@@ -570,9 +570,9 @@ func.func @broadcast_stretch_in_middle(%arg0: vector<4x1x2xf32>) -> vector<4x3x2
 // CHECK-LABEL: @broadcast_stretch_in_middle(
 // CHECK-SAME:  %[[A:.*]]: vector<4x1x2xf32>) -> vector<4x3x2xf32> {
 // CHECK:       %[[T3:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<4x1x2xf32> to !llvm.array<4 x array<1 x vector<2xf32>>>
-// CHECK:       %[[T1:.*]] = arith.constant dense<0.000000e+00> : vector<4x3x2xf32>
+// CHECK:       %[[T1:.*]] = ub.poison : vector<4x3x2xf32>
 // CHECK:       %[[T9:.*]] = builtin.unrealized_conversion_cast %[[T1]] : vector<4x3x2xf32> to !llvm.array<4 x array<3 x vector<2xf32>>>
-// CHECK:       %[[T2:.*]] = arith.constant dense<0.000000e+00> : vector<3x2xf32>
+// CHECK:       %[[T2:.*]] = ub.poison : vector<3x2xf32>
 // CHECK:       %[[T5:.*]] = builtin.unrealized_conversion_cast %[[T2]] : vector<3x2xf32> to !llvm.array<3 x vector<2xf32>>
 // CHECK:       %[[T4:.*]] = llvm.extractvalue %[[T3]][0, 0] : !llvm.array<4 x array<1 x vector<2xf32>>>
 // CHECK:       %[[T6:.*]] = llvm.insertvalue %[[T4]], %[[T5]][0] : !llvm.array<3 x vector<2xf32>>
@@ -606,9 +606,9 @@ func.func @broadcast_stretch_in_middle_scalable_v1(%arg0: vector<4x1x[2]xf32>) -
 // CHECK-LABEL: @broadcast_stretch_in_middle_scalable_v1(
 // CHECK-SAME:  %[[A:.*]]: vector<4x1x[2]xf32>) -> vector<4x3x[2]xf32> {
 // CHECK:       %[[T3:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<4x1x[2]xf32> to !llvm.array<4 x array<1 x vector<[2]xf32>>>
-// CHECK:       %[[T1:.*]] = arith.constant dense<0.000000e+00> : vector<4x3x[2]xf32>
+// CHECK:       %[[T1:.*]] = ub.poison : vector<4x3x[2]xf32>
 // CHECK:       %[[T9:.*]] = builtin.unrealized_conversion_cast %[[T1]] : vector<4x3x[2]xf32> to !llvm.array<4 x array<3 x vector<[2]xf32>>>
-// CHECK:       %[[T2:.*]] = arith.constant dense<0.000000e+00> : vector<3x[2]xf32>
+// CHECK:       %[[T2:.*]] = ub.poison : vector<3x[2]xf32>
 // CHECK:       %[[T5:.*]] = builtin.unrealized_conversion_cast %[[T2]] : vector<3x[2]xf32> to !llvm.array<3 x vector<[2]xf32>>
 // CHECK:       %[[T4:.*]] = llvm.extractvalue %[[T3]][0, 0] : !llvm.array<4 x array<1 x vector<[2]xf32>>>
 // CHECK:       %[[T6:.*]] = llvm.insertvalue %[[T4]], %[[T5]][0] : !llvm.array<3 x vector<[2]xf32>>
diff --git a/mlir/test/Dialect/Vector/vector-bitcast-lowering-transforms.mlir b/mlir/test/Dialect/Vector/vector-bitcast-lowering-transforms.mlir
index 346291019451c7..29e7007666e878 100644
--- a/mlir/test/Dialect/Vector/vector-bitcast-lowering-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-bitcast-lowering-transforms.mlir
@@ -24,7 +24,7 @@ func.func @vector_bitcast_2d(%arg0: vector<2x4xi32>) -> vector<2x2xi64> {
 }
 // CHECK-LABEL: func.func @vector_bitcast_2d
 // CHECK-SAME:    %[[IN:[a-zA-Z0-9]+]]
-// CHECK:         %[[INIT:.+]] = arith.constant {{.+}} : vector<2x2xi64>
+// CHECK:         %[[INIT:.+]] = ub.poison : vector<2x2xi64>
 // CHECK:         %[[V1:.+]] = vector.extract %[[IN]][0] : vector<4xi32> from vector<2x4xi32>
 // CHECK:         %[[B1:.+]] = vector.bitcast %[[V1]] : vector<4xi32> to vector<2xi64>
 // CHECK:         %[[R1:.+]] = vector.insert %[[B1]], %[[INIT]] [0]
@@ -39,7 +39,7 @@ func.func @vector_bitcast_4d_with_scalable_dim(%arg0: vector<1x2x[3]x4xi64>) ->
 }
 // CHECK-LABEL: func.func @vector_bitcast_4d_with_scalable_dim
 // CHECK-SAME:    %[[IN:[a-zA-Z0-9]+]]
-// CHECK:         %[[INIT:.+]] = arith.constant dense<0> : vector<1x2x[3]x8xi32>
+// CHECK:         %[[INIT:.+]] = ub.poison : vector<1x2x[3]x8xi32>
 // CHECK:         %[[V1:.+]] = vector.extract %[[IN]][0, 0] : vector<[3]x4xi64> from vector<1x2x[3]x4xi64>
 // CHECK:         %[[B1:.+]] = vector.bitcast %[[V1]] : vector<[3]x4xi64> to vector<[3]x8xi32>
 // CHECK:         %[[R1:.+]] = vector.insert %[[B1]], %[[INIT]] [0, 0] : vector<[3]x8xi32> into vector<1x2x[3]x8xi32>
@@ -54,7 +54,7 @@ func.func @vector_bitcast_2d_trailing_scalable_dim(%arg0: vector<2x[2]xi64>) ->
 }
 // CHECK-LABEL: func.func @vector_bitcast_2d_trailing_scalable_dim
 // CHECK-SAME:    %[[IN:[a-zA-Z0-9]+]]
-// CHECK:         %[[INIT:.+]] = arith.constant dense<0> : vector<2x[4]xi32>
+// CHECK:         %[[INIT:.+]] = ub.poison : vector<2x[4]xi32>
 // CHECK:         %[[V1:.+]] = vector.extract %[[IN]][0] : vector<[2]xi64> from vector<2x[2]xi64>
 // CHECK:         %[[B1:.+]] = vector.bitcast %[[V1]] : vector<[2]xi64> to vector<[4]xi32>
 // CHECK:         %[[R1:.+]] = vector.insert %[[B1]], %[[INIT]] [0] : vector<[4]xi32> into vecto...
[truncated]

@dcaballe dcaballe requested a review from kuhar February 7, 2025 05:43
Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks!

…pose

This PR continues with the introduction of poison as initialization
vector, in this particular case, in LowerVectorBitCast,
LowerVectorBroadcast and LowerVectorTranspose.
@dcaballe dcaballe force-pushed the init-vector-insert-w-poison1 branch from 68d203b to 312077c Compare February 7, 2025 18:47
@dcaballe dcaballe merged commit 2c4dd89 into llvm:main Feb 7, 2025
5 of 6 checks passed
Icohedron pushed a commit to Icohedron/llvm-project that referenced this pull request Feb 11, 2025
…pose (llvm#126180)

This PR continues with the introduction of poison as initialization
vector, in this particular case, in LowerVectorBitCast,
LowerVectorBroadcast and LowerVectorTranspose.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants