rustc_mir_transform/coroutine/
by_move_body.rs

1//! This pass constructs a second coroutine body sufficient for return from
2//! `FnOnce`/`AsyncFnOnce` implementations for coroutine-closures (e.g. async closures).
3//!
4//! Consider an async closure like:
5//! ```rust
6//! let x = vec![1, 2, 3];
7//!
8//! let closure = async move || {
9//!     println!("{x:#?}");
10//! };
11//! ```
12//!
13//! This desugars to something like:
14//! ```rust,ignore (invalid-borrowck)
15//! let x = vec![1, 2, 3];
16//!
17//! let closure = move || {
18//!     async {
19//!         println!("{x:#?}");
20//!     }
21//! };
22//! ```
23//!
24//! Important to note here is that while the outer closure *moves* `x: Vec<i32>`
25//! into its upvars, the inner `async` coroutine simply captures a ref of `x`.
26//! This is the "magic" of async closures -- the futures that they return are
27//! allowed to borrow from their parent closure's upvars.
28//!
29//! However, what happens when we call `closure` with `AsyncFnOnce` (or `FnOnce`,
30//! since all async closures implement that too)? Well, recall the signature:
31//! ```
32//! use std::future::Future;
33//! pub trait AsyncFnOnce<Args>
34//! {
35//!     type CallOnceFuture: Future<Output = Self::Output>;
36//!     type Output;
37//!     fn async_call_once(
38//!         self,
39//!         args: Args
40//!     ) -> Self::CallOnceFuture;
41//! }
42//! ```
43//!
44//! This signature *consumes* the async closure (`self`) and returns a `CallOnceFuture`.
45//! How do we deal with the fact that the coroutine is supposed to take a reference
46//! to the captured `x` from the parent closure, when that parent closure has been
47//! destroyed?
48//!
49//! This is the second piece of magic of async closures. We can simply create a
50//! *second* `async` coroutine body where that `x` that was previously captured
51//! by reference is now captured by value. This means that we consume the outer
52//! closure and return a new coroutine that will hold onto all of these captures,
53//! and drop them when it is finished (i.e. after it has been `.await`ed).
54//!
55//! We do this with the analysis below, which detects the captures that come from
56//! borrowing from the outer closure, and we simply peel off a `deref` projection
57//! from them. This second body is stored alongside the first body, and optimized
58//! with it in lockstep. When we need to resolve a body for `FnOnce` or `AsyncFnOnce`,
59//! we use this "by-move" body instead.
60//!
61//! ## How does this work?
62//!
63//! This pass essentially remaps the body of the (child) closure of the coroutine-closure
64//! to take the set of upvars of the parent closure by value. This at least requires
65//! changing a by-ref upvar to be by-value in the case that the outer coroutine-closure
66//! captures something by value; however, it may also require renumbering field indices
67//! in case precise captures (edition 2021 closure capture rules) caused the inner coroutine
68//! to split one field capture into two.
69
70use rustc_abi::{FieldIdx, VariantIdx};
71use rustc_data_structures::steal::Steal;
72use rustc_data_structures::unord::UnordMap;
73use rustc_hir as hir;
74use rustc_hir::def::DefKind;
75use rustc_hir::def_id::{DefId, LocalDefId};
76use rustc_hir::definitions::DisambiguatorState;
77use rustc_middle::bug;
78use rustc_middle::hir::place::{Projection, ProjectionKind};
79use rustc_middle::mir::visit::MutVisitor;
80use rustc_middle::mir::{self, dump_mir};
81use rustc_middle::ty::{self, InstanceKind, Ty, TyCtxt, TypeVisitableExt};
82
83pub(crate) fn coroutine_by_move_body_def_id<'tcx>(
84    tcx: TyCtxt<'tcx>,
85    coroutine_def_id: LocalDefId,
86) -> DefId {
87    let body = tcx.mir_built(coroutine_def_id).borrow();
88
89    // If the typeck results are tainted, no need to make a by-ref body.
90    if body.tainted_by_errors.is_some() {
91        return coroutine_def_id.to_def_id();
92    }
93
94    let Some(hir::CoroutineKind::Desugared(_, hir::CoroutineSource::Closure)) =
95        tcx.coroutine_kind(coroutine_def_id)
96    else {
97        bug!("should only be invoked on coroutine-closures");
98    };
99
100    // Also, let's skip processing any bodies with errors, since there's no guarantee
101    // the MIR body will be constructed well.
102    let coroutine_ty = body.local_decls[ty::CAPTURE_STRUCT_LOCAL].ty;
103
104    let ty::Coroutine(_, args) = *coroutine_ty.kind() else {
105        bug!("tried to create by-move body of non-coroutine receiver");
106    };
107    let args = args.as_coroutine();
108
109    let coroutine_kind = args.kind_ty().to_opt_closure_kind().unwrap();
110
111    let parent_def_id = tcx.local_parent(coroutine_def_id);
112    let ty::CoroutineClosure(_, parent_args) =
113        *tcx.type_of(parent_def_id).instantiate_identity().kind()
114    else {
115        bug!("coroutine's parent was not a coroutine-closure");
116    };
117    if parent_args.references_error() {
118        return coroutine_def_id.to_def_id();
119    }
120
121    let parent_closure_args = parent_args.as_coroutine_closure();
122    let num_args = parent_closure_args
123        .coroutine_closure_sig()
124        .skip_binder()
125        .tupled_inputs_ty
126        .tuple_fields()
127        .len();
128
129    let field_remapping: UnordMap<_, _> = ty::analyze_coroutine_closure_captures(
130        tcx.closure_captures(parent_def_id).iter().copied(),
131        tcx.closure_captures(coroutine_def_id).iter().skip(num_args).copied(),
132        |(parent_field_idx, parent_capture), (child_field_idx, child_capture)| {
133            // Store this set of additional projections (fields and derefs).
134            // We need to re-apply them later.
135            let mut child_precise_captures =
136                child_capture.place.projections[parent_capture.place.projections.len()..].to_vec();
137
138            // If the parent capture is by-ref, then we need to apply an additional
139            // deref before applying any further projections to this place.
140            if parent_capture.is_by_ref() {
141                child_precise_captures.insert(
142                    0,
143                    Projection { ty: parent_capture.place.ty(), kind: ProjectionKind::Deref },
144                );
145            }
146            // If the child capture is by-ref, then we need to apply a "ref"
147            // projection (i.e. `&`) at the end. But wait! We don't have that
148            // as a projection kind. So instead, we can apply its dual and
149            // *peel* a deref off of the place when it shows up in the MIR body.
150            // Luckily, by construction this is always possible.
151            let peel_deref = if child_capture.is_by_ref() {
152                assert!(
153                    parent_capture.is_by_ref() || coroutine_kind != ty::ClosureKind::FnOnce,
154                    "`FnOnce` coroutine-closures return coroutines that capture from \
155                        their body; it will always result in a borrowck error!"
156                );
157                true
158            } else {
159                false
160            };
161
162            // Regarding the behavior above, you may think that it's redundant to both
163            // insert a deref and then peel a deref if the parent and child are both
164            // captured by-ref. This would be correct, except for the case where we have
165            // precise capturing projections, since the inserted deref is to the *beginning*
166            // and the peeled deref is at the *end*. I cannot seem to actually find a
167            // case where this happens, though, but let's keep this code flexible.
168
169            // Finally, store the type of the parent's captured place. We need
170            // this when building the field projection in the MIR body later on.
171            let mut parent_capture_ty = parent_capture.place.ty();
172            parent_capture_ty = match parent_capture.info.capture_kind {
173                ty::UpvarCapture::ByValue | ty::UpvarCapture::ByUse => parent_capture_ty,
174                ty::UpvarCapture::ByRef(kind) => Ty::new_ref(
175                    tcx,
176                    tcx.lifetimes.re_erased,
177                    parent_capture_ty,
178                    kind.to_mutbl_lossy(),
179                ),
180            };
181
182            Some((
183                FieldIdx::from_usize(child_field_idx + num_args),
184                (
185                    FieldIdx::from_usize(parent_field_idx + num_args),
186                    parent_capture_ty,
187                    peel_deref,
188                    child_precise_captures,
189                ),
190            ))
191        },
192    )
193    .flatten()
194    .collect();
195
196    if coroutine_kind == ty::ClosureKind::FnOnce {
197        assert_eq!(field_remapping.len(), tcx.closure_captures(parent_def_id).len());
198        // The by-move body is just the body :)
199        return coroutine_def_id.to_def_id();
200    }
201
202    let by_move_coroutine_ty = tcx
203        .instantiate_bound_regions_with_erased(parent_closure_args.coroutine_closure_sig())
204        .to_coroutine_given_kind_and_upvars(
205            tcx,
206            parent_closure_args.parent_args(),
207            coroutine_def_id.to_def_id(),
208            ty::ClosureKind::FnOnce,
209            tcx.lifetimes.re_erased,
210            parent_closure_args.tupled_upvars_ty(),
211            parent_closure_args.coroutine_captures_by_ref_ty(),
212        );
213
214    let mut by_move_body = body.clone();
215    MakeByMoveBody { tcx, field_remapping, by_move_coroutine_ty }.visit_body(&mut by_move_body);
216
217    // This path is unique since we're in a query so we'll only be called once with `parent_def_id`
218    // and this is the only location creating `SyntheticCoroutineBody`.
219    let body_def = tcx.create_def(
220        parent_def_id,
221        None,
222        DefKind::SyntheticCoroutineBody,
223        None,
224        &mut DisambiguatorState::new(),
225    );
226    by_move_body.source =
227        mir::MirSource::from_instance(InstanceKind::Item(body_def.def_id().to_def_id()));
228    dump_mir(tcx, false, "built", &"after", &by_move_body, |_, _| Ok(()));
229
230    // Feed HIR because we try to access this body's attrs in the inliner.
231    body_def.feed_hir();
232    // Inherited from the by-ref coroutine.
233    body_def.codegen_fn_attrs(tcx.codegen_fn_attrs(coroutine_def_id).clone());
234    body_def.coverage_attr_on(tcx.coverage_attr_on(coroutine_def_id));
235    body_def.constness(tcx.constness(coroutine_def_id));
236    body_def.coroutine_kind(tcx.coroutine_kind(coroutine_def_id));
237    body_def.def_ident_span(tcx.def_ident_span(coroutine_def_id));
238    body_def.def_span(tcx.def_span(coroutine_def_id));
239    body_def.explicit_predicates_of(tcx.explicit_predicates_of(coroutine_def_id));
240    body_def.generics_of(tcx.generics_of(coroutine_def_id).clone());
241    body_def.param_env(tcx.param_env(coroutine_def_id));
242    body_def.predicates_of(tcx.predicates_of(coroutine_def_id));
243
244    // The type of the coroutine is the `by_move_coroutine_ty`.
245    body_def.type_of(ty::EarlyBinder::bind(by_move_coroutine_ty));
246
247    body_def.mir_built(tcx.arena.alloc(Steal::new(by_move_body)));
248
249    body_def.def_id().to_def_id()
250}
251
252struct MakeByMoveBody<'tcx> {
253    tcx: TyCtxt<'tcx>,
254    field_remapping: UnordMap<FieldIdx, (FieldIdx, Ty<'tcx>, bool, Vec<Projection<'tcx>>)>,
255    by_move_coroutine_ty: Ty<'tcx>,
256}
257
258impl<'tcx> MutVisitor<'tcx> for MakeByMoveBody<'tcx> {
259    fn tcx(&self) -> TyCtxt<'tcx> {
260        self.tcx
261    }
262
263    fn visit_place(
264        &mut self,
265        place: &mut mir::Place<'tcx>,
266        context: mir::visit::PlaceContext,
267        location: mir::Location,
268    ) {
269        // Initializing an upvar local always starts with `CAPTURE_STRUCT_LOCAL` and a
270        // field projection. If this is in `field_remapping`, then it must not be an
271        // arg from calling the closure, but instead an upvar.
272        if place.local == ty::CAPTURE_STRUCT_LOCAL
273            && let Some((&mir::ProjectionElem::Field(idx, _), projection)) =
274                place.projection.split_first()
275            && let Some(&(remapped_idx, remapped_ty, peel_deref, ref bridging_projections)) =
276                self.field_remapping.get(&idx)
277        {
278            // As noted before, if the parent closure captures a field by value, and
279            // the child captures a field by ref, then for the by-move body we're
280            // generating, we also are taking that field by value. Peel off a deref,
281            // since a layer of ref'ing has now become redundant.
282            let final_projections = if peel_deref {
283                let Some((mir::ProjectionElem::Deref, projection)) = projection.split_first()
284                else {
285                    bug!(
286                        "There should be at least a single deref for an upvar local initialization, found {projection:#?}"
287                    );
288                };
289                // There may be more derefs, since we may also implicitly reborrow
290                // a captured mut pointer.
291                projection
292            } else {
293                projection
294            };
295
296            // These projections are applied in order to "bridge" the local that we are
297            // currently transforming *from* the old upvar that the by-ref coroutine used
298            // to capture *to* the upvar of the parent coroutine-closure. For example, if
299            // the parent captures `&s` but the child captures `&(s.field)`, then we will
300            // apply a field projection.
301            let bridging_projections = bridging_projections.iter().map(|elem| match elem.kind {
302                ProjectionKind::Deref => mir::ProjectionElem::Deref,
303                ProjectionKind::Field(idx, VariantIdx::ZERO) => {
304                    mir::ProjectionElem::Field(idx, elem.ty)
305                }
306                _ => unreachable!("precise captures only through fields and derefs"),
307            });
308
309            // We start out with an adjusted field index (and ty), representing the
310            // upvar that we get from our parent closure. We apply any of the additional
311            // projections to make sure that to the rest of the body of the closure, the
312            // place looks the same, and then apply that final deref if necessary.
313            *place = mir::Place {
314                local: place.local,
315                projection: self.tcx.mk_place_elems_from_iter(
316                    [mir::ProjectionElem::Field(remapped_idx, remapped_ty)]
317                        .into_iter()
318                        .chain(bridging_projections)
319                        .chain(final_projections.iter().copied()),
320                ),
321            };
322        }
323        self.super_place(place, context, location);
324    }
325
326    fn visit_statement(&mut self, statement: &mut mir::Statement<'tcx>, location: mir::Location) {
327        // Remove fake borrows of closure captures if that capture has been
328        // replaced with a by-move version of that capture.
329        //
330        // For example, imagine we capture `Foo` in the parent and `&Foo`
331        // in the child. We will emit two fake borrows like:
332        //
333        // ```
334        //    _2 = &fake shallow (*(_1.0: &Foo));
335        //    _3 = &fake shallow (_1.0: &Foo);
336        // ```
337        //
338        // However, since this transform is responsible for replacing
339        // `_1.0: &Foo` with `_1.0: Foo`, that makes the second fake borrow
340        // obsolete, and we should replace it with a nop.
341        //
342        // As a side-note, we don't actually even care about fake borrows
343        // here at all since they're fully a MIR borrowck artifact, and we
344        // don't need to borrowck by-move MIR bodies. But it's best to preserve
345        // as much as we can between these two bodies :)
346        if let mir::StatementKind::Assign(box (_, rvalue)) = &statement.kind
347            && let mir::Rvalue::Ref(_, mir::BorrowKind::Fake(mir::FakeBorrowKind::Shallow), place) =
348                rvalue
349            && let mir::PlaceRef {
350                local: ty::CAPTURE_STRUCT_LOCAL,
351                projection: [mir::ProjectionElem::Field(idx, _)],
352            } = place.as_ref()
353            && let Some(&(_, _, true, _)) = self.field_remapping.get(&idx)
354        {
355            statement.kind = mir::StatementKind::Nop;
356        }
357
358        self.super_statement(statement, location);
359    }
360
361    fn visit_local_decl(&mut self, local: mir::Local, local_decl: &mut mir::LocalDecl<'tcx>) {
362        // Replace the type of the self arg.
363        if local == ty::CAPTURE_STRUCT_LOCAL {
364            local_decl.ty = self.by_move_coroutine_ty;
365        }
366        self.super_local_decl(local, local_decl);
367    }
368}