[mlir][Vector] Generate poison vectors in vector.shape_cast lowering #125613
Conversation
@llvm/pr-subscribers-mlir-spirv @llvm/pr-subscribers-mlir

Author: Diego Caballero (dcaballe)

Changes

This is the first PR that introduces `ub.poison` vectors as part of a rewrite/conversion pattern in the Vector dialect. It replaces the `arith.constant dense<0>` vector initialization for `vector.insert_slice` ops with a poison vector.

This PR depends on all the previous PRs that introduced support for poison in Vector operations such as `vector.shuffle`, `vector.extract`, `vector.insert`, including ODS, canonicalization and lowering support.

This PR may improve end-to-end compilation time through LLVM, depending on the workloads.

Full diff: https://github.com/llvm/llvm-project/pull/125613.diff

5 Files Affected:
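In short, the shape_cast lowering patterns now seed the result vector with poison instead of a zero splat. A minimal before/after sketch of the generated IR, assuming a hypothetical source value `%src` (shapes and names are illustrative, not copied from the tests in this diff):

```mlir
// Before this patch: the lowering seeded the result with a zero constant.
%zero = arith.constant dense<0.000000e+00> : vector<2x3xf32>

// After this patch: the result starts as poison, which is safe because
// every element is overwritten by the extract/insert chain that follows.
%init = ub.poison : vector<2x3xf32>
%e0   = vector.extract %src[0, 0] : f32 from vector<3x2xf32>
%r0   = vector.insert %e0, %init [0, 0] : f32 into vector<2x3xf32>
// ... remaining element-wise extract/insert pairs ...
```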
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
index 239dc9aa1de6fb..9c1e5fcee91de4 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
@@ -11,8 +11,8 @@
//
//===----------------------------------------------------------------------===//
-#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/UB//IR/UBOps.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
@@ -73,8 +73,7 @@ class ShapeCastOpNDDownCastRewritePattern
SmallVector<int64_t> srcIdx(srcRank - 1, 0);
SmallVector<int64_t> resIdx(resRank, 0);
int64_t extractSize = sourceVectorType.getShape().back();
- Value result = rewriter.create<arith::ConstantOp>(
- loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
+ Value result = rewriter.create<ub::PoisonOp>(loc, resultVectorType);
// Compute the indices of each 1-D vector element of the source extraction
// and destination slice insertion and generate such instructions.
@@ -129,8 +128,7 @@ class ShapeCastOpNDUpCastRewritePattern
SmallVector<int64_t> srcIdx(srcRank, 0);
SmallVector<int64_t> resIdx(resRank - 1, 0);
int64_t extractSize = resultVectorType.getShape().back();
- Value result = rewriter.create<arith::ConstantOp>(
- loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
+ Value result = rewriter.create<ub::PoisonOp>(loc, resultVectorType);
for (int64_t i = 0; i < numElts; ++i) {
if (i != 0) {
incIdx(srcIdx, sourceVectorType, /*step=*/extractSize);
@@ -184,8 +182,7 @@ class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
// within the source and result shape.
SmallVector<int64_t> srcIdx(srcRank, 0);
SmallVector<int64_t> resIdx(resRank, 0);
- Value result = rewriter.create<arith::ConstantOp>(
- loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
+ Value result = rewriter.create<ub::PoisonOp>(loc, resultVectorType);
for (int64_t i = 0; i < numElts; i++) {
if (i != 0) {
incIdx(srcIdx, sourceVectorType);
@@ -291,9 +288,7 @@ class ScalableShapeCastOpRewritePattern
auto extractionVectorType = VectorType::get(
{minExtractionSize}, sourceVectorType.getElementType(), {true});
- Value result = rewriter.create<arith::ConstantOp>(
- loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
-
+ Value result = rewriter.create<ub::PoisonOp>(loc, resultVectorType);
SmallVector<int64_t> srcIdx(srcRank, 0);
SmallVector<int64_t> resIdx(resRank, 0);
diff --git a/mlir/test/Conversion/ConvertToSPIRV/vector-unroll.mlir b/mlir/test/Conversion/ConvertToSPIRV/vector-unroll.mlir
index 043f9422d8790f..f1cc1354d1e3bc 100644
--- a/mlir/test/Conversion/ConvertToSPIRV/vector-unroll.mlir
+++ b/mlir/test/Conversion/ConvertToSPIRV/vector-unroll.mlir
@@ -83,17 +83,20 @@ func.func @vaddi_reduction(%arg0 : vector<8xi32>, %arg1 : vector<8xi32>) -> (i32
// CHECK-LABEL: @transpose
// CHECK-SAME: (%[[ARG0:.+]]: vector<3xi32>, %[[ARG1:.+]]: vector<3xi32>)
func.func @transpose(%arg0 : vector<2x3xi32>) -> (vector<3x2xi32>) {
- // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<2xi32>
+ // CHECK: %[[CST:.*]] = ub.poison : vector<1x2xi32>
// CHECK: %[[EXTRACT0:.*]] = vector.extract %[[ARG0]][0] : i32 from vector<3xi32>
- // CHECK: %[[INSERT0:.*]]= vector.insert %[[EXTRACT0]], %[[CST]] [0] : i32 into vector<2xi32>
+ // CHECK: %[[CST1:.*]] = vector.extract %[[CST]][0] : vector<2xi32> from vector<1x2xi32>
+ // CHECK: %[[INSERT0:.*]]= vector.insert %[[EXTRACT0]], %[[CST1]] [0] : i32 into vector<2xi32>
// CHECK: %[[EXTRACT1:.*]] = vector.extract %[[ARG1]][0] : i32 from vector<3xi32>
// CHECK: %[[INSERT1:.*]] = vector.insert %[[EXTRACT1]], %[[INSERT0]][1] : i32 into vector<2xi32>
// CHECK: %[[EXTRACT2:.*]] = vector.extract %[[ARG0]][1] : i32 from vector<3xi32>
- // CHECK: %[[INSERT2:.*]] = vector.insert %[[EXTRACT2]], %[[CST]] [0] : i32 into vector<2xi32>
+ // CHECK: %[[CST2:.*]] = vector.extract %[[CST]][0] : vector<2xi32> from vector<1x2xi32>
+ // CHECK: %[[INSERT2:.*]] = vector.insert %[[EXTRACT2]], %[[CST2]] [0] : i32 into vector<2xi32>
// CHECK: %[[EXTRACT3:.*]] = vector.extract %[[ARG1]][1] : i32 from vector<3xi32>
// CHECK: %[[INSERT3:.*]] = vector.insert %[[EXTRACT3]], %[[INSERT2]] [1] : i32 into vector<2xi32>
// CHECK: %[[EXTRACT4:.*]] = vector.extract %[[ARG0]][2] : i32 from vector<3xi32>
- // CHECK: %[[INSERT4:.*]] = vector.insert %[[EXTRACT4]], %[[CST]] [0] : i32 into vector<2xi32>
+ // CHECK: %[[CST3:.*]] = vector.extract %[[CST]][0] : vector<2xi32> from vector<1x2xi32>
+ // CHECK: %[[INSERT4:.*]] = vector.insert %[[EXTRACT4]], %[[CST3]] [0] : i32 into vector<2xi32>
// CHECK: %[[EXTRACT5:.*]] = vector.extract %[[ARG1]][2] : i32 from vector<3xi32>
// CHECK: %[[INSERT5:.*]] = vector.insert %[[EXTRACT5]], %[[INSERT4]] [1] : i32 into vector<2xi32>
// CHECK: return %[[INSERT1]], %[[INSERT3]], %[[INSERT5]] : vector<2xi32>, vector<2xi32>, vector<2xi32>
diff --git a/mlir/test/Dialect/Vector/vector-contract-to-matrix-intrinsics-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-to-matrix-intrinsics-transforms.mlir
index 4867a416e5d144..fd6895c01d78bd 100644
--- a/mlir/test/Dialect/Vector/vector-contract-to-matrix-intrinsics-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-to-matrix-intrinsics-transforms.mlir
@@ -14,9 +14,9 @@
// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x4xf32>,
// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<4x3xf32>,
// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32>
-// CHECK-DAG: %[[vcst:.*]] = arith.constant dense<0.000000e+00> : vector<8xf32>
-// CHECK-DAG: %[[vcst_0:.*]] = arith.constant dense<0.000000e+00> : vector<12xf32>
-// CHECK-DAG: %[[vcst_1:.*]] = arith.constant dense<0.000000e+00> : vector<2x3xf32>
+// CHECK-DAG: %[[vcst:.*]] = ub.poison : vector<8xf32>
+// CHECK-DAG: %[[vcst_0:.*]] = ub.poison : vector<12xf32>
+// CHECK-DAG: %[[vcst_1:.*]] = ub.poison : vector<2x3xf32>
// CHECK: %[[a0:.*]] = vector.extract %[[A]][0] : vector<4xf32> from vector<2x4xf32>
// CHECK: %[[a1:.*]] = vector.insert_strided_slice %[[a0]], %[[vcst]] {offsets = [0], strides = [1]} : vector<4xf32> into vector<8xf32>
// CHECK: %[[a2:.*]] = vector.extract %[[A]][1] : vector<4xf32> from vector<2x4xf32>
diff --git a/mlir/test/Dialect/Vector/vector-shape-cast-lowering-scalable-vectors.mlir b/mlir/test/Dialect/Vector/vector-shape-cast-lowering-scalable-vectors.mlir
index fde6ce91024464..b4518e57c39ddd 100644
--- a/mlir/test/Dialect/Vector/vector-shape-cast-lowering-scalable-vectors.mlir
+++ b/mlir/test/Dialect/Vector/vector-shape-cast-lowering-scalable-vectors.mlir
@@ -7,7 +7,7 @@
// CHECK-SAME: %[[arg0:.*]]: vector<2x1x[4]xi32>
func.func @i32_3d_to_1d_last_dim_scalable(%arg0: vector<2x1x[4]xi32>) -> vector<[8]xi32>
{
- // CHECK-NEXT: %[[cst:.*]] = arith.constant dense<0> : vector<[8]xi32>
+ // CHECK-NEXT: %[[cst:.*]] = ub.poison : vector<[8]xi32>
// CHECK-NEXT: %[[subvec0:.*]] = vector.extract %[[arg0]][0, 0] : vector<[4]xi32> from vector<2x1x[4]xi32>
// CHECK-NEXT: %[[res0:.*]] = vector.scalable.insert %[[subvec0]], %[[cst]][0] : vector<[4]xi32> into vector<[8]xi32>
// CHECK-NEXT: %[[subvec1:.*]] = vector.extract %[[arg0]][1, 0] : vector<[4]xi32> from vector<2x1x[4]xi32>
@@ -22,7 +22,7 @@ func.func @i32_3d_to_1d_last_dim_scalable(%arg0: vector<2x1x[4]xi32>) -> vector<
// CHECK-LABEL: i32_1d_to_3d_last_dim_scalable
// CHECK-SAME: %[[arg0:.*]]: vector<[8]xi32>
func.func @i32_1d_to_3d_last_dim_scalable(%arg0: vector<[8]xi32>) -> vector<2x1x[4]xi32> {
- // CHECK-NEXT: %[[cst:.*]] = arith.constant dense<0> : vector<2x1x[4]xi32>
+ // CHECK-NEXT: %[[cst:.*]] = ub.poison : vector<2x1x[4]xi32>
// CHECK-NEXT: %[[subvec0:.*]] = vector.scalable.extract %[[arg0]][0] : vector<[4]xi32> from vector<[8]xi32>
// CHECK-NEXT: %[[res0:.*]] = vector.insert %[[subvec0]], %[[cst]] [0, 0] : vector<[4]xi32> into vector<2x1x[4]xi32>
// CHECK-NEXT: %[[subvec1:.*]] = vector.scalable.extract %[[arg0]][4] : vector<[4]xi32> from vector<[8]xi32>
@@ -37,7 +37,7 @@ func.func @i32_1d_to_3d_last_dim_scalable(%arg0: vector<[8]xi32>) -> vector<2x1x
// CHECK-LABEL: i8_2d_to_1d_last_dim_scalable
// CHECK-SAME: %[[arg0:.*]]: vector<4x[8]xi8>
func.func @i8_2d_to_1d_last_dim_scalable(%arg0: vector<4x[8]xi8>) -> vector<[32]xi8> {
- // CHECK-NEXT: %[[cst:.*]] = arith.constant dense<0> : vector<[32]xi8>
+ // CHECK-NEXT: %[[cst:.*]] = ub.poison : vector<[32]xi8>
// CHECK-NEXT: %[[subvec0:.*]] = vector.extract %[[arg0]][0] : vector<[8]xi8> from vector<4x[8]xi8>
// CHECK-NEXT: %[[res0:.*]] = vector.scalable.insert %[[subvec0]], %[[cst]][0] : vector<[8]xi8> into vector<[32]xi8>
// CHECK-NEXT: %[[subvec1:.*]] = vector.extract %[[arg0]][1] : vector<[8]xi8> from vector<4x[8]xi8>
@@ -56,7 +56,7 @@ func.func @i8_2d_to_1d_last_dim_scalable(%arg0: vector<4x[8]xi8>) -> vector<[32]
// CHECK-LABEL: i8_1d_to_2d_last_dim_scalable
// CHECK-SAME: %[[arg0:.*]]: vector<[32]xi8>
func.func @i8_1d_to_2d_last_dim_scalable(%arg0: vector<[32]xi8>) -> vector<4x[8]xi8> {
- // CHECK-NEXT: %[[cst:.*]] = arith.constant dense<0> : vector<4x[8]xi8>
+ // CHECK-NEXT: %[[cst:.*]] = ub.poison : vector<4x[8]xi8>
// CHECK-NEXT: %[[subvec0:.*]] = vector.scalable.extract %[[arg0]][0] : vector<[8]xi8> from vector<[32]xi8>
// CHECK-NEXT: %[[res0:.*]] = vector.insert %[[subvec0]], %[[cst]] [0] : vector<[8]xi8> into vector<4x[8]xi8>
// CHECK-NEXT: %[[subvec1:.*]] = vector.scalable.extract %[[arg0]][8] : vector<[8]xi8> from vector<[32]xi8>
@@ -75,7 +75,7 @@ func.func @i8_1d_to_2d_last_dim_scalable(%arg0: vector<[32]xi8>) -> vector<4x[8]
// CHECK-LABEL: f32_permute_leading_non_scalable_dims
// CHECK-SAME: %[[arg0:.*]]: vector<2x3x[4]xf32>
func.func @f32_permute_leading_non_scalable_dims(%arg0: vector<2x3x[4]xf32>) -> vector<3x2x[4]xf32> {
- // CHECK-NEXT: %[[cst:.*]] = arith.constant dense<0.000000e+00> : vector<3x2x[4]xf32>
+ // CHECK-NEXT: %[[cst:.*]] = ub.poison : vector<3x2x[4]xf32>
// CHECK-NEXT: %[[subvec0:.*]] = vector.extract %[[arg0]][0, 0] : vector<[4]xf32> from vector<2x3x[4]xf32>
// CHECK-NEXT: %[[res0:.*]] = vector.insert %[[subvec0]], %[[cst]] [0, 0] : vector<[4]xf32> into vector<3x2x[4]xf32>
// CHECK-NEXT: %[[subvec1:.*]] = vector.extract %[[arg0]][0, 1] : vector<[4]xf32> from vector<2x3x[4]xf32>
@@ -99,7 +99,7 @@ func.func @f32_permute_leading_non_scalable_dims(%arg0: vector<2x3x[4]xf32>) ->
// CHECK-SAME: %[[arg0:.*]]: vector<2x2x[2]xf64>
func.func @f64_flatten_leading_non_scalable_dims(%arg0: vector<2x2x[2]xf64>) -> vector<4x[2]xf64>
{
- // CHECK-NEXT: %[[cst:.*]] = arith.constant dense<0.000000e+00> : vector<4x[2]xf64>
+ // CHECK-NEXT: %[[cst:.*]] = ub.poison : vector<4x[2]xf64>
// CHECK-NEXT: %[[subvec0:.*]] = vector.extract %[[arg0]][0, 0] : vector<[2]xf64> from vector<2x2x[2]xf64>
// CHECK-NEXT: %[[res0:.*]] = vector.insert %[[subvec0]], %[[cst]] [0] : vector<[2]xf64> into vector<4x[2]xf64>
// CHECK-NEXT: %[[subvec1:.*]] = vector.extract %[[arg0]][0, 1] : vector<[2]xf64> from vector<2x2x[2]xf64>
@@ -109,7 +109,7 @@ func.func @f64_flatten_leading_non_scalable_dims(%arg0: vector<2x2x[2]xf64>) ->
// CHECK-NEXT: %[[subvec3:.*]] = vector.extract %[[arg0]][1, 1] : vector<[2]xf64> from vector<2x2x[2]xf64>
// CHECK-NEXT: %[[res3:.*]] = vector.insert %[[subvec3]], %[[res2]] [3] : vector<[2]xf64> into vector<4x[2]xf64>
%res = vector.shape_cast %arg0: vector<2x2x[2]xf64> to vector<4x[2]xf64>
- // CHECK-NEXT: return %7 : vector<4x[2]xf64>
+ // CHECK-NEXT: return %[[res3:.*]] : vector<4x[2]xf64>
return %res : vector<4x[2]xf64>
}
@@ -119,7 +119,7 @@ func.func @f64_flatten_leading_non_scalable_dims(%arg0: vector<2x2x[2]xf64>) ->
// CHECK-SAME: %[[arg0:.*]]: vector<3x[4]xf32>
func.func @f32_reduce_trailing_scalable_dim(%arg0: vector<3x[4]xf32>) -> vector<6x[2]xf32>
{
- // CHECK-NEXT: %[[cst:.*]] = arith.constant dense<0.000000e+00> : vector<6x[2]xf32>
+ // CHECK-NEXT: %[[cst:.*]] = ub.poison : vector<6x[2]xf32>
// CHECK-NEXT: %[[srcvec0:.*]] = vector.extract %[[arg0]][0] : vector<[4]xf32> from vector<3x[4]xf32>
// CHECK-NEXT: %[[subvec0:.*]] = vector.scalable.extract %[[srcvec0]][0] : vector<[2]xf32> from vector<[4]xf32>
// CHECK-NEXT: %[[res0:.*]] = vector.insert %[[subvec0]], %[[cst]] [0] : vector<[2]xf32> into vector<6x[2]xf32>
@@ -146,7 +146,7 @@ func.func @f32_reduce_trailing_scalable_dim(%arg0: vector<3x[4]xf32>) -> vector<
// CHECK-SAME: %[[arg0:.*]]: vector<4x[2]xf32>
func.func @f32_increase_trailing_scalable_dim(%arg0: vector<4x[2]xf32>) -> vector<2x[4]xf32>
{
- // CHECK-NEXT: %[[cst:.*]] = arith.constant dense<0.000000e+00> : vector<2x[4]xf32>
+ // CHECK-NEXT: %[[cst:.*]] = ub.poison : vector<2x[4]xf32>
// CHECK-NEXT: %[[subvec0:.*]] = vector.extract %[[arg0]][0] : vector<[2]xf32> from vector<4x[2]xf32>
// CHECK-NEXT: %[[resvec0:.*]] = vector.extract %[[cst]][0] : vector<[4]xf32> from vector<2x[4]xf32>
// CHECK-NEXT: %[[resvec1:.*]] = vector.scalable.insert %[[subvec0]], %[[resvec0]][0] : vector<[2]xf32> into vector<[4]xf32>
diff --git a/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir b/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir
index b4c52d5533116c..ee4fe59424a482 100644
--- a/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir
@@ -22,8 +22,8 @@ func.func @cancel_shape_cast(%arg0: vector<16xf32>) -> vector<16xf32> {
// llvm.matrix operations
// CHECK-LABEL: func @shape_casts
func.func @shape_casts(%a: vector<2x2xf32>) -> (vector<4xf32>, vector<2x2xf32>) {
- // CHECK-DAG: %[[cst22:.*]] = arith.constant dense<0.000000e+00> : vector<2x2xf32>
- // CHECK-DAG: %[[cst:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32>
+ // CHECK-DAG: %[[cst22:.*]] = ub.poison : vector<2x2xf32>
+ // CHECK-DAG: %[[cst:.*]] = ub.poison : vector<4xf32>
// CHECK: %[[ex0:.*]] = vector.extract %{{.*}}[0] : vector<2xf32> from vector<2x2xf32>
//
// CHECK: %[[in0:.*]] = vector.insert_strided_slice %[[ex0]], %[[cst]]
@@ -59,7 +59,7 @@ func.func @shape_casts(%a: vector<2x2xf32>) -> (vector<4xf32>, vector<2x2xf32>)
// CHECK-LABEL: func @shape_cast_2d2d
// CHECK-SAME: %[[A:.*]]: vector<3x2xf32>
-// CHECK: %[[C:.*]] = arith.constant dense<0.000000e+00> : vector<2x3xf32>
+// CHECK: %[[C:.*]] = ub.poison : vector<2x3xf32>
// CHECK: %[[T0:.*]] = vector.extract %[[A]][0, 0] : f32 from vector<3x2xf32>
// CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[C]] [0, 0] : f32 into vector<2x3xf32>
// CHECK: %[[T2:.*]] = vector.extract %[[A]][0, 1] : f32 from vector<3x2xf32>
@@ -81,7 +81,7 @@ func.func @shape_cast_2d2d(%arg0 : vector<3x2xf32>) -> vector<2x3xf32> {
// CHECK-LABEL: func @shape_cast_3d1d
// CHECK-SAME: %[[A:.*]]: vector<1x3x2xf32>
-// CHECK: %[[C:.*]] = arith.constant dense<0.000000e+00> : vector<6xf32>
+// CHECK: %[[C:.*]] = ub.poison : vector<6xf32>
// CHECK: %[[T0:.*]] = vector.extract %[[A]][0, 0] : vector<2xf32> from vector<1x3x2xf32>
// CHECK: %[[T1:.*]] = vector.insert_strided_slice %[[T0]], %[[C]]
// CHECK-SAME: {offsets = [0], strides = [1]} : vector<2xf32> into vector<6xf32>
@@ -100,7 +100,7 @@ func.func @shape_cast_3d1d(%arg0 : vector<1x3x2xf32>) -> vector<6xf32> {
// CHECK-LABEL: func @shape_cast_1d3d
// CHECK-SAME: %[[A:.*]]: vector<6xf32>
-// CHECK: %[[C:.*]] = arith.constant dense<0.000000e+00> : vector<2x1x3xf32>
+// CHECK: %[[C:.*]] = ub.poison : vector<2x1x3xf32>
// CHECK: %[[T0:.*]] = vector.extract_strided_slice %[[A]]
// CHECK-SAME: {offsets = [0], sizes = [3], strides = [1]} : vector<6xf32> to vector<3xf32>
// CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[C]] [0, 0] : vector<3xf32> into vector<2x1x3xf32>
@@ -116,7 +116,7 @@ func.func @shape_cast_1d3d(%arg0 : vector<6xf32>) -> vector<2x1x3xf32> {
// CHECK-LABEL: func.func @shape_cast_0d1d(
// CHECK-SAME: %[[VAL_0:.*]]: vector<f32>) -> vector<1xf32> {
-// CHECK: %[[VAL_1:.*]] = arith.constant dense<0.000000e+00> : vector<1xf32>
+// CHECK: %[[VAL_1:.*]] = ub.poison : vector<1xf32>
// CHECK: %[[VAL_2:.*]] = vector.extractelement %[[VAL_0]][] : vector<f32>
// CHECK: %[[VAL_3:.*]] = vector.insert %[[VAL_2]], %[[VAL_1]] [0] : f32 into vector<1xf32>
// CHECK: return %[[VAL_3]] : vector<1xf32>
@@ -129,7 +129,7 @@ func.func @shape_cast_0d1d(%arg0 : vector<f32>) -> vector<1xf32> {
// CHECK-LABEL: func.func @shape_cast_1d0d(
// CHECK-SAME: %[[VAL_0:.*]]: vector<1xf32>) -> vector<f32> {
-// CHECK: %[[VAL_1:.*]] = arith.constant dense<0.000000e+00> : vector<f32>
+// CHECK: %[[VAL_1:.*]] = ub.poison : vector<f32>
// CHECK: %[[VAL_2:.*]] = vector.extract %[[VAL_0]][0] : f32 from vector<1xf32>
// CHECK: %[[VAL_3:.*]] = vector.insertelement %[[VAL_2]], %[[VAL_1]][] : vector<f32>
// CHECK: return %[[VAL_3]] : vector<f32>
…ipelines (#125145) This PR adds the UB to LLVM/SPIR-V conversion pass to some pipelines and tests. This is in preparation for introducing the generation of `ub.poison` in Vector dialect transformations (first one in llvm/llvm-project#125613). It should effectively be NFC at this point.
We need to handle `ub.poison` first. We (AMD) can take care of this, I expect we can have it landed by the end of the week.

cc: @andfau-amd
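For context, "handling `ub.poison`" means the target conversions need a lowering for it. A minimal sketch, assuming the upstream `-convert-ub-to-llvm` / `-convert-ub-to-spirv` passes are the ones added to the pipelines and that poison maps onto `llvm.mlir.poison` / `spirv.Undef` (those target ops are an assumption here, not something stated in this PR):

```mlir
// A function returning a poison vector. The UB-to-LLVM / UB-to-SPIR-V
// conversions are expected to rewrite %p into the target dialect's
// poison/undef value, e.g. via: mlir-opt --convert-ub-to-llvm input.mlir
func.func @poison_init() -> vector<4xf32> {
  %p = ub.poison : vector<4xf32>
  return %p : vector<4xf32>
}
```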
@@ -7,7 +7,7 @@
// CHECK-SAME: %[[arg0:.*]]: vector<2x1x[4]xi32>
func.func @i32_3d_to_1d_last_dim_scalable(%arg0: vector<2x1x[4]xi32>) -> vector<[8]xi32>
{
- // CHECK-NEXT: %[[cst:.*]] = arith.constant dense<0> : vector<[8]xi32>
+ // CHECK-NEXT: %[[cst:.*]] = ub.poison : vector<[8]xi32>
[nit] Here and other places - since this is no longer an `arith.constant`, perhaps let's rename the LIT variables? `cst` is just going to be confusing (and most likely copied by people).

`cst` -> `poison`? `cst` -> `init`? `cst` -> `ubp`? `cst` -> `undef`?

Naming is hard 🤷🏻♂️ Avoiding `cst` should be a "win" regardless of the alternative.
Force-pushed from d70f6e3 to b4d3cbe
// CHECK: %[[EXTRACT0:.*]] = vector.extract %[[ARG0]][0] : i32 from vector<3xi32>
// CHECK: %[[INSERT0:.*]]= vector.insert %[[EXTRACT0]], %[[CST]] [0] : i32 into vector<2xi32>
// CHECK: %[[UB1:.*]] = vector.extract %[[UB]][0] : vector<2xi32> from vector<1x2xi32>
@dcaballe does this get folded when we run the canonicalizer?
We should add the pattern to the convert-to-spirv pass so that we drop all these 2d vectors.
We don't have this pattern implemented but let me add it in a min...
I've just realized that this breaks convert to spirv
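For reference, the fold being discussed would rewrite an extract whose source is poison into poison of the extracted type, roughly as follows (a hypothetical sketch of the canonicalization, not code taken from this PR):

```mlir
// Before folding: the unrolled code extracts a row from a poison vector.
%p = ub.poison : vector<1x2xi32>
%r = vector.extract %p[0] : vector<2xi32> from vector<1x2xi32>

// After folding: extracting from poison is itself poison, so the
// intermediate 1x2 vector (and the extract) can be dropped:
//   %r = ub.poison : vector<2xi32>
```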
Force-pushed from b4d3cbe to b9b453c
(I made a comment just now but then deleted it because I realised I had missed something due to the way GitHub presents review comments, please ignore the notification email.)
This is the first PR that introduces `ub.poison` vectors as part of a rewrite/conversion pattern in the Vector dialect. It replaces the `arith.constant dense<0>` vector initialization for `vector.insert_slice` ops with a poison vector. This PR depends on all the previous PRs that introduced support for poison in Vector operations such as `vector.shuffle`, `vector.extract`, `vector.insert`, including ODS, canonicalization and lowering support. This PR may improve end-to-end compilation time through LLVM, depending on the workloads.
Force-pushed from b9b453c to ce96c2a
Is SPIR-V ready to land this one already?
LGTM
The updated test looks good. There's a small chance that we will need some other fold for the unrolling to work properly if we have gaps in test coverage. But IMO we can go ahead and fix forward if that turns out to be the case.
This is the first PR that introduces `ub.poison` vectors as part of a rewrite/conversion pattern in the Vector dialect. It replaces the `arith.constant dense<0>` vector initialization for `vector.insert_slice` ops with a poison vector.

This PR depends on all the previous PRs that introduced support for poison in Vector operations such as `vector.shuffle`, `vector.extract`, `vector.insert`, including ODS, canonicalization and lowering support.

This PR may improve end-to-end compilation time through LLVM, depending on the workloads.