Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
183 changes: 87 additions & 96 deletions compiler/rustc_mir_transform/src/coroutine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -562,19 +562,15 @@ fn make_coroutine_state_argument_pinned<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body
);
}

/// Transforms the `body` of the coroutine applying the following transforms:
///
/// - Eliminates all the `get_context` calls that async lowering created.
/// - Replace all `Local` `ResumeTy` types with `&mut Context<'_>` (`context_mut_ref`).
///
/// The `Local`s that have their types replaced are:
/// - The `resume` argument itself.
/// - The argument to `get_context`.
/// - The yielded value of a `yield`.
///
/// Async desugaring uses an unsafe binder type `ResumeTy` to circumvert borrow-checking.
/// The `ResumeTy` hides a `&mut Context<'_>` behind an unsafe raw pointer, and the
/// `get_context` function is being used to convert that back to a `&mut Context<'_>`.
///
/// The actual should be `&mut Context<'_>`. This performs the substitution:
/// - create a new local `_r` of type `ResumeTy`;
/// - assign `ResumeTy(transmute::<&mut Context<'_>, NonNull<Context<'_>>>(_2))` to that local;
/// - let all the code use `_r` instead of `_2`.
///
/// Ideally the async lowering would not use the `ResumeTy`/`get_context` indirection,
/// but rather directly use `&mut Context<'_>`, however that would currently
/// lead to higher-kinded lifetime errors.
Expand All @@ -586,93 +582,90 @@ fn make_coroutine_state_argument_pinned<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body
#[tracing::instrument(level = "trace", skip(tcx, body), ret)]
fn transform_async_context<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
let context_mut_ref = Ty::new_task_context(tcx);
let resume_ty_def_id = tcx.require_lang_item(LangItem::ResumeTy, body.span);
let resume_nonnull_ty = tcx.instantiate_and_normalize_erasing_regions(
ty::GenericArgs::empty(),
body.typing_env(tcx),
tcx.type_of(tcx.adt_def(resume_ty_def_id).non_enum_variant().fields[FieldIdx::ZERO].did),
);

// replace the type of the `resume` argument
replace_resume_ty_local(tcx, body, CTX_ARG, context_mut_ref);
// Replace all occurrences of `CTX_ARG` with `resume_local: ResumeTy`,
// and set `CTX_ARG: &mut Context<'_>`.
let resume_local = body.local_decls.push(LocalDecl::new(context_mut_ref, body.span));
body.local_decls.swap(CTX_ARG, resume_local);
RenameLocalVisitor { from: CTX_ARG, to: resume_local, tcx }.visit_body(body);

let get_context_def_id = tcx.require_lang_item(LangItem::GetContext, body.span);
// Now `CTX_ARG` is `&mut Context` and `resume_local` is a `ResumeTy`.
// Insert a `resume_local = ResumeTy(CTX_ARG as *mut Context<'static>)`
// at the function entry to make the bridge.
let source_info = SourceInfo::outermost(body.span);
let nonnull_local = body.local_decls.push(LocalDecl::new(resume_nonnull_ty, body.span));
let nonnull_rhs =
Rvalue::Cast(CastKind::Transmute, Operand::Move(CTX_ARG.into()), resume_nonnull_ty);
let nonnull_assign = StatementKind::Assign(Box::new((nonnull_local.into(), nonnull_rhs)));
let resume_rhs = Rvalue::Aggregate(
Box::new(AggregateKind::Adt(
resume_ty_def_id,
VariantIdx::ZERO,
ty::GenericArgs::empty(),
None,
None,
)),
indexvec![Operand::Move(nonnull_local.into())],
);
let resume_assign = StatementKind::Assign(Box::new((resume_local.into(), resume_rhs)));
body.basic_blocks.as_mut_preserves_cfg()[START_BLOCK].statements.splice(
0..0,
[Statement::new(source_info, nonnull_assign), Statement::new(source_info, resume_assign)],
);
}

for bb in body.basic_blocks.indices() {
let bb_data = &body[bb];
/// HIR uses `get_context` to unwrap a `&mut Context<'_>` from a `ResumeTy`.
/// Both types are just a single pointer, but liveness analysis does not know that and
/// supposes that the operand and the destination are live at the same time.
/// Forcibly inline those calls to avoid this.
fn eliminate_get_context_calls<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
let context_mut_ref = Ty::new_task_context(tcx);
let resume_ty_def_id = tcx.require_lang_item(LangItem::ResumeTy, body.span);
let resume_nonnull_ty = tcx.instantiate_and_normalize_erasing_regions(
ty::GenericArgs::empty(),
body.typing_env(tcx),
tcx.type_of(tcx.adt_def(resume_ty_def_id).non_enum_variant().fields[FieldIdx::ZERO].did),
);

let get_context_def_id = tcx.require_lang_item(LangItem::GetContext, body.span);
for bb_data in body.basic_blocks.as_mut().iter_mut() {
if bb_data.is_cleanup {
continue;
}

match &bb_data.terminator().kind {
TerminatorKind::Call { func, .. } => {
let func_ty = func.ty(body, tcx);
if let ty::FnDef(def_id, _) = *func_ty.kind()
&& def_id == get_context_def_id
{
let local = eliminate_get_context_call(&mut body[bb]);
replace_resume_ty_local(tcx, body, local, context_mut_ref);
}
}
TerminatorKind::Yield { resume_arg, .. } => {
replace_resume_ty_local(tcx, body, resume_arg.local, context_mut_ref);
}
_ => {}
let terminator = bb_data.terminator_mut();
if let TerminatorKind::Call { func, args, destination, target, .. } = &terminator.kind
&& let func_ty = func.ty(&body.local_decls, tcx)
&& let ty::FnDef(def_id, _) = *func_ty.kind()
&& def_id == get_context_def_id
&& let [arg] = &**args
&& let Some(place) = arg.node.place()
{
let arg =
Rvalue::Cast(
CastKind::Transmute,
Operand::Copy(place.project_deeper(
&[PlaceElem::Field(FieldIdx::ZERO, resume_nonnull_ty)],
tcx,
)),
context_mut_ref,
);
let assign = Statement::new(
terminator.source_info,
StatementKind::Assign(Box::new((*destination, arg))),
);
terminator.kind = TerminatorKind::Goto { target: target.unwrap() };
bb_data.statements.push(assign);
}
}
}

fn eliminate_get_context_call<'tcx>(bb_data: &mut BasicBlockData<'tcx>) -> Local {
let terminator = bb_data.terminator.take().unwrap();
let TerminatorKind::Call { args, destination, target, .. } = terminator.kind else {
bug!();
};
let [arg] = *Box::try_from(args).unwrap();
let local = arg.node.place().unwrap().local;

let arg = Rvalue::Use(arg.node, WithRetag::Yes);
let assign =
Statement::new(terminator.source_info, StatementKind::Assign(Box::new((destination, arg))));
bb_data.statements.push(assign);
bb_data.terminator = Some(Terminator {
source_info: terminator.source_info,
kind: TerminatorKind::Goto { target: target.unwrap() },
});
local
}

#[cfg_attr(not(debug_assertions), allow(unused))]
#[tracing::instrument(level = "trace", skip(tcx, body), ret)]
fn replace_resume_ty_local<'tcx>(
tcx: TyCtxt<'tcx>,
body: &mut Body<'tcx>,
local: Local,
context_mut_ref: Ty<'tcx>,
) {
let local_ty = std::mem::replace(&mut body.local_decls[local].ty, context_mut_ref);
// We have to replace the `ResumeTy` that is used for type and borrow checking
// with `&mut Context<'_>` in MIR.
#[cfg(debug_assertions)]
{
if let ty::Adt(resume_ty_adt, _) = local_ty.kind() {
let expected_adt = tcx.adt_def(tcx.require_lang_item(LangItem::ResumeTy, body.span));
assert_eq!(*resume_ty_adt, expected_adt);
} else {
panic!("expected `ResumeTy`, found `{:?}`", local_ty);
};
}
}

/// Transforms the `body` of the coroutine applying the following transform:
///
/// - Remove the `resume` argument.
///
/// Ideally the async lowering would not add the `resume` argument.
///
/// The async lowering step and the type / lifetime inference / checking are
/// still using the `resume` argument for the time being. After this transform,
/// the coroutine body doesn't have the `resume` argument.
fn transform_gen_context<'tcx>(body: &mut Body<'tcx>) {
// This leaves the local representing the `resume` argument in place,
// but turns it into a regular local variable. This is cheaper than
// adjusting all local references in the body after removing it.
body.arg_count = 1;
}

struct LivenessInfo {
/// Which locals are live across any suspension point.
saved_locals: CoroutineSavedLocals,
Expand Down Expand Up @@ -1293,6 +1286,10 @@ fn create_coroutine_resume_function<'tcx>(
// Run derefer to fix Derefs that are not in the first place
deref_finder(tcx, body, false);

if transform.coroutine_kind.is_async_desugaring() {
transform_async_context(tcx, body);
}

if let Some(dumper) = MirDumper::new(tcx, "coroutine_resume", body) {
dumper.dump_mir(body);
}
Expand Down Expand Up @@ -1508,12 +1505,10 @@ impl<'tcx> crate::MirPass<'tcx> for StateTransform {
// (finally in open_drop_for_tuple) before async drop expansion.
// Async drops, produced by this drop elaboration, will be expanded,
// and corresponding futures kept in layout.
let coroutine_is_async = coroutine_kind.is_async_desugaring();
let has_async_drops = has_async_drops(body);

// Replace all occurrences of `ResumeTy` with `&mut Context<'_>` within async bodies.
if coroutine_is_async {
transform_async_context(tcx, body);
if coroutine_kind.is_async_desugaring() {
eliminate_get_context_calls(tcx, body);
}

let always_live_locals = always_storage_live_locals(body);
Expand Down Expand Up @@ -1580,13 +1575,9 @@ impl<'tcx> crate::MirPass<'tcx> for StateTransform {
}),
);

// Update our MIR struct to reflect the changes we've made
body.arg_count = 2; // self, resume arg
body.spread_arg = None;

// Remove the context argument within generator bodies.
if matches!(coroutine_kind, CoroutineKind::Desugared(CoroutineDesugaring::Gen, _)) {
transform_gen_context(body);
body.arg_count = 1;
}

// The original arguments to the function are no longer arguments, mark them as such.
Expand Down Expand Up @@ -1633,7 +1624,7 @@ impl<'tcx> crate::MirPass<'tcx> for StateTransform {
body.coroutine.as_mut().unwrap().coroutine_drop = Some(drop_shim);

// For coroutine with sync drop, generating async proxy for `future_drop_poll` call
let proxy_shim = create_coroutine_drop_shim_proxy_async(tcx, body);
let proxy_shim = create_coroutine_drop_shim_proxy_async(tcx, body, coroutine_kind);
body.coroutine.as_mut().unwrap().coroutine_drop_proxy_async = Some(proxy_shim);
}

Expand Down
9 changes: 9 additions & 0 deletions compiler/rustc_mir_transform/src/coroutine/drop.rs
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,10 @@ pub(super) fn create_coroutine_drop_shim_async<'tcx>(
// Run derefer to fix Derefs that are not in the first place
deref_finder(tcx, &mut body, false);

if transform.coroutine_kind.is_async_desugaring() {
transform_async_context(tcx, &mut body);
}

if let Some(dumper) = MirDumper::new(tcx, "coroutine_drop_async", &body) {
dumper.dump_mir(&body);
}
Expand All @@ -320,6 +324,7 @@ pub(super) fn create_coroutine_drop_shim_async<'tcx>(
pub(super) fn create_coroutine_drop_shim_proxy_async<'tcx>(
tcx: TyCtxt<'tcx>,
body: &Body<'tcx>,
coroutine_kind: CoroutineKind,
) -> Body<'tcx> {
let mut body = body.clone();
// Take the coroutine info out of the body, since the drop shim is
Expand Down Expand Up @@ -357,6 +362,10 @@ pub(super) fn create_coroutine_drop_shim_proxy_async<'tcx>(
// Run derefer to fix Derefs that are not in the first place
deref_finder(tcx, &mut body, false);

if coroutine_kind.is_async_desugaring() {
transform_async_context(tcx, &mut body);
}

if let Some(dumper) = MirDumper::new(tcx, "coroutine_drop_proxy_async", &body) {
dumper.dump_mir(&body);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
- fn a::{closure#0}(_1: {async fn body of a()}, _2: std::future::ResumeTy) -> ()
- yields ()
- {
- debug _task_context => _2;
- let mut _0: ();
+ fn a::{closure#0}(_1: Pin<&mut {async fn body of a()}>, _2: &mut Context<'_>) -> Poll<()> {
+ coroutine layout {
+ variant_fields = {
Expand All @@ -13,16 +15,19 @@
+ }
+ storage_conflicts = BitMatrix(0x0) {}
+ }
debug _task_context => _2;
- let mut _0: ();
+ debug _task_context => _6;
+ let mut _0: std::task::Poll<()>;
+ let mut _3: ();
+ let mut _4: u32;
+ let mut _5: &mut {async fn body of a()};
+ let mut _6: std::future::ResumeTy;
+ let mut _7: std::ptr::NonNull<std::task::Context<'_>>;

bb0: {
- _0 = const ();
- drop(_1) -> [return: bb1, unwind: bb2];
+ _7 = move _2 as std::ptr::NonNull<std::task::Context<'_>> (Transmute);
+ _6 = std::future::ResumeTy(move _7);
+ _5 = copy (_1.0: &mut {async fn body of a()});
+ _4 = discriminant((*_5));
+ switchInt(move _4) -> [0: bb5, 1: bb3, otherwise: bb4];
Expand Down
Loading
Loading