1use std::mem;
42
43use crate::translate::translate_bodies::BodyTransCtx;
44
45use super::translate_ctx::*;
46use charon_lib::ast::*;
47use charon_lib::ids::Vector;
48use charon_lib::ullbc_ast::*;
49use hax_frontend_exporter as hax;
50use itertools::Itertools;
51
52pub fn translate_closure_kind(kind: &hax::ClosureKind) -> ClosureKind {
53 match kind {
54 hax::ClosureKind::Fn => ClosureKind::Fn,
55 hax::ClosureKind::FnMut => ClosureKind::FnMut,
56 hax::ClosureKind::FnOnce => ClosureKind::FnOnce,
57 }
58}
59
60impl ItemTransCtx<'_, '_> {
61 pub fn translate_closure_info(
62 &mut self,
63 span: Span,
64 args: &hax::ClosureArgs,
65 ) -> Result<ClosureInfo, Error> {
66 use ClosureKind::*;
67 let kind = translate_closure_kind(&args.kind);
68
69 let fn_once_impl = self.translate_closure_bound_impl_ref(span, args, FnOnce)?;
70 let fn_mut_impl = if matches!(kind, FnMut | Fn) {
71 Some(self.translate_closure_bound_impl_ref(span, args, FnMut)?)
72 } else {
73 None
74 };
75 let fn_impl = if matches!(kind, Fn) {
76 Some(self.translate_closure_bound_impl_ref(span, args, Fn)?)
77 } else {
78 None
79 };
80 let signature = self.translate_region_binder(span, &args.fn_sig, |ctx, sig| {
81 let inputs = sig
82 .inputs
83 .iter()
84 .map(|x| ctx.translate_ty(span, x))
85 .try_collect()?;
86 let output = ctx.translate_ty(span, &sig.output)?;
87 Ok((inputs, output))
88 })?;
89 Ok(ClosureInfo {
90 kind,
91 fn_once_impl,
92 fn_mut_impl,
93 fn_impl,
94 signature,
95 })
96 }
97
98 pub fn translate_closure_bound_type_ref(
102 &mut self,
103 span: Span,
104 closure: &hax::ClosureArgs,
105 ) -> Result<RegionBinder<TypeDeclRef>, Error> {
106 let upvar_binder = hax::Binder {
108 value: (),
109 bound_vars: closure
110 .upvar_tys
111 .iter()
112 .filter(|ty| {
113 matches!(
114 ty.kind(),
115 hax::TyKind::Ref(
116 hax::Region {
117 kind: hax::RegionKind::ReErased
118 },
119 ..
120 )
121 )
122 })
123 .map(|_| hax::BoundVariableKind::Region(hax::BoundRegionKind::Anon))
124 .collect(),
125 };
126 let tref = self.translate_type_decl_ref(span, &closure.item)?;
127 self.translate_region_binder(span, &upvar_binder, |ctx, _| {
128 let mut tref = tref.move_under_binder();
129 tref.generics.regions.extend(
130 ctx.innermost_binder()
131 .params
132 .identity_args()
133 .regions
134 .into_iter(),
135 );
136 Ok(tref)
137 })
138 }
139
140 pub fn translate_closure_type_ref(
142 &mut self,
143 span: Span,
144 closure: &hax::ClosureArgs,
145 ) -> Result<TypeDeclRef, Error> {
146 let bound_tref = self.translate_closure_bound_type_ref(span, closure)?;
147 let tref = if self.item_src.as_def_id() == &closure.item.def_id {
148 bound_tref.apply(
150 self.outermost_binder()
151 .by_ref_upvar_regions
152 .iter()
153 .map(|r| Region::Var(DeBruijnVar::bound(self.binding_levels.depth(), *r)))
154 .collect(),
155 )
156 } else {
157 bound_tref.erase()
159 };
160 Ok(tref)
161 }
162
163 pub fn translate_stateless_closure_as_fn_ref(
166 &mut self,
167 span: Span,
168 closure: &hax::ClosureArgs,
169 ) -> Result<RegionBinder<FunDeclRef>, Error> {
170 let id = self.register_closure_as_fun_decl_id(span, &closure.item.def_id);
171 let TypeDeclRef { generics, .. } = self.translate_closure_type_ref(span, closure)?;
172 self.translate_region_binder(span, &closure.fn_sig, |ctx, _| {
173 let mut generics = generics.move_under_binder();
174 generics.regions.extend(
175 ctx.innermost_binder()
176 .params
177 .identity_args()
178 .regions
179 .into_iter(),
180 );
181 Ok(FunDeclRef { id, generics })
182 })
183 }
184
185 pub fn translate_closure_bound_impl_ref(
189 &mut self,
190 span: Span,
191 closure: &hax::ClosureArgs,
192 target_kind: ClosureKind,
193 ) -> Result<RegionBinder<TraitImplRef>, Error> {
194 let impl_id = self.register_closure_trait_impl_id(span, &closure.item.def_id, target_kind);
195 let adt_ref = self.translate_closure_type_ref(span, closure)?;
196 let impl_ref = TraitImplRef {
197 id: impl_id,
198 generics: adt_ref.generics,
199 };
200 self.translate_region_binder(span, &closure.fn_sig, |ctx, _| {
201 let mut impl_ref = impl_ref.move_under_binder();
202 impl_ref.generics.regions.extend(
203 ctx.innermost_binder()
204 .params
205 .identity_args()
206 .regions
207 .into_iter(),
208 );
209 Ok(impl_ref)
210 })
211 }
212
213 pub fn translate_closure_impl_ref(
215 &mut self,
216 span: Span,
217 closure: &hax::ClosureArgs,
218 target_kind: ClosureKind,
219 ) -> Result<TraitImplRef, Error> {
220 let bound_impl_ref = self.translate_closure_bound_impl_ref(span, closure, target_kind)?;
221 let impl_ref = if self.item_src.as_def_id() == &closure.item.def_id {
222 bound_impl_ref.apply(
224 self.outermost_binder()
225 .bound_region_vars
226 .iter()
227 .map(|r| Region::Var(DeBruijnVar::bound(self.binding_levels.depth(), *r)))
228 .collect(),
229 )
230 } else {
231 bound_impl_ref.erase()
232 };
233 Ok(impl_ref)
234 }
235
236 pub fn translate_closure_trait_ref(
238 &mut self,
239 span: Span,
240 args: &hax::ClosureArgs,
241 target_kind: ClosureKind,
242 ) -> Result<TraitDeclRef, Error> {
243 let fn_trait = match target_kind {
245 ClosureKind::FnOnce => self.get_lang_item(rustc_hir::LangItem::FnOnce),
246 ClosureKind::FnMut => self.get_lang_item(rustc_hir::LangItem::FnMut),
247 ClosureKind::Fn => self.get_lang_item(rustc_hir::LangItem::Fn),
248 };
249 let trait_id = self.register_trait_decl_id(span, &fn_trait);
250
251 let state_ty = self.get_closure_state_ty(span, args)?;
252 let (inputs, _) = self.translate_closure_info(span, args)?.signature.erase();
254 let input_tuple = Ty::mk_tuple(inputs);
255
256 Ok(TraitDeclRef {
257 id: trait_id,
258 generics: Box::new(GenericArgs::new_types([state_ty, input_tuple].into())),
259 })
260 }
261
262 pub fn get_closure_state_ty(
263 &mut self,
264 span: Span,
265 args: &hax::ClosureArgs,
266 ) -> Result<Ty, Error> {
267 let tref = self.translate_closure_type_ref(span, args)?;
268 Ok(TyKind::Adt(tref).into_ty())
269 }
270
271 pub fn translate_closure_adt(
272 &mut self,
273 _trans_id: TypeDeclId,
274 span: Span,
275 args: &hax::ClosureArgs,
276 ) -> Result<TypeDeclKind, Error> {
277 let mut by_ref_upvar_regions = self
278 .the_only_binder()
279 .by_ref_upvar_regions
280 .clone()
281 .into_iter();
282 let fields: Vector<FieldId, Field> = args
283 .upvar_tys
284 .iter()
285 .map(|ty| {
286 let mut ty = self.translate_ty(span, ty)?;
287 if let TyKind::Ref(Region::Erased, deref_ty, kind) = ty.kind() {
289 let region_id = by_ref_upvar_regions.next().unwrap();
290 ty = TyKind::Ref(
291 Region::Var(DeBruijnVar::new_at_zero(region_id)),
292 deref_ty.clone(),
293 *kind,
294 )
295 .into_ty();
296 }
297 Ok(Field {
298 span,
299 attr_info: AttrInfo {
300 attributes: vec![],
301 inline: None,
302 rename: None,
303 public: false,
304 },
305 name: None,
306 ty,
307 })
308 })
309 .try_collect()?;
310 Ok(TypeDeclKind::Struct(fields))
311 }
312
313 fn translate_closure_method_sig(
316 &mut self,
317 def: &hax::FullDef,
318 span: Span,
319 args: &hax::ClosureArgs,
320 target_kind: ClosureKind,
321 ) -> Result<FunSig, Error> {
322 let signature = &args.fn_sig;
323 trace!(
324 "signature of closure {:?}:\n{:?}",
325 def.def_id, signature.value,
326 );
327
328 let is_unsafe = match signature.value.safety {
329 hax::Safety::Unsafe => true,
330 hax::Safety::Safe => false,
331 };
332
333 let state_ty = self.get_closure_state_ty(span, args)?;
334
335 let state_ty = match target_kind {
337 ClosureKind::FnOnce => state_ty,
338 ClosureKind::Fn | ClosureKind::FnMut => {
339 let rid = self
340 .innermost_generics_mut()
341 .regions
342 .push_with(|index| RegionVar { index, name: None });
343 let r = Region::Var(DeBruijnVar::new_at_zero(rid));
344 let mutability = if target_kind == ClosureKind::Fn {
345 RefKind::Shared
346 } else {
347 RefKind::Mut
348 };
349 TyKind::Ref(r, state_ty, mutability).into_ty()
350 }
351 };
352
353 let input_tys: Vec<Ty> = signature
355 .value
356 .inputs
357 .iter()
358 .map(|ty| self.translate_ty(span, ty))
359 .try_collect()?;
360 let inputs = vec![state_ty, Ty::mk_tuple(input_tys)];
362 let output = self.translate_ty(span, &signature.value.output)?;
363
364 Ok(FunSig {
365 generics: self.the_only_binder().params.clone(),
366 is_unsafe,
367 inputs,
368 output,
369 })
370 }
371
    /// Builds the body of the closure's `target_kind` trait method.
    ///
    /// `signature` is the already-translated method signature. Its `inputs[0]` is the state and
    /// its `inputs[1]` is the tuple of closure arguments.
    ///
    /// The body depends on how `target_kind` relates to the closure's own kind (`args.kind`):
    /// - Same kind: the closure's own body is translated, then adapted to receive its
    ///   arguments as a single tuple.
    /// - `FnOnce` for an `Fn`/`FnMut` closure: the `once_shim` body provided by hax is
    ///   translated.
    /// - `FnMut` for an `Fn` closure: a small shim is generated. It reborrows `*state` as
    ///   shared and calls the closure's `Fn` method.
    ///
    /// In the first two cases, a failed body translation (`Err`) is turned into `Opaque` rather
    /// than propagated.
    ///
    /// # Panics
    /// Panics for target kinds the closure cannot provide (e.g. `Fn` for an `FnMut` closure).
    /// Also panics if an `FnOnce` shim is needed but the definition has none.
    fn translate_closure_method_body(
        mut self,
        span: Span,
        def: &hax::FullDef,
        target_kind: ClosureKind,
        args: &hax::ClosureArgs,
        signature: &FunSig,
    ) -> Result<Result<Body, Opaque>, Error> {
        use ClosureKind::*;
        let closure_kind = translate_closure_kind(&args.kind);
        let mk_stt = |content| Statement::new(span, content);
        let mk_block = |statements, terminator| -> BlockData {
            BlockData {
                statements,
                terminator: Terminator::new(span, terminator),
            }
        };

        Ok(match (target_kind, closure_kind) {
            (Fn, Fn) | (FnMut, FnMut) | (FnOnce, FnOnce) => {
                let mut bt_ctx = BodyTransCtx::new(&mut self);
                match bt_ctx.translate_def_body(span, def) {
                    Ok(Ok(mut body)) => {
                        let GExprBody {
                            locals,
                            body: blocks,
                            ..
                        } = body.as_unstructured_mut().unwrap();

                        // The closure body takes its arguments as individual locals starting
                        // at index 2 (after local 0 and the state at local 1). The trait method
                        // instead takes them as one tuple. Shift every local use at index >= 2
                        // up by one to free index 2 for the tuple.
                        blocks.dyn_visit_mut(|local: &mut LocalId| {
                            let idx = local.index();
                            if idx >= 2 {
                                *local = LocalId::new(idx + 1)
                            }
                        });

                        // Rebuild the local declarations with the same shift:
                        // - keep locals 0 and 1;
                        // - insert the new `tupled_args` argument at index 2;
                        // - re-append the remaining old locals with their `index` bumped by one.
                        let mut old_locals = mem::take(&mut locals.locals).into_iter();
                        locals.arg_count = 2;
                        locals.locals.push(old_locals.next().unwrap());
                        locals.locals.push(old_locals.next().unwrap());
                        let tupled_arg = locals
                            .new_var(Some("tupled_args".to_string()), signature.inputs[1].clone());
                        locals.locals.extend(old_locals.map(|mut l| {
                            l.index += 1;
                            l
                        }));

                        // Prepend `local_{i + 3} = move tupled_args.i` for every tuple
                        // component. The original body then finds each argument in the
                        // (shifted) local it expects.
                        let untupled_args = signature.inputs[1].as_tuple().unwrap();
                        let closure_arg_count = untupled_args.elem_count();
                        let new_stts = untupled_args.iter().cloned().enumerate().map(|(i, ty)| {
                            let nth_field = tupled_arg.clone().project(
                                ProjectionElem::Field(
                                    FieldProjKind::Tuple(closure_arg_count),
                                    FieldId::new(i),
                                ),
                                ty,
                            );
                            mk_stt(RawStatement::Assign(
                                locals.place_for_var(LocalId::new(i + 3)),
                                Rvalue::Use(Operand::Move(nth_field)),
                            ))
                        });
                        blocks[BlockId::ZERO].statements.splice(0..0, new_stts);

                        Ok(body)
                    }
                    Ok(Err(Opaque)) => Err(Opaque),
                    Err(_) => Err(Opaque),
                }
            }
            (FnOnce, Fn | FnMut) => {
                // hax provides a ready-made body for calling an `Fn`/`FnMut` closure through
                // `FnOnce`.
                let hax::FullDefKind::Closure {
                    once_shim: Some(body),
                    ..
                } = &def.kind
                else {
                    panic!("missing shim for closure")
                };
                let mut bt_ctx = BodyTransCtx::new(&mut self);
                match bt_ctx.translate_body(span, body, &def.source_text) {
                    Ok(Ok(body)) => Ok(body),
                    Ok(Err(Opaque)) => Err(Opaque),
                    Err(_) => Err(Opaque),
                }
            }
            (FnMut, Fn) => {
                // Call the closure's own (`Fn`) method. Its extra region parameter, for the
                // `&self` reference, is instantiated with an erased region.
                let fun_id = self.register_closure_method_decl_id(span, def.def_id(), closure_kind);
                let impl_ref = self.translate_closure_impl_ref(span, args, closure_kind)?;
                let fn_op = FnOperand::Regular(FnPtr {
                    func: FunIdOrTraitMethodRef::Fun(FunId::Regular(fun_id.clone())).into(),
                    generics: Box::new(impl_ref.generics.concat(&GenericArgs {
                        regions: vec![Region::Erased].into(),
                        ..GenericArgs::empty()
                    })),
                });

                // Locals: return value, `state: &mut State`, `args: (..)`, and a shared
                // reborrow of `*state`.
                let mut locals = Locals {
                    arg_count: 2,
                    locals: Vector::new(),
                };
                let mut statements = vec![];
                let mut blocks = Vector::default();

                let output = locals.new_var(None, signature.output.clone());
                let state = locals.new_var(Some("state".to_string()), signature.inputs[0].clone());
                let args = locals.new_var(Some("args".to_string()), signature.inputs[1].clone());
                let deref_state = state.deref();
                let reborrow_ty =
                    TyKind::Ref(Region::Erased, deref_state.ty.clone(), RefKind::Shared).into_ty();
                let reborrow = locals.new_var(None, reborrow_ty);

                statements.push(mk_stt(RawStatement::Assign(
                    reborrow.clone(),
                    Rvalue::Ref(deref_state, BorrowKind::Shared),
                )));

                // Block layout: start block (reborrow + call), a `Return` block on success,
                // and an `UnwindResume` block on unwind.
                let start_block = blocks.reserve_slot();
                let ret_block = blocks.push(mk_block(vec![], RawTerminator::Return));
                let unwind_block = blocks.push(mk_block(vec![], RawTerminator::UnwindResume));
                let call = RawTerminator::Call {
                    target: ret_block,
                    call: Call {
                        func: fn_op,
                        args: vec![Operand::Move(reborrow), Operand::Move(args)],
                        dest: output,
                    },
                    on_unwind: unwind_block,
                };
                blocks.set_slot(start_block, mk_block(statements, call));

                let body: ExprBody = GExprBody {
                    span,
                    locals,
                    comments: vec![],
                    body: blocks,
                };
                Ok(Body::Unstructured(body))
            }
            (Fn, FnOnce) | (Fn, FnMut) | (FnMut, FnOnce) => {
                panic!(
                    "Can't make a closure body for a more restrictive kind \
                    than the closure kind"
                )
            }
        })
    }
545
546 #[tracing::instrument(skip(self, item_meta))]
549 pub fn translate_closure_method(
550 mut self,
551 def_id: FunDeclId,
552 item_meta: ItemMeta,
553 def: &hax::FullDef,
554 target_kind: ClosureKind,
555 ) -> Result<FunDecl, Error> {
556 let span = item_meta.span;
557 let hax::FullDefKind::Closure { args, .. } = &def.kind else {
558 unreachable!()
559 };
560
561 trace!("About to translate closure:\n{:?}", def.def_id);
562
563 self.translate_def_generics(span, def)?;
564 assert!(self.innermost_binder_mut().bound_region_vars.is_empty(),);
566 self.innermost_binder_mut()
567 .push_params_from_binder(args.fn_sig.rebind(()))?;
568
569 let impl_ref = self.translate_closure_impl_ref(span, args, target_kind)?;
570 let implemented_trait = self.translate_closure_trait_ref(span, args, target_kind)?;
571 let kind = ItemKind::TraitImpl {
572 impl_ref,
573 trait_ref: implemented_trait,
574 item_name: TraitItemName(target_kind.method_name().to_owned()),
575 reuses_default: false,
576 };
577
578 let signature = self.translate_closure_method_sig(def, span, args, target_kind)?;
580
581 let body = if item_meta.opacity.with_private_contents().is_opaque() {
582 Err(Opaque)
583 } else {
584 self.translate_closure_method_body(span, def, target_kind, args, &signature)?
585 };
586
587 Ok(FunDecl {
588 def_id,
589 item_meta,
590 signature,
591 kind,
592 is_global_initializer: None,
593 body,
594 })
595 }
596
597 #[tracing::instrument(skip(self, item_meta))]
598 pub fn translate_closure_trait_impl(
599 mut self,
600 def_id: TraitImplId,
601 item_meta: ItemMeta,
602 def: &hax::FullDef,
603 target_kind: ClosureKind,
604 ) -> Result<TraitImpl, Error> {
605 let span = item_meta.span;
606 let hax::FullDefKind::Closure {
607 args,
608 fn_once_impl,
609 fn_mut_impl,
610 fn_impl,
611 ..
612 } = &def.kind
613 else {
614 unreachable!()
615 };
616
617 self.translate_def_generics(span, def)?;
618 assert!(self.innermost_binder_mut().bound_region_vars.is_empty());
620 self.innermost_binder_mut()
621 .push_params_from_binder(args.fn_sig.rebind(()))?;
622
623 let vimpl = match target_kind {
625 ClosureKind::FnOnce => fn_once_impl,
626 ClosureKind::FnMut => fn_mut_impl.as_ref().unwrap(),
627 ClosureKind::Fn => fn_impl.as_ref().unwrap(),
628 };
629 let implemented_trait = self.translate_trait_ref(span, &vimpl.trait_pred.trait_ref)?;
630 let fn_trait = implemented_trait.id;
631
632 let mut parent_trait_refs =
633 self.translate_trait_impl_exprs(span, &vimpl.implied_impl_exprs)?;
634 let mut types = vec![];
635 for (output, impl_exprs) in &vimpl.types {
636 let output = self.translate_ty(span, output)?;
638 types.push((TraitItemName("Output".into()), output.clone()));
639 let trait_refs = self.translate_trait_impl_exprs(span, impl_exprs)?;
640 parent_trait_refs.extend(trait_refs);
641 }
642
643 let call_fn_id = self.register_closure_method_decl_id(span, &def.def_id, target_kind);
645 let call_fn_name = TraitItemName(target_kind.method_name().to_string());
646 let call_fn_binder = {
647 let mut method_params = GenericParams::empty();
648 match target_kind {
649 ClosureKind::FnOnce => {}
650 ClosureKind::FnMut | ClosureKind::Fn => {
651 method_params
652 .regions
653 .push_with(|index| RegionVar { index, name: None });
654 }
655 };
656
657 let generics = self
658 .outermost_binder()
659 .params
660 .identity_args_at_depth(DeBruijnId::one())
661 .concat(&method_params.identity_args_at_depth(DeBruijnId::zero()));
662 Binder::new(
663 BinderKind::TraitMethod(fn_trait, call_fn_name.clone()),
664 method_params,
665 FunDeclRef {
666 id: call_fn_id,
667 generics: Box::new(generics),
668 },
669 )
670 };
671
672 let self_generics = self.into_generics();
673
674 Ok(TraitImpl {
675 def_id,
676 item_meta,
677 impl_trait: implemented_trait,
678 generics: self_generics,
679 parent_trait_refs,
680 type_clauses: vec![],
681 consts: vec![],
682 types,
683 methods: vec![(call_fn_name, call_fn_binder)],
684 })
685 }
686
687 #[tracing::instrument(skip(self, item_meta))]
690 pub fn translate_stateless_closure_as_fn(
691 mut self,
692 def_id: FunDeclId,
693 item_meta: ItemMeta,
694 def: &hax::FullDef,
695 ) -> Result<FunDecl, Error> {
696 let span = item_meta.span;
697 let hax::FullDefKind::Closure { args: closure, .. } = &def.kind else {
698 unreachable!()
699 };
700
701 trace!("About to translate closure as fn:\n{:?}", def.def_id);
702
703 assert!(
704 closure.upvar_tys.is_empty(),
705 "Only stateless closures can be translated as functions"
706 );
707
708 self.translate_def_generics(span, def)?;
709 assert!(self.innermost_binder_mut().bound_region_vars.is_empty(),);
711 self.innermost_binder_mut()
712 .push_params_from_binder(closure.fn_sig.rebind(()))?;
713
714 let mut signature =
716 self.translate_closure_method_sig(def, span, closure, ClosureKind::FnOnce)?;
717 let state_ty = signature.inputs.remove(0);
718 let args_tuple_ty = signature.inputs.remove(0);
719 signature.inputs = args_tuple_ty.as_tuple().unwrap().iter().cloned().collect();
720
721 let body = if item_meta.opacity.with_private_contents().is_opaque() {
722 Err(Opaque)
723 } else {
724 let mk_stt = |content| Statement::new(span, content);
732 let mk_block = |statements, terminator| -> BlockData {
733 BlockData {
734 statements,
735 terminator: Terminator::new(span, terminator),
736 }
737 };
738 let fun_id =
739 self.register_closure_method_decl_id(span, def.def_id(), ClosureKind::FnOnce);
740 let impl_ref = self.translate_closure_impl_ref(span, closure, ClosureKind::FnOnce)?;
741 let fn_op = FnOperand::Regular(FnPtr {
742 func: FunIdOrTraitMethodRef::Fun(FunId::Regular(fun_id.clone())).into(),
743 generics: impl_ref.generics.clone(),
744 });
745
746 let mut locals = Locals {
747 arg_count: signature.inputs.len(),
748 locals: Vector::new(),
749 };
750 let mut statements = vec![];
751 let mut blocks = Vector::default();
752
753 let output = locals.new_var(None, signature.output.clone());
754 let args: Vec<Place> = signature
755 .inputs
756 .iter()
757 .enumerate()
758 .map(|(i, ty)| locals.new_var(Some(format!("arg{}", i + 1)), ty.clone()))
759 .collect();
760 let args_tupled = locals.new_var(Some("args".to_string()), args_tuple_ty.clone());
761 let state = locals.new_var(Some("state".to_string()), state_ty.clone());
762
763 statements.push(mk_stt(RawStatement::Assign(
764 args_tupled.clone(),
765 Rvalue::Aggregate(
766 AggregateKind::Adt(args_tuple_ty.as_adt().unwrap().clone(), None, None),
767 args.into_iter().map(Operand::Move).collect(),
768 ),
769 )));
770
771 let state_ty_adt = state_ty.as_adt().unwrap();
772 statements.push(mk_stt(RawStatement::Assign(
773 state.clone(),
774 Rvalue::Aggregate(AggregateKind::Adt(state_ty_adt.clone(), None, None), vec![]),
775 )));
776
777 let start_block = blocks.reserve_slot();
778 let ret_block = blocks.push(mk_block(vec![], RawTerminator::Return));
779 let unwind_block = blocks.push(mk_block(vec![], RawTerminator::UnwindResume));
780 let call = RawTerminator::Call {
781 target: ret_block,
782 call: Call {
783 func: fn_op,
784 args: vec![Operand::Move(state), Operand::Move(args_tupled)],
785 dest: output,
786 },
787 on_unwind: unwind_block,
788 };
789 blocks.set_slot(start_block, mk_block(statements, call));
790
791 let body: ExprBody = GExprBody {
792 span,
793 locals,
794 comments: vec![],
795 body: blocks,
796 };
797 Ok(Body::Unstructured(body))
798 };
799
800 Ok(FunDecl {
801 def_id,
802 item_meta,
803 signature,
804 kind: ItemKind::TopLevel,
805 is_global_initializer: None,
806 body,
807 })
808 }
809}