From c4950363715fbef5e3f2771e9ab415171cc4ef4b Mon Sep 17 00:00:00 2001 From: axp Date: Sat, 20 Apr 2024 11:05:41 +0800 Subject: [PATCH] address --- clang/lib/CIR/CodeGen/CIRGenFunction.h | 5 +-- clang/lib/CIR/CodeGen/CIRGenStmt.cpp | 51 ++++++++++++-------------- 2 files changed, 26 insertions(+), 30 deletions(-) diff --git a/clang/lib/CIR/CodeGen/CIRGenFunction.h b/clang/lib/CIR/CodeGen/CIRGenFunction.h index 57193376bb1f..76df9abd200d 100644 --- a/clang/lib/CIR/CodeGen/CIRGenFunction.h +++ b/clang/lib/CIR/CodeGen/CIRGenFunction.h @@ -1994,9 +1994,8 @@ class CIRGenFunction : public CIRGenTypeCache { mlir::Block *getEntryBlock() { return EntryBlock; } mlir::Location BeginLoc, EndLoc; - // Each SmallVector object is corresponding to a case region, empty - // vector means default case region. - llvm::SmallVector> caseEltValueLists; + mlir::Type switchCondType; + llvm::SmallVector caseAttrs; mlir::Block *lastCaseBlock = nullptr; }; diff --git a/clang/lib/CIR/CodeGen/CIRGenStmt.cpp b/clang/lib/CIR/CodeGen/CIRGenStmt.cpp index 17b9aca3eb68..13a63f02db8d 100644 --- a/clang/lib/CIR/CodeGen/CIRGenStmt.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenStmt.cpp @@ -608,16 +608,26 @@ mlir::LogicalResult CIRGenFunction::buildBreakStmt(const clang::BreakStmt &S) { const CaseStmt *CIRGenFunction::foldCaseStmt(const clang::CaseStmt &S) { const CaseStmt *caseStmt = &S; const CaseStmt *lastCase = &S; - SmallVector caseEltValueList; + const auto &condType = currLexScope->switchCondType; + SmallVector caseEltValueListAttr; // Fold cascading cases whenever possible to simplify codegen a bit. while (caseStmt) { lastCase = caseStmt; auto intVal = caseStmt->getLHS()->EvaluateKnownConstInt(getContext()); - caseEltValueList.push_back(intVal); + caseEltValueListAttr.push_back(mlir::cir::IntAttr::get(condType, intVal)); caseStmt = dyn_cast_or_null(caseStmt->getSubStmt()); } - currLexScope->caseEltValueLists.push_back(std::move(caseEltValueList)); + + auto *ctxt = builder.getContext(); + + auto caseAttr = mlir::cir::CaseAttr::get( + ctxt, builder.getArrayAttr(caseEltValueListAttr), + CaseOpKindAttr::get(ctxt, caseEltValueListAttr.size() > 1 + ? mlir::cir::CaseOpKind::Anyof + : mlir::cir::CaseOpKind::Equal)); + + currLexScope->caseAttrs.push_back(caseAttr); return lastCase; } @@ -661,7 +671,13 @@ mlir::LogicalResult CIRGenFunction::buildCaseStmt(const CaseStmt &S) { } mlir::LogicalResult CIRGenFunction::buildDefaultStmt(const DefaultStmt &S) { - currLexScope->caseEltValueLists.push_back({}); + auto ctxt = builder.getContext(); + + auto defAttr = mlir::cir::CaseAttr::get( + ctxt, builder.getArrayAttr({}), + CaseOpKindAttr::get(ctxt, mlir::cir::CaseOpKind::Default)); + + currLexScope->caseAttrs.push_back(defAttr); return buildCaseDefaultCascade(&S); } @@ -989,32 +1005,13 @@ mlir::LogicalResult CIRGenFunction::buildSwitchStmt(const SwitchStmt &S) { /*switchBuilder=*/ [&](mlir::OpBuilder &b, mlir::Location loc, mlir::OperationState &os) { currLexScope->setAsSwitch(); + currLexScope->switchCondType = condV.getType(); + res = buildSwitchBody(S.getBody()); os.addRegions(currLexScope->getSwitchRegions()); - - SmallVector caseAttrs; - auto ctxt = builder.getContext(); - for (const auto &caseEltValueList : currLexScope->caseEltValueLists) { - if (caseEltValueList.empty()) { - auto defAttr = mlir::cir::CaseAttr::get( - ctxt, builder.getArrayAttr({}), - CaseOpKindAttr::get(ctxt, mlir::cir::CaseOpKind::Default)); - caseAttrs.push_back(defAttr); - } else { - SmallVector caseEltValueListAttr; - for (auto intVal : caseEltValueList) - caseEltValueListAttr.push_back( - mlir::cir::IntAttr::get(condV.getType(), intVal)); - caseAttrs.push_back(mlir::cir::CaseAttr::get( - ctxt, builder.getArrayAttr(caseEltValueListAttr), - CaseOpKindAttr::get(ctxt, - caseEltValueListAttr.size() > 1 - ? mlir::cir::CaseOpKind::Anyof - : mlir::cir::CaseOpKind::Equal))); - } - } - os.addAttribute("cases", builder.getArrayAttr(caseAttrs)); + os.addAttribute("cases", + builder.getArrayAttr(currLexScope->caseAttrs)); }); if (res.failed())