From 299d3fa8dc679f30143a90d61410913042c7fc20 Mon Sep 17 00:00:00 2001 From: Keren Zhou Date: Wed, 26 Jul 2023 01:13:38 -0700 Subject: [PATCH] [TRANSFORM] Fix perf regression with mmav3 and no block_ptr (#21) --- lib/Dialect/TritonGPU/Transforms/Pipeline.cpp | 24 ++++++++++++------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp index c60603fc3ddb6..6cb0bbcce0330 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp @@ -1046,24 +1046,32 @@ scf::ForOp LoopPipeliner::cloneForOp(ArrayRef newLoopArgs, mapping.map(forOp.getInductionVar(), newForOp.getInductionVar()); // Clone the loop body, replace original args with args of the new ForOp. - // We want to find cvt ops that match the following pattern: - // %0 = load %ptr - // %1 (dotOperand) = cvt %0 for (Operation &op : forOp.getBody()->without_terminator()) { if (auto cvtOp = dyn_cast(op)) { auto result = op.getResult(0); auto cvtDstTy = result.getType().cast(); - if (cvtDstTy.getEncoding().isa()) { - auto it = - std::find(validLoads.begin(), validLoads.end(), op.getOperand(0)); - if (it != validLoads.end()) { + auto it = + std::find(validLoads.begin(), validLoads.end(), op.getOperand(0)); + if (it != validLoads.end()) { + auto loadArgIdx = std::distance(validLoads.begin(), it); + if (cvtDstTy.getEncoding().isa()) { + // We want to find cvt ops that match the following pattern: + // %0 = load %ptr + // %1 (dotOperand) = cvt %0 // We replace the use new load use with a convert layout - auto loadArgIdx = std::distance(validLoads.begin(), it); auto cvt = builder.create( result.getLoc(), cvtDstTy, newForOp.getRegionIterArgs()[loadIdx + loadArgIdx]); mapping.map(result, cvt.getResult()); continue; + } else if (cvtDstTy.getEncoding().isa()) { + // We want to find cvt ops that match the following pattern: + // %0 = load %ptr + // %1 (sharedEncoding) = cvt %0 + 
// We replace the use of the new load with insert_slice_async's result + mapping.map(result, + newForOp.getRegionIterArgs()[loadIdx + loadArgIdx]); + continue; } } } else if (auto loadOp = dyn_cast(op)) {