Skip to content

Commit 7792e1d

Browse files
committed
[mlir][vector] Extend vector.{insert|extract}_strided_slice
Extends `vector.insert_strided_slice` and `vector.insert_strided_slice` to allow scalable input and output vectors. For scalable sizes, the corresponding slice size has to match the corresponding dimension in the output/input vector (insert/extract, respectively). This is supported: ```mlir vector.extract_strided_slice %1 { offsets = [0, 3, 0], sizes = [1, 1, 4], strides = [1, 1, 1] } : vector<1x4x[4]xi32> to vector<1x1x[4]xi32> ``` This is not supported: ```mlir vector.extract_strided_slice %1 { offsets = [0, 3, 0], sizes = [1, 1, 2], strides = [1, 1, 1] } : vector<1x4x[4]xi32> to vector<1x1x[2]xi32> ```
1 parent e3172e8 commit 7792e1d

File tree

4 files changed

+75
-1
lines changed

4 files changed

+75
-1
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

+17-1
Original file line numberDiff line numberDiff line change
@@ -3194,6 +3194,7 @@ void ReshapeOp::getFixedVectorSizes(SmallVectorImpl<int64_t> &results) {
31943194
// Inference works as follows:
31953195
// 1. Add 'sizes' from prefix of dims in 'offsets'.
31963196
// 2. Add sizes from 'vectorType' for remaining dims.
3197+
// Scalable flags are inherited from 'vectorType'.
31973198
static Type inferStridedSliceOpResultType(VectorType vectorType,
31983199
ArrayAttr offsets, ArrayAttr sizes,
31993200
ArrayAttr strides) {
@@ -3206,7 +3207,8 @@ static Type inferStridedSliceOpResultType(VectorType vectorType,
32063207
for (unsigned e = vectorType.getShape().size(); idx < e; ++idx)
32073208
shape.push_back(vectorType.getShape()[idx]);
32083209

3209-
return VectorType::get(shape, vectorType.getElementType());
3210+
return VectorType::get(shape, vectorType.getElementType(),
3211+
vectorType.getScalableDims());
32103212
}
32113213

32123214
void ExtractStridedSliceOp::build(OpBuilder &builder, OperationState &result,
@@ -3265,6 +3267,20 @@ LogicalResult ExtractStridedSliceOp::verify() {
32653267
if (getResult().getType() != resultType)
32663268
return emitOpError("expected result type to be ") << resultType;
32673269

3270+
unsigned idx = 0;
3271+
for (unsigned ub = sizes.size(); idx < ub; ++idx) {
3272+
if (type.getScalableDims()[idx]) {
3273+
auto inputDim = type.getShape()[idx];
3274+
auto inputSize = llvm::cast<IntegerAttr>(sizes[idx]).getInt();
3275+
if (inputDim != inputSize)
3276+
return emitOpError("expected size at idx=")
3277+
<< idx
3278+
<< (" to match the corresponding base size from the input "
3279+
"vector (")
3280+
<< inputSize << (" vs ") << inputDim << (")");
3281+
}
3282+
}
3283+
32683284
return success();
32693285
}
32703286

mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir

+43
Original file line numberDiff line numberDiff line change
@@ -1142,6 +1142,29 @@ func.func @extract_strided_slice3(%arg0: vector<4x8xf32>) -> vector<2x2xf32> {
11421142

11431143
// -----
11441144

1145+
func.func @extract_strided_slice_scalable(%arg0 : vector<1x4x[4]xi32>) -> vector<1x1x[4]xi32> {
1146+
%0 = vector.extract_strided_slice %arg0 {offsets = [0, 3, 0], sizes = [1, 1, 4], strides = [1, 1, 1]} : vector<1x4x[4]xi32> to vector<1x1x[4]xi32>
1147+
return %0 : vector<1x1x[4]xi32>
1148+
}
1149+
1150+
// CHECK-LABEL: func.func @extract_strided_slice_scalable(
1151+
// CHECK-SAME: %[[VAL_0:.*]]: vector<1x4x[4]xi32>) -> vector<1x1x[4]xi32> {
1152+
// CHECK: %[[VAL_1:.*]] = builtin.unrealized_conversion_cast %[[VAL_0]] : vector<1x4x[4]xi32> to !llvm.array<1 x array<4 x vector<[4]xi32>>>
1153+
// CHECK: %[[VAL_2:.*]] = arith.constant 0 : i32
1154+
// CHECK: %[[VAL_3:.*]] = arith.constant dense<0> : vector<1x1x[4]xi32>
1155+
// CHECK: %[[VAL_4:.*]] = builtin.unrealized_conversion_cast %[[VAL_3]] : vector<1x1x[4]xi32> to !llvm.array<1 x array<1 x vector<[4]xi32>>>
1156+
// CHECK: %[[VAL_5:.*]] = llvm.extractvalue %[[VAL_1]][0] : !llvm.array<1 x array<4 x vector<[4]xi32>>>
1157+
// CHECK: %[[VAL_6:.*]] = arith.constant 0 : i32
1158+
// CHECK: %[[VAL_7:.*]] = arith.constant dense<0> : vector<1x[4]xi32>
1159+
// CHECK: %[[VAL_8:.*]] = builtin.unrealized_conversion_cast %[[VAL_7]] : vector<1x[4]xi32> to !llvm.array<1 x vector<[4]xi32>>
1160+
// CHECK: %[[VAL_9:.*]] = llvm.extractvalue %[[VAL_1]][0, 3] : !llvm.array<1 x array<4 x vector<[4]xi32>>>
1161+
// CHECK: %[[VAL_10:.*]] = llvm.insertvalue %[[VAL_9]], %[[VAL_8]][0] : !llvm.array<1 x vector<[4]xi32>>
1162+
// CHECK: %[[VAL_11:.*]] = llvm.insertvalue %[[VAL_10]], %[[VAL_4]][0] : !llvm.array<1 x array<1 x vector<[4]xi32>>>
1163+
// CHECK: %[[VAL_12:.*]] = builtin.unrealized_conversion_cast %[[VAL_11]] : !llvm.array<1 x array<1 x vector<[4]xi32>>> to vector<1x1x[4]xi32>
1164+
// CHECK: return %[[VAL_12]] : vector<1x1x[4]xi32>
1165+
1166+
// -----
1167+
11451168
func.func @insert_strided_slice1(%b: vector<4x4xf32>, %c: vector<4x4x4xf32>) -> vector<4x4x4xf32> {
11461169
%0 = vector.insert_strided_slice %b, %c {offsets = [2, 0, 0], strides = [1, 1]} : vector<4x4xf32> into vector<4x4x4xf32>
11471170
return %0 : vector<4x4x4xf32>
@@ -1207,6 +1230,26 @@ func.func @insert_strided_slice3(%arg0: vector<2x4xf32>, %arg1: vector<16x4x8xf3
12071230

12081231
// -----
12091232

1233+
func.func @insert_strided_slice_scalable(%arg0 : vector<1x1x[4]xi32>, %arg1: vector<1x4x[4]xi32>) -> vector<1x4x[4]xi32> {
1234+
%0 = vector.insert_strided_slice %arg0, %arg1 {offsets = [0, 3, 0], strides = [1, 1, 1]} : vector<1x1x[4]xi32> into vector<1x4x[4]xi32>
1235+
return %0 : vector<1x4x[4]xi32>
1236+
}
1237+
// CHECK-LABEL: func.func @insert_strided_slice_scalable(
1238+
// CHECK-SAME: %[[VAL_0:.*]]: vector<1x1x[4]xi32>,
1239+
// CHECK-SAME: %[[VAL_1:.*]]: vector<1x4x[4]xi32>) -> vector<1x4x[4]xi32> {
1240+
// CHECK: %[[VAL_2:.*]] = builtin.unrealized_conversion_cast %[[VAL_0]] : vector<1x1x[4]xi32> to !llvm.array<1 x array<1 x vector<[4]xi32>>>
1241+
// CHECK: %[[VAL_3:.*]] = builtin.unrealized_conversion_cast %[[VAL_1]] : vector<1x4x[4]xi32> to !llvm.array<1 x array<4 x vector<[4]xi32>>>
1242+
// CHECK: %[[VAL_4:.*]] = llvm.extractvalue %[[VAL_2]][0] : !llvm.array<1 x array<1 x vector<[4]xi32>>>
1243+
// CHECK: %[[VAL_5:.*]] = llvm.extractvalue %[[VAL_3]][0] : !llvm.array<1 x array<4 x vector<[4]xi32>>>
1244+
// CHECK: %[[VAL_6:.*]] = llvm.extractvalue %[[VAL_2]][0, 0] : !llvm.array<1 x array<1 x vector<[4]xi32>>>
1245+
// CHECK: %[[VAL_7:.*]] = llvm.extractvalue %[[VAL_3]][0, 3] : !llvm.array<1 x array<4 x vector<[4]xi32>>>
1246+
// CHECK: %[[VAL_8:.*]] = llvm.insertvalue %[[VAL_6]], %[[VAL_5]][3] : !llvm.array<4 x vector<[4]xi32>>
1247+
// CHECK: %[[VAL_9:.*]] = llvm.insertvalue %[[VAL_8]], %[[VAL_3]][0] : !llvm.array<1 x array<4 x vector<[4]xi32>>>
1248+
// CHECK: %[[VAL_10:.*]] = builtin.unrealized_conversion_cast %[[VAL_9]] : !llvm.array<1 x array<4 x vector<[4]xi32>>> to vector<1x4x[4]xi32>
1249+
// CHECK: return %[[VAL_10]] : vector<1x4x[4]xi32>
1250+
1251+
// -----
1252+
12101253
func.func @vector_fma(%a: vector<8xf32>, %b: vector<2x4xf32>, %c: vector<1x1x1xf32>, %d: vector<f32>) -> (vector<8xf32>, vector<2x4xf32>, vector<1x1x1xf32>, vector<f32>) {
12111254
// CHECK-LABEL: @vector_fma
12121255
// CHECK-SAME: %[[A:.*]]: vector<8xf32>

mlir/test/Dialect/Vector/invalid.mlir

+8
Original file line numberDiff line numberDiff line change
@@ -687,6 +687,14 @@ func.func @extract_strided_slice(%arg0: vector<4x8x16xf32>) {
687687

688688
// -----
689689

690+
func.func @extract_strided_slice_scalable(%arg0 : vector<1x4x[4]xi32>) -> vector<1x1x[2]xi32> {
691+
// expected-error@+1 {{op expected size at idx=2 to match the corresponding base size from the input vector (2 vs 4)}}
692+
%1 = vector.extract_strided_slice %arg0 {offsets = [0, 3, 0], sizes = [1, 1, 2], strides = [1, 1, 1]} : vector<1x4x[4]xi32> to vector<1x1x[2]xi32>
693+
return %1 : vector<1x1x[2]xi32>
694+
}
695+
696+
// -----
697+
690698
func.func @extract_strided_slice(%arg0: vector<4x8x16xf32>) {
691699
// expected-error@+1 {{op expected strides to be confined to [1, 2)}}
692700
%1 = vector.extract_strided_slice %arg0 {offsets = [2], sizes = [1], strides = [100]} : vector<4x8x16xf32> to vector<1x8x16xf32>

mlir/test/Dialect/Vector/ops.mlir

+7
Original file line numberDiff line numberDiff line change
@@ -326,6 +326,13 @@ func.func @extract_strided_slice(%arg0: vector<4x8x16xf32>) -> vector<2x2x16xf32
326326
return %1: vector<2x2x16xf32>
327327
}
328328

329+
// CHECK-LABEL: @extract_strided_slice_scalable
330+
func.func @extract_strided_slice_scalable(%arg0: vector<4x[8]x16xf32>) -> vector<2x[8]x16xf32> {
331+
// CHECK: vector.extract_strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x[8]x16xf32>
332+
%1 = vector.extract_strided_slice %arg0 {offsets = [2, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x[8]x16xf32> to vector<2x[8]x16xf32>
333+
return %1: vector<2x[8]x16xf32>
334+
}
335+
329336
#contraction_to_scalar_accesses = [
330337
affine_map<(i) -> (i)>,
331338
affine_map<(i) -> (i)>,

0 commit comments

Comments
 (0)