Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CIR[CIRGen][NFC] Refactor build switch op #552

Merged
merged 3 commits into from
Apr 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 23 additions & 9 deletions clang/lib/CIR/CodeGen/CIRGenFunction.h
Original file line number Diff line number Diff line change
Expand Up @@ -1085,18 +1085,23 @@ class CIRGenFunction : public CIRGenTypeCache {
template <typename T>
mlir::LogicalResult
buildCaseDefaultCascade(const T *stmt, mlir::Type condType,
SmallVector<mlir::Attribute, 4> &caseAttrs,
mlir::OperationState &os);
SmallVector<mlir::Attribute, 4> &caseAttrs);

mlir::LogicalResult buildCaseStmt(const clang::CaseStmt &S,
mlir::Type condType,
SmallVector<mlir::Attribute, 4> &caseAttrs,
mlir::OperationState &op);
SmallVector<mlir::Attribute, 4> &caseAttrs);

mlir::LogicalResult
buildDefaultStmt(const clang::DefaultStmt &S, mlir::Type condType,
SmallVector<mlir::Attribute, 4> &caseAttrs,
mlir::OperationState &op);
SmallVector<mlir::Attribute, 4> &caseAttrs);

mlir::LogicalResult
buildSwitchCase(const clang::SwitchCase &S, mlir::Type condType,
SmallVector<mlir::Attribute, 4> &caseAttrs);

mlir::LogicalResult
buildSwitchBody(const clang::Stmt *S, mlir::Type condType,
SmallVector<mlir::Attribute, 4> &caseAttrs);

mlir::cir::FuncOp generateCode(clang::GlobalDecl GD, mlir::cir::FuncOp Fn,
const CIRGenFunctionInfo &FnInfo);
Expand Down Expand Up @@ -1951,7 +1956,7 @@ class CIRGenFunction : public CIRGenTypeCache {
// have their own scopes but are distinct regions nonetheless.
llvm::SmallVector<mlir::Block *> RetBlocks;
llvm::SmallVector<std::optional<mlir::Location>> RetLocs;
unsigned int CurrentSwitchRegionIdx = -1;
llvm::SmallVector<std::unique_ptr<mlir::Region>> SwitchRegions;

// There's usually only one ret block per scope, but this needs to be
// get or create because of potential unreachable return statements, note
Expand All @@ -1972,16 +1977,25 @@ class CIRGenFunction : public CIRGenTypeCache {
void buildImplicitReturn();

public:
void updateCurrentSwitchCaseRegion() { CurrentSwitchRegionIdx++; }
llvm::ArrayRef<mlir::Block *> getRetBlocks() { return RetBlocks; }
llvm::ArrayRef<std::optional<mlir::Location>> getRetLocs() {
return RetLocs;
}
llvm::MutableArrayRef<std::unique_ptr<mlir::Region>> getSwitchRegions() {
assert(isSwitch() && "expected switch scope");
return SwitchRegions;
}

mlir::Region *createSwitchRegion() {
assert(isSwitch() && "expected switch scope");
SwitchRegions.push_back(std::make_unique<mlir::Region>());
return SwitchRegions.back().get();
}

mlir::Block *getOrCreateRetBlock(CIRGenFunction &CGF, mlir::Location loc) {
unsigned int regionIdx = 0;
if (isSwitch())
regionIdx = CurrentSwitchRegionIdx;
regionIdx = SwitchRegions.size() - 1;
if (regionIdx >= RetBlocks.size())
return createRetBlock(CGF, loc);
return &*RetBlocks.back();
Expand Down
104 changes: 55 additions & 49 deletions clang/lib/CIR/CodeGen/CIRGenStmt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -636,7 +636,7 @@ CIRGenFunction::foldCaseStmt(const clang::CaseStmt &S, mlir::Type condType,
template <typename T>
mlir::LogicalResult CIRGenFunction::buildCaseDefaultCascade(
const T *stmt, mlir::Type condType,
SmallVector<mlir::Attribute, 4> &caseAttrs, mlir::OperationState &os) {
SmallVector<mlir::Attribute, 4> &caseAttrs) {

assert((isa<CaseStmt, DefaultStmt>(stmt)) &&
"only case or default stmt go here");
Expand All @@ -646,20 +646,18 @@ mlir::LogicalResult CIRGenFunction::buildCaseDefaultCascade(
// Update scope information with the current region we are
// emitting code for. This is useful to allow return blocks to be
// automatically and properly placed during cleanup.
auto *region = os.addRegion();
auto *region = currLexScope->createSwitchRegion();
auto *block = builder.createBlock(region);
builder.setInsertionPointToEnd(block);
currLexScope->updateCurrentSwitchCaseRegion();

auto *sub = stmt->getSubStmt();

if (isa<DefaultStmt>(sub) && isa<CaseStmt>(stmt)) {
builder.createYield(getLoc(stmt->getBeginLoc()));
res =
buildDefaultStmt(*dyn_cast<DefaultStmt>(sub), condType, caseAttrs, os);
res = buildDefaultStmt(*dyn_cast<DefaultStmt>(sub), condType, caseAttrs);
} else if (isa<CaseStmt>(sub) && isa<DefaultStmt>(stmt)) {
builder.createYield(getLoc(stmt->getBeginLoc()));
res = buildCaseStmt(*dyn_cast<CaseStmt>(sub), condType, caseAttrs, os);
res = buildCaseStmt(*dyn_cast<CaseStmt>(sub), condType, caseAttrs);
} else {
res = buildStmt(sub, /*useCurrentScope=*/!isa<CompoundStmt>(sub));
}
Expand All @@ -669,27 +667,37 @@ mlir::LogicalResult CIRGenFunction::buildCaseDefaultCascade(

mlir::LogicalResult
CIRGenFunction::buildCaseStmt(const CaseStmt &S, mlir::Type condType,
SmallVector<mlir::Attribute, 4> &caseAttrs,
mlir::OperationState &os) {
SmallVector<mlir::Attribute, 4> &caseAttrs) {
assert((!S.getRHS() || !S.caseStmtIsGNURange()) &&
"case ranges not implemented");

auto *caseStmt = foldCaseStmt(S, condType, caseAttrs);
return buildCaseDefaultCascade(caseStmt, condType, caseAttrs, os);
return buildCaseDefaultCascade(caseStmt, condType, caseAttrs);
}

mlir::LogicalResult
CIRGenFunction::buildDefaultStmt(const DefaultStmt &S, mlir::Type condType,
SmallVector<mlir::Attribute, 4> &caseAttrs,
mlir::OperationState &os) {
SmallVector<mlir::Attribute, 4> &caseAttrs) {
auto ctxt = builder.getContext();

auto defAttr = mlir::cir::CaseAttr::get(
ctxt, builder.getArrayAttr({}),
CaseOpKindAttr::get(ctxt, mlir::cir::CaseOpKind::Default));

caseAttrs.push_back(defAttr);
return buildCaseDefaultCascade(&S, condType, caseAttrs, os);
return buildCaseDefaultCascade(&S, condType, caseAttrs);
}

mlir::LogicalResult
CIRGenFunction::buildSwitchCase(const SwitchCase &S, mlir::Type condType,
SmallVector<mlir::Attribute, 4> &caseAttrs) {
if (S.getStmtClass() == Stmt::CaseStmtClass)
return buildCaseStmt(cast<CaseStmt>(S), condType, caseAttrs);

if (S.getStmtClass() == Stmt::DefaultStmtClass)
return buildDefaultStmt(cast<DefaultStmt>(S), condType, caseAttrs);

llvm_unreachable("expect case or default stmt");
}

mlir::LogicalResult
Expand Down Expand Up @@ -952,6 +960,36 @@ mlir::LogicalResult CIRGenFunction::buildWhileStmt(const WhileStmt &S) {
return mlir::success();
}

mlir::LogicalResult CIRGenFunction::buildSwitchBody(
const Stmt *S, mlir::Type condType,
llvm::SmallVector<mlir::Attribute, 4> &caseAttrs) {
if (auto *compoundStmt = dyn_cast<CompoundStmt>(S)) {
mlir::Block *lastCaseBlock = nullptr;
auto res = mlir::success();
for (auto *c : compoundStmt->body()) {
if (auto *switchCase = dyn_cast<SwitchCase>(c)) {
res = buildSwitchCase(*switchCase, condType, caseAttrs);
} else if (lastCaseBlock) {
// This means it's a random stmt following up a case, just
// emit it as part of previous known case.
mlir::OpBuilder::InsertionGuard guardCase(builder);
builder.setInsertionPointToEnd(lastCaseBlock);
res = buildStmt(c, /*useCurrentScope=*/!isa<CompoundStmt>(c));
} else {
llvm_unreachable("statement doesn't belong to any case region, NYI");
}

lastCaseBlock = builder.getBlock();

if (res.failed())
break;
}
return res;
}

llvm_unreachable("switch body is not CompoundStmt, NYI");
}

mlir::LogicalResult CIRGenFunction::buildSwitchStmt(const SwitchStmt &S) {
// TODO: LLVM codegen does some early optimization to fold the condition and
// only emit live cases. CIR should use MLIR to achieve similar things,
Expand All @@ -974,49 +1012,17 @@ mlir::LogicalResult CIRGenFunction::buildSwitchStmt(const SwitchStmt &S) {
// TODO: PGO and likelihood (e.g. PGO.haveRegionCounts())
// TODO: if the switch has a condition wrapped by __builtin_unpredictable?

// FIXME: track switch to handle nested stmts.
swop = builder.create<SwitchOp>(
getLoc(S.getBeginLoc()), condV,
/*switchBuilder=*/
[&](mlir::OpBuilder &b, mlir::Location loc, mlir::OperationState &os) {
auto *cs = dyn_cast<CompoundStmt>(S.getBody());
assert(cs && "expected compound stmt");
SmallVector<mlir::Attribute, 4> caseAttrs;

currLexScope->setAsSwitch();
mlir::Block *lastCaseBlock = nullptr;
for (auto *c : cs->body()) {
bool caseLike = isa<CaseStmt, DefaultStmt>(c);
if (!caseLike) {
// This means it's a random stmt following up a case, just
// emit it as part of previous known case.
assert(lastCaseBlock && "expects pre-existing case block");
mlir::OpBuilder::InsertionGuard guardCase(builder);
builder.setInsertionPointToEnd(lastCaseBlock);
res = buildStmt(c, /*useCurrentScope=*/!isa<CompoundStmt>(c));
lastCaseBlock = builder.getBlock();
if (res.failed())
break;
continue;
}

auto *caseStmt = dyn_cast<CaseStmt>(c);

if (caseStmt)
res = buildCaseStmt(*caseStmt, condV.getType(), caseAttrs, os);
else {
auto *defaultStmt = dyn_cast<DefaultStmt>(c);
assert(defaultStmt && "expected default stmt");
res = buildDefaultStmt(*defaultStmt, condV.getType(), caseAttrs,
os);
}

lastCaseBlock = builder.getBlock();

if (res.failed())
break;
}

llvm::SmallVector<mlir::Attribute, 4> caseAttrs;

res = buildSwitchBody(S.getBody(), condV.getType(), caseAttrs);

os.addRegions(currLexScope->getSwitchRegions());
os.addAttribute("cases", builder.getArrayAttr(caseAttrs));
});

Expand Down
27 changes: 27 additions & 0 deletions clang/test/CIR/CodeGen/switch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,7 @@ void sw12(int a) {
break;
}
}

// CHECK: cir.func @_Z4sw12i
// CHECK: cir.scope {
// CHECK: cir.switch
Expand All @@ -275,6 +276,32 @@ void sw12(int a) {
// CHECK-NEXT: cir.break
// CHECK-NEXT: }

void sw13(int a, int b) {
switch (a) {
case 1:
switch (b) {
case 2:
break;
}
}
}

// CHECK: cir.func @_Z4sw13ii
// CHECK: cir.scope {
// CHECK: cir.switch
// CHECK-NEXT: case (equal, 1) {
// CHECK-NEXT: cir.scope {
// CHECK: cir.switch
// CHECK-NEXT: case (equal, 2) {
// CHECK-NEXT: cir.break
// CHECK-NEXT: }
// CHECK-NEXT: ]
// CHECK-NEXT: }
// CHECK-NEXT: cir.yield
// CHECK-NEXT: }
// CHECK: }
// CHECK: cir.return

void fallthrough(int x) {
switch (x) {
case 1:
Expand Down
Loading