[mlir][Vector] Update VectorEmulateNarrowType.cpp (2/N) #123527

Merged: banach-space merged 2 commits into llvm:main from andrzej/refactor_narrow_type_2 on Feb 6, 2025

Conversation

banach-space
Contributor

@banach-space banach-space commented Jan 19, 2025

This is PR 2 in a series of N patches aimed at improving
"VectorEmulateNarrowType.cpp". This is mainly minor refactoring, no
major functional changes are made/added.

CHANGE 1

Renames the variable "scale". Note, "scale" could mean either:

  • "container-elements-per-emulated-type", or
  • "emulated-elements-per-container-type".

While the context makes it clear that it is always the former (the
original type is always a sub-byte type and the emulated type is usually
i8), this PR reduces the cognitive load by encoding that in the name.
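
To make the ratio concrete, here is a minimal standalone sketch (an editorial illustration; the i4/i8 widths are assumed for the example and nothing below is code from the patch):

```cpp
#include <cassert>

int main() {
  // Assumed example widths: i4 elements emulated via i8 containers.
  int emulatedBits = 4;
  int containerBits = 8;

  // "scale" under its new, unambiguous name: the number of emulated
  // (narrow) elements packed into one container element.
  int elementsPerContainerType = containerBits / emulatedBits;
  assert(elementsPerContainerType == 2); // two i4 values per i8

  return 0;
}
```

(Later in this series the same quantity ends up being called emulatedPerContainerElem; see the review thread below.)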

CHANGE 2

Replaces isUnalignedEmulation with isFullyAligned

Note, isUnalignedEmulation is always computed following a
"per-element-alignment" condition:

// Check per-element alignment.
if (containerBits % emulatedBits != 0) {
  return rewriter.notifyMatchFailure(
    op, "impossible to pack emulated elements into container elements "
    "(bit-wise misalignment)");
}

// (...)

bool isUnalignedEmulation = origElements % emulatedPerContainerElem != 0;

Given that isUnalignedEmulation captures only one of the two conditions
required for "full alignment", it would more accurately be named
isPartiallyUnalignedEmulation. Instead, I've flipped the condition and
renamed it isFullyAligned:

bool isFullyAligned = origElements % emulatedPerContainerElem == 0;
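
To see why both conditions matter, here is a standalone sketch (illustrative numbers only; a hypothetical vector<6xi2> emulated via i8 containers, not code from the patch):

```cpp
#include <cassert>

int main() {
  // Assumed example: vector<6xi2> emulated via i8 containers.
  int emulatedBits = 2;
  int containerBits = 8;
  int origElements = 6;

  // Condition 1 (checked first): per-element alignment -- the container
  // width must be a multiple of the emulated width.
  assert(containerBits % emulatedBits == 0); // 8 % 2 == 0 -> OK

  // Condition 2: the elements must also fill a whole number of containers.
  int emulatedPerContainerElem = containerBits / emulatedBits; // 4
  bool isFullyAligned = origElements % emulatedPerContainerElem == 0;
  assert(!isFullyAligned); // 6 % 4 != 0 -> the last container is half-used

  return 0;
}
```

In other words, vector<6xi2> passes the per-element check yet still needs the unaligned code path; that is exactly the distinction the new name makes visible.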

CHANGE 3

  • Unifies various comments throughout the file (for consistency).
  • Adds new comments throughout the file and adds TODOs where high-level
    comments are missing.

GitHub issue to track this work: #123630

@llvmbot
Member

llvmbot commented Jan 19, 2025

@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir

Author: Andrzej Warzyński (banach-space)

Changes
  • [mlir][Vector] Update VectorEmulateNarrowType.cpp (1/N)
  • [mlir][Vector] Update VectorEmulateNarrowType.cpp (2/N)

Full diff: https://github.com/llvm/llvm-project/pull/123527.diff

1 file affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp (+74-63)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 95064083b21d44..4e0be258954496 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -282,13 +282,15 @@ emulatedVectorLoad(OpBuilder &rewriter, Location loc, Value base,
                    OpFoldResult linearizedIndices,
                    int64_t numEmultedElementsToLoad, Type origElemType,
                    Type emulatedElemType) {
-  auto scale = emulatedElemType.getIntOrFloatBitWidth() /
-               origElemType.getIntOrFloatBitWidth();
+  auto elementsPerContainerType = emulatedElemType.getIntOrFloatBitWidth() /
+                                  origElemType.getIntOrFloatBitWidth();
   auto newLoad = rewriter.create<vector::LoadOp>(
       loc, VectorType::get(numEmultedElementsToLoad, emulatedElemType), base,
       getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
   return rewriter.create<vector::BitCastOp>(
-      loc, VectorType::get(numEmultedElementsToLoad * scale, origElemType),
+      loc,
+      VectorType::get(numEmultedElementsToLoad * elementsPerContainerType,
+                      origElemType),
       newLoad);
 }
 
@@ -314,14 +316,14 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
     auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
     Type oldElementType = op.getValueToStore().getType().getElementType();
     Type newElementType = convertedType.getElementType();
-    int srcBits = oldElementType.getIntOrFloatBitWidth();
-    int dstBits = newElementType.getIntOrFloatBitWidth();
+    int oldBits = oldElementType.getIntOrFloatBitWidth();
+    int newBits = newElementType.getIntOrFloatBitWidth();
 
-    if (dstBits % srcBits != 0) {
-      return rewriter.notifyMatchFailure(
-          op, "only dstBits % srcBits == 0 supported");
+    // Check per-element alignment.
+    if (newBits % oldBits != 0) {
+      return rewriter.notifyMatchFailure(op, "unalagined element types");
     }
-    int scale = dstBits / srcBits;
+    int elementsPerContainerType = newBits / oldBits;
 
     // Adjust the number of elements to store when emulating narrow types.
     // Here only the 1-D vector store is considered, and the N-D memref types
@@ -337,7 +339,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
     // vector<4xi8>
 
     auto origElements = op.getValueToStore().getType().getNumElements();
-    if (origElements % scale != 0)
+    if (origElements % elementsPerContainerType != 0)
       return failure();
 
     auto stridedMetadata =
@@ -346,13 +348,13 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
     OpFoldResult linearizedIndices;
     std::tie(std::ignore, linearizedIndices) =
         memref::getLinearizedMemRefOffsetAndSize(
-            rewriter, loc, srcBits, dstBits,
+            rewriter, loc, oldBits, newBits,
             stridedMetadata.getConstifiedMixedOffset(),
             stridedMetadata.getConstifiedMixedSizes(),
             stridedMetadata.getConstifiedMixedStrides(),
             getAsOpFoldResult(adaptor.getIndices()));
 
-    auto numElements = origElements / scale;
+    auto numElements = origElements / elementsPerContainerType;
     auto bitCast = rewriter.create<vector::BitCastOp>(
         loc, VectorType::get(numElements, newElementType),
         op.getValueToStore());
@@ -385,17 +387,17 @@ struct ConvertVectorMaskedStore final
     auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
     Type oldElementType = op.getValueToStore().getType().getElementType();
     Type newElementType = convertedType.getElementType();
-    int srcBits = oldElementType.getIntOrFloatBitWidth();
-    int dstBits = newElementType.getIntOrFloatBitWidth();
+    int oldBits = oldElementType.getIntOrFloatBitWidth();
+    int newBits = newElementType.getIntOrFloatBitWidth();
 
-    if (dstBits % srcBits != 0) {
-      return rewriter.notifyMatchFailure(
-          op, "only dstBits % srcBits == 0 supported");
+    // Check per-element alignment.
+    if (newBits % oldBits != 0) {
+      return rewriter.notifyMatchFailure(op, "unalagined element types");
     }
 
-    int scale = dstBits / srcBits;
+    int elementsPerContainerType = newBits / oldBits;
     int origElements = op.getValueToStore().getType().getNumElements();
-    if (origElements % scale != 0)
+    if (origElements % elementsPerContainerType != 0)
       return failure();
 
     auto stridedMetadata =
@@ -404,7 +406,7 @@ struct ConvertVectorMaskedStore final
     memref::LinearizedMemRefInfo linearizedInfo;
     std::tie(linearizedInfo, linearizedIndicesOfr) =
         memref::getLinearizedMemRefOffsetAndSize(
-            rewriter, loc, srcBits, dstBits,
+            rewriter, loc, oldBits, newBits,
             stridedMetadata.getConstifiedMixedOffset(),
             stridedMetadata.getConstifiedMixedSizes(),
             stridedMetadata.getConstifiedMixedStrides(),
@@ -444,12 +446,13 @@ struct ConvertVectorMaskedStore final
     //
     // FIXME: Make an example based on the comment above work (see #115460 for
     // reproducer).
-    FailureOr<Operation *> newMask =
-        getCompressedMaskOp(rewriter, loc, op.getMask(), origElements, scale);
+    FailureOr<Operation *> newMask = getCompressedMaskOp(
+        rewriter, loc, op.getMask(), origElements, elementsPerContainerType);
     if (failed(newMask))
       return failure();
 
-    auto numElements = (origElements + scale - 1) / scale;
+    auto numElements = (origElements + elementsPerContainerType - 1) /
+                       elementsPerContainerType;
     auto newType = VectorType::get(numElements, newElementType);
     auto passThru = rewriter.create<arith::ConstantOp>(
         loc, newType, rewriter.getZeroAttr(newType));
@@ -458,7 +461,8 @@ struct ConvertVectorMaskedStore final
         loc, newType, adaptor.getBase(), linearizedIndices,
         newMask.value()->getResult(0), passThru);
 
-    auto newBitCastType = VectorType::get(numElements * scale, oldElementType);
+    auto newBitCastType =
+        VectorType::get(numElements * elementsPerContainerType, oldElementType);
     Value valueToStore =
         rewriter.create<vector::BitCastOp>(loc, newBitCastType, newLoad);
     valueToStore = rewriter.create<arith::SelectOp>(
@@ -493,14 +497,14 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
     auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
     Type oldElementType = op.getType().getElementType();
     Type newElementType = convertedType.getElementType();
-    int srcBits = oldElementType.getIntOrFloatBitWidth();
-    int dstBits = newElementType.getIntOrFloatBitWidth();
+    int oldBits = oldElementType.getIntOrFloatBitWidth();
+    int newBits = newElementType.getIntOrFloatBitWidth();
 
-    if (dstBits % srcBits != 0) {
-      return rewriter.notifyMatchFailure(
-          op, "only dstBits % srcBits == 0 supported");
+    // Check per-element alignment.
+    if (newBits % oldBits != 0) {
+      return rewriter.notifyMatchFailure(op, "unalagined element types");
     }
-    int scale = dstBits / srcBits;
+    int elementsPerContainerType = newBits / oldBits;
 
     // Adjust the number of elements to load when emulating narrow types,
     // and then cast back to the original type with vector.bitcast op.
@@ -532,7 +536,7 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
     // compile time as they must be constants.
 
     auto origElements = op.getVectorType().getNumElements();
-    bool isUnalignedEmulation = origElements % scale != 0;
+    bool isUnalignedEmulation = origElements % elementsPerContainerType != 0;
 
     auto stridedMetadata =
         rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
@@ -541,7 +545,7 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
     memref::LinearizedMemRefInfo linearizedInfo;
     std::tie(linearizedInfo, linearizedIndices) =
         memref::getLinearizedMemRefOffsetAndSize(
-            rewriter, loc, srcBits, dstBits,
+            rewriter, loc, oldBits, newBits,
             stridedMetadata.getConstifiedMixedOffset(),
             stridedMetadata.getConstifiedMixedSizes(),
             stridedMetadata.getConstifiedMixedStrides(),
@@ -553,9 +557,10 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
             : 0;
 
     // Always load enough elements which can cover the original elements.
-    int64_t maxintraDataOffset = foldedIntraVectorOffset.value_or(scale - 1);
-    auto numElements =
-        llvm::divideCeil(maxintraDataOffset + origElements, scale);
+    int64_t maxintraDataOffset =
+        foldedIntraVectorOffset.value_or(elementsPerContainerType - 1);
+    auto numElements = llvm::divideCeil(maxintraDataOffset + origElements,
+                                        elementsPerContainerType);
     Value result =
         emulatedVectorLoad(rewriter, loc, adaptor.getBase(), linearizedIndices,
                            numElements, oldElementType, newElementType);
@@ -596,14 +601,14 @@ struct ConvertVectorMaskedLoad final
     auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
     Type oldElementType = op.getType().getElementType();
     Type newElementType = convertedType.getElementType();
-    int srcBits = oldElementType.getIntOrFloatBitWidth();
-    int dstBits = newElementType.getIntOrFloatBitWidth();
+    int oldBits = oldElementType.getIntOrFloatBitWidth();
+    int newBits = newElementType.getIntOrFloatBitWidth();
 
-    if (dstBits % srcBits != 0) {
-      return rewriter.notifyMatchFailure(
-          op, "only dstBits % srcBits == 0 supported");
+    // Check per-element alignment.
+    if (newBits % oldBits != 0) {
+      return rewriter.notifyMatchFailure(op, "unalagined element types");
     }
-    int scale = dstBits / srcBits;
+    int elementsPerContainerType = newBits / oldBits;
 
     // Adjust the number of elements to load when emulating narrow types,
     // and then cast back to the original type with vector.bitcast op.
@@ -649,7 +654,7 @@ struct ConvertVectorMaskedLoad final
     // subvector at the proper offset after bit-casting.
     auto origType = op.getVectorType();
     auto origElements = origType.getNumElements();
-    bool isUnalignedEmulation = origElements % scale != 0;
+    bool isUnalignedEmulation = origElements % elementsPerContainerType != 0;
 
     auto stridedMetadata =
         rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
@@ -657,7 +662,7 @@ struct ConvertVectorMaskedLoad final
     memref::LinearizedMemRefInfo linearizedInfo;
     std::tie(linearizedInfo, linearizedIndices) =
         memref::getLinearizedMemRefOffsetAndSize(
-            rewriter, loc, srcBits, dstBits,
+            rewriter, loc, oldBits, newBits,
             stridedMetadata.getConstifiedMixedOffset(),
             stridedMetadata.getConstifiedMixedSizes(),
             stridedMetadata.getConstifiedMixedStrides(),
@@ -668,18 +673,21 @@ struct ConvertVectorMaskedLoad final
             ? getConstantIntValue(linearizedInfo.intraDataOffset)
             : 0;
 
-    int64_t maxIntraDataOffset = foldedIntraVectorOffset.value_or(scale - 1);
-    FailureOr<Operation *> newMask = getCompressedMaskOp(
-        rewriter, loc, op.getMask(), origElements, scale, maxIntraDataOffset);
+    int64_t maxIntraDataOffset =
+        foldedIntraVectorOffset.value_or(elementsPerContainerType - 1);
+    FailureOr<Operation *> newMask =
+        getCompressedMaskOp(rewriter, loc, op.getMask(), origElements,
+                            elementsPerContainerType, maxIntraDataOffset);
     if (failed(newMask))
       return failure();
 
     Value passthru = op.getPassThru();
 
-    auto numElements =
-        llvm::divideCeil(maxIntraDataOffset + origElements, scale);
+    auto numElements = llvm::divideCeil(maxIntraDataOffset + origElements,
+                                        elementsPerContainerType);
     auto loadType = VectorType::get(numElements, newElementType);
-    auto newBitcastType = VectorType::get(numElements * scale, oldElementType);
+    auto newBitcastType =
+        VectorType::get(numElements * elementsPerContainerType, oldElementType);
 
     auto emptyVector = rewriter.create<arith::ConstantOp>(
         loc, newBitcastType, rewriter.getZeroAttr(newBitcastType));
@@ -706,8 +714,8 @@ struct ConvertVectorMaskedLoad final
         rewriter.create<vector::BitCastOp>(loc, newBitcastType, newLoad);
 
     Value mask = op.getMask();
-    auto newSelectMaskType =
-        VectorType::get(numElements * scale, rewriter.getI1Type());
+    auto newSelectMaskType = VectorType::get(
+        numElements * elementsPerContainerType, rewriter.getI1Type());
     // TODO: try to fold if op's mask is constant
     auto emptyMask = rewriter.create<arith::ConstantOp>(
         loc, newSelectMaskType, rewriter.getZeroAttr(newSelectMaskType));
@@ -758,18 +766,18 @@ struct ConvertVectorTransferRead final
     auto convertedType = cast<MemRefType>(adaptor.getSource().getType());
     Type oldElementType = op.getType().getElementType();
     Type newElementType = convertedType.getElementType();
-    int srcBits = oldElementType.getIntOrFloatBitWidth();
-    int dstBits = newElementType.getIntOrFloatBitWidth();
+    int oldBits = oldElementType.getIntOrFloatBitWidth();
+    int newBits = newElementType.getIntOrFloatBitWidth();
 
-    if (dstBits % srcBits != 0) {
-      return rewriter.notifyMatchFailure(
-          op, "only dstBits % srcBits == 0 supported");
+    // Check per-element alignment.
+    if (newBits % oldBits != 0) {
+      return rewriter.notifyMatchFailure(op, "unalagined element types");
     }
-    int scale = dstBits / srcBits;
+    int elementsPerContainerType = newBits / oldBits;
 
     auto origElements = op.getVectorType().getNumElements();
 
-    bool isUnalignedEmulation = origElements % scale != 0;
+    bool isUnalignedEmulation = origElements % elementsPerContainerType != 0;
 
     auto newPadding = rewriter.create<arith::ExtUIOp>(loc, newElementType,
                                                       adaptor.getPadding());
@@ -781,7 +789,7 @@ struct ConvertVectorTransferRead final
     memref::LinearizedMemRefInfo linearizedInfo;
     std::tie(linearizedInfo, linearizedIndices) =
         memref::getLinearizedMemRefOffsetAndSize(
-            rewriter, loc, srcBits, dstBits,
+            rewriter, loc, oldBits, newBits,
             stridedMetadata.getConstifiedMixedOffset(),
             stridedMetadata.getConstifiedMixedSizes(),
             stridedMetadata.getConstifiedMixedStrides(),
@@ -792,9 +800,10 @@ struct ConvertVectorTransferRead final
             ? getConstantIntValue(linearizedInfo.intraDataOffset)
             : 0;
 
-    int64_t maxIntraDataOffset = foldedIntraVectorOffset.value_or(scale - 1);
-    auto numElements =
-        llvm::divideCeil(maxIntraDataOffset + origElements, scale);
+    int64_t maxIntraDataOffset =
+        foldedIntraVectorOffset.value_or(elementsPerContainerType - 1);
+    auto numElements = llvm::divideCeil(maxIntraDataOffset + origElements,
+                                        elementsPerContainerType);
 
     auto newRead = rewriter.create<vector::TransferReadOp>(
         loc, VectorType::get(numElements, newElementType), adaptor.getSource(),
@@ -802,7 +811,9 @@ struct ConvertVectorTransferRead final
         newPadding);
 
     auto bitCast = rewriter.create<vector::BitCastOp>(
-        loc, VectorType::get(numElements * scale, oldElementType), newRead);
+        loc,
+        VectorType::get(numElements * elementsPerContainerType, oldElementType),
+        newRead);
 
     Value result = bitCast->getResult(0);
     if (!foldedIntraVectorOffset) {
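
All three load paths in the diff above (ConvertVectorLoad, ConvertVectorMaskedLoad, ConvertVectorTransferRead) share one computation: load enough whole containers to cover the original elements, given a worst-case intra-container offset. A standalone sketch with assumed numbers (the local divideCeil mirrors llvm::divideCeil for non-negative operands; nothing here is code from the patch):

```cpp
#include <cassert>
#include <cstdint>

// Ceiling division for non-negative operands, as llvm::divideCeil computes.
int64_t divideCeil(int64_t numerator, int64_t denominator) {
  return (numerator + denominator - 1) / denominator;
}

int main() {
  // Assumed example: vector<6xi2> emulated via i8, i.e. 4 elements per
  // container.
  int64_t elementsPerContainerType = 4;
  int64_t origElements = 6;

  // When the intra-container offset is not a compile-time constant, assume
  // the worst case: the data starts at the last slot of a container.
  int64_t maxIntraDataOffset = elementsPerContainerType - 1; // 3

  // Cover 3 (offset) + 6 (data) = 9 element slots -> ceil(9 / 4) = 3
  // containers loaded; the bitcast then yields 3 * 4 = 12 narrow elements,
  // from which the original 6 are extracted.
  int64_t numElements =
      divideCeil(maxIntraDataOffset + origElements, elementsPerContainerType);
  assert(numElements == 3);

  return 0;
}
```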

@banach-space banach-space changed the title from "andrzej/refactor narrow type 2" to "[mlir][Vector] Update VectorEmulateNarrowType.cpp (2/N)" on Jan 19, 2025
Contributor

@dcaballe dcaballe left a comment


This is continuing the renaming, right? Could we merge with the previous PR? It's difficult to see what the final picture is.

@lialan
Member

lialan commented Jan 20, 2025

May I propose we add a comment section at the beginning of the code that explains the naming conventions? This would eliminate most future naming issues and avoid back-and-forth reviews about names.
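
For illustration, a sketch of what such a note might look like (assumed wording, not text from the file):

```cpp
//===----------------------------------------------------------------------===//
// Naming conventions used in this file (sketch of the proposal above):
//
//  * "emulated type"  - the narrow sub-byte element type being emulated,
//                       e.g. i2 or i4;
//  * "container type" - the wider element type used for storage, usually i8;
//  * emulatedPerContainerElem - the number of emulated elements that fit in
//                       one container element (containerBits / emulatedBits).
//===----------------------------------------------------------------------===//
```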

@banach-space
Contributor Author

May I propose we add a comment section at the beginning of the code that explains the naming conventions? This would eliminate most future naming issues and avoid back-and-forth reviews about names.

Happy to. Added as Proposal 3:

I'd like to finish Proposal 2 and Proposal 3 first (following which the naming should be clearer).

@banach-space
Contributor Author

Could we merge with the previous PR?

Currently as a draft:

It's difficult to see what the final picture is.

Trying to capture it here:

Member

@lialan lialan left a comment


Just some suggestions on the naming.

@@ -282,13 +282,15 @@ emulatedVectorLoad(OpBuilder &rewriter, Location loc, Value base,
                    OpFoldResult linearizedIndices,
                    int64_t numEmultedElementsToLoad, Type origElemType,
                    Type emulatedElemType) {
-  auto scale = emulatedElemType.getIntOrFloatBitWidth() /
-               origElemType.getIntOrFloatBitWidth();
+  auto elementsPerContainerType = emulatedElemType.getIntOrFloatBitWidth() /
Member


elemPerContainerType to make it a little bit shorter?

Contributor Author


Thanks for the suggestion! I'd like to introduce something even less ambiguous, so let me propose:

  • emulatedPerContainerElem.

I'll send an update shortly and you can tell me what you think :)

   auto newLoad = rewriter.create<vector::LoadOp>(
       loc, VectorType::get(numEmultedElementsToLoad, emulatedElemType), base,
       getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
   return rewriter.create<vector::BitCastOp>(
-      loc, VectorType::get(numEmultedElementsToLoad * scale, origElemType),
+      loc,
+      VectorType::get(numEmultedElementsToLoad * elementsPerContainerType,
Member


Typo in numEmultedElementsToLoad?

Contributor Author


Thanks! This and other typos in emulatedVectorLoad are fixed in:

Member

@pashu123 pashu123 left a comment


nits


if (dstBits % srcBits != 0) {
Member

@pashu123 pashu123 Jan 27, 2025


The naming of srcBits/dstBits and oldBits/newBits is quite confusing. Could we use a more descriptive term, like emulatedBits, instead?

Contributor Author


#123526 ;-)

     }
-    int scale = dstBits / srcBits;
+    int elementsPerContainerType = newBits / oldBits;
Member


Should it be elementsPerContainer? It doesn't seem to be a type.

Contributor Author


I've just switched to emulatedPerContainerElem - WDYT?

op, "only dstBits % srcBits == 0 supported");
// Check per-element alignment.
if (newBits % oldBits != 0) {
return rewriter.notifyMatchFailure(op, "unalagined element types");
Member


Suggested change
return rewriter.notifyMatchFailure(op, "unalagined element types");
return rewriter.notifyMatchFailure(op, "unaligned element types");

@banach-space banach-space force-pushed the andrzej/refactor_narrow_type_2 branch from d40b31b to 593df77 Compare February 2, 2025 15:39
@banach-space
Contributor Author

@pashu123, thanks for the review. Looks like you have reviewed changes from one of the dependencies of this PR:

:) Note that it has been merged and those changes are no longer present in this PR (and all of your suggestions have been incorporated 🙏🏻).

I always list dependencies in the summary - suggestions on how to make this clearer are very welcome. Perhaps it's time to try:


github-actions bot commented Feb 2, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@banach-space banach-space force-pushed the andrzej/refactor_narrow_type_2 branch 2 times, most recently from 9fab1bb to a3dfd91 Compare February 2, 2025 15:56
This is PR 2 in a series of N patches aimed at improving
"VectorEmulateNarrowType.cpp". This is mainly minor refactoring, no
major functional changes are made/added.

This PR renames the variable "scale". Note, "scale" could mean either:

  * "original-elements-per-emulated-type", or
  * "emulated-elements-per-original-type".

While from the context it is clear that it's always the former (original
type is always a sub-byte type and the emulated type is usually `i8`),
this PR reduces the cognitive load by making this clear.

**DEPENDS ON:**
* llvm#123526

Please only review the [top
commit](llvm@d40b31b).

**GitHub issue to track this work**:
llvm#123630
@banach-space banach-space force-pushed the andrzej/refactor_narrow_type_2 branch from a3dfd91 to aaeb0fb Compare February 3, 2025 10:21
@banach-space
Contributor Author

UPDATE (3/2/25): All dependencies of this PR have been merged + rebased on top of main.

Contributor

@dcaballe dcaballe left a comment


Thanks for the improvements! LGTM. Let's try to concentrate all the renaming, typo fixing, etc. for this file within a single PR. Otherwise it's hard to get a clear picture of the final state.

@banach-space
Contributor Author

Let's try to concentrate all the renaming, typo fixing, etc. for this file within a single PR. Otherwise it's hard to get a clear picture of the final state.

IIUC, you are suggesting that I include PATCH 3/N in this PR? That's fine with me, but I'd also like to hear from @lialan and @pashu123 as two other active reviewers.

Personally, I prefer smaller, isolated changes (IMHO, GitHub UI is pretty bad for bigger changes). However, ultimately, my priority is to make this easy to review.

So, kind request for further feedback:

  • 🚢 - merge this as is, i.e. ship it! (*)
  • 👍🏻 - I am OK with PATCH 3/N being included here (no further review required).

(*) i.e. leave PATCH 3/N within #123528

@hanhanW
Contributor

hanhanW commented Feb 4, 2025

I'm +1 on having them in a single PR because the change is not big, and they all touch the same file. Having a clear PR description is good enough for me. IMHO, it also makes the codebase state and commit tracking easier. (2) + (3) is about +107 -78 lines, which is a reasonable PR size to me.

@lialan
Member

lialan commented Feb 4, 2025

I'm +1 on having them in a single PR because the change is not big, and they all touch the same file. Having a clear PR description is good enough for me. IMHO, it also makes the codebase state and commit tracking easier. (2) + (3) is about +107 -78 lines, which is a reasonable PR size to me.

+1 too.

This is PR 3 in a series of N patches aimed at improving
"VectorEmulateNarrowType.cpp". This is mainly minor refactoring, no
major functional changes are made/added.

1. Replaces `isUnalignedEmulation` with `isFullyAligned`

Note, `isUnalignedEmulation` is always computed following a
"per-element-alignment" condition:
```cpp
// Check per-element alignment.
if (containerBits % emulatedBits != 0) {
  return rewriter.notifyMatchFailure(
    op, "impossible to pack emulated elements into container elements "
    "(bit-wise misalignment)");
}

// (...)

bool isUnalignedEmulation = origElements % emulatedPerContainerElem != 0;
```

Given that `isUnalignedEmulation` captures only one of the two conditions
required for "full alignment", it would more accurately be named
`isPartiallyUnalignedEmulation`. Instead, I've flipped the condition and
renamed it `isFullyAligned`:

```cpp
bool isFullyAligned = origElements % emulatedPerContainerElem == 0;
```

2. In addition:
  * Unifies various comments throughout the file (for consistency).
  * Adds new comments throughout the file and adds TODOs where high-level
    comments are missing.
@banach-space
Contributor Author

Thanks for the feedback! Merged #123528 into this PR.

@dcaballe
Contributor

dcaballe commented Feb 4, 2025

Personally, I prefer smaller, isolated changes

Yeah, me too... Doing this right is kind of an art, I guess :). I think the key point is to find the right trade-off. I would say "small isolated changes" as long as they can be tested. For an NFC (i.e., var renaming, API changes, refactoring, typos, formatting, ...) I would go with a single PR, even if it's large. For this kind of change, it's good for the reviewer to see the overall final state and add feedback according to that. Multiple patches add more overhead/cognitive load in trying to compose all the pieces together... Well, at least, that's how I feel about it. People may think differently, of course.

@banach-space
Contributor Author

I will wait ~24hrs. If there are no new comments, I will assume that folks are happy with these changes and merge it.

@banach-space banach-space merged commit 78f690b into llvm:main Feb 6, 2025
8 checks passed
banach-space added a commit to banach-space/llvm-project that referenced this pull request Feb 9, 2025
1. Documents `ConvertVectorStore`.

2. As a follow-on for llvm#123527, renames `isAlignedEmulation` to
   `isFullyAligned` and `numSrcElemsPerDest` to
   `emulatedPerContainerElem`.
Icohedron pushed a commit to Icohedron/llvm-project that referenced this pull request Feb 11, 2025
banach-space added a commit that referenced this pull request Feb 15, 2025
…126422)

1. Documents `ConvertVectorStore`. As the generated output is rather complex, I
  have refined the comments + variable names in:
    * "vector-emulate-narrow-type-unaligned-non-atomic.mlir",
  to serve as reference for this pattern.

2. As a follow-on for #123527, renames `isAlignedEmulation` to `isFullyAligned`
  and `numSrcElemsPerDest` to `emulatedPerContainerElem`.
@banach-space banach-space deleted the andrzej/refactor_narrow_type_2 branch February 15, 2025 20:17
sivan-shani pushed a commit to sivan-shani/llvm-project that referenced this pull request Feb 24, 2025
YutongZhuu pushed a commit to YutongZhuu/llvm-project that referenced this pull request Mar 8, 2025