
Propagate DotOp thru Join & improve shmem load into LinearEnc #5924

Merged: 12 commits merged into triton-lang:main on Feb 20, 2025

Conversation

@ggengnv (Contributor) commented Feb 14, 2025:

There are two parts to this PR:

  • Propagate dotOp through join when the dotOp encoding is in the form of a LinearLayout (mostly reusing @lezcano's logic for Fp4ToFp)
  • Add a rough optimization for shmem -> LL loads

Motivation for the second part: currently, a shmem load into an LL falls back to an unswizzled shmem layout in the pipeliner, which results in poor performance.
Not only does the inline_asm > join > reshape path suffer from this; so does fp4_to_fp.
I've added some basic swizzling logic for the shmem layout when loading into dotOp-like LLs. As an example, for bf16xfp4 dot_scaled on a small-M, large-N/K shape, with fixed config (8, 128, 256) and DISABLE_MMA_V3=1:

  • before this shmem optimization: ~160us
  • after this shmem optimization: ~124us

Similar improvements can be observed for bf16xint4 (with inline_asm).

There's also a small change to increase kWidth in the case of a join by halving origBitWidth. This should also be important for perf, since otherwise the shmem load width would be too small.
 
I believe there's still significant room for improvement for small-M shapes, because shmem -> LL does not yet support ldmatrix. I can look into this next.

PTAL @lezcano @ThomasRaoux, thank you.

@ggengnv requested a review from @ptillet as a code owner on February 14, 2025 08:23

@lezcano (Contributor) left a comment:

Can you add comprehensive testing of this machinery in DialectTest.cpp, similar to what we did for Fp4ToFp?

Also, the shmem load part looks a bit too specific to me. I'll let @ThomasRaoux figure this one out, though. At any rate, I think it should be split out into a different PR.

@@ -220,7 +220,7 @@ getWarpsPerTile(DotOp dotOp, const ArrayRef<int64_t> shape, int version,

 static bool bwdFilter(Operation *op) {
   return (op->hasTrait<OpTrait::Elementwise>() && isMemoryEffectFree(op)) ||
-          isa<BroadcastOp, ExpandDimsOp, ReshapeOp, TransOp, Fp4ToFpOp,
+          isa<BroadcastOp, ExpandDimsOp, ReshapeOp, TransOp, Fp4ToFpOp, JoinOp,

Contributor:
It might be good to move this to a helper function isView. Probably also add SplitOp now.

Comment on lines +256 to +264
if (llvm::any_of(slice, [](Operation *op) { return isa<JoinOp>(op); }))
origBitWidth /= 2;

Contributor:
In the future we might want to do something like trying a large kWidth, run layout backpropagation and see what's the contiguity that you get at the loads that feed into it. This is tricky to do right now I think, but could you leave a comment noting this?

Comment on lines 2844 to 2847
auto ll = toLinearLayout(shape, srcEnc);
SmallVector<int64_t> dstShape(shape.begin(), shape.end());
dstShape.push_back(1);
ll = reshapeHelper(ctx, ll, dstShape);

Contributor:
toLinearLayout always returns a layout with dims dim0...dimn-1 (feel free to assert it), so this reshapeHelper can probably just be reshapeOuts

Contributor Author:
I could be wrong (I was confused about this), but it seemed that when I did reshapeOuts earlier, the new dim would come out as the most-major dim. I copied this logic from inferReshapeOp so that the dim is the most minor instead.

Contributor:
Given that you are just adding a dimension of size one, it doesn't matter whether this dimension is the most minor or not: it will not move, since it has size 1.

Contributor Author:
Won't this become size 2, as it's multiplied by 2 in tryJoinOnAxis?

I was also seeing a "failed to infer return types" error, I think due to incompatible encodings between inferReturnTypes (from this function) and the encoding from propagating backwards from the reshape. The former would be contiguous along the most-major dim, and the latter along the most-minor dim.

@ggengnv (Contributor Author) commented Feb 14, 2025:
Actually, I just tried what you suggested for this case and there wasn't an error anymore, so the transpose isn't necessary here, as you said. The error I saw must have been due to something else.

I thought it might be more concise here anyway to invoke an existing helper for creating new outDimNames, even though it introduces some unnecessary transposes. I can forgo the helper and leave some redundant code here if you think clarity is more important?

Contributor:
I think it'd be better to simply call LinearLayout::reshapeOuts with one more dim, yes.
If you want, you can even implement LinearLayout::unsqueeze, similar to torch.Tensor.unsqueeze, to be more concise.
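
For reference, a minimal sketch of what such an unsqueeze helper could look like, built on LinearLayout::reshapeOuts as suggested above. The helper name, the exact signatures, and the accessor calls are assumptions for illustration, not code from this PR:

  // Hypothetical sketch: append a trailing out-dim of size 1, analogous to
  // torch.Tensor.unsqueeze on the last axis.
  LinearLayout unsqueezeOuts(const LinearLayout &ll, StringAttr newDim) {
    SmallVector<std::pair<StringAttr, int32_t>> outDims;
    for (StringAttr dim : ll.getOutDimNames())
      outDims.emplace_back(dim, ll.getOutDimSize(dim));
    // A size-1 dimension does not move under the reshape, so whether it ends
    // up most major or most minor does not change the layout.
    outDims.emplace_back(newDim, 1);
    return ll.reshapeOuts(outDims);
  }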

Comment on lines 2852 to 2856
auto result =
tryJoinOnAxis(ctx, ll, newLl, /*fwdInference=*/true, axis, loc);

if (!result.succeeded())
return result;

Contributor:
This join will always succeed, so better to assert that this is indeed the case.

Comment on lines 549 to 551
SmallVector<unsigned int> getContig(const LinearEncodingAttr &enc,
const char *inDim,
SmallVector<unsigned int> lowerContig) {

Contributor:
Mind moving this next to all the other LL methods and making it a LinearEncodingAttr method?

@lezcano (Contributor) commented Feb 14, 2025:

Regarding generalising the use of ldmatrix, it's on my list of things to do in the not-so-distant future.

@@ -218,10 +218,14 @@ getWarpsPerTile(DotOp dotOp, const ArrayRef<int64_t> shape, int version,
}
}

static bool isView(Operation *op) {
return isa<BroadcastOp, ExpandDimsOp, ReshapeOp, TransOp, JoinOp, SplitOp,
ConvertLayoutOp>(op);

Contributor:
ConvertLayoutOp may not be a view op unless it's the identity

Collaborator:
+1 convert_layout is not a view, broadcast is not a view either

Contributor:
broadcast is a view after layout propagation, though?

Contributor:
ok, I see that it may duplicate some elements internally I guess

@ThomasRaoux (Collaborator) left a comment:
Added some minor comments, but it looks good to me overall.
I think the shmem heuristic should be acceptable. We do have plans to move the shmem layout decision later in the compiler pipeline; hopefully we can improve things at that time.

// Append dim to shape
auto ll = toLinearLayout(shape, srcEnc);
SmallVector<int64_t> dstShape(shape.begin(), shape.end());
dstShape.push_back(1);

Collaborator:
Why are we appending 1 and not 2 here?

Contributor Author:
The LL is doubled along an existing axis in tryJoinOnAxis() (Mario's logic that I refactored out). Before calling that, here I'm just "unsqueezing" the shape to be consistent with Join's API, i.e. adding a new dimension to operate on.
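
For illustration only (generic shapes, not taken from the PR): if the source shape is [d0, ..., dn], the flow is

  [d0, ..., dn]  ->  unsqueeze  ->  [d0, ..., dn, 1]  ->  tryJoinOnAxis  ->  [d0, ..., dn, 2]

so the 1 appended here is what tryJoinOnAxis later doubles into Join's trailing dimension of 2.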

Comment on lines 2891 to 2894
auto axis = shape.size() - 1;
if (shape[axis] != 2)
return emitOptionalError(
loc, "SplitOp requires threadsPerWarp, warpsPerCTA, "
"and CTAsPerCGA = 1 for the last dimension of the input");
loc, "SplitOp input shape should have 2 in the last dim");

Collaborator:
Nit: this should be an assert, I think.

Comment on lines +254 to +264
// If JoinOp occurred at least once, in backward layout propagation,
// the kWidth will be split in half as we pass through the JoinOp.
// Hence we divide origBitWidth by 2 here to compensate for that and
// improve our load width.
// This won't be optimal if there is a tree of multiple JoinOps, which
// would require counting the max number of JoinOp's along any path.
//
// In the future we might want to do something like trying a large kWidth,
// run layout backpropagation and see what's the contiguity that you
// get at the loads that feed into it.
if (llvm::any_of(slice, [](Operation *op) { return isa<JoinOp>(op); }))
origBitWidth /= 2;

Collaborator:
This seems quite hacky; do we really need it? If we do, maybe we should at least have a lit test for it.

Contributor Author:
I agree it's hacky. I added it here since I saw the similar heuristic above for Fp4ToFp:

  if (llvm::any_of(slice, [](Operation *op) { return isa<Fp4ToFpOp>(op); }))
    return 4;

For perf I think this is quite important: if kWidth is too small (e.g. 4 instead of 8 for 4-bit inputs), the load width will be < 32b and so will be very inefficient. And sure, I can add a lit test.
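
As a rough standalone illustration of that load-width arithmetic (the 4-bit element size and the 32b threshold come from the comment above; everything else is assumed for the example and is not code from this PR):

  #include <cstdio>

  int main() {
    const int elemBits = 4; // e.g. an fp4/int4 dot operand
    for (int kWidth : {4, 8}) {
      int loadBits = kWidth * elemBits; // contiguous per-thread load width
      std::printf("kWidth=%d -> %d-bit loads%s\n", kWidth, loadBits,
                  loadBits < 32 ? " (below 32b, inefficient)" : "");
    }
    return 0;
  }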

Contributor:
If in the future we want to push more mixed dots, we can implement something like what's outlined in the comment, but I think this is fine for now.

@lezcano (Contributor) left a comment:

Two minor comments; otherwise looks good!

@lezcano (Contributor) left a comment:

Sure, but could you move getStandardOutDimNames to LayoutUtils.cpp?
After that, feel free to merge.
Thank you for the changes!

@ggengnv (Contributor Author) commented Feb 19, 2025:

Could someone approve the CI please? Thanks.

@ggengnv (Contributor Author) commented Feb 19, 2025:

@lezcano I had to disable the shmem swizzling optimization for the case where the global blocked order != linear order, since otherwise cp.async would be < 4B.

The above happens because OptimizeDotOperands::SwizzleShmemConvert doesn't seem to apply anymore, due to how LLs propagate through TransOps. I added a TODO for that.

@ggengnv (Contributor Author) commented Feb 20, 2025:

Hmm, I don't have merge commits and the workflow passed, but I still see:

  Merging is blocked
  Merge is not an allowed merge method in this repository.
  This branch must not contain merge commits.

I tried rebasing on main, but the message is still there.

@lezcano enabled auto-merge (squash) on February 20, 2025 17:37
@lezcano merged commit 4f30282 into triton-lang:main on Feb 20, 2025
7 checks passed
ThomasRaoux added a commit to ThomasRaoux/triton that referenced this pull request Feb 21, 2025
The changed heuristic to pick the swizzling causes
performance regression in some cases. Reverting it for now.

@ThomasRaoux (Collaborator) commented:
@ggengnv, I'm reverting the smem layout heuristic change as it causes performance degradation in our mixed mode kernels: #5983

@ggengnv (Contributor Author) commented Feb 21, 2025:

@ThomasRaoux I see. I can evaluate the logic locally and see if I can improve it.

ThomasRaoux added a commit that referenced this pull request Feb 21, 2025
Partial revert of #5924

The changed heuristic to pick the swizzling causes performance
regression in some cases. Reverting it for now.

cc: @ggengnv