Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CIR][CodeGen] Handle the case of 'case' after label statement after 'case' #879

Merged
merged 2 commits into from
Sep 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 9 additions & 6 deletions clang/lib/CIR/CodeGen/CIRGenFunction.h
Original file line number Diff line number Diff line change
Expand Up @@ -478,6 +478,13 @@ class CIRGenFunction : public CIRGenTypeCache {
// applies to. nullptr if there is no 'musttail' on the current statement.
const clang::CallExpr *MustTailCall = nullptr;

/// The attributes of cases collected during emitting the body of a switch
/// stmt.
llvm::SmallVector<llvm::SmallVector<mlir::Attribute, 4>, 2> caseAttrsStack;

/// The type of the condition for the emitting switch statement.
llvm::SmallVector<mlir::Type, 2> condTypeStack;

clang::ASTContext &getContext() const;

CIRGenBuilderTy &getBuilder() { return builder; }
Expand Down Expand Up @@ -1210,13 +1217,9 @@ class CIRGenFunction : public CIRGenTypeCache {
buildDefaultStmt(const clang::DefaultStmt &S, mlir::Type condType,
SmallVector<mlir::Attribute, 4> &caseAttrs);

mlir::LogicalResult
buildSwitchCase(const clang::SwitchCase &S, mlir::Type condType,
SmallVector<mlir::Attribute, 4> &caseAttrs);
mlir::LogicalResult buildSwitchCase(const clang::SwitchCase &S);

mlir::LogicalResult
buildSwitchBody(const clang::Stmt *S, mlir::Type condType,
SmallVector<mlir::Attribute, 4> &caseAttrs);
mlir::LogicalResult buildSwitchBody(const clang::Stmt *S);

mlir::cir::FuncOp generateCode(clang::GlobalDecl GD, mlir::cir::FuncOp Fn,
const CIRGenFunctionInfo &FnInfo);
Expand Down
34 changes: 20 additions & 14 deletions clang/lib/CIR/CodeGen/CIRGenStmt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -303,8 +303,7 @@ mlir::LogicalResult CIRGenFunction::buildSimpleStmt(const Stmt *S,

case Stmt::CaseStmtClass:
case Stmt::DefaultStmtClass:
assert(0 &&
"Should not get here, currently handled directly from SwitchStmt");
return buildSwitchCase(cast<SwitchCase>(*S));
break;

case Stmt::BreakStmtClass:
Expand Down Expand Up @@ -715,14 +714,19 @@ CIRGenFunction::buildDefaultStmt(const DefaultStmt &S, mlir::Type condType,
return buildCaseDefaultCascade(&S, condType, caseAttrs);
}

mlir::LogicalResult
CIRGenFunction::buildSwitchCase(const SwitchCase &S, mlir::Type condType,
SmallVector<mlir::Attribute, 4> &caseAttrs) {
mlir::LogicalResult CIRGenFunction::buildSwitchCase(const SwitchCase &S) {
assert(!caseAttrsStack.empty() &&
"build switch case without seeting case attrs");
assert(!condTypeStack.empty() &&
"build switch case without specifying the type of the condition");

if (S.getStmtClass() == Stmt::CaseStmtClass)
return buildCaseStmt(cast<CaseStmt>(S), condType, caseAttrs);
return buildCaseStmt(cast<CaseStmt>(S), condTypeStack.back(),
caseAttrsStack.back());

if (S.getStmtClass() == Stmt::DefaultStmtClass)
return buildDefaultStmt(cast<DefaultStmt>(S), condType, caseAttrs);
return buildDefaultStmt(cast<DefaultStmt>(S), condTypeStack.back(),
caseAttrsStack.back());

llvm_unreachable("expect case or default stmt");
}
Expand Down Expand Up @@ -987,15 +991,13 @@ mlir::LogicalResult CIRGenFunction::buildWhileStmt(const WhileStmt &S) {
return mlir::success();
}

mlir::LogicalResult CIRGenFunction::buildSwitchBody(
const Stmt *S, mlir::Type condType,
llvm::SmallVector<mlir::Attribute, 4> &caseAttrs) {
mlir::LogicalResult CIRGenFunction::buildSwitchBody(const Stmt *S) {
if (auto *compoundStmt = dyn_cast<CompoundStmt>(S)) {
mlir::Block *lastCaseBlock = nullptr;
auto res = mlir::success();
for (auto *c : compoundStmt->body()) {
if (auto *switchCase = dyn_cast<SwitchCase>(c)) {
res = buildSwitchCase(*switchCase, condType, caseAttrs);
res = buildSwitchCase(*switchCase);
lastCaseBlock = builder.getBlock();
} else if (lastCaseBlock) {
// This means it's a random stmt following up a case, just
Expand Down Expand Up @@ -1045,12 +1047,16 @@ mlir::LogicalResult CIRGenFunction::buildSwitchStmt(const SwitchStmt &S) {
[&](mlir::OpBuilder &b, mlir::Location loc, mlir::OperationState &os) {
currLexScope->setAsSwitch();

llvm::SmallVector<mlir::Attribute, 4> caseAttrs;
caseAttrsStack.push_back({});
condTypeStack.push_back(condV.getType());

res = buildSwitchBody(S.getBody(), condV.getType(), caseAttrs);
res = buildSwitchBody(S.getBody());

os.addRegions(currLexScope->getSwitchRegions());
os.addAttribute("cases", builder.getArrayAttr(caseAttrs));
os.addAttribute("cases", builder.getArrayAttr(caseAttrsStack.back()));

caseAttrsStack.pop_back();
condTypeStack.pop_back();
});

if (res.failed())
Expand Down
48 changes: 48 additions & 0 deletions clang/test/CIR/CodeGen/goto.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -310,3 +310,51 @@ extern "C" void multiple_non_case(int v) {
// NOFLAT: cir.label
// NOFLAT: cir.call @action2()
// NOFLAT: cir.break

extern "C" void case_follow_label(int v) {
switch (v) {
case 1:
label:
case 2:
action1();
break;
default:
action2();
goto label;
}
}

// NOFLAT: cir.func @case_follow_label
// NOFLAT: cir.switch
// NOFLAT: case (equal, 1)
// NOFLAT: cir.label "label"
// NOFLAT: cir.yield
// NOFLAT: case (equal, 2)
// NOFLAT: cir.call @action1()
// NOFLAT: cir.break
// NOFLAT: case (default)
// NOFLAT: cir.call @action2()
// NOFLAT: cir.goto "label"

extern "C" void default_follow_label(int v) {
switch (v) {
case 1:
case 2:
action1();
break;
label:
default:
action2();
goto label;
}
}

// NOFLAT: cir.func @default_follow_label
// NOFLAT: cir.switch
// NOFLAT: case (anyof, [1, 2] : !s32i)
// NOFLAT: cir.call @action1()
// NOFLAT: cir.break
// NOFLAT: cir.label "label"
// NOFLAT: case (default)
// NOFLAT: cir.call @action2()
// NOFLAT: cir.goto "label"