Skip to content

Commit

Permalink
[CIR][CIRGen] Add CIRGen for scalar co_yield expression (#761)
Browse files Browse the repository at this point in the history
This PR adds CIRGen for scalar `co_yield` expressions.
  • Loading branch information
Lancern authored Aug 1, 2024
1 parent cc01a56 commit 1674254
Show file tree
Hide file tree
Showing 5 changed files with 82 additions and 12 deletions.
8 changes: 5 additions & 3 deletions clang/include/clang/CIR/Dialect/IR/CIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -3147,12 +3147,13 @@ def TryCallOp : CIR_CallOp<"try_call",

def AK_Initial : I32EnumAttrCase<"init", 1>;
def AK_User : I32EnumAttrCase<"user", 2>;
def AK_Final : I32EnumAttrCase<"final", 3>;
def AK_Yield : I32EnumAttrCase<"yield", 3>;
def AK_Final : I32EnumAttrCase<"final", 4>;

def AwaitKind : I32EnumAttr<
"AwaitKind",
"await kind",
[AK_Initial, AK_User, AK_Final]> {
[AK_Initial, AK_User, AK_Yield, AK_Final]> {
let cppNamespace = "::mlir::cir";
}

Expand Down Expand Up @@ -3186,9 +3187,10 @@ def AwaitOp : CIR_Op<"await",
of CIR, e.g. LLVM, should use the `suspend` region to track more
lower level codegen (e.g. intrinsic emission for coro.save/coro.suspend).

There are also 3 flavors of `cir.await` available:
There are also 4 flavors of `cir.await` available:
- `init`: compiler generated initial suspend via implicit `co_await`.
- `user`: also known as normal, representing user written co_await's.
- `yield`: user written `co_yield` expressions.
- `final`: compiler generated final suspend via implicit `co_await`.

From the C++ snippet we get:
Expand Down
30 changes: 22 additions & 8 deletions clang/lib/CIR/CodeGen/CIRGenCoroutine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -492,25 +492,25 @@ buildSuspendExpression(CIRGenFunction &CGF, CGCoroData &Coro,
return awaitRes;
}

RValue CIRGenFunction::buildCoawaitExpr(const CoawaitExpr &E,
AggValueSlot aggSlot,
bool ignoreResult) {
static RValue buildSuspendExpr(CIRGenFunction &CGF,
const CoroutineSuspendExpr &E,
mlir::cir::AwaitKind kind, AggValueSlot aggSlot,
bool ignoreResult) {
RValue rval;
auto scopeLoc = getLoc(E.getSourceRange());
auto scopeLoc = CGF.getLoc(E.getSourceRange());

// Since we model suspend / resume as an inner region, we must store
// resume scalar results in a tmp alloca, and load it after we build the
// suspend expression. An alternative way to do this would be to make
// every region return a value when promise.return_value() is used, but
// it's a bit awkward given that resume is the only region that actually
// returns a value.
mlir::Block *currEntryBlock = currLexScope->getEntryBlock();
mlir::Block *currEntryBlock = CGF.currLexScope->getEntryBlock();
[[maybe_unused]] mlir::Value tmpResumeRValAddr;

// No need to explicitly wrap this into a scope since the AST already uses a
// ExprWithCleanups, which will wrap this into a cir.scope anyways.
rval = buildSuspendExpression(*this, *CurCoro.Data, E,
CurCoro.Data->CurrentAwaitKind, aggSlot,
rval = buildSuspendExpression(CGF, *CGF.CurCoro.Data, E, kind, aggSlot,
ignoreResult, currEntryBlock, tmpResumeRValAddr,
/*forLValue*/ false)
.RV;
Expand All @@ -519,7 +519,7 @@ RValue CIRGenFunction::buildCoawaitExpr(const CoawaitExpr &E,
return rval;

if (rval.isScalar()) {
rval = RValue::get(builder.create<mlir::cir::LoadOp>(
rval = RValue::get(CGF.getBuilder().create<mlir::cir::LoadOp>(
scopeLoc, rval.getScalarVal().getType(), tmpResumeRValAddr));
} else if (rval.isAggregate()) {
// This is probably already handled via AggSlot, remove this assertion
Expand All @@ -531,6 +531,20 @@ RValue CIRGenFunction::buildCoawaitExpr(const CoawaitExpr &E,
return rval;
}

RValue CIRGenFunction::buildCoawaitExpr(const CoawaitExpr &E,
AggValueSlot aggSlot,
bool ignoreResult) {
return buildSuspendExpr(*this, E, CurCoro.Data->CurrentAwaitKind, aggSlot,
ignoreResult);
}

RValue CIRGenFunction::buildCoyieldExpr(const CoyieldExpr &E,
AggValueSlot aggSlot,
bool ignoreResult) {
return buildSuspendExpr(*this, E, mlir::cir::AwaitKind::yield, aggSlot,
ignoreResult);
}

mlir::LogicalResult CIRGenFunction::buildCoreturnStmt(CoreturnStmt const &S) {
++CurCoro.Data->CoreturnCount;
currLexScope->setCoreturn();
Expand Down
4 changes: 3 additions & 1 deletion clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,9 @@ class ScalarExprEmitter : public StmtVisitor<ScalarExprEmitter, mlir::Value> {
mlir::Value VisitCoawaitExpr(CoawaitExpr *S) {
return CGF.buildCoawaitExpr(*S).getScalarVal();
}
mlir::Value VisitCoyieldExpr(CoyieldExpr *S) { llvm_unreachable("NYI"); }
mlir::Value VisitCoyieldExpr(CoyieldExpr *S) {
return CGF.buildCoyieldExpr(*S).getScalarVal();
}
mlir::Value VisitUnaryCoawait(const UnaryOperator *E) {
llvm_unreachable("NYI");
}
Expand Down
3 changes: 3 additions & 0 deletions clang/lib/CIR/CodeGen/CIRGenFunction.h
Original file line number Diff line number Diff line change
Expand Up @@ -900,6 +900,9 @@ class CIRGenFunction : public CIRGenTypeCache {
RValue buildCoawaitExpr(const CoawaitExpr &E,
AggValueSlot aggSlot = AggValueSlot::ignored(),
bool ignoreResult = false);
RValue buildCoyieldExpr(const CoyieldExpr &E,
AggValueSlot aggSlot = AggValueSlot::ignored(),
bool ignoreResult = false);
RValue buildCoroutineIntrinsic(const CallExpr *E, unsigned int IID);
RValue buildCoroutineFrame();

Expand Down
49 changes: 49 additions & 0 deletions clang/test/CIR/CodeGen/coro-task.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -380,3 +380,52 @@ folly::coro::Task<int> go4() {
// CHECK: }, resume : {
// CHECK: },)
// CHECK: }

folly::coro::Task<void> yield();
folly::coro::Task<void> yield1() {
auto t = yield();
co_yield t;
}

// CHECK: cir.func coroutine @_Z6yield1v() -> !ty_22folly3A3Acoro3A3ATask3Cvoid3E22

// CHECK: cir.await(init, ready : {
// CHECK: }, suspend : {
// CHECK: }, resume : {
// CHECK: },)

// CHECK: cir.scope {
// CHECK-NEXT: %[[#SUSPEND_PTR:]] = cir.alloca !ty_22std3A3Asuspend_always22, !cir.ptr<!ty_22std3A3Asuspend_always22>
// CHECK-NEXT: %[[#AWAITER_PTR:]] = cir.alloca !ty_22folly3A3Acoro3A3ATask3Cvoid3E22, !cir.ptr<!ty_22folly3A3Acoro3A3ATask3Cvoid3E22>
// CHECK-NEXT: %[[#CORO_PTR:]] = cir.alloca !ty_22std3A3Acoroutine_handle3Cvoid3E22, !cir.ptr<!ty_22std3A3Acoroutine_handle3Cvoid3E22>
// CHECK-NEXT: %[[#CORO2_PTR:]] = cir.alloca !ty_22std3A3Acoroutine_handle3Cfolly3A3Acoro3A3ATask3Cvoid3E3A3Apromise_type3E22, !cir.ptr<!ty_22std3A3Acoroutine_handle3Cfolly3A3Acoro3A3ATask3Cvoid3E3A3Apromise_type3E22>
// CHECK-NEXT: cir.call @_ZN5folly4coro4TaskIvEC1ERKS2_(%[[#AWAITER_PTR]], %{{.+}}) : (!cir.ptr<!ty_22folly3A3Acoro3A3ATask3Cvoid3E22>, !cir.ptr<!ty_22folly3A3Acoro3A3ATask3Cvoid3E22>) -> ()
// CHECK-NEXT: %[[#AWAITER:]] = cir.load %[[#AWAITER_PTR]] : !cir.ptr<!ty_22folly3A3Acoro3A3ATask3Cvoid3E22>, !ty_22folly3A3Acoro3A3ATask3Cvoid3E22
// CHECK-NEXT: %[[#SUSPEND:]] = cir.call @_ZN5folly4coro4TaskIvE12promise_type11yield_valueES2_(%{{.+}}, %[[#AWAITER]]) : (!cir.ptr<!ty_22folly3A3Acoro3A3ATask3Cvoid3E3A3Apromise_type22>, !ty_22folly3A3Acoro3A3ATask3Cvoid3E22) -> !ty_22std3A3Asuspend_always22
// CHECK-NEXT: cir.store %[[#SUSPEND]], %[[#SUSPEND_PTR]] : !ty_22std3A3Asuspend_always22, !cir.ptr<!ty_22std3A3Asuspend_always22>
// CHECK-NEXT: cir.await(yield, ready : {
// CHECK-NEXT: %[[#READY:]] = cir.scope {
// CHECK-NEXT: %[[#A:]] = cir.call @_ZNSt14suspend_always11await_readyEv(%[[#SUSPEND_PTR]]) : (!cir.ptr<!ty_22std3A3Asuspend_always22>) -> !cir.bool
// CHECK-NEXT: cir.yield %[[#A]] : !cir.bool
// CHECK-NEXT: } : !cir.bool
// CHECK-NEXT: cir.condition(%[[#READY]])
// CHECK-NEXT: }, suspend : {
// CHECK-NEXT: %[[#CORO2:]] = cir.call @_ZNSt16coroutine_handleIN5folly4coro4TaskIvE12promise_typeEE12from_addressEPv(%9) : (!cir.ptr<!void>) -> !ty_22std3A3Acoroutine_handle3Cfolly3A3Acoro3A3ATask3Cvoid3E3A3Apromise_type3E22
// CHECK-NEXT: cir.store %[[#CORO2]], %[[#CORO2_PTR]] : !ty_22std3A3Acoroutine_handle3Cfolly3A3Acoro3A3ATask3Cvoid3E3A3Apromise_type3E22, !cir.ptr<!ty_22std3A3Acoroutine_handle3Cfolly3A3Acoro3A3ATask3Cvoid3E3A3Apromise_type3E22>
// CHECK-NEXT: %[[#B:]] = cir.load %[[#CORO2_PTR]] : !cir.ptr<!ty_22std3A3Acoroutine_handle3Cfolly3A3Acoro3A3ATask3Cvoid3E3A3Apromise_type3E22>, !ty_22std3A3Acoroutine_handle3Cfolly3A3Acoro3A3ATask3Cvoid3E3A3Apromise_type3E22
// CHECK-NEXT: cir.call @_ZNSt16coroutine_handleIvEC1IN5folly4coro4TaskIvE12promise_typeEEES_IT_E(%[[#CORO_PTR]], %[[#B]]) : (!cir.ptr<!ty_22std3A3Acoroutine_handle3Cvoid3E22>, !ty_22std3A3Acoroutine_handle3Cfolly3A3Acoro3A3ATask3Cvoid3E3A3Apromise_type3E22) -> ()
// CHECK-NEXT: %[[#C:]] = cir.load %[[#CORO_PTR]] : !cir.ptr<!ty_22std3A3Acoroutine_handle3Cvoid3E22>, !ty_22std3A3Acoroutine_handle3Cvoid3E22
// CHECK-NEXT: cir.call @_ZNSt14suspend_always13await_suspendESt16coroutine_handleIvE(%[[#SUSPEND_PTR]], %[[#C]]) : (!cir.ptr<!ty_22std3A3Asuspend_always22>, !ty_22std3A3Acoroutine_handle3Cvoid3E22) -> ()
// CHECK-NEXT: cir.yield
// CHECK-NEXT: }, resume : {
// CHECK-NEXT: cir.call @_ZNSt14suspend_always12await_resumeEv(%[[#SUSPEND_PTR]]) : (!cir.ptr<!ty_22std3A3Asuspend_always22>) -> ()
// CHECK-NEXT: cir.yield
// CHECK-NEXT: },)
// CHECK-NEXT: }

// CHECK: cir.await(final, ready : {
// CHECK: }, suspend : {
// CHECK: }, resume : {
// CHECK: },)

// CHECK: }

0 comments on commit 1674254

Please sign in to comment.