-
Notifications
You must be signed in to change notification settings - Fork 13k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][spirv] Support poison index when converting vector.insert/extract #125560
[mlir][spirv] Support poison index when converting vector.insert/extract #125560
Conversation
@llvm/pr-subscribers-mlir-spirv @llvm/pr-subscribers-mlir Author: Andrea Faulds (andfau-amd) ChangesThis modifies the conversion patterns so that, in the case where the index is known statically to be poison, the insertion/extraction is replaced by an arbitrary junk constant value, and in the dynamic case, the index is sanitized at runtime. This avoids triggering a UB in both cases. The dynamic case is definitely a pessimisation of the generated code, but the use of dynamic indexes is expected to be very rare and already slow on real-world GPU compilers ingesting SPIR-V, so the impact should be negligible. Resolves #124162. Full diff: https://github.com/llvm/llvm-project/pull/125560.diff 2 Files Affected:
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index af882cb1ca6e91..3481a2e8b7733c 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -137,6 +137,26 @@ struct VectorBroadcastConvert final
}
};
+// SPIR-V does not have a concept of a poison index for certain instructions,
+// which creates a UB hazard when lowering from otherwise equivalent Vector
+// dialect instructions, because this index will be considered out-of-bounds.
+// To avoid this, this function implements a dynamic sanitization, arbitrarily
+// choosing to replace the poison index with index 0 (always in-bounds).
+static Value sanitizeDynamicIndex(ConversionPatternRewriter &rewriter,
+ Location loc, Value dynamicIndex,
+ int64_t kPoisonIndex) {
+ Value poisonIndex = rewriter.create<spirv::ConstantOp>(
+ loc, dynamicIndex.getType(),
+ rewriter.getIntegerAttr(dynamicIndex.getType(), kPoisonIndex));
+ Value cmpResult =
+ rewriter.create<spirv::IEqualOp>(loc, dynamicIndex, poisonIndex);
+ Value sanitizedIndex = rewriter.create<spirv::SelectOp>(
+ loc, cmpResult,
+ spirv::ConstantOp::getZero(dynamicIndex.getType(), loc, rewriter),
+ dynamicIndex);
+ return sanitizedIndex;
+}
+
struct VectorExtractOpConvert final
: public OpConversionPattern<vector::ExtractOp> {
using OpConversionPattern::OpConversionPattern;
@@ -154,14 +174,26 @@ struct VectorExtractOpConvert final
}
if (std::optional<int64_t> id =
- getConstantIntValue(extractOp.getMixedPosition()[0]))
- rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
- extractOp, dstType, adaptor.getVector(),
- rewriter.getI32ArrayAttr(id.value()));
- else
+ getConstantIntValue(extractOp.getMixedPosition()[0])) {
+ // TODO: It would be better to apply the ub.poison folding for this case
+ // unconditionally, and have a specific SPIR-V lowering for it,
+ // rather than having to handle it here.
+ if (id == vector::ExtractOp::kPoisonIndex) {
+ // Arbitrary choice of poison result, intended to stick out.
+ Value zero =
+ spirv::ConstantOp::getZero(dstType, extractOp.getLoc(), rewriter);
+ rewriter.replaceOp(extractOp, zero);
+ } else
+ rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
+ extractOp, dstType, adaptor.getVector(),
+ rewriter.getI32ArrayAttr(id.value()));
+ } else {
+ Value sanitizedIndex = sanitizeDynamicIndex(
+ rewriter, extractOp.getLoc(), adaptor.getDynamicPosition()[0],
+ vector::ExtractOp::kPoisonIndex);
rewriter.replaceOpWithNewOp<spirv::VectorExtractDynamicOp>(
- extractOp, dstType, adaptor.getVector(),
- adaptor.getDynamicPosition()[0]);
+ extractOp, dstType, adaptor.getVector(), sanitizedIndex);
+ }
return success();
}
};
@@ -266,13 +298,25 @@ struct VectorInsertOpConvert final
}
if (std::optional<int64_t> id =
- getConstantIntValue(insertOp.getMixedPosition()[0]))
- rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
- insertOp, adaptor.getSource(), adaptor.getDest(), id.value());
- else
+ getConstantIntValue(insertOp.getMixedPosition()[0])) {
+ // TODO: It would be better to apply the ub.poison folding for this case
+ // unconditionally, and have a specific SPIR-V lowering for it,
+ // rather than having to handle it here.
+ if (id == vector::InsertOp::kPoisonIndex) {
+ // Arbitrary choice of poison result, intended to stick out.
+ Value zero = spirv::ConstantOp::getZero(insertOp.getDestVectorType(),
+ insertOp.getLoc(), rewriter);
+ rewriter.replaceOp(insertOp, zero);
+ } else
+ rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
+ insertOp, adaptor.getSource(), adaptor.getDest(), id.value());
+ } else {
+ Value sanitizedIndex = sanitizeDynamicIndex(
+ rewriter, insertOp.getLoc(), adaptor.getDynamicPosition()[0],
+ vector::InsertOp::kPoisonIndex);
rewriter.replaceOpWithNewOp<spirv::VectorInsertDynamicOp>(
- insertOp, insertOp.getDest(), adaptor.getSource(),
- adaptor.getDynamicPosition()[0]);
+ insertOp, insertOp.getDest(), adaptor.getSource(), sanitizedIndex);
+ }
return success();
}
};
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index 383215c016039a..35ef759cf24168 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -176,7 +176,8 @@ func.func @extract(%arg0 : vector<2xf32>) -> (vector<1xf32>, f32) {
// -----
func.func @extract_poison_idx(%arg0 : vector<4xf32>) -> f32 {
- // expected-error@+1 {{index -1 out of bounds for 'vector<4xf32>'}}
+ // CHECK: %[[ZERO:.+]] = spirv.Constant 0.000000e+00
+ // CHECK: return %[[ZERO]]
%0 = vector.extract %arg0[-1] : f32 from vector<4xf32>
return %0: f32
}
@@ -208,7 +209,11 @@ func.func @extract_size1_vector_dynamic(%arg0 : vector<1xf32>, %id : index) -> f
// CHECK-LABEL: @extract_dynamic
// CHECK-SAME: %[[V:.*]]: vector<4xf32>, %[[ARG1:.*]]: index
// CHECK: %[[ID:.+]] = builtin.unrealized_conversion_cast %[[ARG1]] : index to i32
-// CHECK: spirv.VectorExtractDynamic %[[V]][%[[ID]]] : vector<4xf32>, i32
+// CHECK: %[[POISON:.+]] = spirv.Constant -1 :
+// CHECK: %[[CMP:.+]] = spirv.IEqual %[[ID]], %[[POISON]]
+// CHECK: %[[ZERO:.+]] = spirv.Constant 0 :
+// CHECK: %[[SELECT:.+]] = spirv.Select %[[CMP]], %[[ZERO]], %[[ID]] :
+// CHECK: spirv.VectorExtractDynamic %[[V]][%[[SELECT]]] : vector<4xf32>, i32
func.func @extract_dynamic(%arg0 : vector<4xf32>, %id : index) -> f32 {
%0 = vector.extract %arg0[%id] : f32 from vector<4xf32>
return %0: f32
@@ -264,8 +269,10 @@ func.func @insert(%arg0 : vector<4xf32>, %arg1: f32) -> vector<4xf32> {
// -----
+// CHECK-LABEL: @insert_poison_idx
+// CHECK: %[[ZERO:.+]] = spirv.Constant dense<0.000000e+00>
+// CHECK: return %[[ZERO]]
func.func @insert_poison_idx(%arg0 : vector<4xf32>, %arg1: f32) -> vector<4xf32> {
- // expected-error@+1 {{index -1 out of bounds for 'vector<4xf32>'}}
%1 = vector.insert %arg1, %arg0[-1] : f32 into vector<4xf32>
return %1: vector<4xf32>
}
@@ -306,7 +313,11 @@ func.func @insert_size1_vector_dynamic(%arg0 : vector<1xf32>, %arg1: f32, %id :
// CHECK-LABEL: @insert_dynamic
// CHECK-SAME: %[[VAL:.*]]: f32, %[[V:.*]]: vector<4xf32>, %[[ARG2:.*]]: index
// CHECK: %[[ID:.+]] = builtin.unrealized_conversion_cast %[[ARG2]] : index to i32
-// CHECK: spirv.VectorInsertDynamic %[[VAL]], %[[V]][%[[ID]]] : vector<4xf32>, i32
+// CHECK: %[[POISON:.+]] = spirv.Constant -1 :
+// CHECK: %[[CMP:.+]] = spirv.IEqual %[[ID]], %[[POISON]]
+// CHECK: %[[ZERO:.+]] = spirv.Constant 0 :
+// CHECK: %[[SELECT:.+]] = spirv.Select %[[CMP]], %[[ZERO]], %[[ID]] :
+// CHECK: spirv.VectorInsertDynamic %[[VAL]], %[[V]][%[[SELECT]]] : vector<4xf32>, i32
func.func @insert_dynamic(%val: f32, %arg0 : vector<4xf32>, %id : index) -> vector<4xf32> {
%0 = vector.insert %val, %arg0[%id] : f32 into vector<4xf32>
return %0: vector<4xf32>
|
This modifies the conversion patterns so that, in the case where the index is known statically to be poison, the insertion/extraction is replaced by an arbitrary junk constant value, and in the dynamic case, the index is sanitized at runtime. This avoids triggering a UB in both cases. The dynamic case is definitely a pessimisation of the generated code, but the use of dynamic indexes is expected to be very rare and already slow on real-world GPU compilers ingesting SPIR-V, so the impact should be negligible. Resolves llvm#124162.
9b14fc1
to
03a81e7
Compare
else | ||
getConstantIntValue(extractOp.getMixedPosition()[0])) { | ||
// TODO: ExtractOp::fold() already can fold a static poison index to | ||
// ub.poison; remove this once ub.poison can be converted to SPIR-V. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(I intend to tackle ub.poison to SPIR-V conversion myself soon, and I'll clean this up in that commit.)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
…act (llvm#125560) This modifies the conversion patterns so that, in the case where the index is known statically to be poison, the insertion/extraction is replaced by an arbitrary junk constant value, and in the dynamic case, the index is sanitized at runtime. This avoids triggering a UB in both cases. The dynamic case is definitely a pessimisation of the generated code, but the use of dynamic indexes is expected to be very rare and already slow on real-world GPU compilers ingesting SPIR-V, so the impact should be negligible. Resolves llvm#124162.
This modifies the conversion patterns so that, in the case where the index is known statically to be poison, the insertion/extraction is replaced by an arbitrary junk constant value, and in the dynamic case, the index is sanitized at runtime. This avoids triggering a UB in both cases. The dynamic case is definitely a pessimisation of the generated code, but the use of dynamic indexes is expected to be very rare and already slow on real-world GPU compilers ingesting SPIR-V, so the impact should be negligible.
Resolves #124162.