Skip to content

Commit

Permalink
[CIR[CIRGen][NFC] Refactor build switch op (llvm#552)
Browse files Browse the repository at this point in the history
Make logic cleaner and more extensible.

Separate collecting `SwitchStmt` information and building op logic into
different functions.
Add more UT to cover nested switch, which also worked before this pr.

This pr is split from llvm#528.
  • Loading branch information
wenpen authored and lanza committed Oct 1, 2024
1 parent 6a77238 commit f7521fd
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 58 deletions.
32 changes: 23 additions & 9 deletions clang/lib/CIR/CodeGen/CIRGenFunction.h
Original file line number Diff line number Diff line change
Expand Up @@ -1085,18 +1085,23 @@ class CIRGenFunction : public CIRGenTypeCache {
template <typename T>
mlir::LogicalResult
buildCaseDefaultCascade(const T *stmt, mlir::Type condType,
SmallVector<mlir::Attribute, 4> &caseAttrs,
mlir::OperationState &os);
SmallVector<mlir::Attribute, 4> &caseAttrs);

mlir::LogicalResult buildCaseStmt(const clang::CaseStmt &S,
mlir::Type condType,
SmallVector<mlir::Attribute, 4> &caseAttrs,
mlir::OperationState &op);
SmallVector<mlir::Attribute, 4> &caseAttrs);

mlir::LogicalResult
buildDefaultStmt(const clang::DefaultStmt &S, mlir::Type condType,
SmallVector<mlir::Attribute, 4> &caseAttrs,
mlir::OperationState &op);
SmallVector<mlir::Attribute, 4> &caseAttrs);

mlir::LogicalResult
buildSwitchCase(const clang::SwitchCase &S, mlir::Type condType,
SmallVector<mlir::Attribute, 4> &caseAttrs);

mlir::LogicalResult
buildSwitchBody(const clang::Stmt *S, mlir::Type condType,
SmallVector<mlir::Attribute, 4> &caseAttrs);

mlir::cir::FuncOp generateCode(clang::GlobalDecl GD, mlir::cir::FuncOp Fn,
const CIRGenFunctionInfo &FnInfo);
Expand Down Expand Up @@ -1964,7 +1969,7 @@ class CIRGenFunction : public CIRGenTypeCache {
// have their own scopes but are distinct regions nonetheless.
llvm::SmallVector<mlir::Block *> RetBlocks;
llvm::SmallVector<std::optional<mlir::Location>> RetLocs;
unsigned int CurrentSwitchRegionIdx = -1;
llvm::SmallVector<std::unique_ptr<mlir::Region>> SwitchRegions;

// There's usually only one ret block per scope, but this needs to be
// get or create because of potential unreachable return statements, note
Expand All @@ -1985,16 +1990,25 @@ class CIRGenFunction : public CIRGenTypeCache {
void buildImplicitReturn();

public:
void updateCurrentSwitchCaseRegion() { CurrentSwitchRegionIdx++; }
llvm::ArrayRef<mlir::Block *> getRetBlocks() { return RetBlocks; }
llvm::ArrayRef<std::optional<mlir::Location>> getRetLocs() {
return RetLocs;
}
llvm::MutableArrayRef<std::unique_ptr<mlir::Region>> getSwitchRegions() {
assert(isSwitch() && "expected switch scope");
return SwitchRegions;
}

mlir::Region *createSwitchRegion() {
assert(isSwitch() && "expected switch scope");
SwitchRegions.push_back(std::make_unique<mlir::Region>());
return SwitchRegions.back().get();
}

mlir::Block *getOrCreateRetBlock(CIRGenFunction &CGF, mlir::Location loc) {
unsigned int regionIdx = 0;
if (isSwitch())
regionIdx = CurrentSwitchRegionIdx;
regionIdx = SwitchRegions.size() - 1;
if (regionIdx >= RetBlocks.size())
return createRetBlock(CGF, loc);
return &*RetBlocks.back();
Expand Down
104 changes: 55 additions & 49 deletions clang/lib/CIR/CodeGen/CIRGenStmt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -637,7 +637,7 @@ CIRGenFunction::foldCaseStmt(const clang::CaseStmt &S, mlir::Type condType,
template <typename T>
mlir::LogicalResult CIRGenFunction::buildCaseDefaultCascade(
const T *stmt, mlir::Type condType,
SmallVector<mlir::Attribute, 4> &caseAttrs, mlir::OperationState &os) {
SmallVector<mlir::Attribute, 4> &caseAttrs) {

assert((isa<CaseStmt, DefaultStmt>(stmt)) &&
"only case or default stmt go here");
Expand All @@ -647,20 +647,18 @@ mlir::LogicalResult CIRGenFunction::buildCaseDefaultCascade(
// Update scope information with the current region we are
// emitting code for. This is useful to allow return blocks to be
// automatically and properly placed during cleanup.
auto *region = os.addRegion();
auto *region = currLexScope->createSwitchRegion();
auto *block = builder.createBlock(region);
builder.setInsertionPointToEnd(block);
currLexScope->updateCurrentSwitchCaseRegion();

auto *sub = stmt->getSubStmt();

if (isa<DefaultStmt>(sub) && isa<CaseStmt>(stmt)) {
builder.createYield(getLoc(stmt->getBeginLoc()));
res =
buildDefaultStmt(*dyn_cast<DefaultStmt>(sub), condType, caseAttrs, os);
res = buildDefaultStmt(*dyn_cast<DefaultStmt>(sub), condType, caseAttrs);
} else if (isa<CaseStmt>(sub) && isa<DefaultStmt>(stmt)) {
builder.createYield(getLoc(stmt->getBeginLoc()));
res = buildCaseStmt(*dyn_cast<CaseStmt>(sub), condType, caseAttrs, os);
res = buildCaseStmt(*dyn_cast<CaseStmt>(sub), condType, caseAttrs);
} else {
res = buildStmt(sub, /*useCurrentScope=*/!isa<CompoundStmt>(sub));
}
Expand All @@ -670,27 +668,37 @@ mlir::LogicalResult CIRGenFunction::buildCaseDefaultCascade(

mlir::LogicalResult
CIRGenFunction::buildCaseStmt(const CaseStmt &S, mlir::Type condType,
SmallVector<mlir::Attribute, 4> &caseAttrs,
mlir::OperationState &os) {
SmallVector<mlir::Attribute, 4> &caseAttrs) {
assert((!S.getRHS() || !S.caseStmtIsGNURange()) &&
"case ranges not implemented");

auto *caseStmt = foldCaseStmt(S, condType, caseAttrs);
return buildCaseDefaultCascade(caseStmt, condType, caseAttrs, os);
return buildCaseDefaultCascade(caseStmt, condType, caseAttrs);
}

mlir::LogicalResult
CIRGenFunction::buildDefaultStmt(const DefaultStmt &S, mlir::Type condType,
SmallVector<mlir::Attribute, 4> &caseAttrs,
mlir::OperationState &os) {
SmallVector<mlir::Attribute, 4> &caseAttrs) {
auto ctxt = builder.getContext();

auto defAttr = mlir::cir::CaseAttr::get(
ctxt, builder.getArrayAttr({}),
CaseOpKindAttr::get(ctxt, mlir::cir::CaseOpKind::Default));

caseAttrs.push_back(defAttr);
return buildCaseDefaultCascade(&S, condType, caseAttrs, os);
return buildCaseDefaultCascade(&S, condType, caseAttrs);
}

mlir::LogicalResult
CIRGenFunction::buildSwitchCase(const SwitchCase &S, mlir::Type condType,
SmallVector<mlir::Attribute, 4> &caseAttrs) {
if (S.getStmtClass() == Stmt::CaseStmtClass)
return buildCaseStmt(cast<CaseStmt>(S), condType, caseAttrs);

if (S.getStmtClass() == Stmt::DefaultStmtClass)
return buildDefaultStmt(cast<DefaultStmt>(S), condType, caseAttrs);

llvm_unreachable("expect case or default stmt");
}

mlir::LogicalResult
Expand Down Expand Up @@ -953,6 +961,36 @@ mlir::LogicalResult CIRGenFunction::buildWhileStmt(const WhileStmt &S) {
return mlir::success();
}

mlir::LogicalResult CIRGenFunction::buildSwitchBody(
const Stmt *S, mlir::Type condType,
llvm::SmallVector<mlir::Attribute, 4> &caseAttrs) {
if (auto *compoundStmt = dyn_cast<CompoundStmt>(S)) {
mlir::Block *lastCaseBlock = nullptr;
auto res = mlir::success();
for (auto *c : compoundStmt->body()) {
if (auto *switchCase = dyn_cast<SwitchCase>(c)) {
res = buildSwitchCase(*switchCase, condType, caseAttrs);
} else if (lastCaseBlock) {
// This means it's a random stmt following up a case, just
// emit it as part of previous known case.
mlir::OpBuilder::InsertionGuard guardCase(builder);
builder.setInsertionPointToEnd(lastCaseBlock);
res = buildStmt(c, /*useCurrentScope=*/!isa<CompoundStmt>(c));
} else {
llvm_unreachable("statement doesn't belong to any case region, NYI");
}

lastCaseBlock = builder.getBlock();

if (res.failed())
break;
}
return res;
}

llvm_unreachable("switch body is not CompoundStmt, NYI");
}

mlir::LogicalResult CIRGenFunction::buildSwitchStmt(const SwitchStmt &S) {
// TODO: LLVM codegen does some early optimization to fold the condition and
// only emit live cases. CIR should use MLIR to achieve similar things,
Expand All @@ -975,49 +1013,17 @@ mlir::LogicalResult CIRGenFunction::buildSwitchStmt(const SwitchStmt &S) {
// TODO: PGO and likelihood (e.g. PGO.haveRegionCounts())
// TODO: if the switch has a condition wrapped by __builtin_unpredictable?

// FIXME: track switch to handle nested stmts.
swop = builder.create<SwitchOp>(
getLoc(S.getBeginLoc()), condV,
/*switchBuilder=*/
[&](mlir::OpBuilder &b, mlir::Location loc, mlir::OperationState &os) {
auto *cs = dyn_cast<CompoundStmt>(S.getBody());
assert(cs && "expected compound stmt");
SmallVector<mlir::Attribute, 4> caseAttrs;

currLexScope->setAsSwitch();
mlir::Block *lastCaseBlock = nullptr;
for (auto *c : cs->body()) {
bool caseLike = isa<CaseStmt, DefaultStmt>(c);
if (!caseLike) {
// This means it's a random stmt following up a case, just
// emit it as part of previous known case.
assert(lastCaseBlock && "expects pre-existing case block");
mlir::OpBuilder::InsertionGuard guardCase(builder);
builder.setInsertionPointToEnd(lastCaseBlock);
res = buildStmt(c, /*useCurrentScope=*/!isa<CompoundStmt>(c));
lastCaseBlock = builder.getBlock();
if (res.failed())
break;
continue;
}

auto *caseStmt = dyn_cast<CaseStmt>(c);

if (caseStmt)
res = buildCaseStmt(*caseStmt, condV.getType(), caseAttrs, os);
else {
auto *defaultStmt = dyn_cast<DefaultStmt>(c);
assert(defaultStmt && "expected default stmt");
res = buildDefaultStmt(*defaultStmt, condV.getType(), caseAttrs,
os);
}

lastCaseBlock = builder.getBlock();

if (res.failed())
break;
}

llvm::SmallVector<mlir::Attribute, 4> caseAttrs;

res = buildSwitchBody(S.getBody(), condV.getType(), caseAttrs);

os.addRegions(currLexScope->getSwitchRegions());
os.addAttribute("cases", builder.getArrayAttr(caseAttrs));
});

Expand Down
27 changes: 27 additions & 0 deletions clang/test/CIR/CodeGen/switch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,7 @@ void sw12(int a) {
break;
}
}

// CHECK: cir.func @_Z4sw12i
// CHECK: cir.scope {
// CHECK: cir.switch
Expand All @@ -275,6 +276,32 @@ void sw12(int a) {
// CHECK-NEXT: cir.break
// CHECK-NEXT: }

void sw13(int a, int b) {
switch (a) {
case 1:
switch (b) {
case 2:
break;
}
}
}

// CHECK: cir.func @_Z4sw13ii
// CHECK: cir.scope {
// CHECK: cir.switch
// CHECK-NEXT: case (equal, 1) {
// CHECK-NEXT: cir.scope {
// CHECK: cir.switch
// CHECK-NEXT: case (equal, 2) {
// CHECK-NEXT: cir.break
// CHECK-NEXT: }
// CHECK-NEXT: ]
// CHECK-NEXT: }
// CHECK-NEXT: cir.yield
// CHECK-NEXT: }
// CHECK: }
// CHECK: cir.return

void fallthrough(int x) {
switch (x) {
case 1:
Expand Down

0 comments on commit f7521fd

Please sign in to comment.