Skip to content

Commit

Permalink
[TRANSFORM] Use scf.if for boundary checks (triton-lang#12)
Browse files Browse the repository at this point in the history
* Update

* Update

* Update

* Update
  • Loading branch information
Jokeren authored Jul 20, 2023
1 parent 5802f75 commit e2fca48
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 79 deletions.
122 changes: 55 additions & 67 deletions lib/Dialect/TritonGPU/Transforms/Pipeline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,9 +129,6 @@ class LoopPipeliner {

/// Loads to be pipelined
SetVector<Value> validLoads;
/// FIXME(Keren): Check if this is required
/// Number of loads that requires AsyncWait for synchronization
int numLoadsRequiresAsyncWait = 0;
/// The value that each load will be mapped to (after layout conversion)
DenseMap<Value, Value> loadsMapping;
/// load => buffer
Expand All @@ -143,14 +140,25 @@ class LoopPipeliner {
/// load => after extract
DenseMap<Value, Value> loadsExtract;

/// XXX(Keren): The following are tma specific and disabled
/// XXX(Keren): The following are h100 only and disabled
Value curWaitIdx;
Value curPhase;
/// load => full barrier arrive
DenseMap<Value, Operation *> loadsBarrierArvOp;
/// load => mbarriers
DenseMap<Value, Value> loadsFullBarriers;
DenseMap<Value, Value> loadsEmptyBarriers;
/// load => null value, or the previous load it can share a barrier with
DenseMap<Value, Value> loadsCanShareBarriers;
/// Maintains the information to emit consumer_release mbarrier_arrive
ConsumerReleaseMap &consumerReleaseMap;
bool hasHopperDot = false;
// XXX(Keren): why is the variable named "hopper dot", and why do we need
// this check?
void checkHopperDots(SetVector<Operation *> &ops);
// XXX(Keren): this looks more like an optimization; not sure if it should
// exist in the base pipeliner
void checkOpShareBarriers(SetVector<Operation *> &ops);

/// Iterator values
Value pipelineIterIdx;
Expand All @@ -162,13 +170,6 @@ class LoopPipeliner {
SmallVector<Value> extractSlices;
SmallVector<Value> yieldValues;

/// XXX(Keren): The following are tma specific and disabled
Value curWaitIdx;
Value curPhase;
/// Maintains the information to emit consumer_release mbarrier_arrive
ConsumerReleaseMap &consumerReleaseMap;
bool hasHopperDot = false;

/// The number of stages in the pipeline.
/// Stages in the range of [0, numStages-1) are in the prologue.
/// numStages-1 is appended after the loop body.
Expand Down Expand Up @@ -221,14 +222,6 @@ class LoopPipeliner {
/// Check if none of the ops has valid uses
LogicalResult checkOpUses(SetVector<Operation *> &ops);

// XXX(Keren): why is the variable named "hopper dot", and why do we need
// this check?
void checkHopperDots(SetVector<Operation *> &ops);

// XXX(Keren): this looks more like an optimization; not sure if it should
// exist in the base pipeliner
void checkOpShareBarriers(SetVector<Operation *> &ops);

/// Check if ops have dependencies that are not pipelinable
void checkOpDeps(SetVector<Operation *> &ops);

Expand Down Expand Up @@ -271,8 +264,10 @@ class LoopPipeliner {
/// Prefetch the next iteration for `newForOp`
void prefetchNextIteration(scf::ForOp newForOp, OpBuilder &builder);

/// Check if next iteration is out of boundary
SmallVector<Value> getNextPhase(OpBuilder &builder, Value curIdx, Value upperBoundIdx);
/// Check if curIdx is out of bound and wrap value around if necessary
Value getNextIterationValue(OpBuilder &builder, Value curIdx,
Value upperBoundIdx, Value curValue,
Value initValue);

/// Assemble `newForOp`'s yield op
void finalizeYield(scf::ForOp newForOp, OpBuilder &builder);
Expand Down Expand Up @@ -455,11 +450,8 @@ LogicalResult LoopPipeliner::checkOpUses(SetVector<Operation *> &ops) {

if (!isCandidate)
invalidOps.insert(loadOp);
else {
else
validLoads.insert(loadOp);
if (!isLoadFromTensorPtr(loadOp))
numLoadsRequiresAsyncWait++;
}
}
}

Expand Down Expand Up @@ -974,16 +966,18 @@ void LoopPipeliner::emitPrologue() {
loopIterIdx = builder.create<arith::AddIOp>(
loopIterIdx.getLoc(), loopIterIdx,
builder.create<arith::ConstantIntOp>(loopIterIdx.getLoc(), 1, 32));
//curWaitIdx = builder.create<arith::ConstantIntOp>(iv.getLoc(), 0, 32);
//curPhase = builder.create<arith::ConstantIntOp>(iv.getLoc(), 0, 1);
// curWaitIdx = builder.create<arith::ConstantIntOp>(iv.getLoc(), 0, 32);
// curPhase = builder.create<arith::ConstantIntOp>(iv.getLoc(), 0, 1);
}

/// Emits the epilogue that follows the pipelined loop.
///
/// When at least one load was selected for pipelining (`validLoads` is
/// non-empty), exactly one `ttg::AsyncWaitOp` (created with argument 0) is
/// inserted right after `forOp`. When no load was selected, nothing is
/// emitted.
void LoopPipeliner::emitEpilogue() {
  // No pipelined loads: there is nothing to wait for, so emit no wait op.
  if (validLoads.empty())
    return;

  // If there's any outstanding async copies, we need to wait for them.
  OpBuilder builder(forOp);
  OpBuilder::InsertionGuard g(builder);
  builder.setInsertionPointAfter(forOp);
  builder.create<ttg::AsyncWaitOp>(forOp.getLoc(), 0);
}

SmallVector<Value> LoopPipeliner::collectNewLoopArgs() {
Expand Down Expand Up @@ -1028,8 +1022,8 @@ SmallVector<Value> LoopPipeliner::collectNewLoopArgs() {
newLoopArgs.push_back(valueMapping[forOp.getInductionVar()][numStages - 2]);
newLoopArgs.push_back(pipelineIterIdx);
newLoopArgs.push_back(loopIterIdx);
//newLoopArgs.push_back(curWaitIdx);
//newLoopArgs.push_back(curPhase);
// newLoopArgs.push_back(curWaitIdx);
// newLoopArgs.push_back(curPhase);

return newLoopArgs;
}
Expand Down Expand Up @@ -1125,29 +1119,24 @@ scf::ForOp LoopPipeliner::cloneForOp(ArrayRef<Value> newLoopArgs,
return newForOp;
}

SmallVector<Value> LoopPipeliner::getNextPhase(OpBuilder &builder, Value curIdx,
Value upperBoundIdx) {
Value LoopPipeliner::getNextIterationValue(OpBuilder &builder, Value curIdx,
Value upperBoundIdx, Value curValue,
Value initValue) {
Value cond = builder.create<arith::CmpIOp>(
curIdx.getLoc(), arith::CmpIPredicate::uge, curIdx, upperBoundIdx);
auto nextIter = builder.create<scf::IfOp>(
curIdx.getLoc(), ArrayRef<Type>{curIdx.getType(), curPhase.getType()},
cond, true);

auto nextIter = builder.create<scf::IfOp>(curIdx.getLoc(), curValue.getType(),
cond, true);
// True branch
auto insertionPoint = builder.saveInsertionPoint();
builder.setInsertionPointToStart(nextIter.thenBlock());
Value newIdx = builder.create<arith::ConstantIntOp>(cond.getLoc(), 0, 32);
Value newPhase = builder.create<arith::XOrIOp>(
forOp.getLoc(), curPhase,
builder.create<arith::ConstantIntOp>(forOp.getLoc(), 1, 1));
builder.create<scf::YieldOp>(nextIV.getLoc(),
ArrayRef<Value>{newIdx, newPhase});
builder.create<scf::YieldOp>(curIdx.getLoc(), initValue);

// False branch
builder.setInsertionPointToStart(nextIter.elseBlock());
builder.create<scf::YieldOp>(curIdx.getLoc(),
ArrayRef<Value>{curIdx, curPhase});
builder.create<scf::YieldOp>(curIdx.getLoc(), curValue);
builder.restoreInsertionPoint(insertionPoint);

return nextIter.getResults();
return nextIter.getResult(0);
}

void LoopPipeliner::prefetchNextIteration(scf::ForOp newForOp,
Expand All @@ -1170,13 +1159,14 @@ void LoopPipeliner::prefetchNextIteration(scf::ForOp newForOp,
nextIV, newForOp.getUpperBound());

pipelineIterIdx = newForOp.getRegionIterArgs()[ivIdx + 1];
Value insertSliceIndex = builder.create<arith::RemUIOp>(
nextIV.getLoc(), pipelineIterIdx,
builder.create<arith::ConstantIntOp>(nextIV.getLoc(), numStages, 32));
Value numStagesVal =
builder.create<arith::ConstantIntOp>(nextIV.getLoc(), numStages, 32);
Value initVal = builder.create<arith::ConstantIntOp>(nextIV.getLoc(), 0, 32);
Value insertSliceIndex = getNextIterationValue(
builder, pipelineIterIdx, numStagesVal, pipelineIterIdx, initVal);
loopIterIdx = newForOp.getRegionIterArgs()[ivIdx + 2];
Value extractSliceIndex = builder.create<arith::RemUIOp>(
nextIV.getLoc(), loopIterIdx,
builder.create<arith::ConstantIntOp>(nextIV.getLoc(), numStages, 32));
Value extractSliceIndex = getNextIterationValue(
builder, loopIterIdx, numStagesVal, loopIterIdx, initVal);

// Prefetch load deps
// If a load-dependent instruction that uses a block argument, we
Expand Down Expand Up @@ -1250,8 +1240,8 @@ void LoopPipeliner::prefetchNextIteration(scf::ForOp newForOp,
// XXX(Keren): might be wrong for tma
// else
// newMask = builder.create<triton::SplatOp>(
// loadOp.getLoc(), mlir::triton::getI1SameShape(loadOp.getType()),
// nextLoopCond);
// loadOp.getLoc(),
// mlir::triton::getI1SameShape(loadOp.getType()), nextLoopCond);
}
Value insertedVal;
if (mode && isLoadFromTensorPtr(loadOp)) {
Expand Down Expand Up @@ -1370,17 +1360,19 @@ void LoopPipeliner::prefetchNextIteration(scf::ForOp newForOp,
}

// Bump iteration count
pipelineIterIdx = insertSliceIndex;
loopIterIdx = extractSliceIndex;
pipelineIterIdx = builder.create<arith::AddIOp>(
nextIV.getLoc(), pipelineIterIdx,
builder.create<arith::ConstantIntOp>(nextIV.getLoc(), 1, 32));
loopIterIdx = builder.create<arith::AddIOp>(
nextIV.getLoc(), loopIterIdx,
builder.create<arith::ConstantIntOp>(nextIV.getLoc(), 1, 32));
// FIXME(Keren): Reenable after tma is fixed
//curWaitIdx = builder.create<arith::AddIOp>(
// curWaitIdx = builder.create<arith::AddIOp>(
// forOp.getLoc(), curWaitIdx,
// builder.create<arith::ConstantIntOp>(forOp.getLoc(), 1, 32));
//curPhase = builder.create<arith::XOrIOp>(
// curPhase = builder.create<arith::XOrIOp>(
// forOp.getLoc(), curPhase,
// builder.create<arith::ConstantIntOp>(forOp.getLoc(), 1, 1));
}
Expand All @@ -1403,8 +1395,8 @@ void LoopPipeliner::finalizeYield(scf::ForOp newForOp, OpBuilder &builder) {
yieldValues.push_back(nextIV);
yieldValues.push_back(pipelineIterIdx);
yieldValues.push_back(loopIterIdx);
//yieldValues.push_back(curWaitIdx);
//yieldValues.push_back(curPhase);
// yieldValues.push_back(curWaitIdx);
// yieldValues.push_back(curPhase);

builder.setInsertionPointToEnd(newForOp.getBody());
builder.create<scf::YieldOp>(yieldOp->getLoc(), yieldValues);
Expand Down Expand Up @@ -1437,9 +1429,7 @@ struct PipelinePass : public TritonGPUPipelineBase<PipelinePass> {
// prologue is currently not properly provided. Need some second thought on
// the mask definition of InsertSliceOp when the src is ptr<tensor>
bool mode = (computeCapability >= 90);
int numStages = this->numStages;

if (numStages <= 1)
if (this->numStages <= 1)
return;

// phase 0: pipeline loads in loops
Expand All @@ -1455,9 +1445,8 @@ struct PipelinePass : public TritonGPUPipelineBase<PipelinePass> {

// Do the pipelining
getOperation()->walk([&](scf::ForOp forOp) -> void {
LoopPipeliner pipeliner(forOp, numStages, this->numWarps, this->numCTAs,
mode, consumerReleaseMap);

LoopPipeliner pipeliner(forOp, this->numStages, this->numWarps,
this->numCTAs, mode, consumerReleaseMap);
if (pipeliner.initialize().failed())
return;

Expand Down Expand Up @@ -1674,7 +1663,6 @@ void PipelinePass::emitConsumerRelease(Value mbarTensor,
b.setInsertionPointAfter(lastUserWithLargestStage);
auto loc = lastUserWithLargestStage->getLoc();
auto maxStageVal = b.create<arith::ConstantIntOp>(loc, maxStage, 32);
auto numStagesVal = b.create<arith::ConstantIntOp>(loc, numStages, 32);

// pred = (iterVar >= maxStage) &&
// (threadId % (numConsumerThreads / numRemoteCTAs) == 0);
Expand Down
30 changes: 18 additions & 12 deletions test/TritonGPU/loop-pipeline.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -37,15 +37,17 @@
// CHECK: %[[arg_b0_dot_op_0:.*]] = triton_gpu.convert_layout %[[arg_b0]]
// CHECK: %[[arg_b0_dot_op_1:.*]] = arith.mulf %[[arg_b0_dot_op_0]]
// CHECK: tt.dot %[[arg_a0_dot_op]], %[[arg_b0_dot_op_1]], {{.*}}
// CHECK-DAG: %[[INSERT_IDX:.*]] = arith.remui %[[PIPELINE_IDX]], %[[CONSTANT_3]]
// CHECK-DAG: %[[EXTRACT_IDX:.*]] = arith.remui %[[LOOP_IDX]], %[[CONSTANT_3]]
// CHECK-DAG: %[[CMP_PIPELINE:.*]] = arith.cmpi uge, %[[PIPELINE_IDX]], %[[CONSTANT_3]]
// CHECK-DAG: %[[INSERT_IDX:.*]] = arith.select %[[CMP_PIPELINE]], %[[CONSTANT_0]], %[[PIPELINE_IDX]]
// CHECK-DAG: %[[CMP_LOOP:.*]] = arith.cmpi uge, %[[LOOP_IDX]], %[[CONSTANT_3]]
// CHECK-DAG: %[[EXTRACT_IDX:.*]] = arith.select %[[CMP_LOOP]], %[[CONSTANT_0]], %[[LOOP_IDX]]
// CHECK: %[[NEXT_A_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[INSERT_IDX]]
// CHECK: %[[NEXT_B_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[INSERT_IDX]]
// CHECK: triton_gpu.async_wait {num = 2 : i32}
// CHECK: %[[NEXT_A:.*]] = triton_gpu.extract_slice %[[NEXT_A_BUFFER]][%[[EXTRACT_IDX]], 0, 0]
// CHECK: %[[NEXT_B:.*]] = triton_gpu.extract_slice %[[NEXT_B_BUFFER]][%[[EXTRACT_IDX]], 0, 0]
// CHECK-DAG: %[[NEXT_PIPELINE_IDX:.*]] = arith.addi %[[PIPELINE_IDX]], %[[CONSTANT_1]]
// CHECK-DAG: %[[NEXT_LOOP_IDX:.*]] = arith.addi %[[LOOP_IDX]], %[[CONSTANT_1]]
// CHECK-DAG: %[[NEXT_PIPELINE_IDX:.*]] = arith.addi %[[INSERT_IDX]], %[[CONSTANT_1]]
// CHECK-DAG: %[[NEXT_LOOP_IDX:.*]] = arith.addi %[[EXTRACT_IDX]], %[[CONSTANT_1]]
// CHECK: scf.yield {{.*}}, {{.*}}, {{.*}}, %[[NEXT_A_BUFFER]], %[[NEXT_B_BUFFER]], %[[NEXT_A]], %[[NEXT_B]], {{.*}}, {{.*}}, {{.*}}, %[[NEXT_PIPELINE_IDX]], %[[NEXT_LOOP_IDX]]
tt.func @matmul_loop(%lb : index, %ub : index, %step : index,
%A : !tt.ptr<f16> {tt.divisibility = 16 : i32},
Expand Down Expand Up @@ -110,15 +112,17 @@ tt.func @matmul_loop(%lb : index, %ub : index, %step : index,
// CHECK: %[[arg_a0_dot_op:.*]] = triton_gpu.convert_layout %[[arg_a0]]
// CHECK: %[[arg_b0_dot_op:.*]] = triton_gpu.convert_layout %[[arg_b0]]
// CHECK: tt.dot %[[arg_a0_dot_op]], %[[arg_b0_dot_op]], {{.*}}
// CHECK-DAG: %[[INSERT_IDX:.*]] = arith.remui %[[PIPELINE_IDX]], %[[CONSTANT_3]]
// CHECK-DAG: %[[EXTRACT_IDX:.*]] = arith.remui %[[LOOP_IDX]], %[[CONSTANT_3]]
// CHECK-DAG: %[[CMP_PIPELINE:.*]] = arith.cmpi uge, %[[PIPELINE_IDX]], %[[CONSTANT_3]]
// CHECK-DAG: %[[INSERT_IDX:.*]] = arith.select %[[CMP_PIPELINE]], %[[CONSTANT_0]], %[[PIPELINE_IDX]]
// CHECK-DAG: %[[CMP_LOOP:.*]] = arith.cmpi uge, %[[LOOP_IDX]], %[[CONSTANT_3]]
// CHECK-DAG: %[[EXTRACT_IDX:.*]] = arith.select %[[CMP_LOOP]], %[[CONSTANT_0]], %[[LOOP_IDX]]
// CHECK: %[[NEXT_A_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[INSERT_IDX]]
// CHECK: %[[NEXT_B_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[INSERT_IDX]]
// CHECK: triton_gpu.async_wait {num = 2 : i32}
// CHECK: %[[NEXT_A:.*]] = triton_gpu.extract_slice %[[NEXT_A_BUFFER]][%[[EXTRACT_IDX]], 0, 0]
// CHECK: %[[NEXT_B:.*]] = triton_gpu.extract_slice %[[NEXT_B_BUFFER]][%[[EXTRACT_IDX]], 0, 0]
// CHECK-DAG: %[[NEXT_PIPELINE_IDX:.*]] = arith.addi %[[PIPELINE_IDX]], %[[CONSTANT_1]]
// CHECK-DAG: %[[NEXT_LOOP_IDX:.*]] = arith.addi %[[LOOP_IDX]], %[[CONSTANT_1]]
// CHECK-DAG: %[[NEXT_PIPELINE_IDX:.*]] = arith.addi %[[INSERT_IDX]], %[[CONSTANT_1]]
// CHECK-DAG: %[[NEXT_LOOP_IDX:.*]] = arith.addi %[[EXTRACT_IDX]], %[[CONSTANT_1]]
// CHECK: scf.yield {{.*}}, {{.*}}, {{.*}}, %[[NEXT_A_BUFFER]], %[[NEXT_B_BUFFER]], %[[NEXT_A]], %[[NEXT_B]], {{.*}}, {{.*}}, {{.*}}, %[[NEXT_PIPELINE_IDX]], %[[NEXT_LOOP_IDX]]
tt.func @matmul_loop_nested(%lb : index, %ub : index, %step : index,
%A : !tt.ptr<f16> {tt.divisibility = 16 : i32},
Expand Down Expand Up @@ -179,13 +183,15 @@ tt.func @matmul_loop_nested(%lb : index, %ub : index, %step : index,
// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, %[[arg_b0:.*]] = %[[B0]], {{.*}}, {{.*}}, %[[PIPELINE_IDX:.*]] = %[[CONSTANT_2]], %[[LOOP_IDX:.*]] = %[[CONSTANT_1]]
// CHECK: %[[arg_b0_dot_op:.*]] = triton_gpu.convert_layout %[[arg_b0]]
// CHECK: tt.dot {{.*}}, %[[arg_b0_dot_op]], {{.*}}
// CHECK-DAG: %[[INSERT_IDX:.*]] = arith.remui %[[PIPELINE_IDX]], %[[CONSTANT_3]]
// CHECK-DAG: %[[EXTRACT_IDX:.*]] = arith.remui %[[LOOP_IDX]], %[[CONSTANT_3]]
// CHECK-DAG: %[[CMP_PIPELINE:.*]] = arith.cmpi uge, %[[PIPELINE_IDX]], %[[CONSTANT_3]]
// CHECK-DAG: %[[INSERT_IDX:.*]] = arith.select %[[CMP_PIPELINE]], %[[CONSTANT_0]], %[[PIPELINE_IDX]]
// CHECK-DAG: %[[CMP_LOOP:.*]] = arith.cmpi uge, %[[LOOP_IDX]], %[[CONSTANT_3]]
// CHECK-DAG: %[[EXTRACT_IDX:.*]] = arith.select %[[CMP_LOOP]], %[[CONSTANT_0]], %[[LOOP_IDX]]
// CHECK: %[[NEXT_B_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[INSERT_IDX]]
// CHECK: triton_gpu.async_wait {num = 1 : i32}
// CHECK: %[[NEXT_B:.*]] = triton_gpu.extract_slice %[[NEXT_B_BUFFER]][%[[EXTRACT_IDX]], 0, 0]
// CHECK-DAG: %[[NEXT_PIPELINE_IDX:.*]] = arith.addi %[[PIPELINE_IDX]], %[[CONSTANT_1]]
// CHECK-DAG: %[[NEXT_LOOP_IDX:.*]] = arith.addi %[[LOOP_IDX]], %[[CONSTANT_1]]
// CHECK-DAG: %[[NEXT_PIPELINE_IDX:.*]] = arith.addi %[[INSERT_IDX]], %[[CONSTANT_1]]
// CHECK-DAG: %[[NEXT_LOOP_IDX:.*]] = arith.addi %[[EXTRACT_IDX]], %[[CONSTANT_1]]
// CHECK: scf.yield {{.*}}, {{.*}}, %[[NEXT_B_BUFFER]], %[[NEXT_B]], {{.*}}, {{.*}}, %[[NEXT_PIPELINE_IDX]], %[[NEXT_LOOP_IDX]]
tt.func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index,
%A : !tt.ptr<f16> {tt.divisibility = 16 : i32},
Expand Down

0 comments on commit e2fca48

Please sign in to comment.