Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][spirv] Support poison index when converting vector.insert/extract #125560

Merged

Conversation

andfau-amd
Copy link
Contributor

This modifies the conversion patterns so that, in the case where the index is known statically to be poison, the insertion/extraction is replaced by an arbitrary junk constant value, and in the dynamic case, the index is sanitized at runtime. This avoids triggering a UB in both cases. The dynamic case is definitely a pessimisation of the generated code, but the use of dynamic indexes is expected to be very rare and already slow on real-world GPU compilers ingesting SPIR-V, so the impact should be negligible.

Resolves #124162.

@llvmbot
Copy link
Member

llvmbot commented Feb 3, 2025

@llvm/pr-subscribers-mlir-spirv

@llvm/pr-subscribers-mlir

Author: Andrea Faulds (andfau-amd)

Changes

This modifies the conversion patterns so that, in the case where the index is known statically to be poison, the insertion/extraction is replaced by an arbitrary junk constant value, and in the dynamic case, the index is sanitized at runtime. This avoids triggering a UB in both cases. The dynamic case is definitely a pessimisation of the generated code, but the use of dynamic indexes is expected to be very rare and already slow on real-world GPU compilers ingesting SPIR-V, so the impact should be negligible.

Resolves #124162.


Full diff: https://github.com/llvm/llvm-project/pull/125560.diff

2 Files Affected:

  • (modified) mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp (+57-13)
  • (modified) mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir (+15-4)
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index af882cb1ca6e91..3481a2e8b7733c 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -137,6 +137,26 @@ struct VectorBroadcastConvert final
   }
 };
 
+// SPIR-V does not have a concept of a poison index for certain instructions,
+// which creates a UB hazard when lowering from otherwise equivalent Vector
+// dialect instructions, because this index will be considered out-of-bounds.
+// To avoid this, this function implements a dynamic sanitization, arbitrarily
+// choosing to replace the poison index with index 0 (always in-bounds).
+static Value sanitizeDynamicIndex(ConversionPatternRewriter &rewriter,
+                                  Location loc, Value dynamicIndex,
+                                  int64_t kPoisonIndex) {
+  Value poisonIndex = rewriter.create<spirv::ConstantOp>(
+      loc, dynamicIndex.getType(),
+      rewriter.getIntegerAttr(dynamicIndex.getType(), kPoisonIndex));
+  Value cmpResult =
+      rewriter.create<spirv::IEqualOp>(loc, dynamicIndex, poisonIndex);
+  Value sanitizedIndex = rewriter.create<spirv::SelectOp>(
+      loc, cmpResult,
+      spirv::ConstantOp::getZero(dynamicIndex.getType(), loc, rewriter),
+      dynamicIndex);
+  return sanitizedIndex;
+}
+
 struct VectorExtractOpConvert final
     : public OpConversionPattern<vector::ExtractOp> {
   using OpConversionPattern::OpConversionPattern;
@@ -154,14 +174,26 @@ struct VectorExtractOpConvert final
     }
 
     if (std::optional<int64_t> id =
-            getConstantIntValue(extractOp.getMixedPosition()[0]))
-      rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
-          extractOp, dstType, adaptor.getVector(),
-          rewriter.getI32ArrayAttr(id.value()));
-    else
+            getConstantIntValue(extractOp.getMixedPosition()[0])) {
+      // TODO: It would be better to apply the ub.poison folding for this case
+      //       unconditionally, and have a specific SPIR-V lowering for it,
+      //       rather than having to handle it here.
+      if (id == vector::ExtractOp::kPoisonIndex) {
+        // Arbitrary choice of poison result, intended to stick out.
+        Value zero =
+            spirv::ConstantOp::getZero(dstType, extractOp.getLoc(), rewriter);
+        rewriter.replaceOp(extractOp, zero);
+      } else
+        rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
+            extractOp, dstType, adaptor.getVector(),
+            rewriter.getI32ArrayAttr(id.value()));
+    } else {
+      Value sanitizedIndex = sanitizeDynamicIndex(
+          rewriter, extractOp.getLoc(), adaptor.getDynamicPosition()[0],
+          vector::ExtractOp::kPoisonIndex);
       rewriter.replaceOpWithNewOp<spirv::VectorExtractDynamicOp>(
-          extractOp, dstType, adaptor.getVector(),
-          adaptor.getDynamicPosition()[0]);
+          extractOp, dstType, adaptor.getVector(), sanitizedIndex);
+    }
     return success();
   }
 };
@@ -266,13 +298,25 @@ struct VectorInsertOpConvert final
     }
 
     if (std::optional<int64_t> id =
-            getConstantIntValue(insertOp.getMixedPosition()[0]))
-      rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
-          insertOp, adaptor.getSource(), adaptor.getDest(), id.value());
-    else
+            getConstantIntValue(insertOp.getMixedPosition()[0])) {
+      // TODO: It would be better to apply the ub.poison folding for this case
+      //       unconditionally, and have a specific SPIR-V lowering for it,
+      //       rather than having to handle it here.
+      if (id == vector::InsertOp::kPoisonIndex) {
+        // Arbitrary choice of poison result, intended to stick out.
+        Value zero = spirv::ConstantOp::getZero(insertOp.getDestVectorType(),
+                                                insertOp.getLoc(), rewriter);
+        rewriter.replaceOp(insertOp, zero);
+      } else
+        rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
+            insertOp, adaptor.getSource(), adaptor.getDest(), id.value());
+    } else {
+      Value sanitizedIndex = sanitizeDynamicIndex(
+          rewriter, insertOp.getLoc(), adaptor.getDynamicPosition()[0],
+          vector::InsertOp::kPoisonIndex);
       rewriter.replaceOpWithNewOp<spirv::VectorInsertDynamicOp>(
-          insertOp, insertOp.getDest(), adaptor.getSource(),
-          adaptor.getDynamicPosition()[0]);
+          insertOp, insertOp.getDest(), adaptor.getSource(), sanitizedIndex);
+    }
     return success();
   }
 };
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index 383215c016039a..35ef759cf24168 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -176,7 +176,8 @@ func.func @extract(%arg0 : vector<2xf32>) -> (vector<1xf32>, f32) {
 // -----
 
 func.func @extract_poison_idx(%arg0 : vector<4xf32>) -> f32 {
-  // expected-error@+1 {{index -1 out of bounds for 'vector<4xf32>'}}
+  // CHECK: %[[ZERO:.+]] = spirv.Constant 0.000000e+00
+  // CHECK: return %[[ZERO]]
   %0 = vector.extract %arg0[-1] : f32 from vector<4xf32>
   return %0: f32
 }
@@ -208,7 +209,11 @@ func.func @extract_size1_vector_dynamic(%arg0 : vector<1xf32>, %id : index) -> f
 // CHECK-LABEL: @extract_dynamic
 //  CHECK-SAME: %[[V:.*]]: vector<4xf32>, %[[ARG1:.*]]: index
 //       CHECK:   %[[ID:.+]] = builtin.unrealized_conversion_cast %[[ARG1]] : index to i32
-//       CHECK:   spirv.VectorExtractDynamic %[[V]][%[[ID]]] : vector<4xf32>, i32
+//       CHECK:   %[[POISON:.+]] = spirv.Constant -1 :
+//       CHECK:   %[[CMP:.+]] = spirv.IEqual %[[ID]], %[[POISON]]
+//       CHECK:   %[[ZERO:.+]] = spirv.Constant 0 :
+//       CHECK:   %[[SELECT:.+]] = spirv.Select %[[CMP]], %[[ZERO]], %[[ID]] :
+//       CHECK:   spirv.VectorExtractDynamic %[[V]][%[[SELECT]]] : vector<4xf32>, i32
 func.func @extract_dynamic(%arg0 : vector<4xf32>, %id : index) -> f32 {
   %0 = vector.extract %arg0[%id] : f32 from vector<4xf32>
   return %0: f32
@@ -264,8 +269,10 @@ func.func @insert(%arg0 : vector<4xf32>, %arg1: f32) -> vector<4xf32> {
 
 // -----
 
+// CHECK-LABEL: @insert_poison_idx
+// CHECK: %[[ZERO:.+]] = spirv.Constant dense<0.000000e+00>
+// CHECK: return %[[ZERO]]
 func.func @insert_poison_idx(%arg0 : vector<4xf32>, %arg1: f32) -> vector<4xf32> {
-  // expected-error@+1 {{index -1 out of bounds for 'vector<4xf32>'}}
   %1 = vector.insert %arg1, %arg0[-1] : f32 into vector<4xf32>
   return %1: vector<4xf32>
 }
@@ -306,7 +313,11 @@ func.func @insert_size1_vector_dynamic(%arg0 : vector<1xf32>, %arg1: f32, %id :
 // CHECK-LABEL: @insert_dynamic
 //  CHECK-SAME: %[[VAL:.*]]: f32, %[[V:.*]]: vector<4xf32>, %[[ARG2:.*]]: index
 //       CHECK: %[[ID:.+]] = builtin.unrealized_conversion_cast %[[ARG2]] : index to i32
-//       CHECK:   spirv.VectorInsertDynamic %[[VAL]], %[[V]][%[[ID]]] : vector<4xf32>, i32
+//       CHECK:   %[[POISON:.+]] = spirv.Constant -1 :
+//       CHECK:   %[[CMP:.+]] = spirv.IEqual %[[ID]], %[[POISON]]
+//       CHECK:   %[[ZERO:.+]] = spirv.Constant 0 :
+//       CHECK:   %[[SELECT:.+]] = spirv.Select %[[CMP]], %[[ZERO]], %[[ID]] :
+//       CHECK:   spirv.VectorInsertDynamic %[[VAL]], %[[V]][%[[SELECT]]] : vector<4xf32>, i32
 func.func @insert_dynamic(%val: f32, %arg0 : vector<4xf32>, %id : index) -> vector<4xf32> {
   %0 = vector.insert %val, %arg0[%id] : f32 into vector<4xf32>
   return %0: vector<4xf32>

This modifies the conversion patterns so that, in the case where the
index is known statically to be poison, the insertion/extraction is
replaced by an arbitrary junk constant value, and in the dynamic case,
the index is sanitized at runtime. This avoids triggering a UB in both
cases. The dynamic case is definitely a pessimisation of the generated
code, but the use of dynamic indexes is expected to be very rare and
already slow on real-world GPU compilers ingesting SPIR-V, so the impact
should be negligible.

Resolves llvm#124162.
@andfau-amd andfau-amd force-pushed the 124162-CompositeExtractInsert-poison branch from 9b14fc1 to 03a81e7 Compare February 4, 2025 18:48
@andfau-amd andfau-amd requested a review from kuhar February 4, 2025 18:51
else
getConstantIntValue(extractOp.getMixedPosition()[0])) {
// TODO: ExtractOp::fold() already can fold a static poison index to
// ub.poison; remove this once ub.poison can be converted to SPIR-V.
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(I intend to tackle ub.poison to SPIR-V conversion myself soon, and I'll clean this up in that commit.)

Copy link
Member

@kuhar kuhar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@andfau-amd andfau-amd merged commit 5df62bd into llvm:main Feb 5, 2025
8 checks passed
Icohedron pushed a commit to Icohedron/llvm-project that referenced this pull request Feb 11, 2025
…act (llvm#125560)

This modifies the conversion patterns so that, in the case where the
index is known statically to be poison, the insertion/extraction is
replaced by an arbitrary junk constant value, and in the dynamic case,
the index is sanitized at runtime. This avoids triggering a UB in both
cases. The dynamic case is definitely a pessimisation of the generated
code, but the use of dynamic indexes is expected to be very rare and
already slow on real-world GPU compilers ingesting SPIR-V, so the impact
should be negligible.

Resolves llvm#124162.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[mlir][spirv] Add support for poison index values to spirv.CompositeExtract/InsertOp
3 participants