1use itertools::Itertools;
42use std::mem;
43
44use super::translate_crate::TransItemSourceKind;
45use super::translate_ctx::*;
46use charon_lib::ast::ullbc_ast_utils::BodyBuilder;
47use charon_lib::ast::*;
48use charon_lib::ids::IndexVec;
49use charon_lib::ullbc_ast::*;
50
51pub fn translate_closure_kind(kind: &hax::ClosureKind) -> ClosureKind {
52 match kind {
53 hax::ClosureKind::Fn => ClosureKind::Fn,
54 hax::ClosureKind::FnMut => ClosureKind::FnMut,
55 hax::ClosureKind::FnOnce => ClosureKind::FnOnce,
56 }
57}
58
59impl ItemTransCtx<'_, '_> {
64 fn translate_closure_bound_ref_with_late_bound(
67 &mut self,
68 span: Span,
69 closure: &hax::ClosureArgs,
70 kind: TransItemSourceKind,
71 ) -> Result<RegionBinder<DeclRef<ItemId>>, Error> {
72 if !matches!(
73 kind,
74 TransItemSourceKind::TraitImpl(..) | TransItemSourceKind::ClosureAsFnCast
75 ) {
76 raise_error!(
77 self,
78 span,
79 "Called `translate_closure_bound_ref_with_late_bound` on a `{kind:?}`; \
80 use `translate_closure_ref_with_upvars` \
81 or `translate_closure_bound_ref_with_method_bound` instead"
82 )
83 }
84 let dref: DeclRef<ItemId> = self.translate_item(span, &closure.item, kind)?;
85 self.translate_region_binder(span, &closure.fn_sig, |ctx, _| {
86 let mut dref = dref.move_under_binder();
87 for (a, b) in dref.generics.regions.iter_mut().rev().zip(
89 ctx.innermost_binder()
90 .params
91 .identity_args()
92 .regions
93 .into_iter()
94 .rev(),
95 ) {
96 *a = b;
97 }
98 Ok(dref)
99 })
100 }
101
102 fn translate_closure_bound_ref_with_method_bound(
106 &mut self,
107 span: Span,
108 closure: &hax::ClosureArgs,
109 kind: TransItemSourceKind,
110 target_kind: ClosureKind,
111 ) -> Result<RegionBinder<DeclRef<ItemId>>, Error> {
112 if !matches!(kind, TransItemSourceKind::ClosureMethod(..)) {
113 raise_error!(
114 self,
115 span,
116 "Called `translate_closure_bound_ref_with_method_bound` on a `{kind:?}`; \
117 use `translate_closure_ref_with_upvars` \
118 or `translate_closure_bound_ref_with_late_bound` instead"
119 )
120 }
121 let dref: DeclRef<ItemId> = self.translate_item(span, &closure.item, kind)?;
122 let mut dref = dref.move_under_binder();
123 let mut regions = IndexMap::new();
124 match target_kind {
125 ClosureKind::FnOnce => {}
126 ClosureKind::FnMut | ClosureKind::Fn => {
127 let rid = regions.push_with(|index| RegionParam { index, name: None });
128 *dref.generics.regions.iter_mut().last().unwrap() =
129 Region::Var(DeBruijnVar::new_at_zero(rid));
130 }
131 }
132 Ok(RegionBinder {
133 regions,
134 skip_binder: dref,
135 })
136 }
137}
138
139impl ItemTransCtx<'_, '_> {
140 pub fn translate_closure_type_ref(
142 &mut self,
143 span: Span,
144 closure: &hax::ClosureArgs,
145 ) -> Result<TypeDeclRef, Error> {
146 self.translate_item(span, &closure.item, TransItemSourceKind::Type)
147 }
148
149 pub fn translate_stateless_closure_as_fn_ref(
153 &mut self,
154 span: Span,
155 closure: &hax::ClosureArgs,
156 ) -> Result<RegionBinder<FunDeclRef>, Error> {
157 let kind = TransItemSourceKind::ClosureAsFnCast;
158 let bound_dref = self.translate_closure_bound_ref_with_late_bound(span, closure, kind)?;
159 Ok(bound_dref.map(|dref| dref.try_into().unwrap()))
160 }
161
162 pub fn translate_closure_bound_impl_ref(
166 &mut self,
167 span: Span,
168 closure: &hax::ClosureArgs,
169 target_kind: ClosureKind,
170 ) -> Result<RegionBinder<TraitImplRef>, Error> {
171 let kind = TransItemSourceKind::TraitImpl(TraitImplSource::Closure(target_kind));
172 let bound_dref = self.translate_closure_bound_ref_with_late_bound(span, closure, kind)?;
173 Ok(bound_dref.map(|dref| dref.try_into().unwrap()))
174 }
175
176 pub fn translate_closure_impl_ref(
178 &mut self,
179 span: Span,
180 closure: &hax::ClosureArgs,
181 target_kind: ClosureKind,
182 ) -> Result<TraitImplRef, Error> {
183 self.translate_item(
184 span,
185 &closure.item,
186 TransItemSourceKind::TraitImpl(TraitImplSource::Closure(target_kind)),
187 )
188 }
189
190 pub fn translate_closure_info(
191 &mut self,
192 span: Span,
193 args: &hax::ClosureArgs,
194 ) -> Result<ClosureInfo, Error> {
195 use ClosureKind::*;
196 let kind = translate_closure_kind(&args.kind);
197
198 let fn_once_impl = self.translate_closure_bound_impl_ref(span, args, FnOnce)?;
199 let fn_mut_impl = if matches!(kind, FnMut | Fn) {
200 Some(self.translate_closure_bound_impl_ref(span, args, FnMut)?)
201 } else {
202 None
203 };
204 let fn_impl = if matches!(kind, Fn) {
205 Some(self.translate_closure_bound_impl_ref(span, args, Fn)?)
206 } else {
207 None
208 };
209 let signature = self.translate_poly_fun_sig(span, &args.fn_sig)?;
210 Ok(ClosureInfo {
211 kind,
212 fn_once_impl,
213 fn_mut_impl,
214 fn_impl,
215 signature,
216 })
217 }
218
219 pub fn get_closure_state_ty(
220 &mut self,
221 span: Span,
222 args: &hax::ClosureArgs,
223 ) -> Result<Ty, Error> {
224 let tref = self.translate_closure_type_ref(span, args)?;
225 Ok(TyKind::Adt(tref).into_ty())
226 }
227
228 pub fn translate_closure_upvar_tys(
232 &mut self,
233 span: Span,
234 args: &hax::ClosureArgs,
235 ) -> Result<IndexVec<FieldId, Ty>, Error> {
236 let upvar_tys: IndexVec<FieldId, Ty> = args
237 .upvar_tys
238 .iter()
239 .map(|ty| self.translate_ty(span, ty))
240 .try_collect()?;
241 let upvar_tys = upvar_tys.replace_erased_regions(|| {
242 let region_id = self.the_only_binder_mut().push_upvar_region();
243 Region::Var(DeBruijnVar::new_at_zero(region_id))
244 });
245 Ok(upvar_tys)
246 }
247
248 pub fn translate_closure_adt(
249 &mut self,
250 span: Span,
251 _args: &hax::ClosureArgs,
252 ) -> Result<TypeDeclKind, Error> {
253 let fields: IndexVec<FieldId, Field> = self
254 .the_only_binder()
255 .closure_upvar_tys
256 .as_ref()
257 .unwrap()
258 .iter()
259 .cloned()
260 .map(|ty| Field {
261 span,
262 attr_info: AttrInfo::dummy_private(),
263 name: None,
264 ty: ty,
265 })
266 .collect();
267 Ok(TypeDeclKind::Struct(fields))
268 }
269
270 fn translate_closure_method_sig(
273 &mut self,
274 def: &hax::FullDef,
275 span: Span,
276 args: &hax::ClosureArgs,
277 target_kind: ClosureKind,
278 ) -> Result<RegionBinder<FunSig>, Error> {
279 let signature = &args.fn_sig;
280 trace!(
281 "signature of closure {:?}:\n{:?}",
282 def.def_id(),
283 signature.value,
284 );
285
286 let mut bound_regions = IndexMap::new();
287 let mut fun_sig = self
288 .translate_fun_sig(span, signature.hax_skip_binder_ref())?
289 .move_under_binder();
290 let state_ty = self.get_closure_state_ty(span, args)?.move_under_binder();
291
292 let state_ty = match target_kind {
294 ClosureKind::FnOnce => state_ty,
295 ClosureKind::Fn | ClosureKind::FnMut => {
296 let rid = bound_regions.push_with(|index| RegionParam { index, name: None });
297 let r = Region::Var(DeBruijnVar::new_at_zero(rid));
298 let mutability = if target_kind == ClosureKind::Fn {
299 RefKind::Shared
300 } else {
301 RefKind::Mut
302 };
303 TyKind::Ref(r, state_ty, mutability).into_ty()
304 }
305 };
306
307 let input_tys: Vec<Ty> = mem::take(&mut fun_sig.inputs);
309 fun_sig.inputs = vec![state_ty, Ty::mk_tuple(input_tys)];
311
312 Ok(RegionBinder {
313 regions: bound_regions,
314 skip_binder: fun_sig,
315 })
316 }
317
318 fn translate_closure_method_body(
319 &mut self,
320 span: Span,
321 def: &hax::FullDef,
322 target_kind: ClosureKind,
323 args: &hax::ClosureArgs,
324 signature: &FunSig,
325 ) -> Result<Body, Error> {
326 use ClosureKind::*;
327 let closure_kind = translate_closure_kind(&args.kind);
328 Ok(match (target_kind, closure_kind) {
329 (Fn, Fn) | (FnMut, FnMut) | (FnOnce, FnOnce) => {
330 let mut body = self.translate_def_body(span, def);
332 let Body::Unstructured(GExprBody {
341 locals,
342 body: blocks,
343 ..
344 }) = &mut body
345 else {
346 return Ok(body);
347 };
348
349 let tupled_ty = &signature.inputs[1];
351
352 blocks.dyn_visit_mut(|local: &mut LocalId| {
353 if local.index() >= 2 {
354 *local += 1;
355 }
356 });
357
358 let mut old_locals = mem::take(&mut locals.locals).into_iter();
359 locals.arg_count = 2;
360 locals.locals.push(old_locals.next().unwrap()); locals.locals.push(old_locals.next().unwrap()); let tupled_arg = locals.new_var(Some("tupled_args".to_string()), tupled_ty.clone());
363 locals.locals.extend(old_locals.map(|mut l| {
364 l.index += 1;
365 l
366 }));
367
368 let untupled_args = tupled_ty.as_tuple().unwrap();
369 let closure_arg_count = untupled_args.elem_count();
370 let new_stts = untupled_args.iter().cloned().enumerate().map(|(i, ty)| {
371 let nth_field = tupled_arg.clone().project(
372 ProjectionElem::Field(
373 FieldProjKind::Tuple(closure_arg_count),
374 FieldId::new(i),
375 ),
376 ty,
377 );
378 let local_id = LocalId::new(i + 3);
379 Statement::new(
380 span,
381 StatementKind::Assign(
382 locals.place_for_var(local_id),
383 Rvalue::Use(Operand::Move(nth_field)),
384 ),
385 )
386 });
387 blocks[BlockId::ZERO].statements.splice(0..0, new_stts);
388
389 body
390 }
391 (FnOnce, Fn | FnMut) => {
401 let Some(body) = def.this.closure_once_shim(self.hax_state()) else {
403 panic!("missing shim for closure")
404 };
405 self.translate_body(span, body, &def.source_text)
406 }
407 (FnMut, Fn) => {
414 let fun_id: FunDeclId = self.register_item(
415 span,
416 def.this(),
417 TransItemSourceKind::ClosureMethod(closure_kind),
418 );
419 let impl_ref = self.translate_closure_impl_ref(span, args, closure_kind)?;
420 let fn_op = FnOperand::Regular(FnPtr::new(
423 fun_id.into(),
424 impl_ref.generics.concat(&GenericArgs {
425 regions: vec![self.translate_erased_region()].into(),
426 ..GenericArgs::empty()
427 }),
428 ));
429
430 let mut builder = BodyBuilder::new(span, 2);
431
432 let output = builder.new_var(None, signature.output.clone());
433 let state = builder.new_var(Some("state".to_string()), signature.inputs[0].clone());
434 let args = builder.new_var(Some("args".to_string()), signature.inputs[1].clone());
435 let deref_state = state.deref();
436 let reborrow_ty = TyKind::Ref(
437 self.translate_erased_region(),
438 deref_state.ty.clone(),
439 RefKind::Shared,
440 )
441 .into_ty();
442 let reborrow = builder.new_var(None, reborrow_ty);
443
444 builder.push_statement(StatementKind::Assign(
445 reborrow.clone(),
446 Rvalue::Ref {
447 place: deref_state,
448 kind: BorrowKind::Shared,
449 ptr_metadata: Operand::mk_const_unit(),
451 },
452 ));
453
454 builder.call(Call {
455 func: fn_op,
456 args: vec![Operand::Move(reborrow), Operand::Move(args)],
457 dest: output,
458 });
459
460 Body::Unstructured(builder.build())
461 }
462 (Fn, FnOnce) | (Fn, FnMut) | (FnMut, FnOnce) => {
463 panic!(
464 "Can't make a closure body for a more restrictive kind \
465 than the closure kind"
466 )
467 }
468 })
469 }
470
471 #[tracing::instrument(skip(self, item_meta))]
474 pub fn translate_closure_method(
475 mut self,
476 def_id: FunDeclId,
477 item_meta: ItemMeta,
478 def: &hax::FullDef,
479 target_kind: ClosureKind,
480 ) -> Result<FunDecl, Error> {
481 let span = item_meta.span;
482 let hax::FullDefKind::Closure {
483 args,
484 fn_once_impl,
485 fn_mut_impl,
486 fn_impl,
487 ..
488 } = &def.kind
489 else {
490 unreachable!()
491 };
492
493 let vimpl = match target_kind {
495 ClosureKind::FnOnce => fn_once_impl,
496 ClosureKind::FnMut => fn_mut_impl.as_ref().unwrap(),
497 ClosureKind::Fn => fn_impl.as_ref().unwrap(),
498 };
499 let implemented_trait = self.translate_trait_predicate(span, &vimpl.trait_pred)?;
500
501 let impl_ref = self.translate_closure_impl_ref(span, args, target_kind)?;
502 let src = ItemSource::TraitImpl {
503 impl_ref,
504 trait_ref: implemented_trait,
505 item_name: TraitItemName(target_kind.method_name().into()),
506 reuses_default: false,
507 };
508
509 let bound_sig = self.translate_closure_method_sig(def, span, args, target_kind)?;
511 let signature = bound_sig.apply(
513 self.the_only_binder()
514 .closure_call_method_region
515 .iter()
516 .map(|r| Region::Var(DeBruijnVar::new_at_zero(*r)))
517 .collect(),
518 );
519
520 let body = if item_meta.opacity.with_private_contents().is_opaque() {
521 Body::Opaque
522 } else {
523 self.translate_closure_method_body(span, def, target_kind, args, &signature)?
524 };
525
526 Ok(FunDecl {
527 def_id,
528 item_meta,
529 generics: self.into_generics(),
530 signature,
531 src,
532 is_global_initializer: None,
533 body,
534 })
535 }
536
537 #[tracing::instrument(skip(self, item_meta))]
538 pub fn translate_closure_trait_impl(
539 mut self,
540 def_id: TraitImplId,
541 item_meta: ItemMeta,
542 def: &hax::FullDef,
543 target_kind: ClosureKind,
544 ) -> Result<TraitImpl, Error> {
545 let span = item_meta.span;
546 let hax::FullDefKind::Closure {
547 args,
548 fn_once_impl,
549 fn_mut_impl,
550 fn_impl,
551 ..
552 } = def.kind()
553 else {
554 unreachable!()
555 };
556
557 let vimpl = match target_kind {
559 ClosureKind::FnOnce => fn_once_impl,
560 ClosureKind::FnMut => fn_mut_impl.as_ref().unwrap(),
561 ClosureKind::Fn => fn_impl.as_ref().unwrap(),
562 };
563 let mut timpl = self.translate_virtual_trait_impl(def_id, item_meta, vimpl)?;
564
565 let call_fn_name = TraitItemName(target_kind.method_name().into());
567 let call_fn_binder = {
568 let kind = TransItemSourceKind::ClosureMethod(target_kind);
569 let bound_method_ref: RegionBinder<DeclRef<ItemId>> =
570 self.translate_closure_bound_ref_with_method_bound(span, args, kind, target_kind)?;
571 let params = GenericParams {
572 regions: bound_method_ref.regions,
573 ..GenericParams::empty()
574 };
575 let fn_decl_ref: FunDeclRef = bound_method_ref.skip_binder.try_into().unwrap();
576 Binder::new(
577 BinderKind::TraitMethod(timpl.impl_trait.id, call_fn_name),
578 params,
579 fn_decl_ref,
580 )
581 };
582 timpl.methods.push((call_fn_name, call_fn_binder));
583
584 Ok(timpl)
585 }
586
587 #[tracing::instrument(skip(self, item_meta))]
590 pub fn translate_stateless_closure_as_fn(
591 mut self,
592 def_id: FunDeclId,
593 item_meta: ItemMeta,
594 def: &hax::FullDef,
595 ) -> Result<FunDecl, Error> {
596 let span = item_meta.span;
597 let hax::FullDefKind::Closure { args: closure, .. } = &def.kind else {
598 unreachable!()
599 };
600
601 trace!("About to translate closure as fn:\n{:?}", def.def_id());
602
603 assert!(
604 closure.upvar_tys.is_empty(),
605 "Only stateless closures can be translated as functions"
606 );
607
608 let signature = self.translate_fun_sig(span, closure.fn_sig.hax_skip_binder_ref())?;
610 let state_ty = self.get_closure_state_ty(span, closure)?;
611
612 let body = if item_meta.opacity.with_private_contents().is_opaque() {
613 Body::Opaque
614 } else {
615 let fun_id: FunDeclId = self.register_item(
623 span,
624 def.this(),
625 TransItemSourceKind::ClosureMethod(ClosureKind::FnOnce),
626 );
627 let impl_ref = self.translate_closure_impl_ref(span, closure, ClosureKind::FnOnce)?;
628 let fn_op = FnOperand::Regular(FnPtr::new(fun_id.into(), impl_ref.generics.clone()));
629
630 let mut builder = BodyBuilder::new(span, signature.inputs.len());
631
632 let output = builder.new_var(None, signature.output.clone());
633 let args: Vec<Place> = signature
634 .inputs
635 .iter()
636 .enumerate()
637 .map(|(i, ty)| builder.new_var(Some(format!("arg{}", i + 1)), ty.clone()))
638 .collect();
639 let args_tupled_ty = Ty::mk_tuple(signature.inputs.clone());
640 let args_tupled = builder.new_var(Some("args".to_string()), args_tupled_ty.clone());
641 let state = builder.new_var(Some("state".to_string()), state_ty.clone());
642
643 builder.push_statement(StatementKind::Assign(
644 args_tupled.clone(),
645 Rvalue::Aggregate(
646 AggregateKind::Adt(args_tupled_ty.as_adt().unwrap().clone(), None, None),
647 args.into_iter().map(Operand::Move).collect(),
648 ),
649 ));
650
651 let state_ty_adt = state_ty.as_adt().unwrap();
652 builder.push_statement(StatementKind::Assign(
653 state.clone(),
654 Rvalue::Aggregate(AggregateKind::Adt(state_ty_adt.clone(), None, None), vec![]),
655 ));
656
657 builder.call(Call {
658 func: fn_op,
659 args: vec![Operand::Move(state), Operand::Move(args_tupled)],
660 dest: output,
661 });
662
663 Body::Unstructured(builder.build())
664 };
665
666 Ok(FunDecl {
667 def_id,
668 item_meta,
669 generics: self.into_generics(),
670 signature,
671 src: ItemSource::TopLevel,
672 is_global_initializer: None,
673 body,
674 })
675 }
676}