Propagate DotOp thru Join & improve shmem load into LinearEnc #5924
Conversation
Can you add comprehensive testing of this machinery in DialectTest.cpp, similar to how we did for Fp4ToFp?
Also, the shmem load part looks a bit too specific to me. I'll let @ThomasRaoux figure this one out, though. At any rate, I think it should be split off into a separate PR.
@@ -220,7 +220,7 @@ getWarpsPerTile(DotOp dotOp, const ArrayRef<int64_t> shape, int version,

 static bool bwdFilter(Operation *op) {
   return (op->hasTrait<OpTrait::Elementwise>() && isMemoryEffectFree(op)) ||
-         isa<BroadcastOp, ExpandDimsOp, ReshapeOp, TransOp, Fp4ToFpOp,
+         isa<BroadcastOp, ExpandDimsOp, ReshapeOp, TransOp, Fp4ToFpOp, JoinOp,
It might be good to move this to a helper function isView. Probably also add SplitOp now.
if (llvm::any_of(slice, [](Operation *op) { return isa<JoinOp>(op); }))
  origBitWidth /= 2;
In the future we might want to do something like trying a large kWidth, running layout backpropagation, and seeing what contiguity you get at the loads that feed into it. This is tricky to do right now, I think, but could you leave a comment noting this?
lib/Dialect/TritonGPU/IR/Dialect.cpp (Outdated)
auto ll = toLinearLayout(shape, srcEnc);
SmallVector<int64_t> dstShape(shape.begin(), shape.end());
dstShape.push_back(1);
ll = reshapeHelper(ctx, ll, dstShape);
toLinearLayout always returns a layout with dims dim0...dim(n-1) (feel free to assert it), so this reshapeHelper can probably just be reshapeOuts.
I could be wrong (I was confused about this), but it seemed that when I did reshapeOuts earlier, the new dim would come out as the most major dim. I copied this logic from inferReshapeOp so that the dim is the most minor instead.
Given that you are just adding a dimension of size one, it doesn't matter whether this dimension is the most minor or not; it will not move, since it has size 1.
Won't this become size 2, as it's multiplied by 2 in tryJoinOnAxis?

I was also seeing this "failed to infer return types" error, I think due to incompatible encodings between inferReturnTypes (from this function) and the encoding from propagating backwards from reshape. The former would be contiguous along the most major dim, and the latter along the most minor dim.
Actually, I just tried what you suggested for this case and there wasn't an error anymore, so the transpose isn't necessary here as you said. The error I saw must've been due to something else?

I thought it might be more concise here anyway to invoke an existing helper for creating new outDimNames, even though there are some unnecessary transposes. I can forgo the helper and leave some redundant code here, if you think clarity is more important?
I think it'd be better to simply call LinearLayout::reshapeOuts with one more dim, yes. If you want, you can even implement LinearLayout::unsqueeze, similar to torch.Tensor.unsqueeze, to be more concise.
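For illustration, a minimal sketch of what such an unsqueeze-style helper could look like; the accessor names and the reshapeOuts signature used here are assumptions for the sketch, not a reference to the actual implementation:

// Hypothetical LinearLayout::unsqueeze-style helper: append a trailing
// size-1 output dimension, analogous to torch.Tensor.unsqueeze(-1).
LinearLayout unsqueezeOuts(MLIRContext *ctx, const LinearLayout &ll) {
  SmallVector<std::pair<StringAttr, int32_t>> outDims;
  int n = 0;
  for (StringAttr name : ll.getOutDimNames()) {           // assumed accessor
    outDims.emplace_back(name, ll.getOutDimSize(name));   // assumed accessor
    ++n;
  }
  // A trailing size-1 dim does not change the total number of elements,
  // so reshapeOuts should always accept it.
  outDims.emplace_back(StringAttr::get(ctx, "dim" + std::to_string(n)), 1);
  return ll.reshapeOuts(outDims);
}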
lib/Dialect/TritonGPU/IR/Dialect.cpp (Outdated)
auto result =
    tryJoinOnAxis(ctx, ll, newLl, /*fwdInference=*/true, axis, loc);

if (!result.succeeded())
  return result;
This join will always succeed, so better to assert that this is indeed the case.
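A minimal sketch of the suggested change, reusing the names from the quoted snippet (assuming tryJoinOnAxis keeps returning a result with succeeded()):

auto result =
    tryJoinOnAxis(ctx, ll, newLl, /*fwdInference=*/true, axis, loc);
// Joining along the freshly appended size-1 dim is expected to always
// succeed, so assert instead of propagating a failure.
assert(result.succeeded() && "tryJoinOnAxis should not fail here");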
lib/Dialect/TritonGPU/IR/Dialect.cpp (Outdated)
SmallVector<unsigned int> getContig(const LinearEncodingAttr &enc,
                                    const char *inDim,
                                    SmallVector<unsigned int> lowerContig) {
Mind moving this next to all the other LL methods and making it a LinearEncodingAttr method?
Regarding generalising the use of …
@@ -218,10 +218,14 @@ getWarpsPerTile(DotOp dotOp, const ArrayRef<int64_t> shape, int version,
   }
 }

+static bool isView(Operation *op) {
+  return isa<BroadcastOp, ExpandDimsOp, ReshapeOp, TransOp, JoinOp, SplitOp,
+             ConvertLayoutOp>(op);
ConvertLayoutOp may not be a view op unless it's the identity
+1, convert_layout is not a view, and broadcast is not a view either.
broadcast is a view after layout propagation, though?
Ok, I see that it may duplicate some elements internally, I guess.
Added some minor comments, but it looks good to me overall.
I think the shmem heuristic should be acceptable. We do have plans to move the shmem layout decision later in the compiler pipeline; hopefully we can improve things at that time.
// Append dim to shape
auto ll = toLinearLayout(shape, srcEnc);
SmallVector<int64_t> dstShape(shape.begin(), shape.end());
dstShape.push_back(1);
why are we appending 1 and not 2 here?
The LL is doubled along an existing axis in tryJoinOnAxis() (Mario's logic that I refactored out). Before calling that, I'm just "unsqueezing" the shape here to be consistent with Join's API: adding a new dimension to operate on.
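Schematically, with hypothetical sizes just to illustrate the bookkeeping (not actual code from the PR):

// shape before:                  [64, 128]
// dstShape after push_back(1):   [64, 128, 1]   // the "unsqueeze" done here
// after tryJoinOnAxis:           [64, 128, 2]   // the doubling happens inside that helper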
lib/Dialect/TritonGPU/IR/Dialect.cpp (Outdated)
 auto axis = shape.size() - 1;
 if (shape[axis] != 2)
   return emitOptionalError(
-      loc, "SplitOp requires threadsPerWarp, warpsPerCTA, "
-           "and CTAsPerCGA = 1 for the last dimension of the input");
+      loc, "SplitOp input shape should have 2 in the last dim");
nit: should be an assert I think
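A sketch of what the nit suggests, reusing the names from the quoted diff (not the actual change):

auto axis = shape.size() - 1;
// Treat the size-2 last dimension as an invariant of SplitOp's input
// rather than a recoverable error.
assert(shape[axis] == 2 && "SplitOp input shape should have 2 in the last dim");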
// If a JoinOp occurs at least once, then during backward layout propagation
// the kWidth will be halved as we pass through the JoinOp.
// Hence we divide origBitWidth by 2 here to compensate for that and
// improve our load width.
// This won't be optimal if there is a tree of multiple JoinOps, which
// would require counting the max number of JoinOps along any path.
//
// In the future we might want to try a large kWidth, run layout
// backpropagation, and see what contiguity we get at the loads that
// feed into it.
if (llvm::any_of(slice, [](Operation *op) { return isa<JoinOp>(op); }))
  origBitWidth /= 2;
This seems quite hacky; do we really need it? If we do, maybe we should at least have a lit test for it.
I agree it's hacky. I added it here since I saw the similar heuristic above for fp4ToFp:

if (llvm::any_of(slice, [](Operation *op) { return isa<Fp4ToFpOp>(op); }))
  return 4;

For perf I think this is quite important: if kWidth is too small (e.g. 4 instead of 8 for 4-bit inputs), the load width will be < 32b and so will be very inefficient. And for sure I can add a lit test.
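For concreteness, a rough sketch of the load-width arithmetic being referred to (illustrative numbers only):

// Illustrative load-width arithmetic for 4-bit inputs:
//   kWidth = 4  ->  4 elements * 4 bits = 16 bits per load  (< 32b, inefficient)
//   kWidth = 8  ->  8 elements * 4 bits = 32 bits per load  (full-width)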
If in the future we want to push more mixed dots we can implement something like what's outlined in the comment, but I think this is fine for now.
Two minor comments, otherwise looks good!
Sure, but could you move getStandardOutDimNames to LayoutUtils.cpp? After that, feel free to merge.
Thank you for the changes!
Could someone approve the CI please? Thanks.
@lezcano had to disable the shmem swizzling optimization where the global blocked order != linear order, otherwise cp.async will be < 4B. This happens because OptimizeDotOperands::SwizzleShmemConvert doesn't seem to apply anymore, due to how LLs propagate through TransOps. I added a TODO for that.
Hmm, I don't have merge commits and the workflow passed, but I still see the message. I tried rebasing on main, but the message is still there.
The changed heuristic for picking the swizzling causes a performance regression in some cases. Reverting it for now.
@ThomasRaoux I see. I can evaluate the logic locally and see if I can improve it.
There are two parts to this PR: (1) propagate the DotOp encoding through Join, and (2) improve shmem loads into LinearEncoding.

Motivation for the second part: currently, a shmem load into an LL falls back to an unswizzled shmem layout in the pipeliner, which results in poor performance. Not only does the inline_asm > join > reshape path suffer from this, so does fp4_to_fp.

I've added some basic swizzling logic for the shmem layout when loading into dotOp-like LLs. As an example, for bf16xfp4 dot_scaled on a small-M, large-N/K shape, with fixed config (8, 128, 256) and DISABLE_MMA_V3=1, this gives a clear improvement; similar improvements can be observed for bf16xint4 (with inline_asm).

There's also a small change to increase kWidth in the case of join by halving origBitWidth. This should also be important for perf, since otherwise the shmem load width would be too small.

I believe there's still significant room for improvement for small-M shapes, because shmem -> LL does not yet support ldmatrix. I can look into this next.

PTAL @lezcano @ThomasRaoux, thank you.