Skip to content

Commit

Permalink
further review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
newling committed Dec 11, 2024
1 parent e1ecc84 commit a85fb26
Showing 1 changed file with 8 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ void AMDAIEFuseFillIntoForallPass::runOnOperation() {

// Find a unique FillOp with a single output, or return.
SmallVector<linalg::FillOp> fillOps;
getOperation()->walk([&](linalg::FillOp fillOp) { fillOps.push_back(fillOp); });
getOperation()->walk(
[&](linalg::FillOp fillOp) { fillOps.push_back(fillOp); });
if (fillOps.size() != 1) {
LLVM_DEBUG(llvm::dbgs() << "Expected exactly 1 fill op, but found "
<< fillOps.size() << ".\n");
Expand All @@ -51,18 +52,16 @@ void AMDAIEFuseFillIntoForallPass::runOnOperation() {

// Confirm that there is a unique user that is a forall, and match
// the block argument that is used by the fill op, or return.
ResultRange::use_range fillUses = fillOp->getUses();
if (std::distance(fillUses.begin(), fillUses.end()) != 1) {
if (!fillOp->hasOneUse()) {
LLVM_DEBUG(llvm::dbgs()
<< "Expected exactly 1 use of fill op, but found "
<< std::distance(fillUses.begin(), fillUses.end()) << ".\n");
<< "Expected exactly 1 use of fill op, but found 0 or 2+.");
return;
}
OpOperand &fillUse = *fillUses.begin();
OpOperand &fillUse = *fillOp->getUses().begin();
auto forallOp = dyn_cast<scf::ForallOp>(fillUse.getOwner());
if (!forallOp) {
LLVM_DEBUG(llvm::dbgs() << "Expected fill op to be used by a forall op, "
<< "but the only user is "
<< "but unique user is "
<< fillUse.getOwner()->getName() << ".\n");
return;
}
Expand Down Expand Up @@ -99,9 +98,8 @@ void AMDAIEFuseFillIntoForallPass::runOnOperation() {
// if the extract_slice has been folded, for example if the forall is
// over a grid if size 1.
rewriter.setInsertionPointToStart(forallOp.getBody());
Value scalar = fillOp.value();
Location loc = fillOp.getLoc();
auto fusedFill = rewriter.create<linalg::FillOp>(loc, scalar, bbArg);
auto fusedFill =
rewriter.create<linalg::FillOp>(fillOp.getLoc(), fillOp.value(), bbArg);
rewriter.replaceUsesWithIf(
bbArg, fusedFill.getResult(0), [&](OpOperand &operand) {
Operation *owner = operand.getOwner();
Expand Down

0 comments on commit a85fb26

Please sign in to comment.