Skip to content

Commit

Permalink
address
Browse files Browse the repository at this point in the history
  • Loading branch information
wenpen committed Apr 20, 2024
1 parent 01566f1 commit c495036
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 30 deletions.
5 changes: 2 additions & 3 deletions clang/lib/CIR/CodeGen/CIRGenFunction.h
Original file line number Diff line number Diff line change
Expand Up @@ -1994,9 +1994,8 @@ class CIRGenFunction : public CIRGenTypeCache {
mlir::Block *getEntryBlock() { return EntryBlock; }

mlir::Location BeginLoc, EndLoc;
// Each SmallVector<APSInt> object is corresponding to a case region, empty
// vector means default case region.
llvm::SmallVector<llvm::SmallVector<llvm::APSInt>> caseEltValueLists;
mlir::Type switchCondType;
llvm::SmallVector<mlir::Attribute, 4> caseAttrs;
mlir::Block *lastCaseBlock = nullptr;
};

Expand Down
51 changes: 24 additions & 27 deletions clang/lib/CIR/CodeGen/CIRGenStmt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -608,16 +608,26 @@ mlir::LogicalResult CIRGenFunction::buildBreakStmt(const clang::BreakStmt &S) {
const CaseStmt *CIRGenFunction::foldCaseStmt(const clang::CaseStmt &S) {
const CaseStmt *caseStmt = &S;
const CaseStmt *lastCase = &S;
SmallVector<llvm::APSInt> caseEltValueList;
const auto &condType = currLexScope->switchCondType;
SmallVector<mlir::Attribute, 4> caseEltValueListAttr;

// Fold cascading cases whenever possible to simplify codegen a bit.
while (caseStmt) {
lastCase = caseStmt;
auto intVal = caseStmt->getLHS()->EvaluateKnownConstInt(getContext());
caseEltValueList.push_back(intVal);
caseEltValueListAttr.push_back(mlir::cir::IntAttr::get(condType, intVal));
caseStmt = dyn_cast_or_null<CaseStmt>(caseStmt->getSubStmt());
}
currLexScope->caseEltValueLists.push_back(std::move(caseEltValueList));

auto *ctxt = builder.getContext();

auto caseAttr = mlir::cir::CaseAttr::get(
ctxt, builder.getArrayAttr(caseEltValueListAttr),
CaseOpKindAttr::get(ctxt, caseEltValueListAttr.size() > 1
? mlir::cir::CaseOpKind::Anyof
: mlir::cir::CaseOpKind::Equal));

currLexScope->caseAttrs.push_back(caseAttr);

return lastCase;
}
Expand Down Expand Up @@ -661,7 +671,13 @@ mlir::LogicalResult CIRGenFunction::buildCaseStmt(const CaseStmt &S) {
}

mlir::LogicalResult CIRGenFunction::buildDefaultStmt(const DefaultStmt &S) {
currLexScope->caseEltValueLists.push_back({});
auto ctxt = builder.getContext();

auto defAttr = mlir::cir::CaseAttr::get(
ctxt, builder.getArrayAttr({}),
CaseOpKindAttr::get(ctxt, mlir::cir::CaseOpKind::Default));

currLexScope->caseAttrs.push_back(defAttr);
return buildCaseDefaultCascade(&S);
}

Expand Down Expand Up @@ -989,32 +1005,13 @@ mlir::LogicalResult CIRGenFunction::buildSwitchStmt(const SwitchStmt &S) {
/*switchBuilder=*/
[&](mlir::OpBuilder &b, mlir::Location loc, mlir::OperationState &os) {
currLexScope->setAsSwitch();
currLexScope->switchCondType = condV.getType();

res = buildSwitchBody(S.getBody());

os.addRegions(currLexScope->getSwitchRegions());

SmallVector<mlir::Attribute, 4> caseAttrs;
auto ctxt = builder.getContext();
for (const auto &caseEltValueList : currLexScope->caseEltValueLists) {
if (caseEltValueList.empty()) {
auto defAttr = mlir::cir::CaseAttr::get(
ctxt, builder.getArrayAttr({}),
CaseOpKindAttr::get(ctxt, mlir::cir::CaseOpKind::Default));
caseAttrs.push_back(defAttr);
} else {
SmallVector<mlir::Attribute, 4> caseEltValueListAttr;
for (auto intVal : caseEltValueList)
caseEltValueListAttr.push_back(
mlir::cir::IntAttr::get(condV.getType(), intVal));
caseAttrs.push_back(mlir::cir::CaseAttr::get(
ctxt, builder.getArrayAttr(caseEltValueListAttr),
CaseOpKindAttr::get(ctxt,
caseEltValueListAttr.size() > 1
? mlir::cir::CaseOpKind::Anyof
: mlir::cir::CaseOpKind::Equal)));
}
}
os.addAttribute("cases", builder.getArrayAttr(caseAttrs));
os.addAttribute("cases",
builder.getArrayAttr(currLexScope->caseAttrs));
});

if (res.failed())
Expand Down

0 comments on commit c495036

Please sign in to comment.