1use super::*;
4
5struct FixReturnPendingVisitor<'tcx> {
7 tcx: TyCtxt<'tcx>,
8}
9
10impl<'tcx> MutVisitor<'tcx> for FixReturnPendingVisitor<'tcx> {
11 fn tcx(&self) -> TyCtxt<'tcx> {
12 self.tcx
13 }
14
15 fn visit_assign(
16 &mut self,
17 place: &mut Place<'tcx>,
18 rvalue: &mut Rvalue<'tcx>,
19 _location: Location,
20 ) {
21 if place.local != RETURN_PLACE {
22 return;
23 }
24
25 if let Rvalue::Aggregate(kind, _) = rvalue {
27 if let AggregateKind::Adt(_, _, ref mut args, _, _) = **kind {
28 *args = self.tcx.mk_args(&[self.tcx.types.unit.into()]);
29 }
30 }
31 }
32}
33
34fn build_poll_call<'tcx>(
36 tcx: TyCtxt<'tcx>,
37 body: &mut Body<'tcx>,
38 poll_unit_place: &Place<'tcx>,
39 switch_block: BasicBlock,
40 fut_pin_place: &Place<'tcx>,
41 fut_ty: Ty<'tcx>,
42 context_ref_place: &Place<'tcx>,
43 unwind: UnwindAction,
44) -> BasicBlock {
45 let poll_fn = tcx.require_lang_item(LangItem::FuturePoll, None);
46 let poll_fn = Ty::new_fn_def(tcx, poll_fn, [fut_ty]);
47 let poll_fn = Operand::Constant(Box::new(ConstOperand {
48 span: DUMMY_SP,
49 user_ty: None,
50 const_: Const::zero_sized(poll_fn),
51 }));
52 let call = TerminatorKind::Call {
53 func: poll_fn.clone(),
54 args: [
55 dummy_spanned(Operand::Move(*fut_pin_place)),
56 dummy_spanned(Operand::Move(*context_ref_place)),
57 ]
58 .into(),
59 destination: *poll_unit_place,
60 target: Some(switch_block),
61 unwind,
62 call_source: CallSource::Misc,
63 fn_span: DUMMY_SP,
64 };
65 insert_term_block(body, call)
66}
67
68fn build_pin_fut<'tcx>(
70 tcx: TyCtxt<'tcx>,
71 body: &mut Body<'tcx>,
72 fut_place: Place<'tcx>,
73 unwind: UnwindAction,
74) -> (BasicBlock, Place<'tcx>) {
75 let span = body.span;
76 let source_info = SourceInfo::outermost(span);
77 let fut_ty = fut_place.ty(&body.local_decls, tcx).ty;
78 let fut_ref_ty = Ty::new_mut_ref(tcx, tcx.lifetimes.re_erased, fut_ty);
79 let fut_ref_place = Place::from(body.local_decls.push(LocalDecl::new(fut_ref_ty, span)));
80 let pin_fut_new_unchecked_fn = Ty::new_fn_def(
81 tcx,
82 tcx.require_lang_item(LangItem::PinNewUnchecked, Some(span)),
83 [fut_ref_ty],
84 );
85 let fut_pin_ty = pin_fut_new_unchecked_fn.fn_sig(tcx).output().skip_binder();
86 let fut_pin_place = Place::from(body.local_decls.push(LocalDecl::new(fut_pin_ty, span)));
87 let pin_fut_new_unchecked_fn = Operand::Constant(Box::new(ConstOperand {
88 span,
89 user_ty: None,
90 const_: Const::zero_sized(pin_fut_new_unchecked_fn),
91 }));
92
93 let storage_live =
94 Statement { source_info, kind: StatementKind::StorageLive(fut_pin_place.local) };
95
96 let fut_ref_assign = Statement {
97 source_info,
98 kind: StatementKind::Assign(Box::new((
99 fut_ref_place,
100 Rvalue::Ref(
101 tcx.lifetimes.re_erased,
102 BorrowKind::Mut { kind: MutBorrowKind::Default },
103 fut_place,
104 ),
105 ))),
106 };
107
108 let pin_fut_bb = body.basic_blocks_mut().push(BasicBlockData {
110 statements: [storage_live, fut_ref_assign].to_vec(),
111 terminator: Some(Terminator {
112 source_info,
113 kind: TerminatorKind::Call {
114 func: pin_fut_new_unchecked_fn,
115 args: [dummy_spanned(Operand::Move(fut_ref_place))].into(),
116 destination: fut_pin_place,
117 target: None, unwind,
119 call_source: CallSource::Misc,
120 fn_span: span,
121 },
122 }),
123 is_cleanup: false,
124 });
125 (pin_fut_bb, fut_pin_place)
126}
127
128fn build_poll_switch<'tcx>(
134 tcx: TyCtxt<'tcx>,
135 body: &mut Body<'tcx>,
136 poll_enum: Ty<'tcx>,
137 poll_unit_place: &Place<'tcx>,
138 ready_block: BasicBlock,
139 yield_block: BasicBlock,
140) -> BasicBlock {
141 let poll_enum_adt = poll_enum.ty_adt_def().unwrap();
142
143 let Discr { val: poll_ready_discr, ty: poll_discr_ty } = poll_enum
144 .discriminant_for_variant(
145 tcx,
146 poll_enum_adt.variant_index_with_id(tcx.require_lang_item(LangItem::PollReady, None)),
147 )
148 .unwrap();
149 let poll_pending_discr = poll_enum
150 .discriminant_for_variant(
151 tcx,
152 poll_enum_adt.variant_index_with_id(tcx.require_lang_item(LangItem::PollPending, None)),
153 )
154 .unwrap()
155 .val;
156 let source_info = SourceInfo::outermost(body.span);
157 let poll_discr_place =
158 Place::from(body.local_decls.push(LocalDecl::new(poll_discr_ty, source_info.span)));
159 let discr_assign = Statement {
160 source_info,
161 kind: StatementKind::Assign(Box::new((
162 poll_discr_place,
163 Rvalue::Discriminant(*poll_unit_place),
164 ))),
165 };
166 let unreachable_block = insert_term_block(body, TerminatorKind::Unreachable);
167 body.basic_blocks_mut().push(BasicBlockData {
168 statements: [discr_assign].to_vec(),
169 terminator: Some(Terminator {
170 source_info,
171 kind: TerminatorKind::SwitchInt {
172 discr: Operand::Move(poll_discr_place),
173 targets: SwitchTargets::new(
174 [(poll_ready_discr, ready_block), (poll_pending_discr, yield_block)]
175 .into_iter(),
176 unreachable_block,
177 ),
178 },
179 }),
180 is_cleanup: false,
181 })
182}
183
184fn gather_dropline_blocks<'tcx>(body: &mut Body<'tcx>) -> DenseBitSet<BasicBlock> {
186 let mut dropline: DenseBitSet<BasicBlock> = DenseBitSet::new_empty(body.basic_blocks.len());
187 for (bb, data) in traversal::reverse_postorder(body) {
188 if dropline.contains(bb) {
189 data.terminator().successors().for_each(|v| {
190 dropline.insert(v);
191 });
192 } else {
193 match data.terminator().kind {
194 TerminatorKind::Yield { drop: Some(v), .. } => {
195 dropline.insert(v);
196 }
197 TerminatorKind::Drop { drop: Some(v), .. } => {
198 dropline.insert(v);
199 }
200 _ => (),
201 }
202 }
203 }
204 dropline
205}
206
207pub(super) fn cleanup_async_drops<'tcx>(body: &mut Body<'tcx>) {
209 for block in body.basic_blocks_mut() {
210 if let TerminatorKind::Drop {
211 place: _,
212 target: _,
213 unwind: _,
214 replace: _,
215 ref mut drop,
216 ref mut async_fut,
217 } = block.terminator_mut().kind
218 {
219 if drop.is_some() || async_fut.is_some() {
220 *drop = None;
221 *async_fut = None;
222 }
223 }
224 }
225}
226
227pub(super) fn has_expandable_async_drops<'tcx>(
228 tcx: TyCtxt<'tcx>,
229 body: &mut Body<'tcx>,
230 coroutine_ty: Ty<'tcx>,
231) -> bool {
232 for bb in START_BLOCK..body.basic_blocks.next_index() {
233 if body[bb].is_cleanup {
235 continue;
236 }
237 let TerminatorKind::Drop { place, target: _, unwind: _, replace: _, drop: _, async_fut } =
238 body[bb].terminator().kind
239 else {
240 continue;
241 };
242 let place_ty = place.ty(&body.local_decls, tcx).ty;
243 if place_ty == coroutine_ty {
244 continue;
245 }
246 if async_fut.is_none() {
247 continue;
248 }
249 return true;
250 }
251 return false;
252}
253
254pub(super) fn expand_async_drops<'tcx>(
256 tcx: TyCtxt<'tcx>,
257 body: &mut Body<'tcx>,
258 context_mut_ref: Ty<'tcx>,
259 coroutine_kind: hir::CoroutineKind,
260 coroutine_ty: Ty<'tcx>,
261) {
262 let dropline = gather_dropline_blocks(body);
263 let remove_asyncness = |block: &mut BasicBlockData<'tcx>| {
265 if let TerminatorKind::Drop {
266 place: _,
267 target: _,
268 unwind: _,
269 replace: _,
270 ref mut drop,
271 ref mut async_fut,
272 } = block.terminator_mut().kind
273 {
274 *drop = None;
275 *async_fut = None;
276 }
277 };
278 for bb in START_BLOCK..body.basic_blocks.next_index() {
279 if body[bb].is_cleanup {
281 remove_asyncness(&mut body[bb]);
282 continue;
283 }
284 let TerminatorKind::Drop { place, target, unwind, replace: _, drop, async_fut } =
285 body[bb].terminator().kind
286 else {
287 continue;
288 };
289
290 let place_ty = place.ty(&body.local_decls, tcx).ty;
291 if place_ty == coroutine_ty {
292 remove_asyncness(&mut body[bb]);
293 continue;
294 }
295
296 let Some(fut_local) = async_fut else {
297 remove_asyncness(&mut body[bb]);
298 continue;
299 };
300
301 let is_dropline_bb = dropline.contains(bb);
302
303 if !is_dropline_bb && drop.is_none() {
304 remove_asyncness(&mut body[bb]);
305 continue;
306 }
307
308 let fut_place = Place::from(fut_local);
309 let fut_ty = fut_place.ty(&body.local_decls, tcx).ty;
310
311 let poll_adt_ref = tcx.adt_def(tcx.require_lang_item(LangItem::Poll, None));
321 let poll_enum = Ty::new_adt(tcx, poll_adt_ref, tcx.mk_args(&[tcx.types.unit.into()]));
322 let poll_decl = LocalDecl::new(poll_enum, body.span);
323 let poll_unit_place = Place::from(body.local_decls.push(poll_decl));
324
325 let context_ref_place =
327 Place::from(body.local_decls.push(LocalDecl::new(context_mut_ref, body.span)));
328 let source_info = body[bb].terminator.as_ref().unwrap().source_info;
329 let arg = Rvalue::Use(Operand::Move(Place::from(CTX_ARG)));
330 body[bb].statements.push(Statement {
331 source_info,
332 kind: StatementKind::Assign(Box::new((context_ref_place, arg))),
333 });
334 let yield_block = insert_term_block(body, TerminatorKind::Unreachable); let switch_block =
336 build_poll_switch(tcx, body, poll_enum, &poll_unit_place, target, yield_block);
337 let (pin_bb, fut_pin_place) =
338 build_pin_fut(tcx, body, fut_place.clone(), UnwindAction::Continue);
339 let call_bb = build_poll_call(
340 tcx,
341 body,
342 &poll_unit_place,
343 switch_block,
344 &fut_pin_place,
345 fut_ty,
346 &context_ref_place,
347 unwind,
348 );
349
350 let mut dropline_transition_bb: Option<BasicBlock> = None;
352 let mut dropline_yield_bb: Option<BasicBlock> = None;
353 let mut dropline_context_ref: Option<Place<'_>> = None;
354 let mut dropline_call_bb: Option<BasicBlock> = None;
355 if !is_dropline_bb {
356 let context_ref_place2: Place<'_> =
357 Place::from(body.local_decls.push(LocalDecl::new(context_mut_ref, body.span)));
358 let drop_yield_block = insert_term_block(body, TerminatorKind::Unreachable); let drop_switch_block = build_poll_switch(
360 tcx,
361 body,
362 poll_enum,
363 &poll_unit_place,
364 drop.unwrap(),
365 drop_yield_block,
366 );
367 let (pin_bb2, fut_pin_place2) =
368 build_pin_fut(tcx, body, fut_place, UnwindAction::Continue);
369 let drop_call_bb = build_poll_call(
370 tcx,
371 body,
372 &poll_unit_place,
373 drop_switch_block,
374 &fut_pin_place2,
375 fut_ty,
376 &context_ref_place2,
377 unwind,
378 );
379 dropline_transition_bb = Some(pin_bb2);
380 dropline_yield_bb = Some(drop_yield_block);
381 dropline_context_ref = Some(context_ref_place2);
382 dropline_call_bb = Some(drop_call_bb);
383 }
384
385 let value = Operand::Constant(Box::new(ConstOperand {
387 span: body.span,
388 user_ty: None,
389 const_: Const::from_bool(tcx, false),
390 }));
391 use rustc_middle::mir::AssertKind::ResumedAfterDrop;
392 let panic_bb = insert_panic_block(tcx, body, ResumedAfterDrop(coroutine_kind));
393
394 if is_dropline_bb {
395 body[yield_block].terminator_mut().kind = TerminatorKind::Yield {
396 value: value.clone(),
397 resume: panic_bb,
398 resume_arg: context_ref_place,
399 drop: Some(pin_bb),
400 };
401 } else {
402 body[yield_block].terminator_mut().kind = TerminatorKind::Yield {
403 value: value.clone(),
404 resume: pin_bb,
405 resume_arg: context_ref_place,
406 drop: dropline_transition_bb,
407 };
408 body[dropline_yield_bb.unwrap()].terminator_mut().kind = TerminatorKind::Yield {
409 value,
410 resume: panic_bb,
411 resume_arg: dropline_context_ref.unwrap(),
412 drop: dropline_transition_bb,
413 };
414 }
415
416 if let TerminatorKind::Call { ref mut target, .. } = body[pin_bb].terminator_mut().kind {
417 *target = Some(call_bb);
418 } else {
419 bug!()
420 }
421 if !is_dropline_bb {
422 if let TerminatorKind::Call { ref mut target, .. } =
423 body[dropline_transition_bb.unwrap()].terminator_mut().kind
424 {
425 *target = dropline_call_bb;
426 } else {
427 bug!()
428 }
429 }
430
431 body[bb].terminator_mut().kind = TerminatorKind::Goto { target: pin_bb };
432 }
433}
434
435pub(super) fn elaborate_coroutine_drops<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
436 use crate::elaborate_drop::{Unwind, elaborate_drop};
437 use crate::patch::MirPatch;
438 use crate::shim::DropShimElaborator;
439
440 let typing_env = body.typing_env(tcx);
444
445 let mut elaborator = DropShimElaborator {
446 body,
447 patch: MirPatch::new(body),
448 tcx,
449 typing_env,
450 produce_async_drops: false,
451 };
452
453 for (block, block_data) in body.basic_blocks.iter_enumerated() {
454 let (target, unwind, source_info, dropline) = match block_data.terminator() {
455 Terminator {
456 source_info,
457 kind: TerminatorKind::Drop { place, target, unwind, replace: _, drop, async_fut: _ },
458 } => {
459 if let Some(local) = place.as_local()
460 && local == SELF_ARG
461 {
462 (target, unwind, source_info, *drop)
463 } else {
464 continue;
465 }
466 }
467 _ => continue,
468 };
469 let unwind = if block_data.is_cleanup {
470 Unwind::InCleanup
471 } else {
472 Unwind::To(match *unwind {
473 UnwindAction::Cleanup(tgt) => tgt,
474 UnwindAction::Continue => elaborator.patch.resume_block(),
475 UnwindAction::Unreachable => elaborator.patch.unreachable_cleanup_block(),
476 UnwindAction::Terminate(reason) => elaborator.patch.terminate_block(reason),
477 })
478 };
479 elaborate_drop(
480 &mut elaborator,
481 *source_info,
482 Place::from(SELF_ARG),
483 (),
484 *target,
485 unwind,
486 block,
487 dropline,
488 );
489 }
490 elaborator.patch.apply(body);
491}
492
493pub(super) fn insert_clean_drop<'tcx>(
494 tcx: TyCtxt<'tcx>,
495 body: &mut Body<'tcx>,
496 has_async_drops: bool,
497) -> BasicBlock {
498 let source_info = SourceInfo::outermost(body.span);
499 let return_block = if has_async_drops {
500 insert_poll_ready_block(tcx, body)
501 } else {
502 insert_term_block(body, TerminatorKind::Return)
503 };
504
505 let dropline = None;
509
510 let term = TerminatorKind::Drop {
511 place: Place::from(SELF_ARG),
512 target: return_block,
513 unwind: UnwindAction::Continue,
514 replace: false,
515 drop: dropline,
516 async_fut: None,
517 };
518
519 body.basic_blocks_mut().push(BasicBlockData {
521 statements: Vec::new(),
522 terminator: Some(Terminator { source_info, kind: term }),
523 is_cleanup: false,
524 })
525}
526
527pub(super) fn create_coroutine_drop_shim<'tcx>(
528 tcx: TyCtxt<'tcx>,
529 transform: &TransformVisitor<'tcx>,
530 coroutine_ty: Ty<'tcx>,
531 body: &Body<'tcx>,
532 drop_clean: BasicBlock,
533) -> Body<'tcx> {
534 let mut body = body.clone();
535 let _ = body.coroutine.take();
538 body.arg_count = 1;
541
542 let source_info = SourceInfo::outermost(body.span);
543
544 let mut cases = create_cases(&mut body, transform, Operation::Drop);
545
546 cases.insert(0, (CoroutineArgs::UNRESUMED, drop_clean));
547
548 let default_block = insert_term_block(&mut body, TerminatorKind::Return);
552 insert_switch(&mut body, cases, transform, default_block);
553
554 for block in body.basic_blocks_mut() {
555 let kind = &mut block.terminator_mut().kind;
556 if let TerminatorKind::CoroutineDrop = *kind {
557 *kind = TerminatorKind::Return;
558 }
559 }
560
561 body.local_decls[RETURN_PLACE] = LocalDecl::with_source_info(tcx.types.unit, source_info);
563
564 make_coroutine_state_argument_indirect(tcx, &mut body);
565
566 body.local_decls[SELF_ARG] =
568 LocalDecl::with_source_info(Ty::new_mut_ptr(tcx, coroutine_ty), source_info);
569
570 simplify::remove_dead_blocks(&mut body);
573
574 let coroutine_instance = body.source.instance;
576 let drop_in_place = tcx.require_lang_item(LangItem::DropInPlace, None);
577 let drop_instance = InstanceKind::DropGlue(drop_in_place, Some(coroutine_ty));
578
579 body.source.instance = coroutine_instance;
582 dump_mir(tcx, false, "coroutine_drop", &0, &body, |_, _| Ok(()));
583 body.source.instance = drop_instance;
584
585 body.phase = MirPhase::Runtime(RuntimePhase::Initial);
591
592 body
593}
594
595pub(super) fn create_coroutine_drop_shim_async<'tcx>(
597 tcx: TyCtxt<'tcx>,
598 transform: &TransformVisitor<'tcx>,
599 body: &Body<'tcx>,
600 drop_clean: BasicBlock,
601 can_unwind: bool,
602) -> Body<'tcx> {
603 let mut body = body.clone();
604 let _ = body.coroutine.take();
607
608 FixReturnPendingVisitor { tcx }.visit_body(&mut body);
609
610 if can_unwind {
612 generate_poison_block_and_redirect_unwinds_there(transform, &mut body);
613 }
614
615 let source_info = SourceInfo::outermost(body.span);
616
617 let mut cases = create_cases(&mut body, transform, Operation::Drop);
618
619 cases.insert(0, (CoroutineArgs::UNRESUMED, drop_clean));
620
621 use rustc_middle::mir::AssertKind::ResumedAfterPanic;
622 if can_unwind {
624 cases.insert(
625 1,
626 (
627 CoroutineArgs::POISONED,
628 insert_panic_block(tcx, &mut body, ResumedAfterPanic(transform.coroutine_kind)),
629 ),
630 );
631 }
632
633 let default_block = insert_poll_ready_block(tcx, &mut body);
636 insert_switch(&mut body, cases, transform, default_block);
637
638 for block in body.basic_blocks_mut() {
639 let kind = &mut block.terminator_mut().kind;
640 if let TerminatorKind::CoroutineDrop = *kind {
641 *kind = TerminatorKind::Return;
642 block.statements.push(return_poll_ready_assign(tcx, source_info));
643 }
644 }
645
646 let poll_adt_ref = tcx.adt_def(tcx.require_lang_item(LangItem::Poll, None));
648 let poll_enum = Ty::new_adt(tcx, poll_adt_ref, tcx.mk_args(&[tcx.types.unit.into()]));
649 body.local_decls[RETURN_PLACE] = LocalDecl::with_source_info(poll_enum, source_info);
650
651 make_coroutine_state_argument_indirect(tcx, &mut body);
652
653 match transform.coroutine_kind {
654 CoroutineKind::Desugared(CoroutineDesugaring::Gen, _) => {}
657 _ => {
658 make_coroutine_state_argument_pinned(tcx, &mut body);
659 }
660 }
661
662 simplify::remove_dead_blocks(&mut body);
665
666 pm::run_passes_no_validate(
667 tcx,
668 &mut body,
669 &[&abort_unwinding_calls::AbortUnwindingCalls],
670 None,
671 );
672
673 dump_mir(tcx, false, "coroutine_drop_async", &0, &body, |_, _| Ok(()));
674
675 body
676}
677
678pub(super) fn create_coroutine_drop_shim_proxy_async<'tcx>(
681 tcx: TyCtxt<'tcx>,
682 body: &Body<'tcx>,
683) -> Body<'tcx> {
684 let mut body = body.clone();
685 let _ = body.coroutine.take();
688 let basic_blocks: IndexVec<BasicBlock, BasicBlockData<'tcx>> = IndexVec::new();
689 body.basic_blocks = BasicBlocks::new(basic_blocks);
690 body.var_debug_info.clear();
691
692 body.local_decls.truncate(1 + body.arg_count);
694
695 let source_info = SourceInfo::outermost(body.span);
696
697 let poll_adt_ref = tcx.adt_def(tcx.require_lang_item(LangItem::Poll, None));
699 let poll_enum = Ty::new_adt(tcx, poll_adt_ref, tcx.mk_args(&[tcx.types.unit.into()]));
700 body.local_decls[RETURN_PLACE] = LocalDecl::with_source_info(poll_enum, source_info);
701
702 let call_bb = body.basic_blocks_mut().push(BasicBlockData {
704 statements: Vec::new(),
705 terminator: None,
706 is_cleanup: false,
707 });
708
709 let ret_bb = insert_poll_ready_block(tcx, &mut body);
711
712 let kind = TerminatorKind::Drop {
713 place: Place::from(SELF_ARG),
714 target: ret_bb,
715 unwind: UnwindAction::Continue,
716 replace: false,
717 drop: None,
718 async_fut: None,
719 };
720 body.basic_blocks_mut()[call_bb].terminator = Some(Terminator { source_info, kind });
721
722 dump_mir(tcx, false, "coroutine_drop_proxy_async", &0, &body, |_, _| Ok(()));
723
724 body
725}