Skip to content

Commit 5583441

Browse files
committed
[mlir][Vector] Update VectorEmulateNarrowType.cpp (4/N)
This is PR 4 in a series of N patches aimed at improving "VectorEmulateNarrowType.cpp". This is mainly minor refactoring, no major functional changes are made/added. 1. Update `alignedConversionPrecondition` (1): This method didn't require the vector type for the "destination" argument. The underlying element type is sufficient. The corresponding argument has been renamed as `multiByteScalarTy` - this is meant as the multi-byte emulated type (`i8`, `i16`, `i32`, etc). 2. Update `alignedConversionPrecondition` (2): In #121298, we replaced `dstElemBitwidt` in this calculation: ```cpp const int numSrcElemsPerDestElem = dstElemBitwidth / srcElemBitwidth; ``` with the hard-coded value of 8: ```cpp const int numSrcElemsPerDestElem = 8 / srcElemBitwidth; ``` That was correct as for the patterns for which this hook was/is used: * `RewriteAlignedSubByteIntExt`, * `RewriteAlignedSubByteIntTrunc`. The destination type (or, more precisely, the emulated type) was always `i8`. In this PR, I am switching back to a more generic approach - the calculation should take into account the bit-width of the emulated type. Note that at the call sites I am passing `i8` as the emulated type, so the end-result is effectively identical. However, the intent is clearer, i.e., the underlying value is 8 because the emulated type happens to be `i8` (as opposed using a magic number). 3. Update alignedConversionPrecondition (3): The final check has been replaced with a new helper method, `isSubByteVecFittable`. This new method is also re-used within the code and hopefully will allow us more code re-use moving forward (to avoid re-implementing the same condition). 4. Update alignedConversionPrecondition (4): NEXT STEPS: We need to clarify the meaning of "source" and "destination" types. Currently the usage is ambiguous. For example, for this `arith.extsi` Op, `vector<8xi2>` and `vector<8xi32>` are the "source" and "destination" types, respectively: ```mlir %0 = arith.extsi %arg0 : vector<8xi2> to vector<8xi32> } ``` However, patterns like `RewriteAlignedSubByteIntExt` introduce `vector.bitcast` Ops like this: ```mlir %bitcast = vector.bitcast %arg0 : vector<8xi2> to vector<2xi8> ``` I've noticed that we tend to mix `vector<2xi8>` and `vector<8xi32>` as the destination types and that should be clarified.
1 parent e8ffbaa commit 5583441

File tree

1 file changed

+102
-32
lines changed

1 file changed

+102
-32
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp

+102-32
Original file line numberDiff line numberDiff line change
@@ -753,6 +753,38 @@ struct ConvertVectorMaskedLoad final
753753
}
754754
};
755755

756+
/// Check whether `subByteVecTy` fits wthin a vector of `multiByteScalarTy`
757+
///
758+
/// "Fitting" means that `subByteVecTy` (a vector of sub-byte elements, e.g.
759+
/// vector<4xi4>), can fit within N scalar elements of type `multiByteScalarTy`
760+
/// (a multi-byte scalar, e.g. i16), where N is some integer.
761+
///
762+
/// Put differently, this method checks whether this would be valid:
763+
///
764+
/// vector.bitcast subByteVecTy into vector<N x multiByteScalarTy>
765+
///
766+
/// EXAMPLES:
767+
/// * vector<4xi4> -> i16 - yes (N = 1)
768+
/// * vector<4xi4> -> i8 - yes (N = 2)
769+
/// * vector<3xi4> -> i8 - no (N would have to be 1.5)
770+
/// * vector<3xi2> -> i16 - no (N would have to be 0.5)
771+
static bool isSubByteVecFittable(VectorType subByteVecTy,
772+
Type multiByteScalarTy) {
773+
assert((isa<IntegerType, FloatType>(multiByteScalarTy)) && "Not scalar!");
774+
775+
int subByteBits = subByteVecTy.getElementType().getIntOrFloatBitWidth();
776+
int multiByteBits = multiByteScalarTy.getIntOrFloatBitWidth();
777+
778+
assert(subByteBits < 8 && "Not a sub-byte scalar type!");
779+
assert(multiByteBits % 8 == 0 && "Not a multi-byte scalar type!");
780+
assert(multiByteBits % subByteBits == 0 && "Unalagined element types!");
781+
782+
int elemsPerMultiByte = multiByteBits / subByteBits;
783+
784+
// TODO: This is a bit too restrictive for vectors rank > 1.
785+
return subByteVecTy.getShape().back() % elemsPerMultiByte == 0;
786+
}
787+
756788
//===----------------------------------------------------------------------===//
757789
// ConvertVectorTransferRead
758790
//===----------------------------------------------------------------------===//
@@ -787,7 +819,8 @@ struct ConvertVectorTransferRead final
787819
auto origElements = op.getVectorType().getNumElements();
788820

789821
// Note, per-element-alignment was already verified above.
790-
bool isFullyAligned = origElements % elementsPerContainerType == 0;
822+
bool isFullyAligned =
823+
isSubByteVecFittable(op.getVectorType(), newElementType);
791824

792825
auto newPadding = rewriter.create<arith::ExtUIOp>(loc, newElementType,
793826
adaptor.getPadding());
@@ -1089,41 +1122,76 @@ LogicalResult BitCastRewriter::commonPrecondition(PatternRewriter &rewriter,
10891122
return commonConversionPrecondition(rewriter, preconditionType, op);
10901123
}
10911124

1092-
/// Verify that `subByteVecType` and `dstType` are aligned. Alignment
1093-
/// means that:
1094-
/// 1. The `dstType` element type is a multiple of the
1095-
/// `srcVectorOfSubByteType` element type (e.g. i4 vs i8 is OK, but i3 vs i8
1096-
/// is not supported). Let this multiple be `N`.
1097-
/// 2. The number of the (trailing) elements in `srcVectorOfSubByteType` is a
1098-
/// multiple of `N` from 1. (e.g., when targetting i8, 2xi4 is OK, but 3xi4 is
1099-
/// not supported).
1125+
/// Verify that `subByteVecTy` (vector) and `containerTy` (scalar) are aligned.
1126+
///
1127+
/// Alignment means that `subByteVecTy` can be packed into a vector of
1128+
/// `containerTy` elements. More specifically:
1129+
/// 1. The bit-width of `containerTy` is a multiple of the
1130+
/// bit-width of `subByteVecTy` elements. For example, for `i4` and `i16`
1131+
/// this multiple is 4.
1132+
/// 2. The multiple from 1. above divides evenly the number of the (trailing)
1133+
/// elements in `subByteVecTy`.
1134+
///
1135+
/// EXAMPLE 1:
1136+
/// `subByteVecTy = vector<2xi4>`, and
1137+
/// `containerTy = i16`
1138+
///
1139+
/// 2 divides evenly 4 ( = 16 / 4), hence both conditions are _met_.
1140+
///
1141+
/// EXAMPLE 2:
1142+
/// `subByteVecTy = vector<3xi4>`, and
1143+
/// `containerTy = i16`
1144+
///
1145+
/// 3 _does not_ divide evenly 4 (= 16/4), hence the conditions are _not met_.
1146+
///
1147+
/// EXAMPLE 3:
1148+
/// `subByteVecTy = vector<3xi3>`, and
1149+
/// `containerTy = i16`
1150+
///
1151+
/// 16 _is not_ a multiple of 3, hence the conditions are _not met_.
11001152
///
11011153
/// NOTE: This method assumes that common conversion preconditions are met. In
1102-
/// particular, the element type of `dstType` is assumed to be a multi-byte
1103-
/// type (e.g. i8, i16, i32).
1154+
/// particular, `containerTy` is assumed to be a
1155+
/// multi-byte scalar type (e.g., i8, i16, i32).
11041156
static LogicalResult alignedConversionPrecondition(PatternRewriter &rewriter,
1105-
VectorType subByteVecType,
1106-
VectorType dstType,
1157+
VectorType subByteVecTy,
1158+
Type containerTy,
11071159
Operation *op) {
1108-
if (!subByteVecType || !dstType)
1109-
return rewriter.notifyMatchFailure(op, "Not a supported aligned case");
1110-
unsigned srcElemBitwidth = subByteVecType.getElementTypeBitWidth();
1111-
unsigned dstElemBitwidth = dstType.getElementTypeBitWidth();
1160+
// TODO: This is validating the inputs rather than checking the conditions
1161+
// documented above. Replace with an assert.
1162+
if (!subByteVecTy)
1163+
return rewriter.notifyMatchFailure(op, "not a vector!");
11121164

1113-
if (dstElemBitwidth < 8)
1114-
return rewriter.notifyMatchFailure(
1115-
op, "the bitwidth of dstType must be greater than or equal to 8");
1116-
if (dstElemBitwidth % srcElemBitwidth != 0)
1117-
return rewriter.notifyMatchFailure(op, "unaligned cases are not supported");
1118-
if (srcElemBitwidth != 2 && srcElemBitwidth != 4)
1165+
// TODO: This is validating the inputs rather than checking the conditions
1166+
// documented above. Replace with an assert.
1167+
if (!containerTy.isIntOrFloat())
1168+
return rewriter.notifyMatchFailure(op, "not a scalar!");
1169+
1170+
unsigned subByteBits = subByteVecTy.getElementTypeBitWidth();
1171+
unsigned multiByteBits = containerTy.getIntOrFloatBitWidth();
1172+
1173+
// Enforced by the common pre-conditions.
1174+
assert(multiByteBits % 8 == 0 && "Not a multi-byte scalar type!");
1175+
1176+
// TODO: Remove this condition - the assert above (and
1177+
// commonConversionPrecondtion) takes care of that.
1178+
if (multiByteBits < 8)
1179+
return rewriter.notifyMatchFailure(op, "not a multi-byte scalar type!");
1180+
1181+
// TODO: Add support other widths (when/if needed)
1182+
if (subByteBits != 2 && subByteBits != 4)
11191183
return rewriter.notifyMatchFailure(
1120-
op, "only src bitwidth of 2 or 4 is supported at this moment");
1184+
op, "only 2-bit and 4-bit sub-byte type is supported at this moment");
1185+
1186+
// Condition 1.
1187+
if (multiByteBits % subByteBits != 0)
1188+
return rewriter.notifyMatchFailure(op, "unalagined element types");
11211189

1122-
const int numSrcElemsPerByte = 8 / srcElemBitwidth;
1123-
if ((subByteVecType.getShape().back() % numSrcElemsPerByte) != 0)
1190+
// Condition 2.
1191+
if (!isSubByteVecFittable(subByteVecTy, containerTy))
11241192
return rewriter.notifyMatchFailure(
1125-
op, "the trailing dimension of the input vector of sub-bytes must be a "
1126-
"multiple of 8 / <sub-byte-width>");
1193+
op, "not possible to fit this sub-byte vector type into a vector of "
1194+
"the given multi-byte type");
11271195

11281196
return success();
11291197
}
@@ -1560,8 +1628,9 @@ struct RewriteAlignedSubByteIntExt : OpRewritePattern<ConversionOpType> {
15601628
return failure();
15611629

15621630
// Check general alignment preconditions.
1563-
if (failed(alignedConversionPrecondition(rewriter, srcVecType, dstVecType,
1564-
conversionOp)))
1631+
Type containerType = rewriter.getI8Type();
1632+
if (failed(alignedConversionPrecondition(rewriter, srcVecType,
1633+
containerType, conversionOp)))
15651634
return failure();
15661635

15671636
// Perform the rewrite.
@@ -1625,8 +1694,9 @@ struct RewriteAlignedSubByteIntTrunc : OpRewritePattern<arith::TruncIOp> {
16251694

16261695
// Check general alignment preconditions. We invert the src/dst type order
16271696
// to reuse the existing precondition logic.
1628-
if (failed(alignedConversionPrecondition(rewriter, dstVecType, srcVecType,
1629-
truncOp)))
1697+
Type containerType = rewriter.getI8Type();
1698+
if (failed(alignedConversionPrecondition(rewriter, dstVecType,
1699+
containerType, truncOp)))
16301700
return failure();
16311701

16321702
// Create a new iX -> i8 truncation op.

0 commit comments

Comments
 (0)