
[mlir][Vector] Refactor VectorEmulateNarrowType.cpp #123529

Open · wants to merge 3 commits into main from andrzej/refactor_narrow_type_4
Conversation

banach-space
Contributor

@banach-space banach-space commented Jan 19, 2025

This PR refactors alignedConversionPrecondition and adds new helper
hooks.

Update alignedConversionPrecondition (1)

This method doesn't require the vector type for the "container" argument. The
underlying element type is sufficient. The corresponding argument has been
renamed to containerTy - this is meant as the multi-byte container element
type (i8, i16, i32, etc.). With this change, the updated invocations of
alignedConversionPrecondition (e.g. in RewriteAlignedSubByteIntExt) make it
clear that the container element type is assumed to be i8.
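For reference, a minimal sketch of the updated signature (only the containerTy rename is stated in this PR; the remaining parameter names and the return type are assumptions for illustration):

```cpp
// Sketch only: the hook now takes the multi-byte container *element* type
// (i8, i16, i32, ...) instead of a full vector type.
static LogicalResult alignedConversionPrecondition(PatternRewriter &rewriter,
                                                   VectorType subByteVecTy,
                                                   Type containerTy,
                                                   Operation *op);
```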

Update alignedConversionPrecondition (2):

The final check in alignedConversionPrecondition has been replaced with a new
helper method, isSubByteVecFittable. This helper hook is now also re-used in
ConvertVectorTransferRead (to improve code re-use); a short illustration of the
"fitting" condition follows below.
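To illustrate the "fitting" condition that isSubByteVecFittable checks, here are two examples (they mirror the doc-comment visible in the diff below):

```mlir
// vector<4xi4> fits within i8 containers (N = 2), so this bitcast is valid:
%0 = vector.bitcast %a : vector<4xi4> to vector<2xi8>
// vector<3xi4> does not fit (N would have to be 1.5): there is no
// equivalent bitcast to a whole number of i8 elements.
```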

Other updates

Extended + unified comments.

Implements: #123630

@llvmbot
Copy link
Member

llvmbot commented Jan 19, 2025

@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir

Author: Andrzej Warzyński (banach-space)

Changes
  • [mlir][Vector] Update VectorEmulateNarrowType.cpp (1/N)
  • [mlir][Vector] Update VectorEmulateNarrowType.cpp (2/N)
  • [mlir][Vector] Update VectorEmulateNarrowType.cpp (3/N)
  • [mlir][Vector] Update VectorEmulateNarrowType.cpp (4/N)

Patch is 33.38 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/123529.diff

1 file affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp (+234-136)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 95064083b21d44..373b8a8822318f 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -45,6 +45,10 @@ using namespace mlir;
 #define DBGSNL() (llvm::dbgs() << "\n")
 #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
 
+//===----------------------------------------------------------------------===//
+// Utils
+//===----------------------------------------------------------------------===//
+
 /// Returns a compressed mask for the emulated vector. For example, when
 /// emulating an eight-element `i8` vector with `i32` (i.e. when the source
 /// elements span two dest elements), this method compresses `vector<8xi1>`
@@ -282,13 +286,15 @@ emulatedVectorLoad(OpBuilder &rewriter, Location loc, Value base,
                    OpFoldResult linearizedIndices,
                    int64_t numEmultedElementsToLoad, Type origElemType,
                    Type emulatedElemType) {
-  auto scale = emulatedElemType.getIntOrFloatBitWidth() /
-               origElemType.getIntOrFloatBitWidth();
+  auto elementsPerContainerType = emulatedElemType.getIntOrFloatBitWidth() /
+                                  origElemType.getIntOrFloatBitWidth();
   auto newLoad = rewriter.create<vector::LoadOp>(
       loc, VectorType::get(numEmultedElementsToLoad, emulatedElemType), base,
       getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
   return rewriter.create<vector::BitCastOp>(
-      loc, VectorType::get(numEmultedElementsToLoad * scale, origElemType),
+      loc,
+      VectorType::get(numEmultedElementsToLoad * elementsPerContainerType,
+                      origElemType),
       newLoad);
 }
 
@@ -298,6 +304,7 @@ namespace {
 // ConvertVectorStore
 //===----------------------------------------------------------------------===//
 
+// TODO: Document-me
 struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
   using OpConversionPattern::OpConversionPattern;
 
@@ -314,14 +321,14 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
     auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
     Type oldElementType = op.getValueToStore().getType().getElementType();
     Type newElementType = convertedType.getElementType();
-    int srcBits = oldElementType.getIntOrFloatBitWidth();
-    int dstBits = newElementType.getIntOrFloatBitWidth();
+    int oldBits = oldElementType.getIntOrFloatBitWidth();
+    int newBits = newElementType.getIntOrFloatBitWidth();
 
-    if (dstBits % srcBits != 0) {
-      return rewriter.notifyMatchFailure(
-          op, "only dstBits % srcBits == 0 supported");
+    // Check per-element alignment.
+    if (newBits % oldBits != 0) {
+      return rewriter.notifyMatchFailure(op, "unalagined element types");
     }
-    int scale = dstBits / srcBits;
+    int elementsPerContainerType = newBits / oldBits;
 
     // Adjust the number of elements to store when emulating narrow types.
     // Here only the 1-D vector store is considered, and the N-D memref types
@@ -337,7 +344,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
     // vector<4xi8>
 
     auto origElements = op.getValueToStore().getType().getNumElements();
-    if (origElements % scale != 0)
+    if (origElements % elementsPerContainerType != 0)
       return failure();
 
     auto stridedMetadata =
@@ -346,13 +353,13 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
     OpFoldResult linearizedIndices;
     std::tie(std::ignore, linearizedIndices) =
         memref::getLinearizedMemRefOffsetAndSize(
-            rewriter, loc, srcBits, dstBits,
+            rewriter, loc, oldBits, newBits,
             stridedMetadata.getConstifiedMixedOffset(),
             stridedMetadata.getConstifiedMixedSizes(),
             stridedMetadata.getConstifiedMixedStrides(),
             getAsOpFoldResult(adaptor.getIndices()));
 
-    auto numElements = origElements / scale;
+    auto numElements = origElements / elementsPerContainerType;
     auto bitCast = rewriter.create<vector::BitCastOp>(
         loc, VectorType::get(numElements, newElementType),
         op.getValueToStore());
@@ -368,6 +375,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
 // ConvertVectorMaskedStore
 //===----------------------------------------------------------------------===//
 
+// TODO: Document-me
 struct ConvertVectorMaskedStore final
     : OpConversionPattern<vector::MaskedStoreOp> {
   using OpConversionPattern::OpConversionPattern;
@@ -385,17 +393,17 @@ struct ConvertVectorMaskedStore final
     auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
     Type oldElementType = op.getValueToStore().getType().getElementType();
     Type newElementType = convertedType.getElementType();
-    int srcBits = oldElementType.getIntOrFloatBitWidth();
-    int dstBits = newElementType.getIntOrFloatBitWidth();
+    int oldBits = oldElementType.getIntOrFloatBitWidth();
+    int newBits = newElementType.getIntOrFloatBitWidth();
 
-    if (dstBits % srcBits != 0) {
-      return rewriter.notifyMatchFailure(
-          op, "only dstBits % srcBits == 0 supported");
+    // Check per-element alignment.
+    if (newBits % oldBits != 0) {
+      return rewriter.notifyMatchFailure(op, "unalagined element types");
     }
 
-    int scale = dstBits / srcBits;
+    int elementsPerContainerType = newBits / oldBits;
     int origElements = op.getValueToStore().getType().getNumElements();
-    if (origElements % scale != 0)
+    if (origElements % elementsPerContainerType != 0)
       return failure();
 
     auto stridedMetadata =
@@ -404,7 +412,7 @@ struct ConvertVectorMaskedStore final
     memref::LinearizedMemRefInfo linearizedInfo;
     std::tie(linearizedInfo, linearizedIndicesOfr) =
         memref::getLinearizedMemRefOffsetAndSize(
-            rewriter, loc, srcBits, dstBits,
+            rewriter, loc, oldBits, newBits,
             stridedMetadata.getConstifiedMixedOffset(),
             stridedMetadata.getConstifiedMixedSizes(),
             stridedMetadata.getConstifiedMixedStrides(),
@@ -444,12 +452,13 @@ struct ConvertVectorMaskedStore final
     //
     // FIXME: Make an example based on the comment above work (see #115460 for
     // reproducer).
-    FailureOr<Operation *> newMask =
-        getCompressedMaskOp(rewriter, loc, op.getMask(), origElements, scale);
+    FailureOr<Operation *> newMask = getCompressedMaskOp(
+        rewriter, loc, op.getMask(), origElements, elementsPerContainerType);
     if (failed(newMask))
       return failure();
 
-    auto numElements = (origElements + scale - 1) / scale;
+    auto numElements = (origElements + elementsPerContainerType - 1) /
+                       elementsPerContainerType;
     auto newType = VectorType::get(numElements, newElementType);
     auto passThru = rewriter.create<arith::ConstantOp>(
         loc, newType, rewriter.getZeroAttr(newType));
@@ -458,7 +467,8 @@ struct ConvertVectorMaskedStore final
         loc, newType, adaptor.getBase(), linearizedIndices,
         newMask.value()->getResult(0), passThru);
 
-    auto newBitCastType = VectorType::get(numElements * scale, oldElementType);
+    auto newBitCastType =
+        VectorType::get(numElements * elementsPerContainerType, oldElementType);
     Value valueToStore =
         rewriter.create<vector::BitCastOp>(loc, newBitCastType, newLoad);
     valueToStore = rewriter.create<arith::SelectOp>(
@@ -477,6 +487,7 @@ struct ConvertVectorMaskedStore final
 // ConvertVectorLoad
 //===----------------------------------------------------------------------===//
 
+// TODO: Document-me
 struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
   using OpConversionPattern::OpConversionPattern;
 
@@ -493,14 +504,14 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
     auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
     Type oldElementType = op.getType().getElementType();
     Type newElementType = convertedType.getElementType();
-    int srcBits = oldElementType.getIntOrFloatBitWidth();
-    int dstBits = newElementType.getIntOrFloatBitWidth();
+    int oldBits = oldElementType.getIntOrFloatBitWidth();
+    int newBits = newElementType.getIntOrFloatBitWidth();
 
-    if (dstBits % srcBits != 0) {
-      return rewriter.notifyMatchFailure(
-          op, "only dstBits % srcBits == 0 supported");
+    // Check per-element alignment.
+    if (newBits % oldBits != 0) {
+      return rewriter.notifyMatchFailure(op, "unalagined element types");
     }
-    int scale = dstBits / srcBits;
+    int elementsPerContainerType = newBits / oldBits;
 
     // Adjust the number of elements to load when emulating narrow types,
     // and then cast back to the original type with vector.bitcast op.
@@ -532,7 +543,8 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
     // compile time as they must be constants.
 
     auto origElements = op.getVectorType().getNumElements();
-    bool isUnalignedEmulation = origElements % scale != 0;
+    // Note, per-element-alignment was already verified above.
+    bool isFullyAligned = origElements % elementsPerContainerType == 0;
 
     auto stridedMetadata =
         rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
@@ -541,21 +553,21 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
     memref::LinearizedMemRefInfo linearizedInfo;
     std::tie(linearizedInfo, linearizedIndices) =
         memref::getLinearizedMemRefOffsetAndSize(
-            rewriter, loc, srcBits, dstBits,
+            rewriter, loc, oldBits, newBits,
             stridedMetadata.getConstifiedMixedOffset(),
             stridedMetadata.getConstifiedMixedSizes(),
             stridedMetadata.getConstifiedMixedStrides(),
             getAsOpFoldResult(adaptor.getIndices()));
 
     std::optional<int64_t> foldedIntraVectorOffset =
-        isUnalignedEmulation
-            ? getConstantIntValue(linearizedInfo.intraDataOffset)
-            : 0;
+        isFullyAligned ? 0
+                       : getConstantIntValue(linearizedInfo.intraDataOffset);
 
     // Always load enough elements which can cover the original elements.
-    int64_t maxintraDataOffset = foldedIntraVectorOffset.value_or(scale - 1);
-    auto numElements =
-        llvm::divideCeil(maxintraDataOffset + origElements, scale);
+    int64_t maxintraDataOffset =
+        foldedIntraVectorOffset.value_or(elementsPerContainerType - 1);
+    auto numElements = llvm::divideCeil(maxintraDataOffset + origElements,
+                                        elementsPerContainerType);
     Value result =
         emulatedVectorLoad(rewriter, loc, adaptor.getBase(), linearizedIndices,
                            numElements, oldElementType, newElementType);
@@ -566,7 +578,7 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
       result = dynamicallyExtractSubVector(
           rewriter, loc, dyn_cast<TypedValue<VectorType>>(result), resultVector,
           linearizedInfo.intraDataOffset, origElements);
-    } else if (isUnalignedEmulation) {
+    } else if (!isFullyAligned) {
       result =
           staticallyExtractSubvector(rewriter, loc, op.getType(), result,
                                      *foldedIntraVectorOffset, origElements);
@@ -580,6 +592,7 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
 // ConvertVectorMaskedLoad
 //===----------------------------------------------------------------------===//
 
+// TODO: Document-me
 struct ConvertVectorMaskedLoad final
     : OpConversionPattern<vector::MaskedLoadOp> {
   using OpConversionPattern::OpConversionPattern;
@@ -596,14 +609,14 @@ struct ConvertVectorMaskedLoad final
     auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
     Type oldElementType = op.getType().getElementType();
     Type newElementType = convertedType.getElementType();
-    int srcBits = oldElementType.getIntOrFloatBitWidth();
-    int dstBits = newElementType.getIntOrFloatBitWidth();
+    int oldBits = oldElementType.getIntOrFloatBitWidth();
+    int newBits = newElementType.getIntOrFloatBitWidth();
 
-    if (dstBits % srcBits != 0) {
-      return rewriter.notifyMatchFailure(
-          op, "only dstBits % srcBits == 0 supported");
+    // Check per-element alignment.
+    if (newBits % oldBits != 0) {
+      return rewriter.notifyMatchFailure(op, "unalagined element types");
     }
-    int scale = dstBits / srcBits;
+    int elementsPerContainerType = newBits / oldBits;
 
     // Adjust the number of elements to load when emulating narrow types,
     // and then cast back to the original type with vector.bitcast op.
@@ -649,7 +662,7 @@ struct ConvertVectorMaskedLoad final
     // subvector at the proper offset after bit-casting.
     auto origType = op.getVectorType();
     auto origElements = origType.getNumElements();
-    bool isUnalignedEmulation = origElements % scale != 0;
+    bool isUnalignedEmulation = origElements % elementsPerContainerType != 0;
 
     auto stridedMetadata =
         rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
@@ -657,7 +670,7 @@ struct ConvertVectorMaskedLoad final
     memref::LinearizedMemRefInfo linearizedInfo;
     std::tie(linearizedInfo, linearizedIndices) =
         memref::getLinearizedMemRefOffsetAndSize(
-            rewriter, loc, srcBits, dstBits,
+            rewriter, loc, oldBits, newBits,
             stridedMetadata.getConstifiedMixedOffset(),
             stridedMetadata.getConstifiedMixedSizes(),
             stridedMetadata.getConstifiedMixedStrides(),
@@ -668,18 +681,21 @@ struct ConvertVectorMaskedLoad final
             ? getConstantIntValue(linearizedInfo.intraDataOffset)
             : 0;
 
-    int64_t maxIntraDataOffset = foldedIntraVectorOffset.value_or(scale - 1);
-    FailureOr<Operation *> newMask = getCompressedMaskOp(
-        rewriter, loc, op.getMask(), origElements, scale, maxIntraDataOffset);
+    int64_t maxIntraDataOffset =
+        foldedIntraVectorOffset.value_or(elementsPerContainerType - 1);
+    FailureOr<Operation *> newMask =
+        getCompressedMaskOp(rewriter, loc, op.getMask(), origElements,
+                            elementsPerContainerType, maxIntraDataOffset);
     if (failed(newMask))
       return failure();
 
     Value passthru = op.getPassThru();
 
-    auto numElements =
-        llvm::divideCeil(maxIntraDataOffset + origElements, scale);
+    auto numElements = llvm::divideCeil(maxIntraDataOffset + origElements,
+                                        elementsPerContainerType);
     auto loadType = VectorType::get(numElements, newElementType);
-    auto newBitcastType = VectorType::get(numElements * scale, oldElementType);
+    auto newBitcastType =
+        VectorType::get(numElements * elementsPerContainerType, oldElementType);
 
     auto emptyVector = rewriter.create<arith::ConstantOp>(
         loc, newBitcastType, rewriter.getZeroAttr(newBitcastType));
@@ -706,8 +722,8 @@ struct ConvertVectorMaskedLoad final
         rewriter.create<vector::BitCastOp>(loc, newBitcastType, newLoad);
 
     Value mask = op.getMask();
-    auto newSelectMaskType =
-        VectorType::get(numElements * scale, rewriter.getI1Type());
+    auto newSelectMaskType = VectorType::get(
+        numElements * elementsPerContainerType, rewriter.getI1Type());
     // TODO: try to fold if op's mask is constant
     auto emptyMask = rewriter.create<arith::ConstantOp>(
         loc, newSelectMaskType, rewriter.getZeroAttr(newSelectMaskType));
@@ -737,10 +753,43 @@ struct ConvertVectorMaskedLoad final
   }
 };
 
+/// Check whether `subByteVecTy` fits within a vector of `multiByteScalarTy`
+///
+/// "Fitting" means that `subByteVecTy` (a vector of sub-byte elements, e.g.
+/// vector<4xi4>), can fit within N scalar elements of type `multiByteScalarTy`
+/// (a multi-byte scalar, e.g. i16), where N is some integer.
+///
+/// Put differently, this method checks whether this would be valid:
+///
+///   vector.bitcast subByteVecTy into vector<N x multiByteScalarTy>
+///
+/// EXAMPLES:
+///   * vector<4xi4> -> i16 - yes (N = 1)
+///   * vector<4xi4> -> i8 - yes (N = 2)
+///   * vector<3xi4> -> i8 - no (N would have to be 1.5)
+///   * vector<3xi2> -> i16 - no (N would have to be 0.5)
+static bool isSubByteVecFittable(VectorType subByteVecTy,
+                                 Type multiByteScalarTy) {
+  assert((isa<IntegerType, FloatType>(multiByteScalarTy)) && "Not scalar!");
+
+  int subByteBits = subByteVecTy.getElementType().getIntOrFloatBitWidth();
+  int multiByteBits = multiByteScalarTy.getIntOrFloatBitWidth();
+
+  assert(subByteBits < 8 && "Not a sub-byte scalar type!");
+  assert(multiByteBits % 8 == 0 && "Not a multi-byte scalar type!");
+  assert(multiByteBits % subByteBits == 0 && "Unaligned element types!");
+
+  int elemsPerMultiByte = multiByteBits / subByteBits;
+
+  // TODO: This is a bit too restrictive for vectors rank > 1.
+  return subByteVecTy.getShape().back() % elemsPerMultiByte == 0;
+}
+
 //===----------------------------------------------------------------------===//
 // ConvertVectorTransferRead
 //===----------------------------------------------------------------------===//
 
+// TODO: Document-me
 struct ConvertVectorTransferRead final
     : OpConversionPattern<vector::TransferReadOp> {
   using OpConversionPattern::OpConversionPattern;
@@ -758,18 +807,20 @@ struct ConvertVectorTransferRead final
     auto convertedType = cast<MemRefType>(adaptor.getSource().getType());
     Type oldElementType = op.getType().getElementType();
     Type newElementType = convertedType.getElementType();
-    int srcBits = oldElementType.getIntOrFloatBitWidth();
-    int dstBits = newElementType.getIntOrFloatBitWidth();
+    int oldBits = oldElementType.getIntOrFloatBitWidth();
+    int newBits = newElementType.getIntOrFloatBitWidth();
 
-    if (dstBits % srcBits != 0) {
-      return rewriter.notifyMatchFailure(
-          op, "only dstBits % srcBits == 0 supported");
+    // Check per-element alignment.
+    if (newBits % oldBits != 0) {
+      return rewriter.notifyMatchFailure(op, "unalagined element types");
     }
-    int scale = dstBits / srcBits;
+    int elementsPerContainerType = newBits / oldBits;
 
     auto origElements = op.getVectorType().getNumElements();
 
-    bool isUnalignedEmulation = origElements % scale != 0;
+    // Note, per-element-alignment was already verified above.
+    bool isFullyAligned =
+        isSubByteVecFittable(op.getVectorType(), newElementType);
 
     auto newPadding = rewriter.create<arith::ExtUIOp>(loc, newElementType,
                                                       adaptor.getPadding());
@@ -781,20 +832,20 @@ struct ConvertVectorTransferRead final
     memref::LinearizedMemRefInfo linearizedInfo;
     std::tie(linearizedInfo, linearizedIndices) =
         memref::getLinearizedMemRefOffsetAndSize(
-            rewriter, loc, srcBits, dstBits,
+            rewriter, loc, oldBits, newBits,
             stridedMetadata.getConstifiedMixedOffset(),
             stridedMetadata.getConstifiedMixedSizes(),
             stridedMetadata.getConstifiedMixedStrides(),
             getAsOpFoldResult(adaptor.getIndices()));
 
     std::optional<int64_t> foldedIntraVectorOffset =
-        isUnalignedEmulation
-            ? getConstantIntValue(linearizedInfo.intraDataOffset)
-            : 0;
+        isFullyAligned ? 0
+                       : getConstantIntValue(linearizedInfo.intraDataOffset);
 
-    int64_t maxIntraDataOffset = foldedIntraVectorOffset.value_or(scale - 1);
-    auto numElements =
-        llvm::divideCeil(maxIntraDataOffset + origElements, scale);
+    int64_t maxIntraDataOffset =
+        foldedIntraVectorOffset.value_or(elementsPerContainerType - 1);
+    auto numElements = llvm::divideCeil(maxIntraDataOffset + origElements,
+                                        elementsPerContainerType);
 
     auto newRead = rewriter.create<vector::TransferReadOp>(
         loc, VectorType::get(numElements, newElementType), adaptor.getSource(),
@@ -802,7 +853,9 @@ struct ConvertVectorTransferRead final
         newPadding);
 
     auto bitCa...
[truncated]

@banach-space banach-space changed the title andrzej/refactor narrow type 4 [mlir][Vector] Update VectorEmulateNarrowType.cpp (4/N) Jan 19, 2025
Contributor

@dcaballe dcaballe left a comment


Merge with previous, please

@banach-space
Contributor Author

> Merge with previous, please

Currently as a draft:

GitHub issue for more context: #123630

@banach-space
Contributor Author

CC @ziereis

@banach-space banach-space force-pushed the andrzej/refactor_narrow_type_4 branch from 5583441 to 5261216 Compare February 6, 2025 11:06
@banach-space banach-space changed the title [mlir][Vector] Update VectorEmulateNarrowType.cpp (4/N) [mlir][Vector] Update VectorEmulateNarrowType.cpp (3/N) Feb 6, 2025
@banach-space banach-space requested a review from lialan February 6, 2025 11:08
@banach-space
Contributor Author

UPDATE 6/2/25

As all dependencies have been merged, I've rebased this on top of main.

@banach-space banach-space marked this pull request as draft February 9, 2025 16:50
@banach-space banach-space force-pushed the andrzej/refactor_narrow_type_4 branch from 5261216 to 411ee3c Compare February 15, 2025 20:38
@banach-space banach-space force-pushed the andrzej/refactor_narrow_type_4 branch 2 times, most recently from 52bd899 to b3b67eb Compare February 26, 2025 07:50
This is PR 4 in a series of N patches aimed at improving
"VectorEmulateNarrowType.cpp". This is mainly minor refactoring; no
major functional changes are made.

1. Update `alignedConversionPrecondition` (1):

This method didn't require the vector type for the "destination"
argument. The underlying element type is sufficient. The corresponding
argument has been renamed as `multiByteScalarTy` - this is meant as the
multi-byte emulated type (`i8`, `i16`, `i32`, etc).

2. Update `alignedConversionPrecondition` (2):

In llvm#121298, we replaced `dstElemBitwidth` in this calculation:

```cpp
  const int numSrcElemsPerDestElem = dstElemBitwidth / srcElemBitwidth;
```

with the hard-coded value of 8:
```cpp
  const int numSrcElemsPerDestElem = 8 / srcElemBitwidth;
```

That was correct for the patterns for which this hook was/is used:

  * `RewriteAlignedSubByteIntExt`,
  * `RewriteAlignedSubByteIntTrunc`.

The destination type (or, more precisely, the emulated type) was always
`i8`.

In this PR, I am switching back to a more generic approach - the
calculation should take into account the bit-width of the emulated type.

Note that at the call sites I am passing `i8` as the emulated type, so the
end-result is effectively identical. However, the intent is clearer, i.e.,
the underlying value is 8 because the emulated type happens to be `i8`
(as opposed to using a magic number).
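As a worked example (the concrete values are chosen purely for illustration):

```cpp
// Emulating i2 elements inside i8 containers:
constexpr int srcElemBitwidth = 2; // sub-byte source type (i2)
constexpr int dstElemBitwidth = 8; // emulated/container type (i8)
constexpr int numSrcElemsPerDestElem = dstElemBitwidth / srcElemBitwidth;
static_assert(numSrcElemsPerDestElem == 4, "4 x i2 pack into one i8");
```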

3. Update `alignedConversionPrecondition` (3):

The final check has been replaced with a new helper method,
`isSubByteVecFittable`. This new method is also re-used within the code
and hopefully will allow us more code re-use moving forward (to avoid
re-implementing the same condition).
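For reference, the reuse in ConvertVectorTransferRead now reads as follows (copied from the hunk in the diff above):

```cpp
// Note, per-element-alignment was already verified above. The helper
// replaces the previous inlined modulo check.
bool isFullyAligned =
    isSubByteVecFittable(op.getVectorType(), newElementType);
```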

NEXT STEPS (1):

We need to clarify the meaning of "source" and "destination" types.
Currently the usage is ambiguous.

For example, for this `arith.extsi` Op, `vector<8xi2>` and
`vector<8xi32>` are the "source" and "destination" types, respectively:

```mlir
  %0 = arith.extsi %arg0 : vector<8xi2> to vector<8xi32>
```

However, patterns like `RewriteAlignedSubByteIntExt` introduce
`vector.bitcast` Ops like this:

```mlir
  %bitcast = vector.bitcast %arg0 : vector<8xi2> to vector<2xi8>
```

I've noticed that we tend to mix `vector<2xi8>` and `vector<8xi32>` as
the destination types and that should be clarified.

NEXT STEPS (2):

With this PR, I am introducing explicit references to "sub-byte" as
that is effectively what this logic is used for (i.e. for emulating
"sub-byte" types). We should either generalise (which would include
increasing test coverage) or restrict everything to "sub-byte" type
emulation.
@banach-space banach-space force-pushed the andrzej/refactor_narrow_type_4 branch from b3b67eb to 696d7d2 Compare March 1, 2025 19:59
@banach-space banach-space marked this pull request as ready for review March 1, 2025 20:01
@banach-space
Contributor Author

UPDATE 1/3/25

All dependencies have been merged; this is now ready for review.

@banach-space
Contributor Author

banach-space commented Mar 6, 2025

Ping @lialan, this is the refactor that we discussed in #115922 (this comment: #115922 (comment))

Thanks!

@banach-space banach-space changed the title [mlir][Vector] Update VectorEmulateNarrowType.cpp (3/N) [mlir][Vector] Update VectorEmulateNarrowType.cpp Mar 8, 2025
@banach-space banach-space changed the title [mlir][Vector] Update VectorEmulateNarrowType.cpp [mlir][Vector] Refactor VectorEmulateNarrowType.cpp Mar 8, 2025
@banach-space
Contributor Author

UPDATE 8/3/25

Simplified the summary. Apologies, originally I packed too much info there.

@@ -1127,7 +1159,8 @@ struct ConvertVectorTransferRead final
   auto origElements = op.getVectorType().getNumElements();

   // Note, per-element-alignment was already verified above.
-  bool isFullyAligned = origElements % emulatedPerContainerElem == 0;
+  bool isFullyAligned =
Member


I wonder if we should also change the name of isFullyAligned

Contributor Author


Are you suggesting to rename it everywhere? Note that ATM this name is quite widely used throughout this file:

$ rg "isFullyAligned" mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp | wc -l
      14

I am fine renaming it, though I will ask you for a suggestion 😅

Perhaps let me start by explaining the rationale for this name. Basically, most patterns implement this high-level logic (see e.g. here):

    // Check per-element alignment.
    if (containerBits % emulatedBits != 0) {
      return rewriter.notifyMatchFailure(
          op, "impossible to pack emulated elements into container elements "
              "(bit-wise misalignment)");
    }
    
    // (...) Some code here
    
    // Note, per-element-alignment was already verified above.
    bool isFullyAligned = origElements % emulatedPerContainerElem == 0;

As a concrete example:

  • vector<3x2xi2> is "per-element" aligned with vector<2xi8> (because we can fit exactly 4 x i2 into i8)
  • vector<3x2xi2> is not full aligned with vector<2xi8> as (3 x 2 = ) 6 % 4 (= 8 / 2) != 0

Let me know if this makes sense. If not, we can iterate :)
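For completeness, here is the arithmetic of that example spelled out (illustrative only):

```cpp
// vector<3x2xi2> emulated with i8 containers:
constexpr int containerBits = 8, emulatedBits = 2;
static_assert(containerBits % emulatedBits == 0, "per-element aligned");
constexpr int emulatedPerContainerElem = containerBits / emulatedBits; // 4
constexpr int origElements = 3 * 2;                                    // 6
static_assert(origElements % emulatedPerContainerElem != 0,
              "6 % 4 == 2, i.e. not fully aligned");
```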

Member


Naming is indeed hard! I asked my AI and I got isDivisibleInSize out from a bunch of suggestions. :-)

Contributor Author


Thanks for the suggestion!

That's a good name, but it captures the "local" condition computed here rather than the "global" one I had in mind. Let me illustrate:

 bool isPerElementAligned = (containerBits % emulatedBits == 0);
 if (!isPerElementAligned) {
    return failure();
 }
 
 // (...)
 
 // "local" condition
 bool isDivisibleInSize = origElements % emulatedPerContainerElem == 0;
 // "global" condition
 bool isFullyAligned = isPerElementAligned && isDivisibleInSize;

So, I see two options:
OPTION 1

// Check per-element alignment
if (containerBits % emulatedBits != 0) {
    return failure();
}

// (...)
// Note, per-element-alignment was already verified above
bool isFullyAligned = (origElements % emulatedPerContainerElem == 0);

OPTION 2

 bool isPerElementAligned = (containerBits % emulatedBits == 0);
 if (!isPerElementAligned) {
    return failure();
 }
 
 // (...)
 
 // "local" condition
 bool isDivisibleInSize = origElements % emulatedPerContainerElem == 0;
 // "global" condition
 bool isFullyAligned = isPerElementAligned && isDivisibleInSize;

What do you reckon? There's also ...

OPTION 3

// Check per-element alignment
if (containerBits % emulatedBits != 0) {
    return failure();
}

// (...)
// Note, per-element-alignment was already verified above
bool isDivisibleInSize = (origElements % emulatedPerContainerElem == 0);

... but it feels half-way between OPTION 1 and OPTION 2. Not my preference.

Member


I actually prefer option 3... Once isPerElementAligned is tested we don't need subsequent references to it.

It also reduces mental burden while reading the code.

Contributor Author


Updated in this commit

  return rewriter.notifyMatchFailure(
-     op, "only src bitwidth of 2 or 4 is supported at this moment");
+     op, "only 2-bit and 4-bit sub-byte type is supported at this moment");
Member


1 bit?

Contributor Author


There isn't a single test for 1 bit :)

Basically, you've contributed i4 emulation and then @ziereis added i2 emulation. We check specifically for i2 and i4 as that's what we've focused on so far.

Member


I know some downstream projects try to use 1-bit, and I think upstream shouldn't trivially block it in this way. They can contribute i1 tests for sure, but overall the code here should support 1-bit scenarios without problem.

Contributor Author


> I know some downstream projects try to use 1-bit, and I think upstream shouldn't trivially block it in this way.

Oh, definitely not trying to block anyone. This is merely trying to document the existing assumptions. Note that this condition is already present:

    if (srcElemBitwidth != 2 && srcElemBitwidth != 4)
      return rewriter.notifyMatchFailure(
          op, "only src bitwidth of 2 or 4 is supported at this moment");

> They can contribute i1 tests for sure but overall the code here should support 1-bit scenarios without problem.

They would be welcome with praise and gratitude :)

@banach-space banach-space force-pushed the andrzej/refactor_narrow_type_4 branch from 1892ef2 to df27fff Compare March 10, 2025 16:21
@banach-space banach-space force-pushed the andrzej/refactor_narrow_type_4 branch from df27fff to aa17f5e Compare March 10, 2025 16:26
Member

@lialan lialan left a comment


LGTM, let's wrap up the remaining unresolved conversations.

@banach-space
Contributor Author

> LGTM, let's wrap up the remaining unresolved conversations.

Thanks! I assume the one re isDivisibleInSize was the main remaining bit. That has now been addressed.

Please post a comment if you'd like something else addressed as well. Otherwise I will land this tomorrow.

I really appreciate the review 🙏🏻
