Skip to content

Commit

Permalink
[CIR][CodeGen] Flattening for ScopeOp and LoopOpInterface (llvm#546)
Browse files Browse the repository at this point in the history
This PR is the next step towards goto support and adds flattening for
`ScopeOp` and `LoopOpInterface`.

Looks like I can't separate this operations and create two PRs, since
some errors occur if I do so, e.g. `reference to block defined in
another region`. Seems we need to flatten both operations in the same
time. Given it's a copy-pasta, I think there is no need to try to make
several PRs.

I added several tests - just copied them from the lowering part just to
demonstrate how it looks like.

Note, that changes in `dot.cir` caused by `BrCondOp` updates in the
previous PR, when we removed the following casts:
```
    %20 = llvm.zext %19 : i1 to i8
    %21 = llvm.trunc %20 : i8 to i1
    llvm.cond_br %21 ...
```
  • Loading branch information
gitoleg authored and lanza committed Oct 1, 2024
1 parent 5c9116a commit cc8a208
Show file tree
Hide file tree
Showing 5 changed files with 376 additions and 176 deletions.
158 changes: 156 additions & 2 deletions clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,28 @@ using namespace mlir::cir;

namespace {

/// Lowers operations with the terminator trait that have a single successor.
void lowerTerminator(mlir::Operation *op, mlir::Block *dest,
mlir::PatternRewriter &rewriter) {
assert(op->hasTrait<mlir::OpTrait::IsTerminator>() && "not a terminator");
mlir::OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(op);
rewriter.replaceOpWithNewOp<mlir::cir::BrOp>(op, dest);
}

/// Walks a region while skipping operations of type `Ops`. This ensures the
/// callback is not applied to said operations and its children.
template <typename... Ops>
void walkRegionSkipping(mlir::Region &region,
mlir::function_ref<void(mlir::Operation *)> callback) {
region.walk<mlir::WalkOrder::PreOrder>([&](mlir::Operation *op) {
if (isa<Ops...>(op))
return mlir::WalkResult::skip();
callback(op);
return mlir::WalkResult::advance();
});
}

struct FlattenCFGPass : public FlattenCFGBase<FlattenCFGPass> {

FlattenCFGPass() = default;
Expand Down Expand Up @@ -92,8 +114,140 @@ struct CIRIfFlattening : public OpRewritePattern<IfOp> {
}
};

class CIRScopeOpFlattening : public mlir::OpRewritePattern<mlir::cir::ScopeOp> {
public:
using OpRewritePattern<mlir::cir::ScopeOp>::OpRewritePattern;

mlir::LogicalResult
matchAndRewrite(mlir::cir::ScopeOp scopeOp,
mlir::PatternRewriter &rewriter) const override {
mlir::OpBuilder::InsertionGuard guard(rewriter);
auto loc = scopeOp.getLoc();

// Empty scope: just remove it.
if (scopeOp.getRegion().empty()) {
rewriter.eraseOp(scopeOp);
return mlir::success();
}

// Split the current block before the ScopeOp to create the inlining
// point.
auto *currentBlock = rewriter.getInsertionBlock();
auto *remainingOpsBlock =
rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
mlir::Block *continueBlock;
if (scopeOp.getNumResults() == 0)
continueBlock = remainingOpsBlock;
else
llvm_unreachable("NYI");

// Inline body region.
auto *beforeBody = &scopeOp.getRegion().front();
auto *afterBody = &scopeOp.getRegion().back();
rewriter.inlineRegionBefore(scopeOp.getRegion(), continueBlock);

// Save stack and then branch into the body of the region.
rewriter.setInsertionPointToEnd(currentBlock);
// TODO(CIR): stackSaveOp
// auto stackSaveOp = rewriter.create<mlir::LLVM::StackSaveOp>(
// loc, mlir::LLVM::LLVMPointerType::get(
// mlir::IntegerType::get(scopeOp.getContext(), 8)));
rewriter.create<mlir::cir::BrOp>(loc, mlir::ValueRange(), beforeBody);

// Replace the scopeop return with a branch that jumps out of the body.
// Stack restore before leaving the body region.
rewriter.setInsertionPointToEnd(afterBody);
if (auto yieldOp =
dyn_cast<mlir::cir::YieldOp>(afterBody->getTerminator())) {
rewriter.replaceOpWithNewOp<mlir::cir::BrOp>(yieldOp, yieldOp.getArgs(),
continueBlock);
}

// TODO(cir): stackrestore?

// Replace the op with values return from the body region.
rewriter.replaceOp(scopeOp, continueBlock->getArguments());

return mlir::success();
}
};

class CIRLoopOpInterfaceFlattening
: public mlir::OpInterfaceRewritePattern<mlir::cir::LoopOpInterface> {
public:
using mlir::OpInterfaceRewritePattern<
mlir::cir::LoopOpInterface>::OpInterfaceRewritePattern;

inline void lowerConditionOp(mlir::cir::ConditionOp op, mlir::Block *body,
mlir::Block *exit,
mlir::PatternRewriter &rewriter) const {
mlir::OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(op);
rewriter.replaceOpWithNewOp<mlir::cir::BrCondOp>(op, op.getCondition(),
body, exit);
}

mlir::LogicalResult
matchAndRewrite(mlir::cir::LoopOpInterface op,
mlir::PatternRewriter &rewriter) const final {
// Setup CFG blocks.
auto *entry = rewriter.getInsertionBlock();
auto *exit = rewriter.splitBlock(entry, rewriter.getInsertionPoint());
auto *cond = &op.getCond().front();
auto *body = &op.getBody().front();
auto *step = (op.maybeGetStep() ? &op.maybeGetStep()->front() : nullptr);

// Setup loop entry branch.
rewriter.setInsertionPointToEnd(entry);
rewriter.create<mlir::cir::BrOp>(op.getLoc(), &op.getEntry().front());

// Branch from condition region to body or exit.
auto conditionOp = cast<mlir::cir::ConditionOp>(cond->getTerminator());
lowerConditionOp(conditionOp, body, exit, rewriter);

// TODO(cir): Remove the walks below. It visits operations unnecessarily,
// however, to solve this we would likely need a custom DialecConversion
// driver to customize the order that operations are visited.

// Lower continue statements.
mlir::Block *dest = (step ? step : cond);
op.walkBodySkippingNestedLoops([&](mlir::Operation *op) {
if (isa<mlir::cir::ContinueOp>(op))
lowerTerminator(op, dest, rewriter);
});

// Lower break statements.
walkRegionSkipping<mlir::cir::LoopOpInterface, mlir::cir::SwitchOp>(
op.getBody(), [&](mlir::Operation *op) {
if (isa<mlir::cir::BreakOp>(op))
lowerTerminator(op, exit, rewriter);
});

// Lower optional body region yield.
auto bodyYield = dyn_cast<mlir::cir::YieldOp>(body->getTerminator());
if (bodyYield)
lowerTerminator(bodyYield, (step ? step : cond), rewriter);

// Lower mandatory step region yield.
if (step)
lowerTerminator(cast<mlir::cir::YieldOp>(step->getTerminator()), cond,
rewriter);

// Move region contents out of the loop op.
rewriter.inlineRegionBefore(op.getCond(), exit);
rewriter.inlineRegionBefore(op.getBody(), exit);
if (step)
rewriter.inlineRegionBefore(*op.maybeGetStep(), exit);

rewriter.eraseOp(op);
return mlir::success();
}
};

void populateFlattenCFGPatterns(RewritePatternSet &patterns) {
patterns.add<CIRIfFlattening>(patterns.getContext());
patterns
.add<CIRIfFlattening, CIRLoopOpInterfaceFlattening, CIRScopeOpFlattening>(
patterns.getContext());
}

void FlattenCFGPass::runOnOperation() {
Expand All @@ -103,7 +257,7 @@ void FlattenCFGPass::runOnOperation() {
// Collect operations to apply patterns.
SmallVector<Operation *, 16> ops;
getOperation()->walk<mlir::WalkOrder::PostOrder>([&](Operation *op) {
if (isa<IfOp>(op))
if (isa<IfOp, ScopeOp, LoopOpInterface>(op))
ops.push_back(op);
});

Expand Down
168 changes: 17 additions & 151 deletions clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -481,80 +481,6 @@ class CIRPtrStrideOpLowering
}
};

class CIRLoopOpInterfaceLowering
: public mlir::OpInterfaceConversionPattern<mlir::cir::LoopOpInterface> {
public:
using mlir::OpInterfaceConversionPattern<
mlir::cir::LoopOpInterface>::OpInterfaceConversionPattern;

inline void
lowerConditionOp(mlir::cir::ConditionOp op, mlir::Block *body,
mlir::Block *exit,
mlir::ConversionPatternRewriter &rewriter) const {
mlir::OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(op);
rewriter.replaceOpWithNewOp<mlir::cir::BrCondOp>(op, op.getCondition(),
body, exit);
}

mlir::LogicalResult
matchAndRewrite(mlir::cir::LoopOpInterface op,
mlir::ArrayRef<mlir::Value> operands,
mlir::ConversionPatternRewriter &rewriter) const final {
// Setup CFG blocks.
auto *entry = rewriter.getInsertionBlock();
auto *exit = rewriter.splitBlock(entry, rewriter.getInsertionPoint());
auto *cond = &op.getCond().front();
auto *body = &op.getBody().front();
auto *step = (op.maybeGetStep() ? &op.maybeGetStep()->front() : nullptr);

// Setup loop entry branch.
rewriter.setInsertionPointToEnd(entry);
rewriter.create<mlir::LLVM::BrOp>(op.getLoc(), &op.getEntry().front());

// Branch from condition region to body or exit.
auto conditionOp = cast<mlir::cir::ConditionOp>(cond->getTerminator());
lowerConditionOp(conditionOp, body, exit, rewriter);

// TODO(cir): Remove the walks below. It visits operations unnecessarily,
// however, to solve this we would likely need a custom DialecConversion
// driver to customize the order that operations are visited.

// Lower continue statements.
mlir::Block *dest = (step ? step : cond);
op.walkBodySkippingNestedLoops([&](mlir::Operation *op) {
if (isa<mlir::cir::ContinueOp>(op))
lowerTerminator(op, dest, rewriter);
});

// Lower break statements.
walkRegionSkipping<mlir::cir::LoopOpInterface, mlir::cir::SwitchOp>(
op.getBody(), [&](mlir::Operation *op) {
if (isa<mlir::cir::BreakOp>(op))
lowerTerminator(op, exit, rewriter);
});

// Lower optional body region yield.
auto bodyYield = dyn_cast<mlir::cir::YieldOp>(body->getTerminator());
if (bodyYield)
lowerTerminator(bodyYield, (step ? step : cond), rewriter);

// Lower mandatory step region yield.
if (step)
lowerTerminator(cast<mlir::cir::YieldOp>(step->getTerminator()), cond,
rewriter);

// Move region contents out of the loop op.
rewriter.inlineRegionBefore(op.getCond(), exit);
rewriter.inlineRegionBefore(op.getBody(), exit);
if (step)
rewriter.inlineRegionBefore(*op.maybeGetStep(), exit);

rewriter.eraseOp(op);
return mlir::success();
}
};

class CIRBrCondOpLowering
: public mlir::OpConversionPattern<mlir::cir::BrCondOp> {
public:
Expand Down Expand Up @@ -785,65 +711,6 @@ class CIRCastOpLowering : public mlir::OpConversionPattern<mlir::cir::CastOp> {
}
};

class CIRScopeOpLowering
: public mlir::OpConversionPattern<mlir::cir::ScopeOp> {
public:
using OpConversionPattern<mlir::cir::ScopeOp>::OpConversionPattern;

mlir::LogicalResult
matchAndRewrite(mlir::cir::ScopeOp scopeOp, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
mlir::OpBuilder::InsertionGuard guard(rewriter);
auto loc = scopeOp.getLoc();

// Empty scope: just remove it.
if (scopeOp.getRegion().empty()) {
rewriter.eraseOp(scopeOp);
return mlir::success();
}

// Split the current block before the ScopeOp to create the inlining
// point.
auto *currentBlock = rewriter.getInsertionBlock();
auto *remainingOpsBlock =
rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
mlir::Block *continueBlock;
if (scopeOp.getNumResults() == 0)
continueBlock = remainingOpsBlock;
else
llvm_unreachable("NYI");

// Inline body region.
auto *beforeBody = &scopeOp.getRegion().front();
auto *afterBody = &scopeOp.getRegion().back();
rewriter.inlineRegionBefore(scopeOp.getRegion(), continueBlock);

// Save stack and then branch into the body of the region.
rewriter.setInsertionPointToEnd(currentBlock);
// TODO(CIR): stackSaveOp
// auto stackSaveOp = rewriter.create<mlir::LLVM::StackSaveOp>(
// loc, mlir::LLVM::LLVMPointerType::get(
// mlir::IntegerType::get(scopeOp.getContext(), 8)));
rewriter.create<mlir::cir::BrOp>(loc, mlir::ValueRange(), beforeBody);

// Replace the scopeop return with a branch that jumps out of the body.
// Stack restore before leaving the body region.
rewriter.setInsertionPointToEnd(afterBody);
if (auto yieldOp =
dyn_cast<mlir::cir::YieldOp>(afterBody->getTerminator())) {
rewriter.replaceOpWithNewOp<mlir::cir::BrOp>(yieldOp, yieldOp.getArgs(),
continueBlock);
}

// TODO(cir): stackrestore?

// Replace the op with values return from the body region.
rewriter.replaceOp(scopeOp, continueBlock->getArguments());

return mlir::success();
}
};

class CIRReturnLowering
: public mlir::OpConversionPattern<mlir::cir::ReturnOp> {
public:
Expand Down Expand Up @@ -3077,23 +2944,22 @@ void populateCIRToLLVMConversionPatterns(mlir::RewritePatternSet &patterns,
CIRCmpOpLowering, CIRBitClrsbOpLowering, CIRBitClzOpLowering,
CIRBitCtzOpLowering, CIRBitFfsOpLowering, CIRBitParityOpLowering,
CIRBitPopcountOpLowering, CIRAtomicFetchLowering, CIRByteswapOpLowering,
CIRLoopOpInterfaceLowering, CIRBrCondOpLowering, CIRPtrStrideOpLowering,
CIRCallLowering, CIRUnaryOpLowering, CIRBinOpLowering, CIRShiftOpLowering,
CIRLoadLowering, CIRConstantLowering, CIRStoreLowering, CIRAllocaLowering,
CIRFuncLowering, CIRScopeOpLowering, CIRCastOpLowering,
CIRGlobalOpLowering, CIRGetGlobalOpLowering, CIRVAStartLowering,
CIRVAEndLowering, CIRVACopyLowering, CIRVAArgLowering, CIRBrOpLowering,
CIRTernaryOpLowering, CIRGetMemberOpLowering, CIRSwitchOpLowering,
CIRPtrDiffOpLowering, CIRCopyOpLowering, CIRMemCpyOpLowering,
CIRFAbsOpLowering, CIRExpectOpLowering, CIRVTableAddrPointOpLowering,
CIRVectorCreateLowering, CIRVectorInsertLowering,
CIRVectorExtractLowering, CIRVectorCmpOpLowering, CIRVectorSplatLowering,
CIRVectorTernaryLowering, CIRVectorShuffleIntsLowering,
CIRVectorShuffleVecLowering, CIRStackSaveLowering,
CIRStackRestoreLowering, CIRUnreachableLowering, CIRTrapLowering,
CIRInlineAsmOpLowering, CIRSetBitfieldLowering, CIRGetBitfieldLowering,
CIRPrefetchLowering, CIRObjSizeOpLowering, CIRIsConstantOpLowering>(
converter, patterns.getContext());
CIRBrCondOpLowering, CIRPtrStrideOpLowering, CIRCallLowering,
CIRUnaryOpLowering, CIRBinOpLowering, CIRShiftOpLowering, CIRLoadLowering,
CIRConstantLowering, CIRStoreLowering, CIRAllocaLowering, CIRFuncLowering,
CIRCastOpLowering, CIRGlobalOpLowering, CIRGetGlobalOpLowering,
CIRVAStartLowering, CIRVAEndLowering, CIRVACopyLowering, CIRVAArgLowering,
CIRBrOpLowering, CIRTernaryOpLowering, CIRGetMemberOpLowering,
CIRSwitchOpLowering, CIRPtrDiffOpLowering, CIRCopyOpLowering,
CIRMemCpyOpLowering, CIRFAbsOpLowering, CIRExpectOpLowering,
CIRVTableAddrPointOpLowering, CIRVectorCreateLowering,
CIRVectorInsertLowering, CIRVectorExtractLowering, CIRVectorCmpOpLowering,
CIRVectorSplatLowering, CIRVectorTernaryLowering,
CIRVectorShuffleIntsLowering, CIRVectorShuffleVecLowering,
CIRStackSaveLowering, CIRStackRestoreLowering, CIRUnreachableLowering,
CIRTrapLowering, CIRInlineAsmOpLowering, CIRSetBitfieldLowering,
CIRGetBitfieldLowering, CIRPrefetchLowering, CIRObjSizeOpLowering,
CIRIsConstantOpLowering>(converter, patterns.getContext());
}

namespace {
Expand Down Expand Up @@ -3261,7 +3127,7 @@ static void buildCtorDtorList(
// pass it will be placed into the unreachable block. And the possible error
// after the lowering pass is: error: 'cir.return' op expects parent op to be
// one of 'cir.func, cir.scope, cir.if ... The reason that this operation was
// not lowered and the new parent is lllvm.func.
// not lowered and the new parent is llvm.func.
//
// In the future we may want to get rid of this function and use DCE pass or
// something similar. But now we need to guarantee the absence of the dialect
Expand Down
Loading

0 comments on commit cc8a208

Please sign in to comment.