Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][Vector] Refactor VectorEmulateNarrowType.cpp #123529

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
161 changes: 112 additions & 49 deletions mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -519,7 +519,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {

auto origElements = valueToStore.getType().getNumElements();
// Note, per-element-alignment was already verified above.
bool isFullyAligned = origElements % emulatedPerContainerElem == 0;
bool isDivisibleInSize = origElements % emulatedPerContainerElem == 0;

auto stridedMetadata =
rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
Expand All @@ -535,8 +535,8 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
getAsOpFoldResult(adaptor.getIndices()));

std::optional<int64_t> foldedNumFrontPadElems =
isFullyAligned ? 0
: getConstantIntValue(linearizedInfo.intraDataOffset);
isDivisibleInSize ? 0
: getConstantIntValue(linearizedInfo.intraDataOffset);

if (!foldedNumFrontPadElems) {
return rewriter.notifyMatchFailure(
Expand All @@ -554,7 +554,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
// need unaligned emulation because the store address is aligned and the
// source is a whole byte.
bool emulationRequiresPartialStores =
!isFullyAligned || *foldedNumFrontPadElems != 0;
!isDivisibleInSize || *foldedNumFrontPadElems != 0;
if (!emulationRequiresPartialStores) {
// Basic case: storing full bytes.
auto numElements = origElements / emulatedPerContainerElem;
Expand Down Expand Up @@ -881,7 +881,7 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {

auto origElements = op.getVectorType().getNumElements();
// Note, per-element-alignment was already verified above.
bool isFullyAligned = origElements % emulatedPerContainerElem == 0;
bool isDivisibleInSize = origElements % emulatedPerContainerElem == 0;

auto stridedMetadata =
rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
Expand All @@ -897,8 +897,8 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
getAsOpFoldResult(adaptor.getIndices()));

std::optional<int64_t> foldedIntraVectorOffset =
isFullyAligned ? 0
: getConstantIntValue(linearizedInfo.intraDataOffset);
isDivisibleInSize ? 0
: getConstantIntValue(linearizedInfo.intraDataOffset);

// Always load enough elements which can cover the original elements.
int64_t maxintraDataOffset =
Expand All @@ -915,7 +915,7 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
result = dynamicallyExtractSubVector(
rewriter, loc, dyn_cast<TypedValue<VectorType>>(result), resultVector,
linearizedInfo.intraDataOffset, origElements);
} else if (!isFullyAligned) {
} else if (!isDivisibleInSize) {
result = staticallyExtractSubvector(
rewriter, loc, result, *foldedIntraVectorOffset, origElements);
}
Expand Down Expand Up @@ -1002,7 +1002,7 @@ struct ConvertVectorMaskedLoad final
auto origType = op.getVectorType();
auto origElements = origType.getNumElements();
// Note, per-element-alignment was already verified above.
bool isFullyAligned = origElements % emulatedPerContainerElem == 0;
bool isDivisibleInSize = origElements % emulatedPerContainerElem == 0;

auto stridedMetadata =
rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
Expand All @@ -1017,8 +1017,8 @@ struct ConvertVectorMaskedLoad final
getAsOpFoldResult(adaptor.getIndices()));

std::optional<int64_t> foldedIntraVectorOffset =
isFullyAligned ? 0
: getConstantIntValue(linearizedInfo.intraDataOffset);
isDivisibleInSize ? 0
: getConstantIntValue(linearizedInfo.intraDataOffset);

int64_t maxIntraDataOffset =
foldedIntraVectorOffset.value_or(emulatedPerContainerElem - 1);
Expand All @@ -1042,7 +1042,7 @@ struct ConvertVectorMaskedLoad final
passthru = dynamicallyInsertSubVector(
rewriter, loc, passthru, emptyVector, linearizedInfo.intraDataOffset,
origElements);
} else if (!isFullyAligned) {
} else if (!isDivisibleInSize) {
passthru = staticallyInsertSubvector(rewriter, loc, passthru, emptyVector,
*foldedIntraVectorOffset);
}
Expand Down Expand Up @@ -1070,7 +1070,7 @@ struct ConvertVectorMaskedLoad final
mask = dynamicallyInsertSubVector(rewriter, loc, mask, emptyMask,
linearizedInfo.intraDataOffset,
origElements);
} else if (!isFullyAligned) {
} else if (!isDivisibleInSize) {
mask = staticallyInsertSubvector(rewriter, loc, op.getMask(), emptyMask,
*foldedIntraVectorOffset);
}
Expand All @@ -1081,7 +1081,7 @@ struct ConvertVectorMaskedLoad final
result = dynamicallyExtractSubVector(
rewriter, loc, result, op.getPassThru(),
linearizedInfo.intraDataOffset, origElements);
} else if (!isFullyAligned) {
} else if (!isDivisibleInSize) {
result = staticallyExtractSubvector(
rewriter, loc, result, *foldedIntraVectorOffset, origElements);
}
Expand All @@ -1091,6 +1091,38 @@ struct ConvertVectorMaskedLoad final
}
};

/// Check whether `subByteVecTy` fits within a vector of `multiByteScalarTy`
///
/// "Fitting" means that `subByteVecTy` (a vector of sub-byte elements, e.g.
/// vector<4xi4>), can fit within N scalar elements of type `multiByteScalarTy`
/// (a multi-byte scalar, e.g. i16), where N is some integer.
///
/// Put differently, this method checks whether this would be valid:
///
///   vector.bitcast subByteVecTy into vector<N x multiByteScalarTy>
///
/// EXAMPLES:
///   * vector<4xi4> -> i16 - yes (N = 1)
///   * vector<4xi4> -> i8 - yes (N = 2)
///   * vector<3xi4> -> i8 - no (N would have to be 1.5)
///   * vector<3xi2> -> i16 - no (N would have to be 0.375)
static bool fitsInMultiByteContainerTy(VectorType subByteVecTy,
                                       Type multiByteScalarTy) {
  assert((isa<IntegerType, FloatType>(multiByteScalarTy)) && "Not scalar!");

  int subByteBits = subByteVecTy.getElementType().getIntOrFloatBitWidth();
  int multiByteBits = multiByteScalarTy.getIntOrFloatBitWidth();

  // Pre-conditions: the source elements are genuinely sub-byte, the container
  // is a whole number of bytes, and the container width is an exact multiple
  // of the sub-byte width (so `elemsPerMultiByte` below is an integer).
  assert(subByteBits < 8 && "Not a sub-byte scalar type!");
  assert(multiByteBits % 8 == 0 && "Not a multi-byte scalar type!");
  assert(multiByteBits % subByteBits == 0 && "Unaligned element types!");

  int elemsPerMultiByte = multiByteBits / subByteBits;

  // TODO: This is a bit too restrictive for vectors rank > 1.
  return subByteVecTy.getShape().back() % elemsPerMultiByte == 0;
}

//===----------------------------------------------------------------------===//
// ConvertVectorTransferRead
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -1127,7 +1159,8 @@ struct ConvertVectorTransferRead final
auto origElements = op.getVectorType().getNumElements();

// Note, per-element-alignment was already verified above.
bool isFullyAligned = origElements % emulatedPerContainerElem == 0;
bool isDivisibleInSize =
fitsInMultiByteContainerTy(op.getVectorType(), containerElemTy);

auto newPadding = rewriter.create<arith::ExtUIOp>(loc, containerElemTy,
adaptor.getPadding());
Expand All @@ -1146,8 +1179,8 @@ struct ConvertVectorTransferRead final
getAsOpFoldResult(adaptor.getIndices()));

std::optional<int64_t> foldedIntraVectorOffset =
isFullyAligned ? 0
: getConstantIntValue(linearizedInfo.intraDataOffset);
isDivisibleInSize ? 0
: getConstantIntValue(linearizedInfo.intraDataOffset);

int64_t maxIntraDataOffset =
foldedIntraVectorOffset.value_or(emulatedPerContainerElem - 1);
Expand All @@ -1171,7 +1204,7 @@ struct ConvertVectorTransferRead final
result = dynamicallyExtractSubVector(rewriter, loc, bitCast, zeros,
linearizedInfo.intraDataOffset,
origElements);
} else if (!isFullyAligned) {
} else if (!isDivisibleInSize) {
result = staticallyExtractSubvector(
rewriter, loc, result, *foldedIntraVectorOffset, origElements);
}
Expand Down Expand Up @@ -1428,41 +1461,69 @@ LogicalResult BitCastRewriter::commonPrecondition(PatternRewriter &rewriter,
return commonConversionPrecondition(rewriter, preconditionType, op);
}

/// Verify that `subByteVecType` and `dstType` are aligned. Alignment
/// means that:
/// 1. The `dstType` element type is a multiple of the
/// `srcVectorOfSubByteType` element type (e.g. i4 vs i8 is OK, but i3 vs i8
/// is not supported). Let this multiple be `N`.
/// 2. The number of the (trailing) elements in `srcVectorOfSubByteType` is a
/// multiple of `N` from 1. (e.g., when targetting i8, 2xi4 is OK, but 3xi4 is
/// not supported).
/// Verify that `subByteVecTy` (vector) and `containerTy` (scalar) are aligned.
///
/// Alignment means that `subByteVecTy` can be packed into a vector of
/// `containerTy` elements. More specifically:
/// 1. The bit-width of `containerTy` is a multiple of the
/// bit-width of `subByteVecTy` elements. For example, for `i4` and `i16`
/// this multiple is 4.
/// 2. The multiple from 1. above divides evenly the number of the (trailing)
/// elements in `subByteVecTy`.
///
/// EXAMPLE 1:
/// `subByteVecTy = vector<2xi4>`, and
/// `containerTy = i16`
///
/// 2 divides evenly 4 ( = 16 / 4), hence both conditions are _met_.
///
/// EXAMPLE 2:
/// `subByteVecTy = vector<3xi4>`, and
/// `containerTy = i16`
///
/// 3 _does not_ divide evenly 4 (= 16/4), hence the conditions are _not met_.
///
/// EXAMPLE 3:
/// `subByteVecTy = vector<3xi3>`, and
/// `containerTy = i16`
///
/// 16 _is not_ a multiple of 3, hence the conditions are _not met_.
///
/// NOTE: This method assumes that common conversion preconditions are met. In
/// particular, the element type of `dstType` is assumed to be a multi-byte
/// type (e.g. i8, i16, i32).
/// particular, `containerTy` is assumed to be a
/// multi-byte scalar type (e.g., i8, i16, i32).
static LogicalResult alignedConversionPrecondition(PatternRewriter &rewriter,
VectorType subByteVecType,
VectorType dstType,
VectorType subByteVecTy,
Type containerTy,
Operation *op) {
if (!subByteVecType || !dstType)
return rewriter.notifyMatchFailure(op, "Not a supported aligned case");
unsigned srcElemBitwidth = subByteVecType.getElementTypeBitWidth();
unsigned dstElemBitwidth = dstType.getElementTypeBitWidth();
assert(containerTy.isIntOrFloat() &&
"container element type is not a scalar");

if (dstElemBitwidth < 8)
return rewriter.notifyMatchFailure(
op, "the bitwidth of dstType must be greater than or equal to 8");
if (dstElemBitwidth % srcElemBitwidth != 0)
return rewriter.notifyMatchFailure(op, "unaligned cases are not supported");
if (srcElemBitwidth != 2 && srcElemBitwidth != 4)
// TODO: This is validating the inputs rather than checking the conditions
// documented above. Replace with an assert.
if (!subByteVecTy)
return rewriter.notifyMatchFailure(op, "not a vector!");

unsigned subByteBits = subByteVecTy.getElementTypeBitWidth();
unsigned multiByteBits = containerTy.getIntOrFloatBitWidth();

// Enforced by the common pre-conditions.
assert(multiByteBits % 8 == 0 && "Not a multi-byte scalar type!");

// TODO: Add support other widths (when/if needed)
if (subByteBits != 2 && subByteBits != 4)
return rewriter.notifyMatchFailure(
op, "only src bitwidth of 2 or 4 is supported at this moment");
op, "only 2-bit and 4-bit sub-byte type is supported at this moment");
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

1 bit?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's no single test for 1 bit :)

Basically, you've contributed i4 emulation and then @ziereis added i2 emulation. We check specifically for i2 and i4 as that's what we've focused on so far.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I know some downstream projects try to use 1bit, and I think upstream shouldn't trivially block it in this way. They can contribute i1 tests for sure but overall the code here should support 1-bit scenarios without problem.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I know some downstream projects try to use 1bit, and I think upstream shouldn't trivially block it in this way.

Oh, definitely not trying to block anyone. This is merely trying to document the existing assumptions. Note that this condition is already present:

if (srcElemBitwidth != 2 && srcElemBitwidth != 4)
return rewriter.notifyMatchFailure(
op, "only src bitwidth of 2 or 4 is supported at this moment");

They can contribute i1 tests for sure but overall the code here should support 1-bit scenarios without problem.

They would be welcome with praise and gratitude :)


// Condition 1 ("per-element" alignment)
if (multiByteBits % subByteBits != 0)
return rewriter.notifyMatchFailure(op, "unalagined element types");

const int numSrcElemsPerByte = 8 / srcElemBitwidth;
if ((subByteVecType.getShape().back() % numSrcElemsPerByte) != 0)
// Condition 2 ("full" alignment)
if (!fitsInMultiByteContainerTy(subByteVecTy, containerTy))
return rewriter.notifyMatchFailure(
op, "the trailing dimension of the input vector of sub-bytes must be a "
"multiple of 8 / <sub-byte-width>");
op, "not possible to fit this sub-byte vector type into a vector of "
"the given multi-byte type");

return success();
}
Expand Down Expand Up @@ -1899,8 +1960,9 @@ struct RewriteAlignedSubByteIntExt : OpRewritePattern<ConversionOpType> {
return failure();

// Check general alignment preconditions.
if (failed(alignedConversionPrecondition(rewriter, srcVecType, dstVecType,
conversionOp)))
if (failed(alignedConversionPrecondition(
rewriter, srcVecType,
/*containerTy=*/rewriter.getI8Type(), conversionOp)))
return failure();

// Perform the rewrite.
Expand Down Expand Up @@ -1964,8 +2026,9 @@ struct RewriteAlignedSubByteIntTrunc : OpRewritePattern<arith::TruncIOp> {

// Check general alignment preconditions. We invert the src/dst type order
// to reuse the existing precondition logic.
if (failed(alignedConversionPrecondition(rewriter, dstVecType, srcVecType,
truncOp)))
if (failed(alignedConversionPrecondition(
rewriter, dstVecType,
/*containerTy=*/rewriter.getI8Type(), truncOp)))
return failure();

// Create a new iX -> i8 truncation op.
Expand Down