diff --git a/clang/lib/CIR/CodeGen/CIRGenFunction.h b/clang/lib/CIR/CodeGen/CIRGenFunction.h
index 9ece64bc89bb..22b863ba0d7c 100644
--- a/clang/lib/CIR/CodeGen/CIRGenFunction.h
+++ b/clang/lib/CIR/CodeGen/CIRGenFunction.h
@@ -1085,18 +1085,23 @@ class CIRGenFunction : public CIRGenTypeCache {
   template <typename T>
   mlir::LogicalResult
   buildCaseDefaultCascade(const T *stmt, mlir::Type condType,
-                          SmallVector<mlir::Attribute, 4> &caseAttrs,
-                          mlir::OperationState &os);
+                          SmallVector<mlir::Attribute, 4> &caseAttrs);
 
   mlir::LogicalResult buildCaseStmt(const clang::CaseStmt &S,
                                     mlir::Type condType,
-                                    SmallVector<mlir::Attribute, 4> &caseAttrs,
-                                    mlir::OperationState &op);
+                                    SmallVector<mlir::Attribute, 4> &caseAttrs);
 
   mlir::LogicalResult buildDefaultStmt(const clang::DefaultStmt &S,
                                        mlir::Type condType,
-                                       SmallVector<mlir::Attribute, 4> &caseAttrs,
-                                       mlir::OperationState &op);
+                                       SmallVector<mlir::Attribute, 4> &caseAttrs);
+
+  mlir::LogicalResult
+  buildSwitchCase(const clang::SwitchCase &S, mlir::Type condType,
+                  SmallVector<mlir::Attribute, 4> &caseAttrs);
+
+  mlir::LogicalResult
+  buildSwitchBody(const clang::Stmt *S, mlir::Type condType,
+                  SmallVector<mlir::Attribute, 4> &caseAttrs);
 
   mlir::cir::FuncOp generateCode(clang::GlobalDecl GD, mlir::cir::FuncOp Fn,
                                  const CIRGenFunctionInfo &FnInfo);
@@ -1964,7 +1969,7 @@ class CIRGenFunction : public CIRGenTypeCache {
     // have their own scopes but are distinct regions nonetheless.
    llvm::SmallVector<mlir::Block *> RetBlocks;
    llvm::SmallVector<std::optional<mlir::Location>> RetLocs;
-    unsigned int CurrentSwitchRegionIdx = -1;
+    llvm::SmallVector<std::unique_ptr<mlir::Region>> SwitchRegions;
 
    // There's usually only one ret block per scope, but this needs to be
    // get or create because of potential unreachable return statements, note
@@ -1985,16 +1990,25 @@ class CIRGenFunction : public CIRGenTypeCache {
     void buildImplicitReturn();
 
   public:
-    void updateCurrentSwitchCaseRegion() { CurrentSwitchRegionIdx++; }
     llvm::ArrayRef<mlir::Block *> getRetBlocks() { return RetBlocks; }
     llvm::ArrayRef<std::optional<mlir::Location>> getRetLocs() { return RetLocs; }
+    llvm::MutableArrayRef<std::unique_ptr<mlir::Region>> getSwitchRegions() {
+      assert(isSwitch() && "expected switch scope");
+      return SwitchRegions;
+    }
+
+    mlir::Region *createSwitchRegion() {
+      assert(isSwitch() && "expected switch scope");
+      SwitchRegions.push_back(std::make_unique<mlir::Region>());
+      return SwitchRegions.back().get();
+    }
 
     mlir::Block *getOrCreateRetBlock(CIRGenFunction &CGF, mlir::Location loc) {
       unsigned int regionIdx = 0;
       if (isSwitch())
-        regionIdx = CurrentSwitchRegionIdx;
+        regionIdx = SwitchRegions.size() - 1;
       if (regionIdx >= RetBlocks.size())
         return createRetBlock(CGF, loc);
       return &*RetBlocks.back();
diff --git a/clang/lib/CIR/CodeGen/CIRGenStmt.cpp b/clang/lib/CIR/CodeGen/CIRGenStmt.cpp
index 17264d36e588..981804892ebb 100644
--- a/clang/lib/CIR/CodeGen/CIRGenStmt.cpp
+++ b/clang/lib/CIR/CodeGen/CIRGenStmt.cpp
@@ -637,7 +637,7 @@ CIRGenFunction::foldCaseStmt(const clang::CaseStmt &S, mlir::Type condType,
 template <typename T>
 mlir::LogicalResult CIRGenFunction::buildCaseDefaultCascade(
     const T *stmt, mlir::Type condType,
-    SmallVector<mlir::Attribute, 4> &caseAttrs, mlir::OperationState &os) {
+    SmallVector<mlir::Attribute, 4> &caseAttrs) {
 
   assert((isa<CaseStmt, DefaultStmt>(stmt)) &&
          "only case or default stmt go here");
@@ -647,20 +647,18 @@ mlir::LogicalResult CIRGenFunction::buildCaseDefaultCascade(
   // Update scope information with the current region we are
   // emitting code for. This is useful to allow return blocks to be
   // automatically and properly placed during cleanup.
-  auto *region = os.addRegion();
+  auto *region = currLexScope->createSwitchRegion();
   auto *block = builder.createBlock(region);
   builder.setInsertionPointToEnd(block);
-  currLexScope->updateCurrentSwitchCaseRegion();
 
   auto *sub = stmt->getSubStmt();
 
   if (isa<DefaultStmt>(sub) && isa<CaseStmt>(stmt)) {
     builder.createYield(getLoc(stmt->getBeginLoc()));
-    res =
-        buildDefaultStmt(*dyn_cast<DefaultStmt>(sub), condType, caseAttrs, os);
+    res = buildDefaultStmt(*dyn_cast<DefaultStmt>(sub), condType, caseAttrs);
   } else if (isa<CaseStmt>(sub) && isa<DefaultStmt>(stmt)) {
     builder.createYield(getLoc(stmt->getBeginLoc()));
-    res = buildCaseStmt(*dyn_cast<CaseStmt>(sub), condType, caseAttrs, os);
+    res = buildCaseStmt(*dyn_cast<CaseStmt>(sub), condType, caseAttrs);
   } else {
     res = buildStmt(sub, /*useCurrentScope=*/!isa<CompoundStmt>(sub));
   }
@@ -670,19 +668,17 @@ mlir::LogicalResult CIRGenFunction::buildCaseDefaultCascade(
 
 mlir::LogicalResult
 CIRGenFunction::buildCaseStmt(const CaseStmt &S, mlir::Type condType,
-                              SmallVector<mlir::Attribute, 4> &caseAttrs,
-                              mlir::OperationState &os) {
+                              SmallVector<mlir::Attribute, 4> &caseAttrs) {
   assert((!S.getRHS() || !S.caseStmtIsGNURange()) &&
          "case ranges not implemented");
 
   auto *caseStmt = foldCaseStmt(S, condType, caseAttrs);
-  return buildCaseDefaultCascade(caseStmt, condType, caseAttrs, os);
+  return buildCaseDefaultCascade(caseStmt, condType, caseAttrs);
 }
 
 mlir::LogicalResult
 CIRGenFunction::buildDefaultStmt(const DefaultStmt &S, mlir::Type condType,
-                                 SmallVector<mlir::Attribute, 4> &caseAttrs,
-                                 mlir::OperationState &os) {
+                                 SmallVector<mlir::Attribute, 4> &caseAttrs) {
   auto ctxt = builder.getContext();
 
   auto defAttr = mlir::cir::CaseAttr::get(
@@ -690,7 +686,19 @@ CIRGenFunction::buildDefaultStmt(const DefaultStmt &S, mlir::Type condType,
       CaseOpKindAttr::get(ctxt, mlir::cir::CaseOpKind::Default));
 
   caseAttrs.push_back(defAttr);
-  return buildCaseDefaultCascade(&S, condType, caseAttrs, os);
+  return buildCaseDefaultCascade(&S, condType, caseAttrs);
+}
+
+mlir::LogicalResult
+CIRGenFunction::buildSwitchCase(const SwitchCase &S, mlir::Type condType,
+                                SmallVector<mlir::Attribute, 4> &caseAttrs) {
+  if (S.getStmtClass() == Stmt::CaseStmtClass)
+    return buildCaseStmt(cast<CaseStmt>(S), condType, caseAttrs);
+
+  if (S.getStmtClass() == Stmt::DefaultStmtClass)
+    return buildDefaultStmt(cast<DefaultStmt>(S), condType, caseAttrs);
+
+  llvm_unreachable("expect case or default stmt");
 }
 
 mlir::LogicalResult
@@ -953,6 +961,36 @@ mlir::LogicalResult CIRGenFunction::buildWhileStmt(const WhileStmt &S) {
   return mlir::success();
 }
 
+mlir::LogicalResult CIRGenFunction::buildSwitchBody(
+    const Stmt *S, mlir::Type condType,
+    llvm::SmallVector<mlir::Attribute, 4> &caseAttrs) {
+  if (auto *compoundStmt = dyn_cast<CompoundStmt>(S)) {
+    mlir::Block *lastCaseBlock = nullptr;
+    auto res = mlir::success();
+    for (auto *c : compoundStmt->body()) {
+      if (auto *switchCase = dyn_cast<SwitchCase>(c)) {
+        res = buildSwitchCase(*switchCase, condType, caseAttrs);
+      } else if (lastCaseBlock) {
+        // This means it's a random stmt following up a case, just
+        // emit it as part of previous known case.
+        mlir::OpBuilder::InsertionGuard guardCase(builder);
+        builder.setInsertionPointToEnd(lastCaseBlock);
+        res = buildStmt(c, /*useCurrentScope=*/!isa<CompoundStmt>(c));
+      } else {
+        llvm_unreachable("statement doesn't belong to any case region, NYI");
+      }
+
+      lastCaseBlock = builder.getBlock();
+
+      if (res.failed())
+        break;
+    }
+    return res;
+  }
+
+  llvm_unreachable("switch body is not CompoundStmt, NYI");
+}
+
 mlir::LogicalResult CIRGenFunction::buildSwitchStmt(const SwitchStmt &S) {
   // TODO: LLVM codegen does some early optimization to fold the condition and
   // only emit live cases. CIR should use MLIR to achieve similar things,
@@ -975,49 +1013,17 @@ mlir::LogicalResult CIRGenFunction::buildSwitchStmt(const SwitchStmt &S) {
   // TODO: PGO and likelihood (e.g. PGO.haveRegionCounts())
   // TODO: if the switch has a condition wrapped by __builtin_unpredictable?
 
-  // FIXME: track switch to handle nested stmts.
   swop = builder.create<SwitchOp>(
       getLoc(S.getBeginLoc()), condV,
       /*switchBuilder=*/
       [&](mlir::OpBuilder &b, mlir::Location loc, mlir::OperationState &os) {
-        auto *cs = dyn_cast<CompoundStmt>(S.getBody());
-        assert(cs && "expected compound stmt");
-        SmallVector<mlir::Attribute, 4> caseAttrs;
-
         currLexScope->setAsSwitch();
-        mlir::Block *lastCaseBlock = nullptr;
-        for (auto *c : cs->body()) {
-          bool caseLike = isa<CaseStmt, DefaultStmt>(c);
-          if (!caseLike) {
-            // This means it's a random stmt following up a case, just
-            // emit it as part of previous known case.
-            assert(lastCaseBlock && "expects pre-existing case block");
-            mlir::OpBuilder::InsertionGuard guardCase(builder);
-            builder.setInsertionPointToEnd(lastCaseBlock);
-            res = buildStmt(c, /*useCurrentScope=*/!isa<CompoundStmt>(c));
-            lastCaseBlock = builder.getBlock();
-            if (res.failed())
-              break;
-            continue;
-          }
-
-          auto *caseStmt = dyn_cast<CaseStmt>(c);
-
-          if (caseStmt)
-            res = buildCaseStmt(*caseStmt, condV.getType(), caseAttrs, os);
-          else {
-            auto *defaultStmt = dyn_cast<DefaultStmt>(c);
-            assert(defaultStmt && "expected default stmt");
-            res = buildDefaultStmt(*defaultStmt, condV.getType(), caseAttrs,
-                                   os);
-          }
-
-          lastCaseBlock = builder.getBlock();
-
-          if (res.failed())
-            break;
-        }
+        llvm::SmallVector<mlir::Attribute, 4> caseAttrs;
+
+        res = buildSwitchBody(S.getBody(), condV.getType(), caseAttrs);
+
+        os.addRegions(currLexScope->getSwitchRegions());
 
         os.addAttribute("cases", builder.getArrayAttr(caseAttrs));
       });
diff --git a/clang/test/CIR/CodeGen/switch.cpp b/clang/test/CIR/CodeGen/switch.cpp
index 3c63e4ea4820..b378c7364475 100644
--- a/clang/test/CIR/CodeGen/switch.cpp
+++ b/clang/test/CIR/CodeGen/switch.cpp
@@ -266,6 +266,7 @@ void sw12(int a) {
     break;
   }
 }
+
 // CHECK: cir.func @_Z4sw12i
 // CHECK: cir.scope {
 // CHECK:   cir.switch
@@ -275,6 +276,32 @@ void sw12(int a) {
 // CHECK-NEXT:   cir.break
 // CHECK-NEXT: }
 
+void sw13(int a, int b) {
+  switch (a) {
+  case 1:
+    switch (b) {
+    case 2:
+      break;
+    }
+  }
+}
+
+// CHECK: cir.func @_Z4sw13ii
+// CHECK: cir.scope {
+// CHECK:   cir.switch
+// CHECK-NEXT:   case (equal, 1) {
+// CHECK-NEXT:     cir.scope {
+// CHECK:            cir.switch
+// CHECK-NEXT:       case (equal, 2) {
+// CHECK-NEXT:         cir.break
+// CHECK-NEXT:       }
+// CHECK-NEXT:       ]
+// CHECK-NEXT:     }
+// CHECK-NEXT:     cir.yield
+// CHECK-NEXT:   }
+// CHECK: }
+// CHECK: cir.return
+
 void fallthrough(int x) {
   switch (x) {
   case 1: