Skip to content

Commit

Permalink
revert adding variables in LexScope
Browse files Browse the repository at this point in the history
  • Loading branch information
wenpen committed Apr 21, 2024
1 parent c495036 commit 879ba09
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 33 deletions.
26 changes: 17 additions & 9 deletions clang/lib/CIR/CodeGen/CIRGenFunction.h
Original file line number Diff line number Diff line change
Expand Up @@ -1079,18 +1079,29 @@ class CIRGenFunction : public CIRGenTypeCache {

mlir::Type getCIRType(const clang::QualType &type);

const CaseStmt *foldCaseStmt(const clang::CaseStmt &S);
const CaseStmt *foldCaseStmt(const clang::CaseStmt &S, mlir::Type condType,
SmallVector<mlir::Attribute, 4> &caseAttrs);

template <typename T>
mlir::LogicalResult buildCaseDefaultCascade(const T *stmt);
mlir::LogicalResult
buildCaseDefaultCascade(const T *stmt, mlir::Type condType,
SmallVector<mlir::Attribute, 4> &caseAttrs);

mlir::LogicalResult buildCaseStmt(const clang::CaseStmt &S);
mlir::LogicalResult buildCaseStmt(const clang::CaseStmt &S,
mlir::Type condType,
SmallVector<mlir::Attribute, 4> &caseAttrs);

mlir::LogicalResult buildDefaultStmt(const clang::DefaultStmt &S);
mlir::LogicalResult
buildDefaultStmt(const clang::DefaultStmt &S, mlir::Type condType,
SmallVector<mlir::Attribute, 4> &caseAttrs);

mlir::LogicalResult buildSwitchCase(const clang::SwitchCase &S);
mlir::LogicalResult
buildSwitchCase(const clang::SwitchCase &S, mlir::Type condType,
SmallVector<mlir::Attribute, 4> &caseAttrs);

mlir::LogicalResult buildSwitchBody(const clang::Stmt *S);
mlir::LogicalResult
buildSwitchBody(const clang::Stmt *S, mlir::Type condType,
SmallVector<mlir::Attribute, 4> &caseAttrs);

mlir::cir::FuncOp generateCode(clang::GlobalDecl GD, mlir::cir::FuncOp Fn,
const CIRGenFunctionInfo &FnInfo);
Expand Down Expand Up @@ -1994,9 +2005,6 @@ class CIRGenFunction : public CIRGenTypeCache {
mlir::Block *getEntryBlock() { return EntryBlock; }

mlir::Location BeginLoc, EndLoc;
mlir::Type switchCondType;
llvm::SmallVector<mlir::Attribute, 4> caseAttrs;
mlir::Block *lastCaseBlock = nullptr;
};

LexicalScope *currLexScope = nullptr;
Expand Down
60 changes: 36 additions & 24 deletions clang/lib/CIR/CodeGen/CIRGenStmt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -605,10 +605,11 @@ mlir::LogicalResult CIRGenFunction::buildBreakStmt(const clang::BreakStmt &S) {
return mlir::success();
}

const CaseStmt *CIRGenFunction::foldCaseStmt(const clang::CaseStmt &S) {
const CaseStmt *
CIRGenFunction::foldCaseStmt(const clang::CaseStmt &S, mlir::Type condType,
SmallVector<mlir::Attribute, 4> &caseAttrs) {
const CaseStmt *caseStmt = &S;
const CaseStmt *lastCase = &S;
const auto &condType = currLexScope->switchCondType;
SmallVector<mlir::Attribute, 4> caseEltValueListAttr;

// Fold cascading cases whenever possible to simplify codegen a bit.
Expand All @@ -627,13 +628,15 @@ const CaseStmt *CIRGenFunction::foldCaseStmt(const clang::CaseStmt &S) {
? mlir::cir::CaseOpKind::Anyof
: mlir::cir::CaseOpKind::Equal));

currLexScope->caseAttrs.push_back(caseAttr);
caseAttrs.push_back(caseAttr);

return lastCase;
}

template <typename T>
mlir::LogicalResult CIRGenFunction::buildCaseDefaultCascade(const T *stmt) {
mlir::LogicalResult CIRGenFunction::buildCaseDefaultCascade(
const T *stmt, mlir::Type condType,
SmallVector<mlir::Attribute, 4> &caseAttrs) {

assert((isa<CaseStmt, DefaultStmt>(stmt)) &&
"only case or default stmt go here");
Expand All @@ -651,42 +654,48 @@ mlir::LogicalResult CIRGenFunction::buildCaseDefaultCascade(const T *stmt) {

if (isa<DefaultStmt>(sub) && isa<CaseStmt>(stmt)) {
builder.createYield(getLoc(stmt->getBeginLoc()));
res = buildDefaultStmt(*dyn_cast<DefaultStmt>(sub));
res = buildDefaultStmt(*dyn_cast<DefaultStmt>(sub), condType, caseAttrs);
} else if (isa<CaseStmt>(sub) && isa<DefaultStmt>(stmt)) {
builder.createYield(getLoc(stmt->getBeginLoc()));
res = buildCaseStmt(*dyn_cast<CaseStmt>(sub));
res = buildCaseStmt(*dyn_cast<CaseStmt>(sub), condType, caseAttrs);
} else {
res = buildStmt(sub, /*useCurrentScope=*/!isa<CompoundStmt>(sub));
}

return res;
}

mlir::LogicalResult CIRGenFunction::buildCaseStmt(const CaseStmt &S) {
mlir::LogicalResult
CIRGenFunction::buildCaseStmt(const CaseStmt &S, mlir::Type condType,
SmallVector<mlir::Attribute, 4> &caseAttrs) {
assert((!S.getRHS() || !S.caseStmtIsGNURange()) &&
"case ranges not implemented");

auto *caseStmt = foldCaseStmt(S);
return buildCaseDefaultCascade(caseStmt);
auto *caseStmt = foldCaseStmt(S, condType, caseAttrs);
return buildCaseDefaultCascade(caseStmt, condType, caseAttrs);
}

mlir::LogicalResult CIRGenFunction::buildDefaultStmt(const DefaultStmt &S) {
mlir::LogicalResult
CIRGenFunction::buildDefaultStmt(const DefaultStmt &S, mlir::Type condType,
SmallVector<mlir::Attribute, 4> &caseAttrs) {
auto ctxt = builder.getContext();

auto defAttr = mlir::cir::CaseAttr::get(
ctxt, builder.getArrayAttr({}),
CaseOpKindAttr::get(ctxt, mlir::cir::CaseOpKind::Default));

currLexScope->caseAttrs.push_back(defAttr);
return buildCaseDefaultCascade(&S);
caseAttrs.push_back(defAttr);
return buildCaseDefaultCascade(&S, condType, caseAttrs);
}

mlir::LogicalResult CIRGenFunction::buildSwitchCase(const SwitchCase &S) {
mlir::LogicalResult
CIRGenFunction::buildSwitchCase(const SwitchCase &S, mlir::Type condType,
SmallVector<mlir::Attribute, 4> &caseAttrs) {
if (S.getStmtClass() == Stmt::CaseStmtClass)
return buildCaseStmt(cast<CaseStmt>(S));
return buildCaseStmt(cast<CaseStmt>(S), condType, caseAttrs);

if (S.getStmtClass() == Stmt::DefaultStmtClass)
return buildDefaultStmt(cast<DefaultStmt>(S));
return buildDefaultStmt(cast<DefaultStmt>(S), condType, caseAttrs);

llvm_unreachable("expect case or default stmt");
}
Expand Down Expand Up @@ -951,23 +960,26 @@ mlir::LogicalResult CIRGenFunction::buildWhileStmt(const WhileStmt &S) {
return mlir::success();
}

mlir::LogicalResult CIRGenFunction::buildSwitchBody(const Stmt *S) {
mlir::LogicalResult CIRGenFunction::buildSwitchBody(
const Stmt *S, mlir::Type condType,
llvm::SmallVector<mlir::Attribute, 4> &caseAttrs) {
if (auto *compoundStmt = dyn_cast<CompoundStmt>(S)) {
mlir::Block *lastCaseBlock = nullptr;
auto res = mlir::success();
for (auto *c : compoundStmt->body()) {
if (auto *switchCase = dyn_cast<SwitchCase>(c)) {
res = buildSwitchCase(*switchCase);
} else if (currLexScope->lastCaseBlock) {
res = buildSwitchCase(*switchCase, condType, caseAttrs);
} else if (lastCaseBlock) {
// This means it's a random stmt following up a case, just
// emit it as part of previous known case.
mlir::OpBuilder::InsertionGuard guardCase(builder);
builder.setInsertionPointToEnd(currLexScope->lastCaseBlock);
builder.setInsertionPointToEnd(lastCaseBlock);
res = buildStmt(c, /*useCurrentScope=*/!isa<CompoundStmt>(c));
} else {
llvm_unreachable("statement doesn't belong to any case region, NYI");
}

currLexScope->lastCaseBlock = builder.getBlock();
lastCaseBlock = builder.getBlock();

if (res.failed())
break;
Expand Down Expand Up @@ -1005,13 +1017,13 @@ mlir::LogicalResult CIRGenFunction::buildSwitchStmt(const SwitchStmt &S) {
/*switchBuilder=*/
[&](mlir::OpBuilder &b, mlir::Location loc, mlir::OperationState &os) {
currLexScope->setAsSwitch();
currLexScope->switchCondType = condV.getType();

res = buildSwitchBody(S.getBody());
llvm::SmallVector<mlir::Attribute, 4> caseAttrs;

res = buildSwitchBody(S.getBody(), condV.getType(), caseAttrs);

os.addRegions(currLexScope->getSwitchRegions());
os.addAttribute("cases",
builder.getArrayAttr(currLexScope->caseAttrs));
os.addAttribute("cases", builder.getArrayAttr(caseAttrs));
});

if (res.failed())
Expand Down

0 comments on commit 879ba09

Please sign in to comment.