Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rework coroutine transform to be more flexible in preparation for async generators #118418

Closed
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
203 changes: 123 additions & 80 deletions compiler/rustc_mir_transform/src/coroutine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,9 @@ use rustc_index::{Idx, IndexVec};
use rustc_middle::mir::dump_mir;
use rustc_middle::mir::visit::{MutVisitor, PlaceContext, Visitor};
use rustc_middle::mir::*;
use rustc_middle::ty::CoroutineArgs;
use rustc_middle::ty::InstanceDef;
use rustc_middle::ty::{self, AdtDef, Ty, TyCtxt};
use rustc_middle::ty::{CoroutineArgs, GenericArgsRef};
use rustc_middle::ty::{self, Ty, TyCtxt};
use rustc_mir_dataflow::impls::{
MaybeBorrowedLocals, MaybeLiveLocals, MaybeRequiresStorage, MaybeStorageLive,
};
Expand Down Expand Up @@ -226,8 +226,6 @@ struct SuspensionPoint<'tcx> {
struct TransformVisitor<'tcx> {
tcx: TyCtxt<'tcx>,
coroutine_kind: hir::CoroutineKind,
state_adt_ref: AdtDef<'tcx>,
state_args: GenericArgsRef<'tcx>,

// The type of the discriminant in the coroutine struct
discr_ty: Ty<'tcx>,
Expand All @@ -246,21 +244,34 @@ struct TransformVisitor<'tcx> {
always_live_locals: BitSet<Local>,

// The original RETURN_PLACE local
new_ret_local: Local,
old_ret_local: Local,

old_yield_ty: Ty<'tcx>,

old_ret_ty: Ty<'tcx>,
}

impl<'tcx> TransformVisitor<'tcx> {
fn insert_none_ret_block(&self, body: &mut Body<'tcx>) -> BasicBlock {
let block = BasicBlock::new(body.basic_blocks.len());
assert!(matches!(self.coroutine_kind, CoroutineKind::Gen(_)));

let block = BasicBlock::new(body.basic_blocks.len());
let source_info = SourceInfo::outermost(body.span);
let option_def_id = self.tcx.require_lang_item(LangItem::Option, None);

let (kind, idx) = self.coroutine_state_adt_and_variant_idx(true);
assert_eq!(self.state_adt_ref.variant(idx).fields.len(), 0);
let statements = vec![Statement {
kind: StatementKind::Assign(Box::new((
Place::return_place(),
Rvalue::Aggregate(Box::new(kind), IndexVec::new()),
Rvalue::Aggregate(
Box::new(AggregateKind::Adt(
option_def_id,
VariantIdx::from_usize(0),
self.tcx.mk_args(&[self.old_yield_ty.into()]),
None,
None,
)),
IndexVec::new(),
),
))),
source_info,
}];
Expand All @@ -274,23 +285,6 @@ impl<'tcx> TransformVisitor<'tcx> {
block
}

fn coroutine_state_adt_and_variant_idx(
&self,
is_return: bool,
) -> (AggregateKind<'tcx>, VariantIdx) {
let idx = VariantIdx::new(match (is_return, self.coroutine_kind) {
(true, hir::CoroutineKind::Coroutine) => 1, // CoroutineState::Complete
(false, hir::CoroutineKind::Coroutine) => 0, // CoroutineState::Yielded
(true, hir::CoroutineKind::Async(_)) => 0, // Poll::Ready
(false, hir::CoroutineKind::Async(_)) => 1, // Poll::Pending
(true, hir::CoroutineKind::Gen(_)) => 0, // Option::None
(false, hir::CoroutineKind::Gen(_)) => 1, // Option::Some
});

let kind = AggregateKind::Adt(self.state_adt_ref.did(), idx, self.state_args, None, None);
(kind, idx)
}

// Make a `CoroutineState` or `Poll` variant assignment.
//
// `core::ops::CoroutineState` only has single element tuple variants,
Expand All @@ -303,51 +297,99 @@ impl<'tcx> TransformVisitor<'tcx> {
is_return: bool,
statements: &mut Vec<Statement<'tcx>>,
) {
let (kind, idx) = self.coroutine_state_adt_and_variant_idx(is_return);

match self.coroutine_kind {
// `Poll::Pending`
let rvalue = match self.coroutine_kind {
CoroutineKind::Async(_) => {
if !is_return {
assert_eq!(self.state_adt_ref.variant(idx).fields.len(), 0);

// FIXME(swatinem): assert that `val` is indeed unit?
statements.push(Statement {
kind: StatementKind::Assign(Box::new((
Place::return_place(),
Rvalue::Aggregate(Box::new(kind), IndexVec::new()),
))),
source_info,
});
return;
let poll_def_id = self.tcx.require_lang_item(LangItem::Poll, None);
let args = self.tcx.mk_args(&[self.old_ret_ty.into()]);
if is_return {
// Poll::Ready(val)
Rvalue::Aggregate(
Box::new(AggregateKind::Adt(
poll_def_id,
VariantIdx::from_usize(0),
args,
None,
None,
)),
IndexVec::from_raw(vec![val]),
)
} else {
// Poll::Pending
Rvalue::Aggregate(
Box::new(AggregateKind::Adt(
poll_def_id,
VariantIdx::from_usize(1),
args,
None,
None,
)),
IndexVec::new(),
)
}
}
// `Option::None`
CoroutineKind::Gen(_) => {
let option_def_id = self.tcx.require_lang_item(LangItem::Option, None);
let args = self.tcx.mk_args(&[self.old_yield_ty.into()]);
if is_return {
assert_eq!(self.state_adt_ref.variant(idx).fields.len(), 0);

statements.push(Statement {
kind: StatementKind::Assign(Box::new((
Place::return_place(),
Rvalue::Aggregate(Box::new(kind), IndexVec::new()),
))),
source_info,
});
return;
// None
Rvalue::Aggregate(
Box::new(AggregateKind::Adt(
option_def_id,
VariantIdx::from_usize(0),
args,
None,
None,
)),
IndexVec::new(),
)
} else {
// Some(val)
Rvalue::Aggregate(
Box::new(AggregateKind::Adt(
option_def_id,
VariantIdx::from_usize(1),
args,
None,
None,
)),
IndexVec::from_raw(vec![val]),
)
}
}
CoroutineKind::Coroutine => {}
}

// else: `Poll::Ready(x)`, `CoroutineState::Yielded(x)`, `CoroutineState::Complete(x)`, or `Option::Some(x)`
assert_eq!(self.state_adt_ref.variant(idx).fields.len(), 1);
CoroutineKind::Coroutine => {
let coroutine_state_def_id =
self.tcx.require_lang_item(LangItem::CoroutineState, None);
let args = self.tcx.mk_args(&[self.old_yield_ty.into(), self.old_ret_ty.into()]);
if is_return {
// CoroutineState::Complete(val)
Rvalue::Aggregate(
Box::new(AggregateKind::Adt(
coroutine_state_def_id,
VariantIdx::from_usize(1),
args,
None,
None,
)),
IndexVec::from_raw(vec![val]),
)
} else {
// CoroutineState::Yielded(val)
Rvalue::Aggregate(
Box::new(AggregateKind::Adt(
coroutine_state_def_id,
VariantIdx::from_usize(0),
args,
None,
None,
)),
IndexVec::from_raw(vec![val]),
)
}
}
};

statements.push(Statement {
kind: StatementKind::Assign(Box::new((
Place::return_place(),
Rvalue::Aggregate(Box::new(kind), [val].into()),
))),
kind: StatementKind::Assign(Box::new((Place::return_place(), rvalue))),
source_info,
});
}
Expand Down Expand Up @@ -421,7 +463,7 @@ impl<'tcx> MutVisitor<'tcx> for TransformVisitor<'tcx> {

let ret_val = match data.terminator().kind {
TerminatorKind::Return => {
Some((true, None, Operand::Move(Place::from(self.new_ret_local)), None))
Some((true, None, Operand::Move(Place::from(self.old_ret_local)), None))
}
TerminatorKind::Yield { ref value, resume, resume_arg, drop } => {
Some((false, Some((resume, resume_arg)), value.clone(), drop))
Expand Down Expand Up @@ -1503,10 +1545,11 @@ pub(crate) fn mir_coroutine_witnesses<'tcx>(

impl<'tcx> MirPass<'tcx> for StateTransform {
fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
let Some(yield_ty) = body.yield_ty() else {
let Some(old_yield_ty) = body.yield_ty() else {
// This only applies to coroutines
return;
};
let old_ret_ty = body.return_ty();

assert!(body.coroutine_drop().is_none());

Expand All @@ -1528,34 +1571,33 @@ impl<'tcx> MirPass<'tcx> for StateTransform {

let is_async_kind = matches!(body.coroutine_kind(), Some(CoroutineKind::Async(_)));
let is_gen_kind = matches!(body.coroutine_kind(), Some(CoroutineKind::Gen(_)));
let (state_adt_ref, state_args) = match body.coroutine_kind().unwrap() {
let new_ret_ty = match body.coroutine_kind().unwrap() {
CoroutineKind::Async(_) => {
// Compute Poll<return_ty>
let poll_did = tcx.require_lang_item(LangItem::Poll, None);
let poll_adt_ref = tcx.adt_def(poll_did);
let poll_args = tcx.mk_args(&[body.return_ty().into()]);
(poll_adt_ref, poll_args)
let poll_args = tcx.mk_args(&[old_ret_ty.into()]);
Ty::new_adt(tcx, poll_adt_ref, poll_args)
}
CoroutineKind::Gen(_) => {
// Compute Option<yield_ty>
let option_did = tcx.require_lang_item(LangItem::Option, None);
let option_adt_ref = tcx.adt_def(option_did);
let option_args = tcx.mk_args(&[body.yield_ty().unwrap().into()]);
(option_adt_ref, option_args)
let option_args = tcx.mk_args(&[old_yield_ty.into()]);
Ty::new_adt(tcx, option_adt_ref, option_args)
}
CoroutineKind::Coroutine => {
// Compute CoroutineState<yield_ty, return_ty>
let state_did = tcx.require_lang_item(LangItem::CoroutineState, None);
let state_adt_ref = tcx.adt_def(state_did);
let state_args = tcx.mk_args(&[yield_ty.into(), body.return_ty().into()]);
(state_adt_ref, state_args)
let state_args = tcx.mk_args(&[old_yield_ty.into(), old_ret_ty.into()]);
Ty::new_adt(tcx, state_adt_ref, state_args)
}
};
let ret_ty = Ty::new_adt(tcx, state_adt_ref, state_args);

// We rename RETURN_PLACE which has type mir.return_ty to new_ret_local
// We rename RETURN_PLACE which has type mir.return_ty to old_ret_local
// RETURN_PLACE then is a fresh unused local with type ret_ty.
let new_ret_local = replace_local(RETURN_PLACE, ret_ty, body, tcx);
let old_ret_local = replace_local(RETURN_PLACE, new_ret_ty, body, tcx);

// Replace all occurrences of `ResumeTy` with `&mut Context<'_>` within async bodies.
if is_async_kind {
Expand All @@ -1572,17 +1614,18 @@ impl<'tcx> MirPass<'tcx> for StateTransform {
} else {
body.local_decls[resume_local].ty
};
let new_resume_local = replace_local(resume_local, resume_ty, body, tcx);
let old_resume_local = replace_local(resume_local, resume_ty, body, tcx);

// When first entering the coroutine, move the resume argument into its new local.
// When first entering the coroutine, move the resume argument into its old local
// (which is now a generator interior).
let source_info = SourceInfo::outermost(body.span);
let stmts = &mut body.basic_blocks_mut()[START_BLOCK].statements;
stmts.insert(
0,
Statement {
source_info,
kind: StatementKind::Assign(Box::new((
new_resume_local.into(),
old_resume_local.into(),
Rvalue::Use(Operand::Move(resume_local.into())),
))),
},
Expand Down Expand Up @@ -1618,14 +1661,14 @@ impl<'tcx> MirPass<'tcx> for StateTransform {
let mut transform = TransformVisitor {
tcx,
coroutine_kind: body.coroutine_kind().unwrap(),
state_adt_ref,
state_args,
remap,
storage_liveness,
always_live_locals,
suspension_points: Vec::new(),
new_ret_local,
old_ret_local,
discr_ty,
old_ret_ty,
old_yield_ty,
};
transform.visit_body(body);

Expand Down
Loading