Skip to content

Commit

Permalink
[TRANSFORM] Fix perf regression with mmav3 and no block_ptr (triton-l…
Browse files Browse the repository at this point in the history
  • Loading branch information
Jokeren authored Jul 26, 2023
1 parent da11a2d commit 299d3fa
Showing 1 changed file with 16 additions and 8 deletions.
24 changes: 16 additions & 8 deletions lib/Dialect/TritonGPU/Transforms/Pipeline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1046,24 +1046,32 @@ scf::ForOp LoopPipeliner::cloneForOp(ArrayRef<Value> newLoopArgs,
mapping.map(forOp.getInductionVar(), newForOp.getInductionVar());

// Clone the loop body, replace original args with args of the new ForOp.
// We want to find cvt ops that match the following pattern:
// %0 = load %ptr
// %1 (dotOperand) = cvt %0
for (Operation &op : forOp.getBody()->without_terminator()) {
if (auto cvtOp = dyn_cast<ttg::ConvertLayoutOp>(op)) {
auto result = op.getResult(0);
auto cvtDstTy = result.getType().cast<RankedTensorType>();
if (cvtDstTy.getEncoding().isa<ttg::DotOperandEncodingAttr>()) {
auto it =
std::find(validLoads.begin(), validLoads.end(), op.getOperand(0));
if (it != validLoads.end()) {
auto it =
std::find(validLoads.begin(), validLoads.end(), op.getOperand(0));
if (it != validLoads.end()) {
auto loadArgIdx = std::distance(validLoads.begin(), it);
if (cvtDstTy.getEncoding().isa<ttg::DotOperandEncodingAttr>()) {
// We want to find cvt ops that match the following pattern:
// %0 = load %ptr
// %1 (dotOperand) = cvt %0
// We replace the use new load use with a convert layout
auto loadArgIdx = std::distance(validLoads.begin(), it);
auto cvt = builder.create<ttg::ConvertLayoutOp>(
result.getLoc(), cvtDstTy,
newForOp.getRegionIterArgs()[loadIdx + loadArgIdx]);
mapping.map(result, cvt.getResult());
continue;
} else if (cvtDstTy.getEncoding().isa<ttg::SharedEncodingAttr>()) {
// We want to find cvt ops that match the following pattern:
// %0 = load %ptr
// %1 (sharedEncoding) = cvt %0
// We replace the use new load use with insert_slice_async's result
mapping.map(result,
newForOp.getRegionIterArgs()[loadIdx + loadArgIdx]);
continue;
}
}
} else if (auto loadOp = dyn_cast<tt::LoadOp>(op)) {
Expand Down

0 comments on commit 299d3fa

Please sign in to comment.