Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Propagate DotOp thru Join & improve shmem load into LinearEnc #5924

Merged
merged 12 commits
Feb 20, 2025
2 changes: 2 additions & 0 deletions include/triton/Dialect/Triton/IR/Dialect.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,12 @@ class DialectInferLayoutInterface

virtual LogicalResult
inferJoinOpEncoding(Attribute srcEnc, Attribute &dstEnc,
ArrayRef<int64_t> shape,
std::optional<Location> loc) const = 0;

virtual LogicalResult
inferSplitOpEncoding(Attribute srcEnc, Attribute &dstEnc,
ArrayRef<int64_t> shape,
std::optional<Location> loc) const = 0;

// Verify that the encoding are compatible to be used together in a dot
Expand Down
2 changes: 2 additions & 0 deletions include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
Original file line number Diff line number Diff line change
Expand Up @@ -620,7 +620,9 @@ def LinearEncodingAttr : DistributedEncoding<"LinearEncoding", "linear_encoding"
unsigned getTotalElemsPerThread(ArrayRef<int64_t> shape) const;
SmallVector<unsigned> getElemsPerThread(ArrayRef<int64_t> shape) const;

SmallVector<unsigned int> getContig(const char *, SmallVector<unsigned int>) const;
SmallVector<unsigned> getContigPerThread() const;
SmallVector<unsigned> getContigPerWarp() const;
SmallVector<unsigned> getOrder() const;

// Generalizes get{Warp,Thread,CTA}Order to linear layouts.
Expand Down
3 changes: 3 additions & 0 deletions include/triton/Dialect/TritonGPU/Transforms/Utility.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ unsigned
getNumElementsPerThread(Operation *op, SmallVector<unsigned> order,
triton::ModuleAxisInfoAnalysis &axisInfoAnalysis);

// Returns whether the op is a "view op", i.e. doesn't move any data
bool isView(Operation *op);

/* Dump Triton IR in graphviz dot format.
*
* You can override `onValue` and `onOperation` in a subclass to mark
Expand Down
5 changes: 5 additions & 0 deletions include/triton/Tools/LayoutUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,11 @@ LinearLayout ensureLayoutNotSmallerThan(
// are "dim0", "dim1", etc.
SmallVector<StringAttr> standardOutDimNames(MLIRContext *ctx, int rank);

// Return a vector of the standard out dimension name/value pairs, i.e.
// ("dim0", dstShape[0]), ("dim1", dstShape[1]), etc.
SmallVector<std::pair<StringAttr, int32_t>>
standardOutDimPairs(MLIRContext *ctx, ArrayRef<int64_t> dstShape);

// Return an identity mapping from `inDimName` to the standard out dimensions,
// with the dimensions sized according to the shape. The bases are sorted
// according to `order`, with the most minor dimension first.
Expand Down
4 changes: 2 additions & 2 deletions lib/Dialect/Triton/IR/Ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1046,7 +1046,7 @@ JoinOp::inferReturnTypes(MLIRContext *context, std::optional<Location> location,
Attribute retEnc;
if (srcEnc) {
if (cast<DialectInferLayoutInterface>(&srcEnc.getDialect())
->inferJoinOpEncoding(srcEnc, retEnc, location)
->inferJoinOpEncoding(srcEnc, retEnc, srcTy.getShape(), location)
.failed()) {
return failure();
}
Expand Down Expand Up @@ -1079,7 +1079,7 @@ LogicalResult SplitOp::inferReturnTypes(
Attribute retEnc;
if (srcEnc) {
if (cast<DialectInferLayoutInterface>(&srcEnc.getDialect())
->inferSplitOpEncoding(srcEnc, retEnc, location)
->inferSplitOpEncoding(srcEnc, retEnc, srcTy.getShape(), location)
.failed()) {
return failure();
}
Expand Down
269 changes: 161 additions & 108 deletions lib/Dialect/TritonGPU/IR/Dialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -488,6 +488,42 @@ getDefaultBlockedEncoding(MLIRContext *context, ArrayRef<int64_t> shape,
return encoding;
}

// Computes the linear layout that results from joining along `axis`
// (fwdInference == true) or from splitting along `axis`
// (fwdInference == false), writing the result into the out-parameter `outLl`.
//
// Forward: multiply the input by a size-2 identity layout on the "register"
// in-dimension targeting out-dim `axis`, i.e. add one register basis that
// doubles that axis. This path always succeeds.
//
// Backward: undo that doubling — drop the basis whose `axis` component is
// exactly the unit vector and halve the `axis` component of every remaining
// basis. This is only valid when the layout has at least two contiguous
// elements per thread along `axis`; otherwise an error is emitted at `loc`.
LogicalResult tryJoinOnAxis(MLIRContext *ctx, const LinearLayout &inLl,
                            LinearLayout &outLl, bool fwdInference, int axis,
                            std::optional<Location> loc) {
  auto kRegister = StringAttr::get(ctx, "register");
  auto outDims = llvm::to_vector(inLl.getOutDimNames());
  if (fwdInference) {
    // Doubling out-dim `axis` with one extra register basis.
    auto split = LinearLayout::identity1D(2, kRegister, outDims[axis]);
    outLl = split * inLl;
  } else {
    // TODO: this really wants a proper division algorithm — implement
    // ll.divideLeft(split); for now we divide out the size-2 factor manually.
    auto contiguousElems =
        LinearEncodingAttr::get(ctx, inLl).getContigPerThread();
    if (contiguousElems[axis] > 1) {
      LinearLayout::BasesT newBases;
      for (const auto &basesDim : inLl.getBases()) {
        std::vector<std::vector<int32_t>> newBasesDim;
        for (auto base : basesDim.second) {
          // Drop the unit basis along `axis` (the factor being divided out).
          if (base[axis] == 1) {
            continue;
          }
          // Halve the `axis` component of every surviving basis.
          base[axis] /= 2;
          newBasesDim.push_back(std::move(base));
        }
        newBases.insert({basesDim.first, std::move(newBasesDim)});
      }
      outLl = LinearLayout(std::move(newBases), std::move(outDims));
    } else {
      return emitOptionalError(loc,
                               "Fp4ToFpOp/SplitOp requires at least 2 elements "
                               "per thread in the axis/last dimension");
    }
  }
  return success();
}

} // namespace gpu
} // namespace triton
} // namespace mlir
Expand Down Expand Up @@ -1239,28 +1275,39 @@ LinearEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
return scaledLayout.basesPerDim(kRegister, /*skipBroadcast=*/false);
}

// Computes, per dimension, the number of contiguous elements covered by the
// bases of `inDim` (e.g. "register" or "lane"), starting from the contiguity
// `lowerContig` already guaranteed by the lower level of the hierarchy.
//
// Walking the dimensions in layout order, a dimension's contiguity keeps
// doubling while the next basis of `inDim` is exactly the next power of two
// along that dimension (and zero everywhere else).
SmallVector<unsigned>
LinearEncodingAttr::getContig(const char *inDim,
                              SmallVector<unsigned int> lowerContig) const {
  auto ll = getLinearLayout();
  const auto &bases =
      ll.getBases().find(StringAttr::get(getContext(), inDim))->second;
  auto order = getOrder();
  auto rank = order.size();

  SmallVector<unsigned> contig(lowerContig);
  auto basisIt = bases.begin();
  for (unsigned dim : order) {
    // The basis that would extend the contiguous run in `dim` equals the
    // current contiguity (a power of two), starting from lowerContig[dim].
    std::vector<int32_t> basis(rank, 0);
    basis[dim] = contig[dim];

    while (basisIt != bases.end() && *basisIt == basis) {
      contig[dim] *= 2;
      basis[dim] *= 2;
      ++basisIt;
    }
  }
  return contig;
}

// Number of contiguous elements per dimension owned by a single thread.
SmallVector<unsigned> LinearEncodingAttr::getContigPerThread() const {
  SmallVector<unsigned> contig(getOrder().size(), 1);
  return getContig("register", contig);
}

// Number of contiguous elements per dimension owned by a single warp:
// the per-thread contiguity extended by the "lane" bases.
SmallVector<unsigned> LinearEncodingAttr::getContigPerWarp() const {
  return getContig("lane", getContigPerThread());
}

unsigned
LinearEncodingAttr::getTotalElemsPerThread(ArrayRef<int64_t> shape) const {
return product(getElemsPerThread(shape));
Expand Down Expand Up @@ -2721,14 +2768,12 @@ struct TritonGPUInferLayoutInterface
}

auto newRank = dstShape.size();
SmallVector<std::pair<StringAttr, int32_t>> newOutDims;
for (auto [dim, size] :
llvm::zip(standardOutDimNames(ctx, newRank), dstShape)) {
newOutDims.emplace_back(dim, size);
}
auto srcOutDims = to_vector(src.getOutDimNames());

auto newOutDims = standardOutDimPairs(ctx, dstShape);

// reshapeOp assumes minor-to-major, so we need to transpose the out dims
// before the reshape
auto srcOutDims = to_vector(src.getOutDimNames());
std::reverse(srcOutDims.begin(), srcOutDims.end());
std::reverse(newOutDims.begin(), newOutDims.end());
auto dst = src.transposeOuts(srcOutDims)
Expand All @@ -2740,82 +2785,117 @@ struct TritonGPUInferLayoutInterface

// Infers the result encoding of JoinOp: two AxBxC tensors joined into one
// AxBxCx2 tensor, with the new (most-minor) dimension holding 2 elements per
// thread. Blocked encodings stay blocked; any other distributed encoding is
// handled generically through linear layouts.
LogicalResult
inferJoinOpEncoding(Attribute srcEnc, Attribute &dstEnc,
                    ArrayRef<int64_t> shape,
                    std::optional<Location> loc) const override {
  if (auto enc = mlir::dyn_cast<BlockedEncodingAttr>(srcEnc)) {
    // JoinOp takes two tensors of shape AxBxC and generates a tensor of shape
    // AxBxCx2. The encoding is the same as the input, but with 2 elems per
    // thread in the new dimension. The new dimension is most-minor.
    auto append = [](ArrayRef<unsigned> vals, int val) {
      SmallVector<unsigned> ret(vals);
      ret.push_back(val);
      return ret;
    };
    // Prepend the new dimension's index so it is most-minor in the order.
    auto appendMinorDim = [](ArrayRef<unsigned> order) {
      SmallVector<unsigned> ret(order);
      ret.insert(ret.begin(), ret.size());
      return ret;
    };
    dstEnc = BlockedEncodingAttr::get(
        enc.getContext(),                    //
        append(enc.getSizePerThread(), 2),   //
        append(enc.getThreadsPerWarp(), 1),  //
        append(enc.getWarpsPerCTA(), 1),     //
        appendMinorDim(enc.getOrder()),      //
        CTALayoutAttr::get(enc.getContext(), //
                           append(enc.getCTAsPerCGA(), 1),
                           append(enc.getCTASplitNum(), 1),
                           appendMinorDim(enc.getCTAOrder())));
    return success();
  }

  auto ctx = getContext();

  // Unsqueeze: append a size-1 dim to the shape. We push 1 (not 2) because
  // the layout is doubled along this new axis inside tryJoinOnAxis(); here we
  // only make the shape consistent with JoinOp's rank.
  auto ll = toLinearLayout(shape, srcEnc);
  SmallVector<int64_t> dstShape(shape.begin(), shape.end());
  dstShape.push_back(1);
  ll = ll.reshapeOuts(standardOutDimPairs(ctx, dstShape));

  // Join on the last dim.
  auto axis = dstShape.size() - 1;
  auto newLl = LinearLayout::empty();
  auto result =
      tryJoinOnAxis(ctx, ll, newLl, /*fwdInference=*/true, axis, loc);

  // Forward inference never fails (it only multiplies in an identity layout).
  assert(result.succeeded());
  (void)result; // silence unused-variable warning in release builds
  dstEnc = LinearEncodingAttr::get(ctx, newLl);
  return success();
}

// Infers the result encoding of SplitOp: an AxBxCx2 tensor split into two
// AxBxC tensors. Blocked encodings stay blocked (with strict checks on the
// last dimension); any other distributed encoding is handled generically by
// dividing the linear layout along the last axis.
LogicalResult
inferSplitOpEncoding(Attribute srcEnc, Attribute &dstEnc,
                     ArrayRef<int64_t> shape,
                     std::optional<Location> loc) const override {
  auto enc = mlir::dyn_cast<BlockedEncodingAttr>(srcEnc);
  if (enc) {
    // SplitOp takes a tensor of shape AxBxCx2 and generates two tensors of
    // shape AxBxC. The input must have 2 elements per thread in the last
    // dimension, which must be most-minor. The result encoding is the same
    // as the input, but with the last dimension removed.
    if (enc.getSizePerThread().back() != 2) {
      return emitOptionalError(
          loc, "SplitOp requires 2 elements per thread in the "
               "last dimension of the input");
    }
    if (enc.getThreadsPerWarp().back() != 1 ||
        enc.getWarpsPerCTA().back() != 1 || enc.getCTAsPerCGA().back() != 1) {
      return emitOptionalError(
          loc, "SplitOp requires threadsPerWarp, warpsPerCTA, "
               "and CTAsPerCGA = 1 for the last dimension of the input");
    }
    if (enc.getCTALayout().getCTAsPerCGA().back() != 1) {
      return emitOptionalError(
          loc,
          "SplitOp requires the last dimension to be most-minor in CTAOrder");
    }
    SmallVector<unsigned> newOrder(enc.getOrder());
    int splitDim = newOrder.size() - 1;
    // Remove splitDim from order.
    newOrder.erase(std::remove(newOrder.begin(), newOrder.end(), splitDim),
                   newOrder.end());
    // CTAOrder uses drop_front: the joined dim was prepended (most-minor) to
    // CTAOrder, so it sits at the front.
    dstEnc = BlockedEncodingAttr::get(
        enc.getContext(), //
        ArrayRef(enc.getSizePerThread()).drop_back(1),
        ArrayRef(enc.getThreadsPerWarp()).drop_back(1),
        ArrayRef(enc.getWarpsPerCTA()).drop_back(1), ArrayRef(newOrder),
        CTALayoutAttr::get(enc.getContext(), //
                           ArrayRef(enc.getCTAsPerCGA()).drop_back(1),
                           ArrayRef(enc.getCTASplitNum()).drop_back(1),
                           ArrayRef(enc.getCTAOrder()).drop_front(1)));
    return success();
  }

  auto axis = shape.size() - 1;
  assert(shape[axis] == 2 &&
         "SplitOp input shape should have 2 in the last dim");

  auto ctx = getContext();

  // Divide the layout by 2 along the last dim (backward inference).
  auto ll = toLinearLayout(shape, srcEnc);
  auto newLl = LinearLayout::empty();
  auto result =
      tryJoinOnAxis(ctx, ll, newLl, /*fwdInference=*/false, axis, loc);
  if (failed(result)) {
    return failure();
  }

  // Squeeze out the last dim of newLl (its size is now 1).
  SmallVector<int64_t> dstShape(shape.begin(), shape.end());
  dstShape.pop_back();
  newLl = newLl.reshapeOuts(standardOutDimPairs(ctx, dstShape));
  dstEnc = LinearEncodingAttr::get(ctx, newLl);
  return success();
}

Expand Down Expand Up @@ -2873,37 +2953,10 @@ struct TritonGPUInferLayoutInterface
}

auto ll = toLinearLayout(shape, inEnc);

auto kRegister = StringAttr::get(ctx, "register");
auto outDims = llvm::to_vector(ll.getOutDimNames());
LinearLayout newLl = LinearLayout::empty();
if (fwdInference) {
auto split = LinearLayout::identity1D(2, kRegister, outDims[axis]);
newLl = split * ll;
} else {
// TODO This requires a division algorithm!
// Implement manually ll.divideLeft(split)
auto contiguousElems =
LinearEncodingAttr::get(ctx, ll).getContigPerThread();
if (contiguousElems[axis] > 1) {
LinearLayout::BasesT newBases;
for (const auto &basesDim : ll.getBases()) {
std::vector<std::vector<int32_t>> newBasesDim;
for (auto base : basesDim.second) {
if (base[axis] == 1) {
continue;
}
base[axis] /= 2;
newBasesDim.push_back(std::move(base));
}
newBases.insert({basesDim.first, std::move(newBasesDim)});
}
newLl = LinearLayout(std::move(newBases), std::move(outDims));
} else {
return emitOptionalError(loc, "Fp4ToFpOp requires at least 2 elements "
"per thread in the axis dimension");
}
}
auto newLl = LinearLayout::empty();
auto result = tryJoinOnAxis(ctx, ll, newLl, fwdInference, axis, loc);
if (!result.succeeded())
return result;
outEnc = LinearEncodingAttr::get(ctx, newLl);
return success();
}
Expand Down
Loading
Loading