
[mlir][Vector] Generate poison vectors in vector.shape_cast lowering #125613

Merged
merged 4 commits into llvm:main from slice-to-poison-shuffle on Feb 7, 2025

Conversation

dcaballe
Contributor

@dcaballe dcaballe commented Feb 4, 2025

This is the first PR that introduces ub.poison vectors as part of a rewrite/conversion pattern in the Vector dialect. It replaces the arith.constant dense<0> vector that initializes the insertion destination of vector.insert/vector.insert_strided_slice ops with a poison vector.

This PR depends on the previous PRs that introduced poison support (ODS, canonicalization, and lowering) for Vector operations such as vector.shuffle, vector.extract, and vector.insert.

This PR may improve end-to-end compilation time through LLVM, depending on the workload.
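
For illustration, the lowered IR for a simple 2-D shape cast now starts from a poison destination instead of a zero splat. The snippet below is a sketch distilled from the updated shape_cast_2d2d test in this PR, with the intermediate element moves elided:

  func.func @shape_cast_2d2d_sketch(%arg0: vector<3x2xf32>) -> vector<2x3xf32> {
    // Before this change the destination was:
    //   arith.constant dense<0.000000e+00> : vector<2x3xf32>
    %0 = ub.poison : vector<2x3xf32>
    %1 = vector.extract %arg0[0, 0] : f32 from vector<3x2xf32>
    %2 = vector.insert %1, %0 [0, 0] : f32 into vector<2x3xf32>
    // ... five more extract/insert pairs move the remaining elements ...
    return %2 : vector<2x3xf32>
  }

Since every lane of the destination is overwritten before the result is used, starting from poison is semantically equivalent to starting from zeros while leaving LLVM more freedom to optimize.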

dcaballe added a commit that referenced this pull request Feb 4, 2025
…125145)

This PR adds the UB-to-LLVM/SPIR-V conversion pass to some pipelines and
tests. This is in preparation for introducing the generation of
`ub.poison` in Vector dialect transformations (first one in #125613).
It should effectively be NFC at this point.
@llvmbot
Member

llvmbot commented Feb 4, 2025

@llvm/pr-subscribers-mlir-spirv
@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir

Author: Diego Caballero (dcaballe)

Changes

This is the first PR that introduces ub.poison vectors as part of a rewrite/conversion pattern in the Vector dialect. It replaces the arith.constant dense<0> vector that initializes the insertion destination of vector.insert/vector.insert_strided_slice ops with a poison vector.

This PR depends on the previous PRs that introduced poison support (ODS, canonicalization, and lowering) for Vector operations such as vector.shuffle, vector.extract, and vector.insert.

This PR may improve end-to-end compilation time through LLVM, depending on the workload.


Full diff: https://github.com/llvm/llvm-project/pull/125613.diff

5 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp (+5-10)
  • (modified) mlir/test/Conversion/ConvertToSPIRV/vector-unroll.mlir (+7-4)
  • (modified) mlir/test/Dialect/Vector/vector-contract-to-matrix-intrinsics-transforms.mlir (+3-3)
  • (modified) mlir/test/Dialect/Vector/vector-shape-cast-lowering-scalable-vectors.mlir (+9-9)
  • (modified) mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir (+7-7)
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
index 239dc9aa1de6fb..9c1e5fcee91de4 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
@@ -11,8 +11,8 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/UB//IR/UBOps.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
@@ -73,8 +73,7 @@ class ShapeCastOpNDDownCastRewritePattern
     SmallVector<int64_t> srcIdx(srcRank - 1, 0);
     SmallVector<int64_t> resIdx(resRank, 0);
     int64_t extractSize = sourceVectorType.getShape().back();
-    Value result = rewriter.create<arith::ConstantOp>(
-        loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
+    Value result = rewriter.create<ub::PoisonOp>(loc, resultVectorType);
 
     // Compute the indices of each 1-D vector element of the source extraction
     // and destination slice insertion and generate such instructions.
@@ -129,8 +128,7 @@ class ShapeCastOpNDUpCastRewritePattern
     SmallVector<int64_t> srcIdx(srcRank, 0);
     SmallVector<int64_t> resIdx(resRank - 1, 0);
     int64_t extractSize = resultVectorType.getShape().back();
-    Value result = rewriter.create<arith::ConstantOp>(
-        loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
+    Value result = rewriter.create<ub::PoisonOp>(loc, resultVectorType);
     for (int64_t i = 0; i < numElts; ++i) {
       if (i != 0) {
         incIdx(srcIdx, sourceVectorType, /*step=*/extractSize);
@@ -184,8 +182,7 @@ class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
     // within the source and result shape.
     SmallVector<int64_t> srcIdx(srcRank, 0);
     SmallVector<int64_t> resIdx(resRank, 0);
-    Value result = rewriter.create<arith::ConstantOp>(
-        loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
+    Value result = rewriter.create<ub::PoisonOp>(loc, resultVectorType);
     for (int64_t i = 0; i < numElts; i++) {
       if (i != 0) {
         incIdx(srcIdx, sourceVectorType);
@@ -291,9 +288,7 @@ class ScalableShapeCastOpRewritePattern
     auto extractionVectorType = VectorType::get(
         {minExtractionSize}, sourceVectorType.getElementType(), {true});
 
-    Value result = rewriter.create<arith::ConstantOp>(
-        loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
-
+    Value result = rewriter.create<ub::PoisonOp>(loc, resultVectorType);
     SmallVector<int64_t> srcIdx(srcRank, 0);
     SmallVector<int64_t> resIdx(resRank, 0);
 
diff --git a/mlir/test/Conversion/ConvertToSPIRV/vector-unroll.mlir b/mlir/test/Conversion/ConvertToSPIRV/vector-unroll.mlir
index 043f9422d8790f..f1cc1354d1e3bc 100644
--- a/mlir/test/Conversion/ConvertToSPIRV/vector-unroll.mlir
+++ b/mlir/test/Conversion/ConvertToSPIRV/vector-unroll.mlir
@@ -83,17 +83,20 @@ func.func @vaddi_reduction(%arg0 : vector<8xi32>, %arg1 : vector<8xi32>) -> (i32
 // CHECK-LABEL: @transpose
 // CHECK-SAME: (%[[ARG0:.+]]: vector<3xi32>, %[[ARG1:.+]]: vector<3xi32>)
 func.func @transpose(%arg0 : vector<2x3xi32>) -> (vector<3x2xi32>) {
-  // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<2xi32>
+  // CHECK: %[[CST:.*]] = ub.poison : vector<1x2xi32>
   // CHECK: %[[EXTRACT0:.*]] = vector.extract %[[ARG0]][0] : i32 from vector<3xi32>
-  // CHECK: %[[INSERT0:.*]]= vector.insert %[[EXTRACT0]], %[[CST]] [0] : i32 into vector<2xi32>
+  // CHECK: %[[CST1:.*]] = vector.extract %[[CST]][0] : vector<2xi32> from vector<1x2xi32>
+  // CHECK: %[[INSERT0:.*]]= vector.insert %[[EXTRACT0]], %[[CST1]] [0] : i32 into vector<2xi32>
   // CHECK: %[[EXTRACT1:.*]] = vector.extract %[[ARG1]][0] : i32 from vector<3xi32>
   // CHECK: %[[INSERT1:.*]] = vector.insert %[[EXTRACT1]], %[[INSERT0]][1] : i32 into vector<2xi32>
   // CHECK: %[[EXTRACT2:.*]] = vector.extract %[[ARG0]][1] : i32 from vector<3xi32>
-  // CHECK: %[[INSERT2:.*]] = vector.insert %[[EXTRACT2]], %[[CST]] [0] : i32 into vector<2xi32>
+  // CHECK: %[[CST2:.*]] = vector.extract %[[CST]][0] : vector<2xi32> from vector<1x2xi32>
+  // CHECK: %[[INSERT2:.*]] = vector.insert %[[EXTRACT2]], %[[CST2]] [0] : i32 into vector<2xi32>
   // CHECK: %[[EXTRACT3:.*]] = vector.extract %[[ARG1]][1] : i32 from vector<3xi32>
   // CHECK: %[[INSERT3:.*]] = vector.insert %[[EXTRACT3]], %[[INSERT2]] [1] : i32 into vector<2xi32>
   // CHECK: %[[EXTRACT4:.*]] = vector.extract %[[ARG0]][2] : i32 from vector<3xi32>
-  // CHECK: %[[INSERT4:.*]] = vector.insert %[[EXTRACT4]], %[[CST]] [0] : i32 into vector<2xi32>
+  // CHECK: %[[CST3:.*]] = vector.extract %[[CST]][0] : vector<2xi32> from vector<1x2xi32>
+  // CHECK: %[[INSERT4:.*]] = vector.insert %[[EXTRACT4]], %[[CST3]] [0] : i32 into vector<2xi32>
   // CHECK: %[[EXTRACT5:.*]] = vector.extract %[[ARG1]][2] : i32 from vector<3xi32>
   // CHECK: %[[INSERT5:.*]] = vector.insert %[[EXTRACT5]], %[[INSERT4]] [1] : i32 into vector<2xi32>
   // CHECK: return %[[INSERT1]], %[[INSERT3]], %[[INSERT5]] : vector<2xi32>, vector<2xi32>, vector<2xi32>
diff --git a/mlir/test/Dialect/Vector/vector-contract-to-matrix-intrinsics-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-to-matrix-intrinsics-transforms.mlir
index 4867a416e5d144..fd6895c01d78bd 100644
--- a/mlir/test/Dialect/Vector/vector-contract-to-matrix-intrinsics-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-to-matrix-intrinsics-transforms.mlir
@@ -14,9 +14,9 @@
 // CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x4xf32>,
 // CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<4x3xf32>,
 // CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32>
-//  CHECK-DAG:  %[[vcst:.*]] = arith.constant dense<0.000000e+00> : vector<8xf32>
-//  CHECK-DAG:  %[[vcst_0:.*]] = arith.constant dense<0.000000e+00> : vector<12xf32>
-//  CHECK-DAG:  %[[vcst_1:.*]] = arith.constant dense<0.000000e+00> : vector<2x3xf32>
+//  CHECK-DAG:  %[[vcst:.*]] = ub.poison : vector<8xf32>
+//  CHECK-DAG:  %[[vcst_0:.*]] = ub.poison : vector<12xf32>
+//  CHECK-DAG:  %[[vcst_1:.*]] = ub.poison : vector<2x3xf32>
 //      CHECK:  %[[a0:.*]] = vector.extract %[[A]][0] : vector<4xf32> from vector<2x4xf32>
 //      CHECK:  %[[a1:.*]] = vector.insert_strided_slice %[[a0]], %[[vcst]] {offsets = [0], strides = [1]} : vector<4xf32> into vector<8xf32>
 //      CHECK:  %[[a2:.*]] = vector.extract %[[A]][1] : vector<4xf32> from vector<2x4xf32>
diff --git a/mlir/test/Dialect/Vector/vector-shape-cast-lowering-scalable-vectors.mlir b/mlir/test/Dialect/Vector/vector-shape-cast-lowering-scalable-vectors.mlir
index fde6ce91024464..b4518e57c39ddd 100644
--- a/mlir/test/Dialect/Vector/vector-shape-cast-lowering-scalable-vectors.mlir
+++ b/mlir/test/Dialect/Vector/vector-shape-cast-lowering-scalable-vectors.mlir
@@ -7,7 +7,7 @@
 // CHECK-SAME: %[[arg0:.*]]: vector<2x1x[4]xi32>
 func.func @i32_3d_to_1d_last_dim_scalable(%arg0: vector<2x1x[4]xi32>) -> vector<[8]xi32>
 {
-  // CHECK-NEXT: %[[cst:.*]] = arith.constant dense<0> : vector<[8]xi32>
+  // CHECK-NEXT: %[[cst:.*]] = ub.poison : vector<[8]xi32>
   // CHECK-NEXT: %[[subvec0:.*]] = vector.extract %[[arg0]][0, 0] : vector<[4]xi32> from vector<2x1x[4]xi32>
   // CHECK-NEXT: %[[res0:.*]] = vector.scalable.insert %[[subvec0]], %[[cst]][0] : vector<[4]xi32> into vector<[8]xi32>
   // CHECK-NEXT: %[[subvec1:.*]] = vector.extract %[[arg0]][1, 0] : vector<[4]xi32> from vector<2x1x[4]xi32>
@@ -22,7 +22,7 @@ func.func @i32_3d_to_1d_last_dim_scalable(%arg0: vector<2x1x[4]xi32>) -> vector<
 // CHECK-LABEL: i32_1d_to_3d_last_dim_scalable
 // CHECK-SAME: %[[arg0:.*]]: vector<[8]xi32>
 func.func @i32_1d_to_3d_last_dim_scalable(%arg0: vector<[8]xi32>) -> vector<2x1x[4]xi32> {
-  // CHECK-NEXT: %[[cst:.*]] = arith.constant dense<0> : vector<2x1x[4]xi32>
+  // CHECK-NEXT: %[[cst:.*]] = ub.poison : vector<2x1x[4]xi32>
   // CHECK-NEXT: %[[subvec0:.*]] = vector.scalable.extract %[[arg0]][0] : vector<[4]xi32> from vector<[8]xi32>
   // CHECK-NEXT: %[[res0:.*]] = vector.insert %[[subvec0]], %[[cst]] [0, 0] : vector<[4]xi32> into vector<2x1x[4]xi32>
   // CHECK-NEXT: %[[subvec1:.*]] = vector.scalable.extract %[[arg0]][4] : vector<[4]xi32> from vector<[8]xi32>
@@ -37,7 +37,7 @@ func.func @i32_1d_to_3d_last_dim_scalable(%arg0: vector<[8]xi32>) -> vector<2x1x
 // CHECK-LABEL: i8_2d_to_1d_last_dim_scalable
 // CHECK-SAME: %[[arg0:.*]]: vector<4x[8]xi8>
 func.func @i8_2d_to_1d_last_dim_scalable(%arg0: vector<4x[8]xi8>) -> vector<[32]xi8> {
-  // CHECK-NEXT: %[[cst:.*]] = arith.constant dense<0> : vector<[32]xi8>
+  // CHECK-NEXT: %[[cst:.*]] = ub.poison : vector<[32]xi8>
   // CHECK-NEXT: %[[subvec0:.*]] = vector.extract %[[arg0]][0] : vector<[8]xi8> from vector<4x[8]xi8>
   // CHECK-NEXT: %[[res0:.*]] = vector.scalable.insert %[[subvec0]], %[[cst]][0] : vector<[8]xi8> into vector<[32]xi8>
   // CHECK-NEXT: %[[subvec1:.*]] = vector.extract %[[arg0]][1] : vector<[8]xi8> from vector<4x[8]xi8>
@@ -56,7 +56,7 @@ func.func @i8_2d_to_1d_last_dim_scalable(%arg0: vector<4x[8]xi8>) -> vector<[32]
 // CHECK-LABEL: i8_1d_to_2d_last_dim_scalable
 // CHECK-SAME: %[[arg0:.*]]: vector<[32]xi8>
 func.func @i8_1d_to_2d_last_dim_scalable(%arg0: vector<[32]xi8>) -> vector<4x[8]xi8> {
-  // CHECK-NEXT: %[[cst:.*]] = arith.constant dense<0> : vector<4x[8]xi8>
+  // CHECK-NEXT: %[[cst:.*]] = ub.poison : vector<4x[8]xi8>
   // CHECK-NEXT: %[[subvec0:.*]] = vector.scalable.extract %[[arg0]][0] : vector<[8]xi8> from vector<[32]xi8>
   // CHECK-NEXT: %[[res0:.*]] = vector.insert %[[subvec0]], %[[cst]] [0] : vector<[8]xi8> into vector<4x[8]xi8>
   // CHECK-NEXT: %[[subvec1:.*]] = vector.scalable.extract %[[arg0]][8] : vector<[8]xi8> from vector<[32]xi8>
@@ -75,7 +75,7 @@ func.func @i8_1d_to_2d_last_dim_scalable(%arg0: vector<[32]xi8>) -> vector<4x[8]
 // CHECK-LABEL: f32_permute_leading_non_scalable_dims
 // CHECK-SAME: %[[arg0:.*]]: vector<2x3x[4]xf32>
 func.func @f32_permute_leading_non_scalable_dims(%arg0: vector<2x3x[4]xf32>) -> vector<3x2x[4]xf32> {
-  // CHECK-NEXT: %[[cst:.*]] = arith.constant dense<0.000000e+00> : vector<3x2x[4]xf32>
+  // CHECK-NEXT: %[[cst:.*]] = ub.poison : vector<3x2x[4]xf32>
   // CHECK-NEXT: %[[subvec0:.*]] = vector.extract %[[arg0]][0, 0] : vector<[4]xf32> from vector<2x3x[4]xf32>
   // CHECK-NEXT: %[[res0:.*]] = vector.insert %[[subvec0]], %[[cst]] [0, 0] : vector<[4]xf32> into vector<3x2x[4]xf32>
   // CHECK-NEXT: %[[subvec1:.*]] = vector.extract %[[arg0]][0, 1] : vector<[4]xf32> from vector<2x3x[4]xf32>
@@ -99,7 +99,7 @@ func.func @f32_permute_leading_non_scalable_dims(%arg0: vector<2x3x[4]xf32>) ->
 // CHECK-SAME: %[[arg0:.*]]: vector<2x2x[2]xf64>
 func.func @f64_flatten_leading_non_scalable_dims(%arg0: vector<2x2x[2]xf64>) -> vector<4x[2]xf64>
 {
-  // CHECK-NEXT: %[[cst:.*]] = arith.constant dense<0.000000e+00> : vector<4x[2]xf64>
+  // CHECK-NEXT: %[[cst:.*]] = ub.poison : vector<4x[2]xf64>
   // CHECK-NEXT: %[[subvec0:.*]] = vector.extract %[[arg0]][0, 0] : vector<[2]xf64> from vector<2x2x[2]xf64>
   // CHECK-NEXT: %[[res0:.*]] = vector.insert %[[subvec0]], %[[cst]] [0] : vector<[2]xf64> into vector<4x[2]xf64>
   // CHECK-NEXT: %[[subvec1:.*]] = vector.extract %[[arg0]][0, 1] : vector<[2]xf64> from vector<2x2x[2]xf64>
@@ -109,7 +109,7 @@ func.func @f64_flatten_leading_non_scalable_dims(%arg0: vector<2x2x[2]xf64>) ->
   // CHECK-NEXT: %[[subvec3:.*]] = vector.extract %[[arg0]][1, 1] : vector<[2]xf64> from vector<2x2x[2]xf64>
   // CHECK-NEXT: %[[res3:.*]] = vector.insert %[[subvec3]], %[[res2]] [3] : vector<[2]xf64> into vector<4x[2]xf64>
   %res = vector.shape_cast %arg0: vector<2x2x[2]xf64> to vector<4x[2]xf64>
-  // CHECK-NEXT: return %7 : vector<4x[2]xf64>
+  // CHECK-NEXT: return %[[res3:.*]] : vector<4x[2]xf64>
   return %res : vector<4x[2]xf64>
 }
 
@@ -119,7 +119,7 @@ func.func @f64_flatten_leading_non_scalable_dims(%arg0: vector<2x2x[2]xf64>) ->
 // CHECK-SAME: %[[arg0:.*]]: vector<3x[4]xf32>
 func.func @f32_reduce_trailing_scalable_dim(%arg0: vector<3x[4]xf32>) -> vector<6x[2]xf32>
 {
-  // CHECK-NEXT: %[[cst:.*]] = arith.constant dense<0.000000e+00> : vector<6x[2]xf32>
+  // CHECK-NEXT: %[[cst:.*]] = ub.poison : vector<6x[2]xf32>
   // CHECK-NEXT: %[[srcvec0:.*]] = vector.extract %[[arg0]][0] : vector<[4]xf32> from vector<3x[4]xf32>
   // CHECK-NEXT: %[[subvec0:.*]] = vector.scalable.extract %[[srcvec0]][0] : vector<[2]xf32> from vector<[4]xf32>
   // CHECK-NEXT: %[[res0:.*]] = vector.insert %[[subvec0]], %[[cst]] [0] : vector<[2]xf32> into vector<6x[2]xf32>
@@ -146,7 +146,7 @@ func.func @f32_reduce_trailing_scalable_dim(%arg0: vector<3x[4]xf32>) -> vector<
 // CHECK-SAME: %[[arg0:.*]]: vector<4x[2]xf32>
 func.func @f32_increase_trailing_scalable_dim(%arg0: vector<4x[2]xf32>) -> vector<2x[4]xf32>
 {
-  // CHECK-NEXT: %[[cst:.*]] = arith.constant dense<0.000000e+00> : vector<2x[4]xf32>
+  // CHECK-NEXT: %[[cst:.*]] = ub.poison : vector<2x[4]xf32>
   // CHECK-NEXT: %[[subvec0:.*]] = vector.extract %[[arg0]][0] : vector<[2]xf32> from vector<4x[2]xf32>
   // CHECK-NEXT: %[[resvec0:.*]] = vector.extract %[[cst]][0] : vector<[4]xf32> from vector<2x[4]xf32>
   // CHECK-NEXT: %[[resvec1:.*]] = vector.scalable.insert %[[subvec0]], %[[resvec0]][0] : vector<[2]xf32> into vector<[4]xf32>
diff --git a/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir b/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir
index b4c52d5533116c..ee4fe59424a482 100644
--- a/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir
@@ -22,8 +22,8 @@ func.func @cancel_shape_cast(%arg0: vector<16xf32>) -> vector<16xf32> {
 // llvm.matrix operations
 // CHECK-LABEL: func @shape_casts
 func.func @shape_casts(%a: vector<2x2xf32>) -> (vector<4xf32>, vector<2x2xf32>) {
-  // CHECK-DAG: %[[cst22:.*]] = arith.constant dense<0.000000e+00> : vector<2x2xf32>
-  // CHECK-DAG: %[[cst:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32>
+  // CHECK-DAG: %[[cst22:.*]] = ub.poison : vector<2x2xf32>
+  // CHECK-DAG: %[[cst:.*]] = ub.poison : vector<4xf32>
   // CHECK: %[[ex0:.*]] = vector.extract %{{.*}}[0] : vector<2xf32> from vector<2x2xf32>
   //
   // CHECK: %[[in0:.*]] = vector.insert_strided_slice %[[ex0]], %[[cst]]
@@ -59,7 +59,7 @@ func.func @shape_casts(%a: vector<2x2xf32>) -> (vector<4xf32>, vector<2x2xf32>)
 
 // CHECK-LABEL: func @shape_cast_2d2d
 // CHECK-SAME: %[[A:.*]]: vector<3x2xf32>
-// CHECK: %[[C:.*]] = arith.constant dense<0.000000e+00> : vector<2x3xf32>
+// CHECK: %[[C:.*]] = ub.poison : vector<2x3xf32>
 // CHECK: %[[T0:.*]] = vector.extract %[[A]][0, 0] : f32 from vector<3x2xf32>
 // CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[C]] [0, 0] : f32 into vector<2x3xf32>
 // CHECK: %[[T2:.*]] = vector.extract %[[A]][0, 1] : f32 from vector<3x2xf32>
@@ -81,7 +81,7 @@ func.func @shape_cast_2d2d(%arg0 : vector<3x2xf32>) -> vector<2x3xf32> {
 
 // CHECK-LABEL: func @shape_cast_3d1d
 // CHECK-SAME: %[[A:.*]]: vector<1x3x2xf32>
-// CHECK: %[[C:.*]] = arith.constant dense<0.000000e+00> : vector<6xf32>
+// CHECK: %[[C:.*]] = ub.poison : vector<6xf32>
 // CHECK: %[[T0:.*]] = vector.extract %[[A]][0, 0] : vector<2xf32> from vector<1x3x2xf32>
 // CHECK: %[[T1:.*]] = vector.insert_strided_slice %[[T0]], %[[C]]
 // CHECK-SAME:           {offsets = [0], strides = [1]} : vector<2xf32> into vector<6xf32>
@@ -100,7 +100,7 @@ func.func @shape_cast_3d1d(%arg0 : vector<1x3x2xf32>) -> vector<6xf32> {
 
 // CHECK-LABEL: func @shape_cast_1d3d
 // CHECK-SAME: %[[A:.*]]: vector<6xf32>
-// CHECK: %[[C:.*]] = arith.constant dense<0.000000e+00> : vector<2x1x3xf32>
+// CHECK: %[[C:.*]] = ub.poison : vector<2x1x3xf32>
 // CHECK: %[[T0:.*]] = vector.extract_strided_slice %[[A]]
 // CHECK-SAME:           {offsets = [0], sizes = [3], strides = [1]} : vector<6xf32> to vector<3xf32>
 // CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[C]] [0, 0] : vector<3xf32> into vector<2x1x3xf32>
@@ -116,7 +116,7 @@ func.func @shape_cast_1d3d(%arg0 : vector<6xf32>) -> vector<2x1x3xf32> {
 
 // CHECK-LABEL:   func.func @shape_cast_0d1d(
 // CHECK-SAME:                               %[[VAL_0:.*]]: vector<f32>) -> vector<1xf32> {
-// CHECK:           %[[VAL_1:.*]] = arith.constant dense<0.000000e+00> : vector<1xf32>
+// CHECK:           %[[VAL_1:.*]] = ub.poison : vector<1xf32>
 // CHECK:           %[[VAL_2:.*]] = vector.extractelement %[[VAL_0]][] : vector<f32>
 // CHECK:           %[[VAL_3:.*]] = vector.insert %[[VAL_2]], %[[VAL_1]] [0] : f32 into vector<1xf32>
 // CHECK:           return %[[VAL_3]] : vector<1xf32>
@@ -129,7 +129,7 @@ func.func @shape_cast_0d1d(%arg0 : vector<f32>) -> vector<1xf32> {
 
 // CHECK-LABEL:   func.func @shape_cast_1d0d(
 // CHECK-SAME:                               %[[VAL_0:.*]]: vector<1xf32>) -> vector<f32> {
-// CHECK:           %[[VAL_1:.*]] = arith.constant dense<0.000000e+00> : vector<f32>
+// CHECK:           %[[VAL_1:.*]] = ub.poison : vector<f32>
 // CHECK:           %[[VAL_2:.*]] = vector.extract %[[VAL_0]][0] : f32 from vector<1xf32>
 // CHECK:           %[[VAL_3:.*]] = vector.insertelement %[[VAL_2]], %[[VAL_1]][] : vector<f32>
 // CHECK:           return %[[VAL_3]] : vector<f32>

github-actions bot pushed a commit to arm/arm-toolchain that referenced this pull request Feb 4, 2025
…ipelines (#125145)

This PR adds the UB-to-LLVM/SPIR-V conversion pass to some pipelines and
tests. This is in preparation for introducing the generation of
`ub.poison` in Vector dialect transformations (first one in llvm/llvm-project#125613).
It should effectively be NFC at this point.
Member

@kuhar kuhar left a comment


We need to handle ub.poison first. We (AMD) can take care of this; I expect we can have it landed by the end of the week.

@kuhar
Member

kuhar commented Feb 4, 2025

cc: @andfau-amd

@@ -7,7 +7,7 @@
// CHECK-SAME: %[[arg0:.*]]: vector<2x1x[4]xi32>
func.func @i32_3d_to_1d_last_dim_scalable(%arg0: vector<2x1x[4]xi32>) -> vector<[8]xi32>
{
// CHECK-NEXT: %[[cst:.*]] = arith.constant dense<0> : vector<[8]xi32>
// CHECK-NEXT: %[[cst:.*]] = ub.poison : vector<[8]xi32>
Contributor


[nit] Here and other places - since this is no longer an arith.constant, perhaps let's rename the LIT variables? cst is just going to be confusing (and most likely copied by people).

cst -> poison? cst -> init? cst -> ubp? cst -> undef?

Naming is hard 🤷🏻‍♂️ Avoiding cst should be a "win" regardless of the alternative.

@dcaballe dcaballe force-pushed the slice-to-poison-shuffle branch from d70f6e3 to b4d3cbe Compare February 6, 2025 18:40
// CHECK: %[[EXTRACT0:.*]] = vector.extract %[[ARG0]][0] : i32 from vector<3xi32>
// CHECK: %[[INSERT0:.*]]= vector.insert %[[EXTRACT0]], %[[CST]] [0] : i32 into vector<2xi32>
// CHECK: %[[UB1:.*]] = vector.extract %[[UB]][0] : vector<2xi32> from vector<1x2xi32>
Member


@dcaballe does this get folded when we run the canonicalizer?

Member


We should add this pattern to the convert-to-spirv pass so that we drop all these 2-D vectors.

Contributor Author


We don't have this pattern implemented but let me add it in a min...

Member

@kuhar kuhar left a comment


I've just realized that this breaks convert to spirv

@dcaballe
Contributor Author

dcaballe commented Feb 6, 2025

vector.extract folds added here: #126122
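
To make the expected cleanup concrete, here is a minimal sketch of the kind of IR produced by the SPIR-V unrolling above, written under the assumption that the new folds turn a vector.extract of a ub.poison source into poison of the extracted type (the function name is made up for illustration):

  func.func @fold_extract_of_poison() -> vector<2xi32> {
    %p = ub.poison : vector<1x2xi32>
    // With the folds referenced above, this extract is expected to fold away,
    // leaving the function to return ub.poison : vector<2xi32> directly.
    %row = vector.extract %p[0] : vector<2xi32> from vector<1x2xi32>
    return %row : vector<2xi32>
  }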

@dcaballe dcaballe force-pushed the slice-to-poison-shuffle branch from b4d3cbe to b9b453c Compare February 6, 2025 23:04
@andfau-amd
Contributor

(I made a comment just now but then deleted it because I realised I had missed something due to the way GitHub presents review comments; please ignore the notification email.)

This is the first PR that introduces `ub.poison` vectors as part of a
rewrite/conversion pattern in the Vector dialect. It replaces the
`arith.constant dense<0>` vector initialization for `vector.insert_slice`
ops with a poison vector.

This PR depends on all the previous PRs that introduced support for
poison in Vector operations such as `vector.shuffle`, `vector.extract`,
`vector.insert`, including ODS, canonicalization and lowering support.

This PR may improve end-to-end compilation time through LLVM, depending
on the workloads.
@dcaballe dcaballe force-pushed the slice-to-poison-shuffle branch from b9b453c to ce96c2a Compare February 7, 2025 18:26
@dcaballe
Contributor Author

dcaballe commented Feb 7, 2025

Is SPIR-V ready to land this one already?

Member

@kuhar kuhar left a comment


LGTM

@kuhar
Member

kuhar commented Feb 7, 2025

Is SPIR-V ready to land this one already?

The updated test looks good. There's a small chance that we will need some other fold for the unrolling to work properly if we have gaps in test coverage. But IMO we can go ahead and fix forward if that turns out to be the case.

@dcaballe dcaballe merged commit 5a0075a into llvm:main Feb 7, 2025
6 of 7 checks passed
Icohedron pushed a commit to Icohedron/llvm-project that referenced this pull request Feb 11, 2025
…lvm#125145)

This PR adds the UB-to-LLVM/SPIR-V conversion pass to some pipelines and
tests. This is in preparation for introducing the generation of
`ub.poison` in Vector dialect transformations (first one in llvm#125613).
It should effectively be NFC at this point.
Icohedron pushed a commit to Icohedron/llvm-project that referenced this pull request Feb 11, 2025
…lvm#125613)

This is the first PR that introduces `ub.poison` vectors as part of a
rewrite/conversion pattern in the Vector dialect. It replaces the
`arith.constant dense<0>` vector initialization for
`vector.insert_slice` ops with a poison vector.

This PR depends on all the previous PRs that introduced support for
poison in Vector operations such as `vector.shuffle`, `vector.extract`,
`vector.insert`, including ODS, canonicalization and lowering support.

This PR may improve end-to-end compilation time through LLVM, depending
on the workloads.