1use super::*;
4
5struct FixReturnPendingVisitor<'tcx> {
7 tcx: TyCtxt<'tcx>,
8}
9
10impl<'tcx> MutVisitor<'tcx> for FixReturnPendingVisitor<'tcx> {
11 fn tcx(&self) -> TyCtxt<'tcx> {
12 self.tcx
13 }
14
15 fn visit_assign(
16 &mut self,
17 place: &mut Place<'tcx>,
18 rvalue: &mut Rvalue<'tcx>,
19 _location: Location,
20 ) {
21 if place.local != RETURN_PLACE {
22 return;
23 }
24
25 if let Rvalue::Aggregate(kind, _) = rvalue {
27 if let AggregateKind::Adt(_, _, ref mut args, _, _) = **kind {
28 *args = self.tcx.mk_args(&[self.tcx.types.unit.into()]);
29 }
30 }
31 }
32}
33
34fn build_poll_call<'tcx>(
36 tcx: TyCtxt<'tcx>,
37 body: &mut Body<'tcx>,
38 poll_unit_place: &Place<'tcx>,
39 switch_block: BasicBlock,
40 fut_pin_place: &Place<'tcx>,
41 fut_ty: Ty<'tcx>,
42 context_ref_place: &Place<'tcx>,
43 unwind: UnwindAction,
44) -> BasicBlock {
45 let poll_fn = tcx.require_lang_item(LangItem::FuturePoll, DUMMY_SP);
46 let poll_fn = Ty::new_fn_def(tcx, poll_fn, [fut_ty]);
47 let poll_fn = Operand::Constant(Box::new(ConstOperand {
48 span: DUMMY_SP,
49 user_ty: None,
50 const_: Const::zero_sized(poll_fn),
51 }));
52 let call = TerminatorKind::Call {
53 func: poll_fn.clone(),
54 args: [
55 dummy_spanned(Operand::Move(*fut_pin_place)),
56 dummy_spanned(Operand::Move(*context_ref_place)),
57 ]
58 .into(),
59 destination: *poll_unit_place,
60 target: Some(switch_block),
61 unwind,
62 call_source: CallSource::Misc,
63 fn_span: DUMMY_SP,
64 };
65 insert_term_block(body, call)
66}
67
68fn build_pin_fut<'tcx>(
70 tcx: TyCtxt<'tcx>,
71 body: &mut Body<'tcx>,
72 fut_place: Place<'tcx>,
73 unwind: UnwindAction,
74) -> (BasicBlock, Place<'tcx>) {
75 let span = body.span;
76 let source_info = SourceInfo::outermost(span);
77 let fut_ty = fut_place.ty(&body.local_decls, tcx).ty;
78 let fut_ref_ty = Ty::new_mut_ref(tcx, tcx.lifetimes.re_erased, fut_ty);
79 let fut_ref_place = Place::from(body.local_decls.push(LocalDecl::new(fut_ref_ty, span)));
80 let pin_fut_new_unchecked_fn =
81 Ty::new_fn_def(tcx, tcx.require_lang_item(LangItem::PinNewUnchecked, span), [fut_ref_ty]);
82 let fut_pin_ty = pin_fut_new_unchecked_fn.fn_sig(tcx).output().skip_binder();
83 let fut_pin_place = Place::from(body.local_decls.push(LocalDecl::new(fut_pin_ty, span)));
84 let pin_fut_new_unchecked_fn = Operand::Constant(Box::new(ConstOperand {
85 span,
86 user_ty: None,
87 const_: Const::zero_sized(pin_fut_new_unchecked_fn),
88 }));
89
90 let storage_live = Statement::new(source_info, StatementKind::StorageLive(fut_pin_place.local));
91
92 let fut_ref_assign = Statement::new(
93 source_info,
94 StatementKind::Assign(Box::new((
95 fut_ref_place,
96 Rvalue::Ref(
97 tcx.lifetimes.re_erased,
98 BorrowKind::Mut { kind: MutBorrowKind::Default },
99 fut_place,
100 ),
101 ))),
102 );
103
104 let pin_fut_bb = body.basic_blocks_mut().push(BasicBlockData::new_stmts(
106 [storage_live, fut_ref_assign].to_vec(),
107 Some(Terminator {
108 source_info,
109 kind: TerminatorKind::Call {
110 func: pin_fut_new_unchecked_fn,
111 args: [dummy_spanned(Operand::Move(fut_ref_place))].into(),
112 destination: fut_pin_place,
113 target: None, unwind,
115 call_source: CallSource::Misc,
116 fn_span: span,
117 },
118 }),
119 false,
120 ));
121 (pin_fut_bb, fut_pin_place)
122}
123
124fn build_poll_switch<'tcx>(
130 tcx: TyCtxt<'tcx>,
131 body: &mut Body<'tcx>,
132 poll_enum: Ty<'tcx>,
133 poll_unit_place: &Place<'tcx>,
134 fut_pin_place: &Place<'tcx>,
135 ready_block: BasicBlock,
136 yield_block: BasicBlock,
137) -> BasicBlock {
138 let poll_enum_adt = poll_enum.ty_adt_def().unwrap();
139
140 let Discr { val: poll_ready_discr, ty: poll_discr_ty } = poll_enum
141 .discriminant_for_variant(
142 tcx,
143 poll_enum_adt
144 .variant_index_with_id(tcx.require_lang_item(LangItem::PollReady, DUMMY_SP)),
145 )
146 .unwrap();
147 let poll_pending_discr = poll_enum
148 .discriminant_for_variant(
149 tcx,
150 poll_enum_adt
151 .variant_index_with_id(tcx.require_lang_item(LangItem::PollPending, DUMMY_SP)),
152 )
153 .unwrap()
154 .val;
155 let source_info = SourceInfo::outermost(body.span);
156 let poll_discr_place =
157 Place::from(body.local_decls.push(LocalDecl::new(poll_discr_ty, source_info.span)));
158 let discr_assign = Statement::new(
159 source_info,
160 StatementKind::Assign(Box::new((poll_discr_place, Rvalue::Discriminant(*poll_unit_place)))),
161 );
162 let storage_dead = Statement::new(source_info, StatementKind::StorageDead(fut_pin_place.local));
163 let unreachable_block = insert_term_block(body, TerminatorKind::Unreachable);
164 body.basic_blocks_mut().push(BasicBlockData::new_stmts(
165 [storage_dead, discr_assign].to_vec(),
166 Some(Terminator {
167 source_info,
168 kind: TerminatorKind::SwitchInt {
169 discr: Operand::Move(poll_discr_place),
170 targets: SwitchTargets::new(
171 [(poll_ready_discr, ready_block), (poll_pending_discr, yield_block)]
172 .into_iter(),
173 unreachable_block,
174 ),
175 },
176 }),
177 false,
178 ))
179}
180
181fn gather_dropline_blocks<'tcx>(body: &mut Body<'tcx>) -> DenseBitSet<BasicBlock> {
183 let mut dropline: DenseBitSet<BasicBlock> = DenseBitSet::new_empty(body.basic_blocks.len());
184 for (bb, data) in traversal::reverse_postorder(body) {
185 if dropline.contains(bb) {
186 data.terminator().successors().for_each(|v| {
187 dropline.insert(v);
188 });
189 } else {
190 match data.terminator().kind {
191 TerminatorKind::Yield { drop: Some(v), .. } => {
192 dropline.insert(v);
193 }
194 TerminatorKind::Drop { drop: Some(v), .. } => {
195 dropline.insert(v);
196 }
197 _ => (),
198 }
199 }
200 }
201 dropline
202}
203
204pub(super) fn cleanup_async_drops<'tcx>(body: &mut Body<'tcx>) {
206 for block in body.basic_blocks_mut() {
207 if let TerminatorKind::Drop {
208 place: _,
209 target: _,
210 unwind: _,
211 replace: _,
212 ref mut drop,
213 ref mut async_fut,
214 } = block.terminator_mut().kind
215 {
216 if drop.is_some() || async_fut.is_some() {
217 *drop = None;
218 *async_fut = None;
219 }
220 }
221 }
222}
223
224pub(super) fn has_expandable_async_drops<'tcx>(
225 tcx: TyCtxt<'tcx>,
226 body: &mut Body<'tcx>,
227 coroutine_ty: Ty<'tcx>,
228) -> bool {
229 for bb in START_BLOCK..body.basic_blocks.next_index() {
230 if body[bb].is_cleanup {
232 continue;
233 }
234 let TerminatorKind::Drop { place, target: _, unwind: _, replace: _, drop: _, async_fut } =
235 body[bb].terminator().kind
236 else {
237 continue;
238 };
239 let place_ty = place.ty(&body.local_decls, tcx).ty;
240 if place_ty == coroutine_ty {
241 continue;
242 }
243 if async_fut.is_none() {
244 continue;
245 }
246 return true;
247 }
248 return false;
249}
250
251pub(super) fn expand_async_drops<'tcx>(
253 tcx: TyCtxt<'tcx>,
254 body: &mut Body<'tcx>,
255 context_mut_ref: Ty<'tcx>,
256 coroutine_kind: hir::CoroutineKind,
257 coroutine_ty: Ty<'tcx>,
258) {
259 let dropline = gather_dropline_blocks(body);
260 let remove_asyncness = |block: &mut BasicBlockData<'tcx>| {
262 if let TerminatorKind::Drop {
263 place: _,
264 target: _,
265 unwind: _,
266 replace: _,
267 ref mut drop,
268 ref mut async_fut,
269 } = block.terminator_mut().kind
270 {
271 *drop = None;
272 *async_fut = None;
273 }
274 };
275 for bb in START_BLOCK..body.basic_blocks.next_index() {
276 if body[bb].is_cleanup {
278 remove_asyncness(&mut body[bb]);
279 continue;
280 }
281 let TerminatorKind::Drop { place, target, unwind, replace: _, drop, async_fut } =
282 body[bb].terminator().kind
283 else {
284 continue;
285 };
286
287 let place_ty = place.ty(&body.local_decls, tcx).ty;
288 if place_ty == coroutine_ty {
289 remove_asyncness(&mut body[bb]);
290 continue;
291 }
292
293 let Some(fut_local) = async_fut else {
294 remove_asyncness(&mut body[bb]);
295 continue;
296 };
297
298 let is_dropline_bb = dropline.contains(bb);
299
300 if !is_dropline_bb && drop.is_none() {
301 remove_asyncness(&mut body[bb]);
302 continue;
303 }
304
305 let fut_place = Place::from(fut_local);
306 let fut_ty = fut_place.ty(&body.local_decls, tcx).ty;
307
308 let source_info = body[bb].terminator.as_ref().unwrap().source_info;
317
318 let poll_adt_ref = tcx.adt_def(tcx.require_lang_item(LangItem::Poll, source_info.span));
320 let poll_enum = Ty::new_adt(tcx, poll_adt_ref, tcx.mk_args(&[tcx.types.unit.into()]));
321 let poll_decl = LocalDecl::new(poll_enum, source_info.span);
322 let poll_unit_place = Place::from(body.local_decls.push(poll_decl));
323
324 let context_ref_place =
326 Place::from(body.local_decls.push(LocalDecl::new(context_mut_ref, source_info.span)));
327 let arg = Rvalue::Use(Operand::Move(Place::from(CTX_ARG)));
328 body[bb].statements.push(Statement::new(
329 source_info,
330 StatementKind::Assign(Box::new((context_ref_place, arg))),
331 ));
332 let yield_block = insert_term_block(body, TerminatorKind::Unreachable); let (pin_bb, fut_pin_place) =
334 build_pin_fut(tcx, body, fut_place.clone(), UnwindAction::Continue);
335 let switch_block = build_poll_switch(
336 tcx,
337 body,
338 poll_enum,
339 &poll_unit_place,
340 &fut_pin_place,
341 target,
342 yield_block,
343 );
344 let call_bb = build_poll_call(
345 tcx,
346 body,
347 &poll_unit_place,
348 switch_block,
349 &fut_pin_place,
350 fut_ty,
351 &context_ref_place,
352 unwind,
353 );
354
355 let mut dropline_transition_bb: Option<BasicBlock> = None;
357 let mut dropline_yield_bb: Option<BasicBlock> = None;
358 let mut dropline_context_ref: Option<Place<'_>> = None;
359 let mut dropline_call_bb: Option<BasicBlock> = None;
360 if !is_dropline_bb {
361 let context_ref_place2: Place<'_> = Place::from(
362 body.local_decls.push(LocalDecl::new(context_mut_ref, source_info.span)),
363 );
364 let drop_yield_block = insert_term_block(body, TerminatorKind::Unreachable); let (pin_bb2, fut_pin_place2) =
366 build_pin_fut(tcx, body, fut_place, UnwindAction::Continue);
367 let drop_switch_block = build_poll_switch(
368 tcx,
369 body,
370 poll_enum,
371 &poll_unit_place,
372 &fut_pin_place2,
373 drop.unwrap(),
374 drop_yield_block,
375 );
376 let drop_call_bb = build_poll_call(
377 tcx,
378 body,
379 &poll_unit_place,
380 drop_switch_block,
381 &fut_pin_place2,
382 fut_ty,
383 &context_ref_place2,
384 unwind,
385 );
386 dropline_transition_bb = Some(pin_bb2);
387 dropline_yield_bb = Some(drop_yield_block);
388 dropline_context_ref = Some(context_ref_place2);
389 dropline_call_bb = Some(drop_call_bb);
390 }
391
392 let value =
393 if matches!(coroutine_kind, CoroutineKind::Desugared(CoroutineDesugaring::AsyncGen, _))
394 {
395 let full_yield_ty = body.yield_ty().unwrap();
397 let ty::Adt(_poll_adt, args) = *full_yield_ty.kind() else { bug!() };
398 let ty::Adt(_option_adt, args) = *args.type_at(0).kind() else { bug!() };
399 let yield_ty = args.type_at(0);
400 Operand::Constant(Box::new(ConstOperand {
401 span: source_info.span,
402 const_: Const::Unevaluated(
403 UnevaluatedConst::new(
404 tcx.require_lang_item(LangItem::AsyncGenPending, source_info.span),
405 tcx.mk_args(&[yield_ty.into()]),
406 ),
407 full_yield_ty,
408 ),
409 user_ty: None,
410 }))
411 } else {
412 Operand::Constant(Box::new(ConstOperand {
414 span: source_info.span,
415 user_ty: None,
416 const_: Const::from_bool(tcx, false),
417 }))
418 };
419
420 use rustc_middle::mir::AssertKind::ResumedAfterDrop;
421 let panic_bb = insert_panic_block(tcx, body, ResumedAfterDrop(coroutine_kind));
422
423 if is_dropline_bb {
424 body[yield_block].terminator_mut().kind = TerminatorKind::Yield {
425 value: value.clone(),
426 resume: panic_bb,
427 resume_arg: context_ref_place,
428 drop: Some(pin_bb),
429 };
430 } else {
431 body[yield_block].terminator_mut().kind = TerminatorKind::Yield {
432 value: value.clone(),
433 resume: pin_bb,
434 resume_arg: context_ref_place,
435 drop: dropline_transition_bb,
436 };
437 body[dropline_yield_bb.unwrap()].terminator_mut().kind = TerminatorKind::Yield {
438 value,
439 resume: panic_bb,
440 resume_arg: dropline_context_ref.unwrap(),
441 drop: dropline_transition_bb,
442 };
443 }
444
445 if let TerminatorKind::Call { ref mut target, .. } = body[pin_bb].terminator_mut().kind {
446 *target = Some(call_bb);
447 } else {
448 bug!()
449 }
450 if !is_dropline_bb {
451 if let TerminatorKind::Call { ref mut target, .. } =
452 body[dropline_transition_bb.unwrap()].terminator_mut().kind
453 {
454 *target = dropline_call_bb;
455 } else {
456 bug!()
457 }
458 }
459
460 body[bb].terminator_mut().kind = TerminatorKind::Goto { target: pin_bb };
461 }
462}
463
464pub(super) fn elaborate_coroutine_drops<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
465 use crate::elaborate_drop::{Unwind, elaborate_drop};
466 use crate::patch::MirPatch;
467 use crate::shim::DropShimElaborator;
468
469 let typing_env = body.typing_env(tcx);
473
474 let mut elaborator = DropShimElaborator {
475 body,
476 patch: MirPatch::new(body),
477 tcx,
478 typing_env,
479 produce_async_drops: false,
480 };
481
482 for (block, block_data) in body.basic_blocks.iter_enumerated() {
483 let (target, unwind, source_info, dropline) = match block_data.terminator() {
484 Terminator {
485 source_info,
486 kind: TerminatorKind::Drop { place, target, unwind, replace: _, drop, async_fut: _ },
487 } => {
488 if let Some(local) = place.as_local()
489 && local == SELF_ARG
490 {
491 (target, unwind, source_info, *drop)
492 } else {
493 continue;
494 }
495 }
496 _ => continue,
497 };
498 let unwind = if block_data.is_cleanup {
499 Unwind::InCleanup
500 } else {
501 Unwind::To(match *unwind {
502 UnwindAction::Cleanup(tgt) => tgt,
503 UnwindAction::Continue => elaborator.patch.resume_block(),
504 UnwindAction::Unreachable => elaborator.patch.unreachable_cleanup_block(),
505 UnwindAction::Terminate(reason) => elaborator.patch.terminate_block(reason),
506 })
507 };
508 elaborate_drop(
509 &mut elaborator,
510 *source_info,
511 Place::from(SELF_ARG),
512 (),
513 *target,
514 unwind,
515 block,
516 dropline,
517 );
518 }
519 elaborator.patch.apply(body);
520}
521
522pub(super) fn insert_clean_drop<'tcx>(
523 tcx: TyCtxt<'tcx>,
524 body: &mut Body<'tcx>,
525 has_async_drops: bool,
526) -> BasicBlock {
527 let source_info = SourceInfo::outermost(body.span);
528 let return_block = if has_async_drops {
529 insert_poll_ready_block(tcx, body)
530 } else {
531 insert_term_block(body, TerminatorKind::Return)
532 };
533
534 let dropline = None;
538
539 let term = TerminatorKind::Drop {
540 place: Place::from(SELF_ARG),
541 target: return_block,
542 unwind: UnwindAction::Continue,
543 replace: false,
544 drop: dropline,
545 async_fut: None,
546 };
547
548 body.basic_blocks_mut()
550 .push(BasicBlockData::new(Some(Terminator { source_info, kind: term }), false))
551}
552
553pub(super) fn create_coroutine_drop_shim<'tcx>(
554 tcx: TyCtxt<'tcx>,
555 transform: &TransformVisitor<'tcx>,
556 coroutine_ty: Ty<'tcx>,
557 body: &Body<'tcx>,
558 drop_clean: BasicBlock,
559) -> Body<'tcx> {
560 let mut body = body.clone();
561 let _ = body.coroutine.take();
564 body.arg_count = 1;
567
568 let source_info = SourceInfo::outermost(body.span);
569
570 let mut cases = create_cases(&mut body, transform, Operation::Drop);
571
572 cases.insert(0, (CoroutineArgs::UNRESUMED, drop_clean));
573
574 let default_block = insert_term_block(&mut body, TerminatorKind::Return);
578 insert_switch(&mut body, cases, transform, default_block);
579
580 for block in body.basic_blocks_mut() {
581 let kind = &mut block.terminator_mut().kind;
582 if let TerminatorKind::CoroutineDrop = *kind {
583 *kind = TerminatorKind::Return;
584 }
585 }
586
587 body.local_decls[RETURN_PLACE] = LocalDecl::with_source_info(tcx.types.unit, source_info);
589
590 make_coroutine_state_argument_indirect(tcx, &mut body);
591
592 body.local_decls[SELF_ARG] =
594 LocalDecl::with_source_info(Ty::new_mut_ptr(tcx, coroutine_ty), source_info);
595
596 simplify::remove_dead_blocks(&mut body);
599
600 let coroutine_instance = body.source.instance;
602 let drop_in_place = tcx.require_lang_item(LangItem::DropInPlace, body.span);
603 let drop_instance = InstanceKind::DropGlue(drop_in_place, Some(coroutine_ty));
604
605 body.source.instance = coroutine_instance;
608 dump_mir(tcx, false, "coroutine_drop", &0, &body, |_, _| Ok(()));
609 body.source.instance = drop_instance;
610
611 body.phase = MirPhase::Runtime(RuntimePhase::Initial);
617
618 body
619}
620
621pub(super) fn create_coroutine_drop_shim_async<'tcx>(
623 tcx: TyCtxt<'tcx>,
624 transform: &TransformVisitor<'tcx>,
625 body: &Body<'tcx>,
626 drop_clean: BasicBlock,
627 can_unwind: bool,
628) -> Body<'tcx> {
629 let mut body = body.clone();
630 let _ = body.coroutine.take();
633
634 FixReturnPendingVisitor { tcx }.visit_body(&mut body);
635
636 if can_unwind {
638 generate_poison_block_and_redirect_unwinds_there(transform, &mut body);
639 }
640
641 let source_info = SourceInfo::outermost(body.span);
642
643 let mut cases = create_cases(&mut body, transform, Operation::Drop);
644
645 cases.insert(0, (CoroutineArgs::UNRESUMED, drop_clean));
646
647 use rustc_middle::mir::AssertKind::ResumedAfterPanic;
648 if can_unwind {
650 cases.insert(
651 1,
652 (
653 CoroutineArgs::POISONED,
654 insert_panic_block(tcx, &mut body, ResumedAfterPanic(transform.coroutine_kind)),
655 ),
656 );
657 }
658
659 let default_block = insert_poll_ready_block(tcx, &mut body);
662 insert_switch(&mut body, cases, transform, default_block);
663
664 for block in body.basic_blocks_mut() {
665 let kind = &mut block.terminator_mut().kind;
666 if let TerminatorKind::CoroutineDrop = *kind {
667 *kind = TerminatorKind::Return;
668 block.statements.push(return_poll_ready_assign(tcx, source_info));
669 }
670 }
671
672 let poll_adt_ref = tcx.adt_def(tcx.require_lang_item(LangItem::Poll, body.span));
674 let poll_enum = Ty::new_adt(tcx, poll_adt_ref, tcx.mk_args(&[tcx.types.unit.into()]));
675 body.local_decls[RETURN_PLACE] = LocalDecl::with_source_info(poll_enum, source_info);
676
677 make_coroutine_state_argument_indirect(tcx, &mut body);
678
679 match transform.coroutine_kind {
680 CoroutineKind::Desugared(CoroutineDesugaring::Gen, _) => {}
683 _ => {
684 make_coroutine_state_argument_pinned(tcx, &mut body);
685 }
686 }
687
688 simplify::remove_dead_blocks(&mut body);
691
692 pm::run_passes_no_validate(
693 tcx,
694 &mut body,
695 &[&abort_unwinding_calls::AbortUnwindingCalls],
696 None,
697 );
698
699 dump_mir(tcx, false, "coroutine_drop_async", &0, &body, |_, _| Ok(()));
700
701 body
702}
703
704pub(super) fn create_coroutine_drop_shim_proxy_async<'tcx>(
707 tcx: TyCtxt<'tcx>,
708 body: &Body<'tcx>,
709) -> Body<'tcx> {
710 let mut body = body.clone();
711 let _ = body.coroutine.take();
714 let basic_blocks: IndexVec<BasicBlock, BasicBlockData<'tcx>> = IndexVec::new();
715 body.basic_blocks = BasicBlocks::new(basic_blocks);
716 body.var_debug_info.clear();
717
718 body.local_decls.truncate(1 + body.arg_count);
720
721 let source_info = SourceInfo::outermost(body.span);
722
723 let poll_adt_ref = tcx.adt_def(tcx.require_lang_item(LangItem::Poll, body.span));
725 let poll_enum = Ty::new_adt(tcx, poll_adt_ref, tcx.mk_args(&[tcx.types.unit.into()]));
726 body.local_decls[RETURN_PLACE] = LocalDecl::with_source_info(poll_enum, source_info);
727
728 let call_bb = body.basic_blocks_mut().push(BasicBlockData::new(None, false));
730
731 let ret_bb = insert_poll_ready_block(tcx, &mut body);
733
734 let kind = TerminatorKind::Drop {
735 place: Place::from(SELF_ARG),
736 target: ret_bb,
737 unwind: UnwindAction::Continue,
738 replace: false,
739 drop: None,
740 async_fut: None,
741 };
742 body.basic_blocks_mut()[call_bb].terminator = Some(Terminator { source_info, kind });
743
744 dump_mir(tcx, false, "coroutine_drop_proxy_async", &0, &body, |_, _| Ok(()));
745
746 body
747}