1use std::mem;
42
43use super::translate_crate::TransItemSourceKind;
44use super::translate_ctx::*;
45use charon_lib::ast::ullbc_ast_utils::BodyBuilder;
46use charon_lib::ast::*;
47use charon_lib::ids::IndexVec;
48use charon_lib::ullbc_ast::*;
49use itertools::Itertools;
50
51pub fn translate_closure_kind(kind: &hax::ClosureKind) -> ClosureKind {
52 match kind {
53 hax::ClosureKind::Fn => ClosureKind::Fn,
54 hax::ClosureKind::FnMut => ClosureKind::FnMut,
55 hax::ClosureKind::FnOnce => ClosureKind::FnOnce,
56 }
57}
58
59impl ItemTransCtx<'_, '_> {
64 fn by_ref_upvar_regions(&self, closure: &hax::ClosureArgs) -> IndexMap<RegionId, Region> {
67 if self.item_src.def_id() == &closure.item.def_id {
68 self.outermost_binder()
70 .by_ref_upvar_regions
71 .iter()
72 .map(|r| Region::Var(DeBruijnVar::bound(self.binding_levels.depth(), *r)))
73 .collect()
74 } else {
75 closure
76 .iter_upvar_borrows()
77 .map(|_| Region::Erased)
78 .collect()
79 }
80 }
81
82 fn closure_late_regions(&self, closure: &hax::ClosureArgs) -> IndexMap<RegionId, Region> {
85 if self.item_src.def_id() == &closure.item.def_id {
86 self.outermost_binder()
88 .bound_region_vars
89 .iter()
90 .map(|r| Region::Var(DeBruijnVar::bound(self.binding_levels.depth(), *r)))
91 .collect()
92 } else {
93 closure
94 .fn_sig
95 .bound_vars
96 .iter()
97 .map(|_| Region::Erased)
98 .collect()
99 }
100 }
101
102 fn closure_method_regions(
105 &self,
106 closure: &hax::ClosureArgs,
107 target_kind: ClosureKind,
108 ) -> IndexMap<RegionId, Region> {
109 if self.item_src.def_id() == &closure.item.def_id {
110 self.outermost_binder()
112 .closure_call_method_region
113 .iter()
114 .map(|r| Region::Var(DeBruijnVar::bound(self.binding_levels.depth(), *r)))
115 .collect()
116 } else {
117 match target_kind {
118 ClosureKind::FnOnce => IndexMap::new(),
119 ClosureKind::FnMut | ClosureKind::Fn => [Region::Erased].into_iter().collect(),
120 }
121 }
122 }
123
124 fn translate_closure_bound_ref_with_upvars(
128 &mut self,
129 span: Span,
130 closure: &hax::ClosureArgs,
131 kind: TransItemSourceKind,
132 ) -> Result<RegionBinder<DeclRef<ItemId>>, Error> {
133 let upvar_binder = hax::Binder {
135 value: (),
136 bound_vars: closure
137 .iter_upvar_borrows()
138 .map(|_| hax::BoundVariableKind::Region(hax::BoundRegionKind::Anon))
139 .collect(),
140 };
141 let mut tref: DeclRef<ItemId> = self.translate_item(span, &closure.item, kind)?;
142 if self.item_src.def_id() == &closure.item.def_id && !self.monomorphize() {
144 let depth = self.binding_levels.depth();
145 for (rid, r) in tref.generics.regions.iter_mut_indexed() {
149 *r = Region::Var(DeBruijnVar::bound(depth, rid));
150 }
151 }
152 self.translate_region_binder(span, &upvar_binder, |ctx, _| {
153 let mut tref = tref.move_under_binder();
154 tref.generics.regions.extend(
155 ctx.innermost_binder()
156 .params
157 .identity_args()
158 .regions
159 .into_iter(),
160 );
161 Ok(tref)
162 })
163 }
164
165 fn translate_closure_bound_ref_with_late_bound(
170 &mut self,
171 span: Span,
172 closure: &hax::ClosureArgs,
173 kind: TransItemSourceKind,
174 ) -> Result<RegionBinder<DeclRef<ItemId>>, Error> {
175 let inner_ref = self
176 .translate_closure_bound_ref_with_upvars(span, closure, kind)?
177 .apply(self.by_ref_upvar_regions(closure));
178 self.translate_region_binder(span, &closure.fn_sig, |ctx, _| {
179 let mut inner_ref = inner_ref.move_under_binder();
180 inner_ref.generics.regions.extend(
181 ctx.innermost_binder()
182 .params
183 .identity_args()
184 .regions
185 .into_iter(),
186 );
187 Ok(inner_ref)
188 })
189 }
190
191 fn translate_closure_bound_ref_with_method_bound(
196 &mut self,
197 span: Span,
198 closure: &hax::ClosureArgs,
199 kind: TransItemSourceKind,
200 target_kind: ClosureKind,
201 ) -> Result<RegionBinder<DeclRef<ItemId>>, Error> {
202 let mut dref = self
203 .translate_closure_bound_ref_with_late_bound(span, closure, kind)?
204 .apply(self.closure_late_regions(closure))
205 .move_under_binder();
206 let mut regions = IndexMap::new();
207 match target_kind {
208 ClosureKind::FnOnce => {}
209 ClosureKind::FnMut | ClosureKind::Fn => {
210 let rid = regions.push_with(|index| RegionParam { index, name: None });
211 dref.generics
212 .regions
213 .push(Region::Var(DeBruijnVar::new_at_zero(rid)));
214 }
215 }
216 Ok(RegionBinder {
217 regions,
218 skip_binder: dref,
219 })
220 }
221}
222
impl ItemTransCtx<'_, '_> {
    /// Translate a reference to the closure's state type (the ADT built by
    /// `translate_closure_adt`).
    ///
    /// The upvar region binder from `translate_closure_bound_ref_with_upvars` is instantiated
    /// with `by_ref_upvar_regions`, so the returned reference has no bound regions left.
    ///
    /// Panics (via `unwrap`) if the translated reference does not convert into a
    /// `TypeDeclRef`; with `TransItemSourceKind::Type` this is presumably impossible.
    pub fn translate_closure_type_ref(
        &mut self,
        span: Span,
        closure: &hax::ClosureArgs,
    ) -> Result<TypeDeclRef, Error> {
        let kind = TransItemSourceKind::Type;
        let bound_dref = self.translate_closure_bound_ref_with_upvars(span, closure, kind)?;
        let dref = bound_dref.apply(self.by_ref_upvar_regions(closure));
        Ok(dref.try_into().unwrap())
    }

    /// Translate a reference to the standalone function that a stateless closure is cast to
    /// (`TransItemSourceKind::ClosureAsFnCast`).
    ///
    /// The result stays bound over the late-bound regions of the closure signature; the
    /// caller decides how to instantiate them.
    pub fn translate_stateless_closure_as_fn_ref(
        &mut self,
        span: Span,
        closure: &hax::ClosureArgs,
    ) -> Result<RegionBinder<FunDeclRef>, Error> {
        let kind = TransItemSourceKind::ClosureAsFnCast;
        let bound_dref = self.translate_closure_bound_ref_with_late_bound(span, closure, kind)?;
        Ok(bound_dref.map(|dref| dref.try_into().unwrap()))
    }

    /// Translate a reference to the closure's `target_kind` trait impl.
    ///
    /// The result is still bound over the late-bound regions of the closure signature.
    pub fn translate_closure_bound_impl_ref(
        &mut self,
        span: Span,
        closure: &hax::ClosureArgs,
        target_kind: ClosureKind,
    ) -> Result<RegionBinder<TraitImplRef>, Error> {
        let kind = TransItemSourceKind::TraitImpl(TraitImplSource::Closure(target_kind));
        let bound_dref = self.translate_closure_bound_ref_with_late_bound(span, closure, kind)?;
        Ok(bound_dref.map(|dref| dref.try_into().unwrap()))
    }

    /// Like `translate_closure_bound_impl_ref`, but with the late-bound regions instantiated
    /// by `closure_late_regions`. Those are bound variables when translating the closure item
    /// itself, and erased regions otherwise.
    pub fn translate_closure_impl_ref(
        &mut self,
        span: Span,
        closure: &hax::ClosureArgs,
        target_kind: ClosureKind,
    ) -> Result<TraitImplRef, Error> {
        Ok(self
            .translate_closure_bound_impl_ref(span, closure, target_kind)?
            .apply(self.closure_late_regions(closure)))
    }

    /// Build the `ClosureInfo` for a closure: its kind, the `Fn*` impls it has, and its
    /// (still binder-wrapped) signature.
    ///
    /// The `FnOnce` impl is always present. The `FnMut` impl is present for `FnMut`/`Fn`
    /// closures, and the `Fn` impl only for `Fn` closures. The impl references are left bound
    /// over the closure's late-bound regions.
    pub fn translate_closure_info(
        &mut self,
        span: Span,
        args: &hax::ClosureArgs,
    ) -> Result<ClosureInfo, Error> {
        use ClosureKind::*;
        let kind = translate_closure_kind(&args.kind);

        let fn_once_impl = self.translate_closure_bound_impl_ref(span, args, FnOnce)?;
        let fn_mut_impl = if matches!(kind, FnMut | Fn) {
            Some(self.translate_closure_bound_impl_ref(span, args, FnMut)?)
        } else {
            None
        };
        let fn_impl = if matches!(kind, Fn) {
            Some(self.translate_closure_bound_impl_ref(span, args, Fn)?)
        } else {
            None
        };
        let signature = self.translate_poly_fun_sig(span, &args.fn_sig)?;
        Ok(ClosureInfo {
            kind,
            fn_once_impl,
            fn_mut_impl,
            fn_impl,
            signature,
        })
    }

    /// The closure's state type as a `Ty`: the ADT referenced by `translate_closure_type_ref`.
    pub fn get_closure_state_ty(
        &mut self,
        span: Span,
        args: &hax::ClosureArgs,
    ) -> Result<Ty, Error> {
        let tref = self.translate_closure_type_ref(span, args)?;
        Ok(TyKind::Adt(tref).into_ty())
    }

    /// Build the struct definition that holds the closure's state: one unnamed, private field
    /// per upvar, in `args.upvar_tys` order.
    ///
    /// A translated upvar type of the form `&'erased T` (shared or mutable) is given the next
    /// region from the binder's `by_ref_upvar_regions`, referenced at De Bruijn index zero.
    /// `_trans_id` is unused.
    ///
    /// Panics if there are more erased-region reference upvars than recorded
    /// `by_ref_upvar_regions`.
    pub fn translate_closure_adt(
        &mut self,
        _trans_id: TypeDeclId,
        span: Span,
        args: &hax::ClosureArgs,
    ) -> Result<TypeDeclKind, Error> {
        // Regions are handed out in order, one per erased-region reference upvar.
        let mut by_ref_upvar_regions = self
            .the_only_binder()
            .by_ref_upvar_regions
            .clone()
            .into_iter();
        let fields: IndexVec<FieldId, Field> = args
            .upvar_tys
            .iter()
            .map(|ty| -> Result<Field, Error> {
                let mut ty = self.translate_ty(span, ty)?;
                // NOTE(review): this treats every reference with an erased region as a by-ref
                // capture; confirm that a by-value captured `&T` cannot reach this branch, or the
                // region assignment would shift.
                if let TyKind::Ref(Region::Erased, deref_ty, kind) = ty.kind() {
                    let region_id = by_ref_upvar_regions.next().unwrap();
                    ty = TyKind::Ref(
                        Region::Var(DeBruijnVar::new_at_zero(region_id)),
                        deref_ty.clone(),
                        *kind,
                    )
                    .into_ty();
                }
                Ok(Field {
                    span,
                    attr_info: AttrInfo::dummy_private(),
                    name: None,
                    ty,
                })
            })
            .try_collect()?;
        Ok(TypeDeclKind::Struct(fields))
    }

    /// Build the signature of the closure's `target_kind` trait method.
    ///
    /// The method takes two inputs:
    /// - the state: `State` for `FnOnce`, `&'a State` for `Fn`, `&'a mut State` for `FnMut`;
    /// - a single tuple of the closure's original inputs.
    ///
    /// The output is the closure's output, unchanged. The returned binder binds `'a` for
    /// `Fn`/`FnMut` and nothing for `FnOnce`.
    fn translate_closure_method_sig(
        &mut self,
        def: &hax::FullDef,
        span: Span,
        args: &hax::ClosureArgs,
        target_kind: ClosureKind,
    ) -> Result<RegionBinder<FunSig>, Error> {
        let signature = &args.fn_sig;
        trace!(
            "signature of closure {:?}:\n{:?}",
            def.def_id(),
            signature.value,
        );

        let mut bound_regions = IndexMap::new();
        // Both types are moved under the new method-region binder returned below.
        let mut fun_sig = self
            .translate_fun_sig(span, signature.hax_skip_binder_ref())?
            .move_under_binder();
        let state_ty = self.get_closure_state_ty(span, args)?.move_under_binder();

        // `FnOnce` consumes the state; `Fn`/`FnMut` borrow it through a freshly bound region.
        let state_ty = match target_kind {
            ClosureKind::FnOnce => state_ty,
            ClosureKind::Fn | ClosureKind::FnMut => {
                let rid = bound_regions.push_with(|index| RegionParam { index, name: None });
                let r = Region::Var(DeBruijnVar::new_at_zero(rid));
                let mutability = if target_kind == ClosureKind::Fn {
                    RefKind::Shared
                } else {
                    RefKind::Mut
                };
                TyKind::Ref(r, state_ty, mutability).into_ty()
            }
        };

        // Replace the original inputs by `(state, (input0, input1, ...))`.
        let input_tys: Vec<Ty> = mem::take(&mut fun_sig.inputs);
        fun_sig.inputs = vec![state_ty, Ty::mk_tuple(input_tys)];

        Ok(RegionBinder {
            regions: bound_regions,
            skip_binder: fun_sig,
        })
    }

    /// Build the body of the closure's `target_kind` method, given that method's `signature`
    /// (`inputs[0]` is the state, `inputs[1]` the tupled arguments).
    ///
    /// - Same kind as the closure: reuse the closure's own body. Locals are renumbered so the
    ///   tupled arguments occupy local 2, and the original argument locals are filled from
    ///   the tuple's fields at the start of block 0.
    /// - `FnOnce` method of an `Fn`/`FnMut` closure: translate the `once_shim` body from hax.
    /// - `FnMut` method of an `Fn` closure: reborrow `*state` as shared and call the `Fn`
    ///   method.
    ///
    /// Panics when asked for a more restrictive kind than the closure's own kind, or when the
    /// `once_shim` body is missing.
    fn translate_closure_method_body(
        &mut self,
        span: Span,
        def: &hax::FullDef,
        target_kind: ClosureKind,
        args: &hax::ClosureArgs,
        signature: &FunSig,
    ) -> Result<Body, Error> {
        use ClosureKind::*;
        let closure_kind = translate_closure_kind(&args.kind);
        Ok(match (target_kind, closure_kind) {
            (Fn, Fn) | (FnMut, FnMut) | (FnOnce, FnOnce) => {
                let mut body = self.translate_def_body(span, def);
                // Only unstructured bodies are rewritten; any other body is returned as is.
                let Body::Unstructured(GExprBody {
                    locals,
                    body: blocks,
                    ..
                }) = &mut body
                else {
                    return Ok(body);
                };

                // Keep locals 0 (return place) and 1 (state). Shift every other local up by one
                // to free index 2 for the tupled-arguments local.
                blocks.dyn_visit_mut(|local: &mut LocalId| {
                    if local.index() >= 2 {
                        *local += 1;
                    }
                });

                // Rebuild the local table in the same order: [ret, state, tupled_args, rest...].
                let mut old_locals = mem::take(&mut locals.locals).into_iter();
                locals.arg_count = 2;
                locals.locals.push(old_locals.next().unwrap());
                locals.locals.push(old_locals.next().unwrap());
                let tupled_arg =
                    locals.new_var(Some("tupled_args".to_string()), signature.inputs[1].clone());
                locals.locals.extend(old_locals.map(|mut l| {
                    l.index += 1;
                    l
                }));

                // Original argument `i` now lives in local `i + 3`. Initialize it by moving
                // field `i` out of the tuple.
                let untupled_args = signature.inputs[1].as_tuple().unwrap();
                let closure_arg_count = untupled_args.elem_count();
                let new_stts = untupled_args.iter().cloned().enumerate().map(|(i, ty)| {
                    let nth_field = tupled_arg.clone().project(
                        ProjectionElem::Field(
                            FieldProjKind::Tuple(closure_arg_count),
                            FieldId::new(i),
                        ),
                        ty,
                    );
                    let local_id = LocalId::new(i + 3);
                    Statement::new(
                        span,
                        StatementKind::Assign(
                            locals.place_for_var(local_id),
                            Rvalue::Use(Operand::Move(nth_field)),
                        ),
                    )
                });
                // Prepend the unpacking statements to the entry block.
                blocks[BlockId::ZERO].statements.splice(0..0, new_stts);

                body
            }
            (FnOnce, Fn | FnMut) => {
                // hax supplies a separate shim body for calling this closure by value.
                let hax::FullDefKind::Closure {
                    once_shim: Some(body),
                    ..
                } = &def.kind
                else {
                    panic!("missing shim for closure")
                };
                self.translate_body(span, body, &def.source_text)
            }
            (FnMut, Fn) => {
                // Forward to the `Fn` method of the same closure (`closure_kind` is `Fn` here).
                let fun_id: FunDeclId = self.register_item(
                    span,
                    def.this(),
                    TransItemSourceKind::ClosureMethod(closure_kind),
                );
                let impl_ref = self.translate_closure_impl_ref(span, args, closure_kind)?;
                // Append an erased region for the `Fn` method's own region parameter.
                let fn_op = FnOperand::Regular(FnPtr::new(
                    fun_id.into(),
                    impl_ref.generics.concat(&GenericArgs {
                        regions: vec![Region::Erased].into(),
                        ..GenericArgs::empty()
                    }),
                ));

                let mut builder = BodyBuilder::new(span, 2);

                let output = builder.new_var(None, signature.output.clone());
                let state = builder.new_var(Some("state".to_string()), signature.inputs[0].clone());
                let args = builder.new_var(Some("args".to_string()), signature.inputs[1].clone());
                let deref_state = state.deref();
                let reborrow_ty =
                    TyKind::Ref(Region::Erased, deref_state.ty.clone(), RefKind::Shared).into_ty();
                let reborrow = builder.new_var(None, reborrow_ty);

                // reborrow = &*state
                builder.push_statement(StatementKind::Assign(
                    reborrow.clone(),
                    Rvalue::Ref {
                        place: deref_state,
                        kind: BorrowKind::Shared,
                        ptr_metadata: Operand::mk_const_unit(),
                    },
                ));

                // output = <Fn method>(move reborrow, move args)
                builder.call(Call {
                    func: fn_op,
                    args: vec![Operand::Move(reborrow), Operand::Move(args)],
                    dest: output,
                });

                Body::Unstructured(builder.build())
            }
            (Fn, FnOnce) | (Fn, FnMut) | (FnMut, FnOnce) => {
                panic!(
                    "Can't make a closure body for a more restrictive kind \
                    than the closure kind"
                )
            }
        })
    }

    /// Translate the closure's `target_kind` trait method into a `FunDecl`.
    ///
    /// The declaration is recorded as the item `target_kind.method_name()` of the closure's
    /// `target_kind` trait impl. Its body is `Body::Opaque` when the item's opacity (with
    /// private contents) is opaque.
    ///
    /// Panics if `def` is not a closure, or if the closure has no impl for `target_kind`.
    #[tracing::instrument(skip(self, item_meta))]
    pub fn translate_closure_method(
        mut self,
        def_id: FunDeclId,
        item_meta: ItemMeta,
        def: &hax::FullDef,
        target_kind: ClosureKind,
    ) -> Result<FunDecl, Error> {
        let span = item_meta.span;
        let hax::FullDefKind::Closure {
            args,
            fn_once_impl,
            fn_mut_impl,
            fn_impl,
            ..
        } = &def.kind
        else {
            unreachable!()
        };

        // The `FnMut`/`Fn` impls are optional; asking for a missing one is a caller bug.
        let vimpl = match target_kind {
            ClosureKind::FnOnce => fn_once_impl,
            ClosureKind::FnMut => fn_mut_impl.as_ref().unwrap(),
            ClosureKind::Fn => fn_impl.as_ref().unwrap(),
        };
        let implemented_trait = self.translate_trait_predicate(span, &vimpl.trait_pred)?;

        let impl_ref = self.translate_closure_impl_ref(span, args, target_kind)?;
        let src = ItemSource::TraitImpl {
            impl_ref,
            trait_ref: implemented_trait,
            item_name: TraitItemName(target_kind.method_name().into()),
            reuses_default: false,
        };

        // Instantiate the method's own region binder (if any) via `closure_method_regions`.
        let signature = self
            .translate_closure_method_sig(def, span, args, target_kind)?
            .apply(self.closure_method_regions(args, target_kind));

        let body = if item_meta.opacity.with_private_contents().is_opaque() {
            Body::Opaque
        } else {
            self.translate_closure_method_body(span, def, target_kind, args, &signature)?
        };

        Ok(FunDecl {
            def_id,
            item_meta,
            generics: self.into_generics(),
            signature,
            src,
            is_global_initializer: None,
            body,
        })
    }

    /// Translate the closure's `target_kind` trait impl.
    ///
    /// The impl is produced by `translate_virtual_trait_impl`; then its single method
    /// (`target_kind.method_name()`) is added, bound over that method's own regions.
    ///
    /// Panics if `def` is not a closure, or if the closure has no impl for `target_kind`.
    #[tracing::instrument(skip(self, item_meta))]
    pub fn translate_closure_trait_impl(
        mut self,
        def_id: TraitImplId,
        item_meta: ItemMeta,
        def: &hax::FullDef,
        target_kind: ClosureKind,
    ) -> Result<TraitImpl, Error> {
        let span = item_meta.span;
        let hax::FullDefKind::Closure {
            args,
            fn_once_impl,
            fn_mut_impl,
            fn_impl,
            ..
        } = def.kind()
        else {
            unreachable!()
        };

        // The `FnMut`/`Fn` impls are optional; asking for a missing one is a caller bug.
        let vimpl = match target_kind {
            ClosureKind::FnOnce => fn_once_impl,
            ClosureKind::FnMut => fn_mut_impl.as_ref().unwrap(),
            ClosureKind::Fn => fn_impl.as_ref().unwrap(),
        };
        let mut timpl = self.translate_virtual_trait_impl(def_id, item_meta, vimpl)?;

        let call_fn_name = TraitItemName(target_kind.method_name().into());
        // Reference to the method, bound over its own region(s): one for `FnMut`/`Fn`, none
        // for `FnOnce`.
        let call_fn_binder = {
            let kind = TransItemSourceKind::ClosureMethod(target_kind);
            let bound_method_ref: RegionBinder<DeclRef<ItemId>> =
                self.translate_closure_bound_ref_with_method_bound(span, args, kind, target_kind)?;
            let params = GenericParams {
                regions: bound_method_ref.regions,
                ..GenericParams::empty()
            };
            let fn_decl_ref: FunDeclRef = bound_method_ref.skip_binder.try_into().unwrap();
            Binder::new(
                BinderKind::TraitMethod(timpl.impl_trait.id, call_fn_name),
                params,
                fn_decl_ref,
            )
        };
        timpl.methods.push((call_fn_name, call_fn_binder));

        Ok(timpl)
    }

    /// Translate a closure that captures no variables into a standalone function with the
    /// closure's (unbound) signature.
    ///
    /// Unless the item is opaque, the generated body packs the arguments into a tuple, builds
    /// an empty state value, and calls the closure's `FnOnce` method with both.
    ///
    /// Panics if `def` is not a closure or if it has any upvars.
    #[tracing::instrument(skip(self, item_meta))]
    pub fn translate_stateless_closure_as_fn(
        mut self,
        def_id: FunDeclId,
        item_meta: ItemMeta,
        def: &hax::FullDef,
    ) -> Result<FunDecl, Error> {
        let span = item_meta.span;
        let hax::FullDefKind::Closure { args: closure, .. } = &def.kind else {
            unreachable!()
        };

        trace!("About to translate closure as fn:\n{:?}", def.def_id());

        assert!(
            closure.upvar_tys.is_empty(),
            "Only stateless closures can be translated as functions"
        );

        let signature = self.translate_fun_sig(span, closure.fn_sig.hax_skip_binder_ref())?;
        let state_ty = self.get_closure_state_ty(span, closure)?;

        let body = if item_meta.opacity.with_private_contents().is_opaque() {
            Body::Opaque
        } else {
            // Target of the call: the closure's `FnOnce` method (no method region to supply).
            let fun_id: FunDeclId = self.register_item(
                span,
                def.this(),
                TransItemSourceKind::ClosureMethod(ClosureKind::FnOnce),
            );
            let impl_ref = self.translate_closure_impl_ref(span, closure, ClosureKind::FnOnce)?;
            let fn_op = FnOperand::Regular(FnPtr::new(fun_id.into(), impl_ref.generics.clone()));

            let mut builder = BodyBuilder::new(span, signature.inputs.len());

            // Locals: return place, one per argument, then the tuple and the state.
            let output = builder.new_var(None, signature.output.clone());
            let args: Vec<Place> = signature
                .inputs
                .iter()
                .enumerate()
                .map(|(i, ty)| builder.new_var(Some(format!("arg{}", i + 1)), ty.clone()))
                .collect();
            let args_tupled_ty = Ty::mk_tuple(signature.inputs.clone());
            let args_tupled = builder.new_var(Some("args".to_string()), args_tupled_ty.clone());
            let state = builder.new_var(Some("state".to_string()), state_ty.clone());

            // args = (move arg1, move arg2, ...)
            builder.push_statement(StatementKind::Assign(
                args_tupled.clone(),
                Rvalue::Aggregate(
                    AggregateKind::Adt(args_tupled_ty.as_adt().unwrap().clone(), None, None),
                    args.into_iter().map(Operand::Move).collect(),
                ),
            ));

            // state = State {} (no fields, since there are no upvars)
            let state_ty_adt = state_ty.as_adt().unwrap();
            builder.push_statement(StatementKind::Assign(
                state.clone(),
                Rvalue::Aggregate(AggregateKind::Adt(state_ty_adt.clone(), None, None), vec![]),
            ));

            // output = <FnOnce method>(move state, move args)
            builder.call(Call {
                func: fn_op,
                args: vec![Operand::Move(state), Operand::Move(args_tupled)],
                dest: output,
            });

            Body::Unstructured(builder.build())
        };

        Ok(FunDecl {
            def_id,
            item_meta,
            generics: self.into_generics(),
            signature,
            src: ItemSource::TopLevel,
            is_global_initializer: None,
            body,
        })
    }
}