Skip to content

Commit

Permalink
refactor build switch op
Browse files Browse the repository at this point in the history
  • Loading branch information
wenpen committed Apr 19, 2024
1 parent b7e1c76 commit 01566f1
Show file tree
Hide file tree
Showing 3 changed files with 121 additions and 90 deletions.
41 changes: 24 additions & 17 deletions clang/lib/CIR/CodeGen/CIRGenFunction.h
Original file line number Diff line number Diff line change
Expand Up @@ -1079,24 +1079,18 @@ class CIRGenFunction : public CIRGenTypeCache {

mlir::Type getCIRType(const clang::QualType &type);

const CaseStmt *foldCaseStmt(const clang::CaseStmt &S, mlir::Type condType,
SmallVector<mlir::Attribute, 4> &caseAttrs);
const CaseStmt *foldCaseStmt(const clang::CaseStmt &S);

template <typename T>
mlir::LogicalResult
buildCaseDefaultCascade(const T *stmt, mlir::Type condType,
SmallVector<mlir::Attribute, 4> &caseAttrs,
mlir::OperationState &os);
mlir::LogicalResult buildCaseDefaultCascade(const T *stmt);

mlir::LogicalResult buildCaseStmt(const clang::CaseStmt &S,
mlir::Type condType,
SmallVector<mlir::Attribute, 4> &caseAttrs,
mlir::OperationState &op);
mlir::LogicalResult buildCaseStmt(const clang::CaseStmt &S);

mlir::LogicalResult
buildDefaultStmt(const clang::DefaultStmt &S, mlir::Type condType,
SmallVector<mlir::Attribute, 4> &caseAttrs,
mlir::OperationState &op);
mlir::LogicalResult buildDefaultStmt(const clang::DefaultStmt &S);

mlir::LogicalResult buildSwitchCase(const clang::SwitchCase &S);

mlir::LogicalResult buildSwitchBody(const clang::Stmt *S);

mlir::cir::FuncOp generateCode(clang::GlobalDecl GD, mlir::cir::FuncOp Fn,
const CIRGenFunctionInfo &FnInfo);
Expand Down Expand Up @@ -1951,7 +1945,7 @@ class CIRGenFunction : public CIRGenTypeCache {
// have their own scopes but are distinct regions nonetheless.
llvm::SmallVector<mlir::Block *> RetBlocks;
llvm::SmallVector<std::optional<mlir::Location>> RetLocs;
unsigned int CurrentSwitchRegionIdx = -1;
llvm::SmallVector<std::unique_ptr<mlir::Region>> SwitchRegions;

// There's usually only one ret block per scope, but this needs to be
// get or create because of potential unreachable return statements, note
Expand All @@ -1972,16 +1966,25 @@ class CIRGenFunction : public CIRGenTypeCache {
void buildImplicitReturn();

public:
void updateCurrentSwitchCaseRegion() { CurrentSwitchRegionIdx++; }
llvm::ArrayRef<mlir::Block *> getRetBlocks() { return RetBlocks; }
llvm::ArrayRef<std::optional<mlir::Location>> getRetLocs() {
return RetLocs;
}
llvm::MutableArrayRef<std::unique_ptr<mlir::Region>> getSwitchRegions() {
assert(isSwitch() && "expected switch scope");
return SwitchRegions;
}

mlir::Region *createSwitchRegion() {
assert(isSwitch() && "expected switch scope");
SwitchRegions.push_back(std::make_unique<mlir::Region>());
return SwitchRegions.back().get();
}

mlir::Block *getOrCreateRetBlock(CIRGenFunction &CGF, mlir::Location loc) {
unsigned int regionIdx = 0;
if (isSwitch())
regionIdx = CurrentSwitchRegionIdx;
regionIdx = SwitchRegions.size() - 1;
if (regionIdx >= RetBlocks.size())
return createRetBlock(CGF, loc);
return &*RetBlocks.back();
Expand All @@ -1991,6 +1994,10 @@ class CIRGenFunction : public CIRGenTypeCache {
mlir::Block *getEntryBlock() { return EntryBlock; }

mlir::Location BeginLoc, EndLoc;
// Each SmallVector<APSInt> object is corresponding to a case region, empty
// vector means default case region.
llvm::SmallVector<llvm::SmallVector<llvm::APSInt>> caseEltValueLists;
mlir::Block *lastCaseBlock = nullptr;
};

LexicalScope *currLexScope = nullptr;
Expand Down
143 changes: 70 additions & 73 deletions clang/lib/CIR/CodeGen/CIRGenStmt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -605,38 +605,25 @@ mlir::LogicalResult CIRGenFunction::buildBreakStmt(const clang::BreakStmt &S) {
return mlir::success();
}

const CaseStmt *
CIRGenFunction::foldCaseStmt(const clang::CaseStmt &S, mlir::Type condType,
SmallVector<mlir::Attribute, 4> &caseAttrs) {
const CaseStmt *CIRGenFunction::foldCaseStmt(const clang::CaseStmt &S) {
const CaseStmt *caseStmt = &S;
const CaseStmt *lastCase = &S;
SmallVector<mlir::Attribute, 4> caseEltValueListAttr;
SmallVector<llvm::APSInt> caseEltValueList;

// Fold cascading cases whenever possible to simplify codegen a bit.
while (caseStmt) {
lastCase = caseStmt;
auto intVal = caseStmt->getLHS()->EvaluateKnownConstInt(getContext());
caseEltValueListAttr.push_back(mlir::cir::IntAttr::get(condType, intVal));
caseEltValueList.push_back(intVal);
caseStmt = dyn_cast_or_null<CaseStmt>(caseStmt->getSubStmt());
}

auto *ctxt = builder.getContext();

auto caseAttr = mlir::cir::CaseAttr::get(
ctxt, builder.getArrayAttr(caseEltValueListAttr),
CaseOpKindAttr::get(ctxt, caseEltValueListAttr.size() > 1
? mlir::cir::CaseOpKind::Anyof
: mlir::cir::CaseOpKind::Equal));

caseAttrs.push_back(caseAttr);
currLexScope->caseEltValueLists.push_back(std::move(caseEltValueList));

return lastCase;
}

template <typename T>
mlir::LogicalResult CIRGenFunction::buildCaseDefaultCascade(
const T *stmt, mlir::Type condType,
SmallVector<mlir::Attribute, 4> &caseAttrs, mlir::OperationState &os) {
mlir::LogicalResult CIRGenFunction::buildCaseDefaultCascade(const T *stmt) {

assert((isa<CaseStmt, DefaultStmt>(stmt)) &&
"only case or default stmt go here");
Expand All @@ -646,50 +633,46 @@ mlir::LogicalResult CIRGenFunction::buildCaseDefaultCascade(
// Update scope information with the current region we are
// emitting code for. This is useful to allow return blocks to be
// automatically and properly placed during cleanup.
auto *region = os.addRegion();
auto *region = currLexScope->createSwitchRegion();
auto *block = builder.createBlock(region);
builder.setInsertionPointToEnd(block);
currLexScope->updateCurrentSwitchCaseRegion();

auto *sub = stmt->getSubStmt();

if (isa<DefaultStmt>(sub) && isa<CaseStmt>(stmt)) {
builder.createYield(getLoc(stmt->getBeginLoc()));
res =
buildDefaultStmt(*dyn_cast<DefaultStmt>(sub), condType, caseAttrs, os);
res = buildDefaultStmt(*dyn_cast<DefaultStmt>(sub));
} else if (isa<CaseStmt>(sub) && isa<DefaultStmt>(stmt)) {
builder.createYield(getLoc(stmt->getBeginLoc()));
res = buildCaseStmt(*dyn_cast<CaseStmt>(sub), condType, caseAttrs, os);
res = buildCaseStmt(*dyn_cast<CaseStmt>(sub));
} else {
res = buildStmt(sub, /*useCurrentScope=*/!isa<CompoundStmt>(sub));
}

return res;
}

mlir::LogicalResult
CIRGenFunction::buildCaseStmt(const CaseStmt &S, mlir::Type condType,
SmallVector<mlir::Attribute, 4> &caseAttrs,
mlir::OperationState &os) {
mlir::LogicalResult CIRGenFunction::buildCaseStmt(const CaseStmt &S) {
assert((!S.getRHS() || !S.caseStmtIsGNURange()) &&
"case ranges not implemented");

auto *caseStmt = foldCaseStmt(S, condType, caseAttrs);
return buildCaseDefaultCascade(caseStmt, condType, caseAttrs, os);
auto *caseStmt = foldCaseStmt(S);
return buildCaseDefaultCascade(caseStmt);
}

mlir::LogicalResult
CIRGenFunction::buildDefaultStmt(const DefaultStmt &S, mlir::Type condType,
SmallVector<mlir::Attribute, 4> &caseAttrs,
mlir::OperationState &os) {
auto ctxt = builder.getContext();
mlir::LogicalResult CIRGenFunction::buildDefaultStmt(const DefaultStmt &S) {
currLexScope->caseEltValueLists.push_back({});
return buildCaseDefaultCascade(&S);
}

mlir::LogicalResult CIRGenFunction::buildSwitchCase(const SwitchCase &S) {
if (S.getStmtClass() == Stmt::CaseStmtClass)
return buildCaseStmt(cast<CaseStmt>(S));

auto defAttr = mlir::cir::CaseAttr::get(
ctxt, builder.getArrayAttr({}),
CaseOpKindAttr::get(ctxt, mlir::cir::CaseOpKind::Default));
if (S.getStmtClass() == Stmt::DefaultStmtClass)
return buildDefaultStmt(cast<DefaultStmt>(S));

caseAttrs.push_back(defAttr);
return buildCaseDefaultCascade(&S, condType, caseAttrs, os);
llvm_unreachable("expect case or default stmt");
}

mlir::LogicalResult
Expand Down Expand Up @@ -952,6 +935,33 @@ mlir::LogicalResult CIRGenFunction::buildWhileStmt(const WhileStmt &S) {
return mlir::success();
}

mlir::LogicalResult CIRGenFunction::buildSwitchBody(const Stmt *S) {
if (auto *compoundStmt = dyn_cast<CompoundStmt>(S)) {
auto res = mlir::success();
for (auto *c : compoundStmt->body()) {
if (auto *switchCase = dyn_cast<SwitchCase>(c)) {
res = buildSwitchCase(*switchCase);
} else if (currLexScope->lastCaseBlock) {
// This means it's a random stmt following up a case, just
// emit it as part of previous known case.
mlir::OpBuilder::InsertionGuard guardCase(builder);
builder.setInsertionPointToEnd(currLexScope->lastCaseBlock);
res = buildStmt(c, /*useCurrentScope=*/!isa<CompoundStmt>(c));
} else {
llvm_unreachable("statement doesn't belong to any case region, NYI");
}

currLexScope->lastCaseBlock = builder.getBlock();

if (res.failed())
break;
}
return res;
}

llvm_unreachable("switch body is not CompoundStmt, NYI");
}

mlir::LogicalResult CIRGenFunction::buildSwitchStmt(const SwitchStmt &S) {
// TODO: LLVM codegen does some early optimization to fold the condition and
// only emit live cases. CIR should use MLIR to achieve similar things,
Expand All @@ -974,49 +984,36 @@ mlir::LogicalResult CIRGenFunction::buildSwitchStmt(const SwitchStmt &S) {
// TODO: PGO and likelihood (e.g. PGO.haveRegionCounts())
// TODO: if the switch has a condition wrapped by __builtin_unpredictable?

// FIXME: track switch to handle nested stmts.
swop = builder.create<SwitchOp>(
getLoc(S.getBeginLoc()), condV,
/*switchBuilder=*/
[&](mlir::OpBuilder &b, mlir::Location loc, mlir::OperationState &os) {
auto *cs = dyn_cast<CompoundStmt>(S.getBody());
assert(cs && "expected compound stmt");
SmallVector<mlir::Attribute, 4> caseAttrs;

currLexScope->setAsSwitch();
mlir::Block *lastCaseBlock = nullptr;
for (auto *c : cs->body()) {
bool caseLike = isa<CaseStmt, DefaultStmt>(c);
if (!caseLike) {
// This means it's a random stmt following up a case, just
// emit it as part of previous known case.
assert(lastCaseBlock && "expects pre-existing case block");
mlir::OpBuilder::InsertionGuard guardCase(builder);
builder.setInsertionPointToEnd(lastCaseBlock);
res = buildStmt(c, /*useCurrentScope=*/!isa<CompoundStmt>(c));
lastCaseBlock = builder.getBlock();
if (res.failed())
break;
continue;
}
res = buildSwitchBody(S.getBody());

auto *caseStmt = dyn_cast<CaseStmt>(c);
os.addRegions(currLexScope->getSwitchRegions());

if (caseStmt)
res = buildCaseStmt(*caseStmt, condV.getType(), caseAttrs, os);
else {
auto *defaultStmt = dyn_cast<DefaultStmt>(c);
assert(defaultStmt && "expected default stmt");
res = buildDefaultStmt(*defaultStmt, condV.getType(), caseAttrs,
os);
SmallVector<mlir::Attribute, 4> caseAttrs;
auto ctxt = builder.getContext();
for (const auto &caseEltValueList : currLexScope->caseEltValueLists) {
if (caseEltValueList.empty()) {
auto defAttr = mlir::cir::CaseAttr::get(
ctxt, builder.getArrayAttr({}),
CaseOpKindAttr::get(ctxt, mlir::cir::CaseOpKind::Default));
caseAttrs.push_back(defAttr);
} else {
SmallVector<mlir::Attribute, 4> caseEltValueListAttr;
for (auto intVal : caseEltValueList)
caseEltValueListAttr.push_back(
mlir::cir::IntAttr::get(condV.getType(), intVal));
caseAttrs.push_back(mlir::cir::CaseAttr::get(
ctxt, builder.getArrayAttr(caseEltValueListAttr),
CaseOpKindAttr::get(ctxt,
caseEltValueListAttr.size() > 1
? mlir::cir::CaseOpKind::Anyof
: mlir::cir::CaseOpKind::Equal)));
}

lastCaseBlock = builder.getBlock();

if (res.failed())
break;
}

os.addAttribute("cases", builder.getArrayAttr(caseAttrs));
});

Expand Down
27 changes: 27 additions & 0 deletions clang/test/CIR/CodeGen/switch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,7 @@ void sw12(int a) {
break;
}
}

// CHECK: cir.func @_Z4sw12i
// CHECK: cir.scope {
// CHECK: cir.switch
Expand All @@ -275,6 +276,32 @@ void sw12(int a) {
// CHECK-NEXT: cir.break
// CHECK-NEXT: }

void sw13(int a, int b) {
switch (a) {
case 1:
switch (b) {
case 2:
break;
}
}
}

// CHECK: cir.func @_Z4sw13ii
// CHECK: cir.scope {
// CHECK: cir.switch
// CHECK-NEXT: case (equal, 1) {
// CHECK-NEXT: cir.scope {
// CHECK: cir.switch
// CHECK-NEXT: case (equal, 2) {
// CHECK-NEXT: cir.break
// CHECK-NEXT: }
// CHECK-NEXT: ]
// CHECK-NEXT: }
// CHECK-NEXT: cir.yield
// CHECK-NEXT: }
// CHECK: }
// CHECK: cir.return

void fallthrough(int x) {
switch (x) {
case 1:
Expand Down

0 comments on commit 01566f1

Please sign in to comment.