diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp index 5e4dc5ccc9e75..46cba0b1dae1b 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp @@ -129,9 +129,6 @@ class LoopPipeliner { /// Loads to be pipelined SetVector validLoads; - /// FIXME(Keren): Check if this is required - /// Number of loads that requires AsyncWait for synchronization - int numLoadsRequiresAsyncWait = 0; /// The value that each load will be mapped to (after layout conversion) DenseMap loadsMapping; /// load => buffer @@ -143,7 +140,9 @@ class LoopPipeliner { /// load => after extract DenseMap loadsExtract; - /// XXX(Keren): The following are tma specific and disabled + /// XXX(Keren): The following are h100 only and disabled + Value curWaitIdx; + Value curPhase; /// load => full barrier arrive DenseMap loadsBarrierArvOp; /// load => mbarriers @@ -151,6 +150,15 @@ class LoopPipeliner { DenseMap loadsEmptyBarriers; /// load => null value or previous load which can share barrier with DenseMap loadsCanShareBarriers; + /// Maintains the information to emit consumer_release mbarrier_arrive + ConsumerReleaseMap &consumerReleaseMap; + bool hasHopperDot = false; + // XXX(Keren): why the variable name is hopper dot and why do we need this + // check? + void checkHopperDots(SetVector &ops); + // XXX(Keren): it looks more like an optimization to be, not sure if it should + // exist in the base pipeliner + void checkOpShareBarriers(SetVector &ops); /// Iterator values Value pipelineIterIdx; @@ -162,13 +170,6 @@ class LoopPipeliner { SmallVector extractSlices; SmallVector yieldValues; - /// XXX(Keren): The following are tma specific and disabled - Value curWaitIdx; - Value curPhase; - /// Maintains the information to emit consumer_release mbarrier_arrive - ConsumerReleaseMap &consumerReleaseMap; - bool hasHopperDot = false; - /// The number of stages in the pipeline. /// Stages in the range of [0, numStages-1) are in the prologue. /// numStages-1 is appended after the loop body. @@ -221,14 +222,6 @@ class LoopPipeliner { /// Check if none of the ops has valid uses LogicalResult checkOpUses(SetVector &ops); - // XXX(Keren): why the variable name is hopper dot and why do we need this - // check? - void checkHopperDots(SetVector &ops); - - // XXX(Keren): it looks more like an optimization to be, not sure if it should - // exist in the base pipeliner - void checkOpShareBarriers(SetVector &ops); - /// Check if ops have dependencies that are not pipelinable void checkOpDeps(SetVector &ops); @@ -271,8 +264,10 @@ class LoopPipeliner { /// Prefetch the next iteration for `newForOp` void prefetchNextIteration(scf::ForOp newForOp, OpBuilder &builder); - /// Check if next iteration is out of bounary - SmallVector getNextPhase(OpBuilder &builder, Value curIdx, Value upperBoundIdx); + /// Check if curIdx is out of bound and wrap value around if necessary + Value getNextIterationValue(OpBuilder &builder, Value curIdx, + Value upperBoundIdx, Value curValue, + Value initValue); /// Assemble `newForOp`'s yield op void finalizeYield(scf::ForOp newForOp, OpBuilder &builder); @@ -455,11 +450,8 @@ LogicalResult LoopPipeliner::checkOpUses(SetVector &ops) { if (!isCandidate) invalidOps.insert(loadOp); - else { + else validLoads.insert(loadOp); - if (!isLoadFromTensorPtr(loadOp)) - numLoadsRequiresAsyncWait++; - } } } @@ -974,16 +966,18 @@ void LoopPipeliner::emitPrologue() { loopIterIdx = builder.create( loopIterIdx.getLoc(), loopIterIdx, builder.create(loopIterIdx.getLoc(), 1, 32)); - //curWaitIdx = builder.create(iv.getLoc(), 0, 32); - //curPhase = builder.create(iv.getLoc(), 0, 1); + // curWaitIdx = builder.create(iv.getLoc(), 0, 32); + // curPhase = builder.create(iv.getLoc(), 0, 1); } void LoopPipeliner::emitEpilogue() { // If there's any outstanding async copies, we need to wait for them. - OpBuilder builder(forOp); - OpBuilder::InsertionGuard g(builder); - builder.setInsertionPointAfter(forOp); - builder.create(forOp.getLoc(), 0); + if (validLoads.size() > 0) { + OpBuilder builder(forOp); + OpBuilder::InsertionGuard g(builder); + builder.setInsertionPointAfter(forOp); + builder.create(forOp.getLoc(), 0); + } } SmallVector LoopPipeliner::collectNewLoopArgs() { @@ -1028,8 +1022,8 @@ SmallVector LoopPipeliner::collectNewLoopArgs() { newLoopArgs.push_back(valueMapping[forOp.getInductionVar()][numStages - 2]); newLoopArgs.push_back(pipelineIterIdx); newLoopArgs.push_back(loopIterIdx); - //newLoopArgs.push_back(curWaitIdx); - //newLoopArgs.push_back(curPhase); + // newLoopArgs.push_back(curWaitIdx); + // newLoopArgs.push_back(curPhase); return newLoopArgs; } @@ -1125,29 +1119,24 @@ scf::ForOp LoopPipeliner::cloneForOp(ArrayRef newLoopArgs, return newForOp; } -SmallVector LoopPipeliner::getNextPhase(OpBuilder &builder, Value curIdx, - Value upperBoundIdx) { +Value LoopPipeliner::getNextIterationValue(OpBuilder &builder, Value curIdx, + Value upperBoundIdx, Value curValue, + Value initValue) { Value cond = builder.create( curIdx.getLoc(), arith::CmpIPredicate::uge, curIdx, upperBoundIdx); - auto nextIter = builder.create( - curIdx.getLoc(), ArrayRef{curIdx.getType(), curPhase.getType()}, - cond, true); - + auto nextIter = builder.create(curIdx.getLoc(), curValue.getType(), + cond, true); // True branch + auto insertionPoint = builder.saveInsertionPoint(); builder.setInsertionPointToStart(nextIter.thenBlock()); - Value newIdx = builder.create(cond.getLoc(), 0, 32); - Value newPhase = builder.create( - forOp.getLoc(), curPhase, - builder.create(forOp.getLoc(), 1, 1)); - builder.create(nextIV.getLoc(), - ArrayRef{newIdx, newPhase}); + builder.create(curIdx.getLoc(), initValue); // False branch builder.setInsertionPointToStart(nextIter.elseBlock()); - builder.create(curIdx.getLoc(), - ArrayRef{curIdx, curPhase}); + builder.create(curIdx.getLoc(), curValue); + builder.restoreInsertionPoint(insertionPoint); - return nextIter.getResults(); + return nextIter.getResult(0); } void LoopPipeliner::prefetchNextIteration(scf::ForOp newForOp, @@ -1170,13 +1159,14 @@ void LoopPipeliner::prefetchNextIteration(scf::ForOp newForOp, nextIV, newForOp.getUpperBound()); pipelineIterIdx = newForOp.getRegionIterArgs()[ivIdx + 1]; - Value insertSliceIndex = builder.create( - nextIV.getLoc(), pipelineIterIdx, - builder.create(nextIV.getLoc(), numStages, 32)); + Value numStagesVal = + builder.create(nextIV.getLoc(), numStages, 32); + Value initVal = builder.create(nextIV.getLoc(), 0, 32); + Value insertSliceIndex = getNextIterationValue( + builder, pipelineIterIdx, numStagesVal, pipelineIterIdx, initVal); loopIterIdx = newForOp.getRegionIterArgs()[ivIdx + 2]; - Value extractSliceIndex = builder.create( - nextIV.getLoc(), loopIterIdx, - builder.create(nextIV.getLoc(), numStages, 32)); + Value extractSliceIndex = getNextIterationValue( + builder, loopIterIdx, numStagesVal, loopIterIdx, initVal); // Prefetch load deps // If a load-dependent instruction that uses a block argument, we @@ -1250,8 +1240,8 @@ void LoopPipeliner::prefetchNextIteration(scf::ForOp newForOp, // XXX(Keren): might be wrong for tma // else // newMask = builder.create( - // loadOp.getLoc(), mlir::triton::getI1SameShape(loadOp.getType()), - // nextLoopCond); + // loadOp.getLoc(), + // mlir::triton::getI1SameShape(loadOp.getType()), nextLoopCond); } Value insertedVal; if (mode && isLoadFromTensorPtr(loadOp)) { @@ -1370,6 +1360,8 @@ void LoopPipeliner::prefetchNextIteration(scf::ForOp newForOp, } // Bump iteration count + pipelineIterIdx = insertSliceIndex; + loopIterIdx = extractSliceIndex; pipelineIterIdx = builder.create( nextIV.getLoc(), pipelineIterIdx, builder.create(nextIV.getLoc(), 1, 32)); @@ -1377,10 +1369,10 @@ void LoopPipeliner::prefetchNextIteration(scf::ForOp newForOp, nextIV.getLoc(), loopIterIdx, builder.create(nextIV.getLoc(), 1, 32)); // FIXME(Keren): Reenable after tma is fixed - //curWaitIdx = builder.create( + // curWaitIdx = builder.create( // forOp.getLoc(), curWaitIdx, // builder.create(forOp.getLoc(), 1, 32)); - //curPhase = builder.create( + // curPhase = builder.create( // forOp.getLoc(), curPhase, // builder.create(forOp.getLoc(), 1, 1)); } @@ -1403,8 +1395,8 @@ void LoopPipeliner::finalizeYield(scf::ForOp newForOp, OpBuilder &builder) { yieldValues.push_back(nextIV); yieldValues.push_back(pipelineIterIdx); yieldValues.push_back(loopIterIdx); - //yieldValues.push_back(curWaitIdx); - //yieldValues.push_back(curPhase); + // yieldValues.push_back(curWaitIdx); + // yieldValues.push_back(curPhase); builder.setInsertionPointToEnd(newForOp.getBody()); builder.create(yieldOp->getLoc(), yieldValues); @@ -1437,9 +1429,7 @@ struct PipelinePass : public TritonGPUPipelineBase { // prologue is currently not properly provided. Need some second thought on // the mask definition of InsertSliceOp when the src is ptr bool mode = (computeCapability >= 90); - int numStages = this->numStages; - - if (numStages <= 1) + if (this->numStages <= 1) return; // phase 0: pipeline loads in loops @@ -1455,9 +1445,8 @@ struct PipelinePass : public TritonGPUPipelineBase { // Do the pipelining getOperation()->walk([&](scf::ForOp forOp) -> void { - LoopPipeliner pipeliner(forOp, numStages, this->numWarps, this->numCTAs, - mode, consumerReleaseMap); - + LoopPipeliner pipeliner(forOp, this->numStages, this->numWarps, + this->numCTAs, mode, consumerReleaseMap); if (pipeliner.initialize().failed()) return; @@ -1674,7 +1663,6 @@ void PipelinePass::emitConsumerRelease(Value mbarTensor, b.setInsertionPointAfter(lastUserWithLargestStage); auto loc = lastUserWithLargestStage->getLoc(); auto maxStageVal = b.create(loc, maxStage, 32); - auto numStagesVal = b.create(loc, numStages, 32); // pred = (iterVar >= maxStage) && // (threadId % (numConsumerThreads / numRemoteCTAs) == 0); diff --git a/test/TritonGPU/loop-pipeline.mlir b/test/TritonGPU/loop-pipeline.mlir index 97078f3e0bded..f01731a37c198 100644 --- a/test/TritonGPU/loop-pipeline.mlir +++ b/test/TritonGPU/loop-pipeline.mlir @@ -37,15 +37,17 @@ // CHECK: %[[arg_b0_dot_op_0:.*]] = triton_gpu.convert_layout %[[arg_b0]] // CHECK: %[[arg_b0_dot_op_1:.*]] = arith.mulf %[[arg_b0_dot_op_0]] // CHECK: tt.dot %[[arg_a0_dot_op]], %[[arg_b0_dot_op_1]], {{.*}} -// CHECK-DAG: %[[INSERT_IDX:.*]] = arith.remui %[[PIPELINE_IDX]], %[[CONSTANT_3]] -// CHECK-DAG: %[[EXTRACT_IDX:.*]] = arith.remui %[[LOOP_IDX]], %[[CONSTANT_3]] +// CHECK-DAG: %[[CMP_PIPELINE:.*]] = arith.cmpi uge, %[[PIPELINE_IDX]], %[[CONSTANT_3]] +// CHECK-DAG: %[[INSERT_IDX:.*]] = arith.select %[[CMP_PIPELINE]], %[[CONSTANT_0]], %[[PIPELINE_IDX]] +// CHECK-DAG: %[[CMP_LOOP:.*]] = arith.cmpi uge, %[[LOOP_IDX]], %[[CONSTANT_3]] +// CHECK-DAG: %[[EXTRACT_IDX:.*]] = arith.select %[[CMP_LOOP]], %[[CONSTANT_0]], %[[LOOP_IDX]] // CHECK: %[[NEXT_A_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[INSERT_IDX]] // CHECK: %[[NEXT_B_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[INSERT_IDX]] // CHECK: triton_gpu.async_wait {num = 2 : i32} // CHECK: %[[NEXT_A:.*]] = triton_gpu.extract_slice %[[NEXT_A_BUFFER]][%[[EXTRACT_IDX]], 0, 0] // CHECK: %[[NEXT_B:.*]] = triton_gpu.extract_slice %[[NEXT_B_BUFFER]][%[[EXTRACT_IDX]], 0, 0] -// CHECK-DAG: %[[NEXT_PIPELINE_IDX:.*]] = arith.addi %[[PIPELINE_IDX]], %[[CONSTANT_1]] -// CHECK-DAG: %[[NEXT_LOOP_IDX:.*]] = arith.addi %[[LOOP_IDX]], %[[CONSTANT_1]] +// CHECK-DAG: %[[NEXT_PIPELINE_IDX:.*]] = arith.addi %[[INSERT_IDX]], %[[CONSTANT_1]] +// CHECK-DAG: %[[NEXT_LOOP_IDX:.*]] = arith.addi %[[EXTRACT_IDX]], %[[CONSTANT_1]] // CHECK: scf.yield {{.*}}, {{.*}}, {{.*}}, %[[NEXT_A_BUFFER]], %[[NEXT_B_BUFFER]], %[[NEXT_A]], %[[NEXT_B]], {{.*}}, {{.*}}, {{.*}}, %[[NEXT_PIPELINE_IDX]], %[[NEXT_LOOP_IDX]] tt.func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr {tt.divisibility = 16 : i32}, @@ -110,15 +112,17 @@ tt.func @matmul_loop(%lb : index, %ub : index, %step : index, // CHECK: %[[arg_a0_dot_op:.*]] = triton_gpu.convert_layout %[[arg_a0]] // CHECK: %[[arg_b0_dot_op:.*]] = triton_gpu.convert_layout %[[arg_b0]] // CHECK: tt.dot %[[arg_a0_dot_op]], %[[arg_b0_dot_op]], {{.*}} -// CHECK-DAG: %[[INSERT_IDX:.*]] = arith.remui %[[PIPELINE_IDX]], %[[CONSTANT_3]] -// CHECK-DAG: %[[EXTRACT_IDX:.*]] = arith.remui %[[LOOP_IDX]], %[[CONSTANT_3]] +// CHECK-DAG: %[[CMP_PIPELINE:.*]] = arith.cmpi uge, %[[PIPELINE_IDX]], %[[CONSTANT_3]] +// CHECK-DAG: %[[INSERT_IDX:.*]] = arith.select %[[CMP_PIPELINE]], %[[CONSTANT_0]], %[[PIPELINE_IDX]] +// CHECK-DAG: %[[CMP_LOOP:.*]] = arith.cmpi uge, %[[LOOP_IDX]], %[[CONSTANT_3]] +// CHECK-DAG: %[[EXTRACT_IDX:.*]] = arith.select %[[CMP_LOOP]], %[[CONSTANT_0]], %[[LOOP_IDX]] // CHECK: %[[NEXT_A_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[INSERT_IDX]] // CHECK: %[[NEXT_B_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[INSERT_IDX]] // CHECK: triton_gpu.async_wait {num = 2 : i32} // CHECK: %[[NEXT_A:.*]] = triton_gpu.extract_slice %[[NEXT_A_BUFFER]][%[[EXTRACT_IDX]], 0, 0] // CHECK: %[[NEXT_B:.*]] = triton_gpu.extract_slice %[[NEXT_B_BUFFER]][%[[EXTRACT_IDX]], 0, 0] -// CHECK-DAG: %[[NEXT_PIPELINE_IDX:.*]] = arith.addi %[[PIPELINE_IDX]], %[[CONSTANT_1]] -// CHECK-DAG: %[[NEXT_LOOP_IDX:.*]] = arith.addi %[[LOOP_IDX]], %[[CONSTANT_1]] +// CHECK-DAG: %[[NEXT_PIPELINE_IDX:.*]] = arith.addi %[[INSERT_IDX]], %[[CONSTANT_1]] +// CHECK-DAG: %[[NEXT_LOOP_IDX:.*]] = arith.addi %[[EXTRACT_IDX]], %[[CONSTANT_1]] // CHECK: scf.yield {{.*}}, {{.*}}, {{.*}}, %[[NEXT_A_BUFFER]], %[[NEXT_B_BUFFER]], %[[NEXT_A]], %[[NEXT_B]], {{.*}}, {{.*}}, {{.*}}, %[[NEXT_PIPELINE_IDX]], %[[NEXT_LOOP_IDX]] tt.func @matmul_loop_nested(%lb : index, %ub : index, %step : index, %A : !tt.ptr {tt.divisibility = 16 : i32}, @@ -179,13 +183,15 @@ tt.func @matmul_loop_nested(%lb : index, %ub : index, %step : index, // CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, %[[arg_b0:.*]] = %[[B0]], {{.*}}, {{.*}}, %[[PIPELINE_IDX:.*]] = %[[CONSTANT_2]], %[[LOOP_IDX:.*]] = %[[CONSTANT_1]] // CHECK: %[[arg_b0_dot_op:.*]] = triton_gpu.convert_layout %[[arg_b0]] // CHECK: tt.dot {{.*}}, %[[arg_b0_dot_op]], {{.*}} -// CHECK-DAG: %[[INSERT_IDX:.*]] = arith.remui %[[PIPELINE_IDX]], %[[CONSTANT_3]] -// CHECK-DAG: %[[EXTRACT_IDX:.*]] = arith.remui %[[LOOP_IDX]], %[[CONSTANT_3]] +// CHECK-DAG: %[[CMP_PIPELINE:.*]] = arith.cmpi uge, %[[PIPELINE_IDX]], %[[CONSTANT_3]] +// CHECK-DAG: %[[INSERT_IDX:.*]] = arith.select %[[CMP_PIPELINE]], %[[CONSTANT_0]], %[[PIPELINE_IDX]] +// CHECK-DAG: %[[CMP_LOOP:.*]] = arith.cmpi uge, %[[LOOP_IDX]], %[[CONSTANT_3]] +// CHECK-DAG: %[[EXTRACT_IDX:.*]] = arith.select %[[CMP_LOOP]], %[[CONSTANT_0]], %[[LOOP_IDX]] // CHECK: %[[NEXT_B_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[INSERT_IDX]] // CHECK: triton_gpu.async_wait {num = 1 : i32} // CHECK: %[[NEXT_B:.*]] = triton_gpu.extract_slice %[[NEXT_B_BUFFER]][%[[EXTRACT_IDX]], 0, 0] -// CHECK-DAG: %[[NEXT_PIPELINE_IDX:.*]] = arith.addi %[[PIPELINE_IDX]], %[[CONSTANT_1]] -// CHECK-DAG: %[[NEXT_LOOP_IDX:.*]] = arith.addi %[[LOOP_IDX]], %[[CONSTANT_1]] +// CHECK-DAG: %[[NEXT_PIPELINE_IDX:.*]] = arith.addi %[[INSERT_IDX]], %[[CONSTANT_1]] +// CHECK-DAG: %[[NEXT_LOOP_IDX:.*]] = arith.addi %[[EXTRACT_IDX]], %[[CONSTANT_1]] // CHECK: scf.yield {{.*}}, {{.*}}, %[[NEXT_B_BUFFER]], %[[NEXT_B]], {{.*}}, {{.*}}, %[[NEXT_PIPELINE_IDX]], %[[NEXT_LOOP_IDX]] tt.func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index, %A : !tt.ptr {tt.divisibility = 16 : i32},