Skip to content

Commit

Permalink
Merge branch llvm-head (triton-lang#1600)
Browse files Browse the repository at this point in the history
  • Loading branch information
chsigg authored May 1, 2023
1 parent f98bb96 commit ed37d5b
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 9 deletions.
7 changes: 4 additions & 3 deletions include/triton/Dialect/Triton/IR/TritonOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def TT_LoadOp : TT_Op<"load",
"triton::EvictionPolicy":$evict, "bool":$isVolatile)>,
// A tensor pointer with boundary check and padding
OpBuilder<(ins "Value":$ptr, "ArrayRef<int32_t>":$boundaryCheck,
"Optional<triton::PaddingOption>":$padding, "triton::CacheModifier":$cache,
"std::optional<triton::PaddingOption>":$padding, "triton::CacheModifier":$cache,
"triton::EvictionPolicy":$evict, "bool":$isVolatile)>,
// A tensor of pointers or a pointer to a scalar with mask
OpBuilder<(ins "Value":$ptr, "Value":$mask, "triton::CacheModifier":$cache,
Expand All @@ -164,8 +164,9 @@ def TT_LoadOp : TT_Op<"load",
OpBuilder<(ins "Value":$ptr, "Value":$mask, "Value":$other, "triton::CacheModifier":$cache,
"triton::EvictionPolicy":$evict, "bool":$isVolatile)>,
// A utility function to build the operation with all attributes
OpBuilder<(ins "Value":$ptr, "Value":$mask, "Value":$other, "Optional<ArrayRef<int32_t>>":$boundaryCheck,
"Optional<triton::PaddingOption>":$padding, "triton::CacheModifier":$cache,
OpBuilder<(ins "Value":$ptr, "Value":$mask, "Value":$other,
"std::optional<ArrayRef<int32_t>>":$boundaryCheck,
"std::optional<triton::PaddingOption>":$padding, "triton::CacheModifier":$cache,
"triton::EvictionPolicy":$evict, "bool":$isVolatile)>
];

Expand Down
8 changes: 2 additions & 6 deletions lib/Dialect/TritonGPU/Transforms/Pipeline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -364,10 +364,6 @@ Value LoopPipeliner::getLoadMask(triton::LoadOp loadOp, Value mappedMask,
}

void LoopPipeliner::emitPrologue() {
// llvm::errs() << "loads to pipeline...:\n";
// for (Value load : loads)
// llvm::errs() << load << "\n";

OpBuilder builder(forOp);
for (BlockArgument &arg : forOp.getRegionIterArgs()) {
OpOperand &operand = forOp.getOpOperandForRegionIterArg(arg);
Expand All @@ -392,7 +388,7 @@ void LoopPipeliner::emitPrologue() {
for (Operation &op : forOp.getLoopBody().front()) {
if (depOps.contains(&op))
orderedDeps.push_back(&op);
else if (loads.contains(op.getResult(0)))
else if (op.getNumResults() > 0 && loads.contains(op.getResult(0)))
orderedDeps.push_back(&op);
}
assert(depOps.size() + loads.size() == orderedDeps.size() &&
Expand Down Expand Up @@ -601,7 +597,7 @@ scf::ForOp LoopPipeliner::createNewForOp() {
for (Operation &op : forOp.getLoopBody().front()) {
if (depOps.contains(&op))
orderedDeps.push_back(&op);
else if (loads.contains(op.getResult(0)))
else if (op.getNumResults() > 0 && loads.contains(op.getResult(0)))
orderedDeps.push_back(&op);
}
assert(depOps.size() + loads.size() == orderedDeps.size() &&
Expand Down

0 comments on commit ed37d5b

Please sign in to comment.