1use crate::hax;
42use itertools::Itertools;
43use std::mem;
44
45use super::translate_crate::TransItemSourceKind;
46use super::translate_ctx::*;
47use charon_lib::ast::ullbc_ast_utils::BodyBuilder;
48use charon_lib::ast::*;
49use charon_lib::ids::IndexVec;
50use charon_lib::ullbc_ast::*;
51
52pub fn translate_closure_kind(kind: &hax::ClosureKind) -> ClosureKind {
53 match kind {
54 hax::ClosureKind::Fn => ClosureKind::Fn,
55 hax::ClosureKind::FnMut => ClosureKind::FnMut,
56 hax::ClosureKind::FnOnce => ClosureKind::FnOnce,
57 }
58}
59
60impl ItemTransCtx<'_, '_> {
65 fn translate_closure_bound_ref_with_late_bound(
68 &mut self,
69 span: Span,
70 closure: &hax::ClosureArgs,
71 kind: TransItemSourceKind,
72 ) -> Result<RegionBinder<DeclRef<ItemId>>, Error> {
73 if !matches!(
74 kind,
75 TransItemSourceKind::TraitImpl(..) | TransItemSourceKind::ClosureAsFnCast
76 ) {
77 raise_error!(
78 self,
79 span,
80 "Called `translate_closure_bound_ref_with_late_bound` on a `{kind:?}`; \
81 use `translate_closure_ref_with_upvars` \
82 or `translate_closure_bound_ref_with_method_bound` instead"
83 )
84 }
85 let dref: DeclRef<ItemId> = self.translate_item(span, &closure.item, kind)?;
86 self.translate_region_binder(span, &closure.fn_sig, |ctx, _| {
87 let mut dref = dref.move_under_binder();
88 for (a, b) in dref.generics.regions.iter_mut().rev().zip(
90 ctx.innermost_binder()
91 .params
92 .identity_args()
93 .regions
94 .into_iter()
95 .rev(),
96 ) {
97 *a = b;
98 }
99 Ok(dref)
100 })
101 }
102
103 fn translate_closure_bound_ref_with_method_bound(
107 &mut self,
108 span: Span,
109 closure: &hax::ClosureArgs,
110 kind: TransItemSourceKind,
111 target_kind: ClosureKind,
112 ) -> Result<RegionBinder<DeclRef<ItemId>>, Error> {
113 if !matches!(kind, TransItemSourceKind::ClosureMethod(..)) {
114 raise_error!(
115 self,
116 span,
117 "Called `translate_closure_bound_ref_with_method_bound` on a `{kind:?}`; \
118 use `translate_closure_ref_with_upvars` \
119 or `translate_closure_bound_ref_with_late_bound` instead"
120 )
121 }
122 let dref: DeclRef<ItemId> = self.translate_item(span, &closure.item, kind)?;
123 let mut dref = dref.move_under_binder();
124 let mut regions = IndexVec::new();
125 match target_kind {
126 ClosureKind::FnOnce => {}
127 ClosureKind::FnMut | ClosureKind::Fn => {
128 let rid = regions.push_with(|index| RegionParam::new(index, None));
129 *dref.generics.regions.iter_mut().last().unwrap() =
130 Region::Var(DeBruijnVar::new_at_zero(rid));
131 }
132 }
133 Ok(RegionBinder {
134 regions,
135 skip_binder: dref,
136 })
137 }
138}
139
140impl ItemTransCtx<'_, '_> {
141 pub fn translate_closure_type_ref(
143 &mut self,
144 span: Span,
145 closure: &hax::ClosureArgs,
146 ) -> Result<TypeDeclRef, Error> {
147 self.translate_item(span, &closure.item, TransItemSourceKind::Type)
148 }
149
150 pub fn translate_stateless_closure_as_fn_ref(
154 &mut self,
155 span: Span,
156 closure: &hax::ClosureArgs,
157 ) -> Result<RegionBinder<FunDeclRef>, Error> {
158 let kind = TransItemSourceKind::ClosureAsFnCast;
159 let bound_dref = self.translate_closure_bound_ref_with_late_bound(span, closure, kind)?;
160 Ok(bound_dref.map(|dref| dref.try_into().unwrap()))
161 }
162
163 pub fn translate_closure_bound_impl_ref(
167 &mut self,
168 span: Span,
169 closure: &hax::ClosureArgs,
170 target_kind: ClosureKind,
171 ) -> Result<RegionBinder<TraitImplRef>, Error> {
172 let kind = TransItemSourceKind::TraitImpl(TraitImplSource::Closure(target_kind));
173 let bound_dref = self.translate_closure_bound_ref_with_late_bound(span, closure, kind)?;
174 Ok(bound_dref.map(|dref| dref.try_into().unwrap()))
175 }
176
177 pub fn translate_closure_impl_ref(
179 &mut self,
180 span: Span,
181 closure: &hax::ClosureArgs,
182 target_kind: ClosureKind,
183 ) -> Result<TraitImplRef, Error> {
184 self.translate_item(
185 span,
186 &closure.item,
187 TransItemSourceKind::TraitImpl(TraitImplSource::Closure(target_kind)),
188 )
189 }
190
191 pub fn translate_closure_info(
192 &mut self,
193 span: Span,
194 args: &hax::ClosureArgs,
195 ) -> Result<ClosureInfo, Error> {
196 use ClosureKind::*;
197 let kind = translate_closure_kind(&args.kind);
198
199 let fn_once_impl = self.translate_closure_bound_impl_ref(span, args, FnOnce)?;
200 let fn_mut_impl = if matches!(kind, FnMut | Fn) {
201 Some(self.translate_closure_bound_impl_ref(span, args, FnMut)?)
202 } else {
203 None
204 };
205 let fn_impl = if matches!(kind, Fn) {
206 Some(self.translate_closure_bound_impl_ref(span, args, Fn)?)
207 } else {
208 None
209 };
210 let signature = self.translate_poly_fun_sig(span, &args.fn_sig)?;
211 Ok(ClosureInfo {
212 kind,
213 fn_once_impl,
214 fn_mut_impl,
215 fn_impl,
216 signature,
217 })
218 }
219
220 pub fn get_closure_state_ty(
221 &mut self,
222 span: Span,
223 args: &hax::ClosureArgs,
224 ) -> Result<Ty, Error> {
225 let tref = self.translate_closure_type_ref(span, args)?;
226 Ok(TyKind::Adt(tref).into_ty())
227 }
228
229 pub fn translate_closure_upvar_tys(
233 &mut self,
234 span: Span,
235 args: &hax::ClosureArgs,
236 ) -> Result<IndexVec<FieldId, Ty>, Error> {
237 args.upvar_tys
238 .iter()
239 .map(|ty| self.translate_ty(span, ty))
240 .try_collect()
241 }
242
243 pub fn translate_closure_adt(
244 &mut self,
245 span: Span,
246 _args: &hax::ClosureArgs,
247 ) -> Result<TypeDeclKind, Error> {
248 let fields: IndexVec<FieldId, Field> = self
249 .the_only_binder()
250 .closure_upvar_tys
251 .as_ref()
252 .unwrap()
253 .iter()
254 .cloned()
255 .map(|ty| Field {
256 span,
257 attr_info: AttrInfo::dummy_private(),
258 name: None,
259 ty,
260 })
261 .collect();
262 Ok(TypeDeclKind::Struct(fields))
263 }
264
265 fn translate_closure_method_sig(
268 &mut self,
269 def: &hax::FullDef,
270 span: Span,
271 args: &hax::ClosureArgs,
272 target_kind: ClosureKind,
273 ) -> Result<RegionBinder<FunSig>, Error> {
274 let signature = &args.fn_sig;
275 trace!(
276 "signature of closure {:?}:\n{:?}",
277 def.def_id(),
278 signature.value,
279 );
280
281 let mut bound_regions = IndexVec::new();
282 let mut fun_sig = self
283 .translate_fun_sig(span, signature.hax_skip_binder_ref())?
284 .move_under_binder();
285 let state_ty = self.get_closure_state_ty(span, args)?.move_under_binder();
286
287 let state_ty = match target_kind {
289 ClosureKind::FnOnce => state_ty,
290 ClosureKind::Fn | ClosureKind::FnMut => {
291 let rid = bound_regions.push_with(|index| RegionParam::new(index, None));
292 let r = Region::Var(DeBruijnVar::new_at_zero(rid));
293 let mutability = if target_kind == ClosureKind::Fn {
294 RefKind::Shared
295 } else {
296 RefKind::Mut
297 };
298 TyKind::Ref(r, state_ty, mutability).into_ty()
299 }
300 };
301
302 let input_tys: Vec<Ty> = mem::take(&mut fun_sig.inputs);
304 fun_sig.inputs = vec![state_ty, Ty::mk_tuple(input_tys)];
306
307 Ok(RegionBinder {
308 regions: bound_regions,
309 skip_binder: fun_sig,
310 })
311 }
312
313 fn translate_closure_method_body(
314 &mut self,
315 span: Span,
316 def: &hax::FullDef,
317 target_kind: ClosureKind,
318 args: &hax::ClosureArgs,
319 signature: &FunSig,
320 ) -> Result<Body, Error> {
321 use ClosureKind::*;
322 let closure_kind = translate_closure_kind(&args.kind);
323 Ok(match (target_kind, closure_kind) {
324 (Fn, Fn) | (FnMut, FnMut) | (FnOnce, FnOnce) => {
325 let mut body = self.translate_def_body(span, def);
327 let Body::Unstructured(GExprBody {
336 locals,
337 body: blocks,
338 ..
339 }) = &mut body
340 else {
341 return Ok(body);
342 };
343
344 let tupled_ty = &signature.inputs[1];
346
347 blocks.dyn_visit_mut(|local: &mut LocalId| {
348 if local.index() >= 2 {
349 *local += 1;
350 }
351 });
352
353 let mut old_locals = mem::take(&mut locals.locals).into_iter();
354 locals.arg_count = 2;
355 locals.locals.push(old_locals.next().unwrap()); locals.locals.push(old_locals.next().unwrap()); let tupled_arg = locals.new_var(Some("tupled_args".to_string()), tupled_ty.clone());
358 locals.locals.extend(old_locals.map(|mut l| {
359 l.index += 1;
360 l
361 }));
362
363 let untupled_args = tupled_ty.as_tuple().unwrap();
364 let closure_arg_count = untupled_args.len();
365 let new_stts = untupled_args.iter().cloned().enumerate().map(|(i, ty)| {
366 let nth_field = tupled_arg.clone().project(
367 ProjectionElem::Field(
368 FieldProjKind::Tuple(closure_arg_count),
369 FieldId::new(i),
370 ),
371 ty,
372 );
373 let local_id = LocalId::new(i + 3);
374 Statement::new(
375 span,
376 StatementKind::Assign(
377 locals.place_for_var(local_id),
378 Rvalue::Use(Operand::Move(nth_field)),
379 ),
380 )
381 });
382 blocks[BlockId::ZERO].statements.splice(0..0, new_stts);
383
384 body
385 }
386 (FnOnce, Fn | FnMut) => {
396 let Some(body) = def.this.closure_once_shim(self.hax_state()) else {
398 panic!("missing shim for closure")
399 };
400 self.translate_body(span, body, &def.source_text)
401 }
402 (FnMut, Fn) => {
409 let fun_id: FunDeclId = self.register_item(
410 span,
411 def.this(),
412 TransItemSourceKind::ClosureMethod(closure_kind),
413 );
414 let impl_ref = self.translate_closure_impl_ref(span, args, closure_kind)?;
415 let fn_op = FnOperand::Regular(FnPtr::new(
418 fun_id.into(),
419 impl_ref.generics.concat(&GenericArgs {
420 regions: vec![self.translate_erased_region()].into(),
421 ..GenericArgs::empty()
422 }),
423 ));
424
425 let mut builder = BodyBuilder::new(span, 2);
426
427 let output = builder.new_var(None, signature.output.clone());
428 let state = builder.new_var(Some("state".to_string()), signature.inputs[0].clone());
429 let args = builder.new_var(Some("args".to_string()), signature.inputs[1].clone());
430 let deref_state = state.deref();
431 let reborrow_ty = TyKind::Ref(
432 self.translate_erased_region(),
433 deref_state.ty.clone(),
434 RefKind::Shared,
435 )
436 .into_ty();
437 let reborrow = builder.new_var(None, reborrow_ty);
438
439 builder.push_statement(StatementKind::Assign(
440 reborrow.clone(),
441 Rvalue::Ref {
442 place: deref_state,
443 kind: BorrowKind::Shared,
444 ptr_metadata: Operand::mk_const_unit(),
446 },
447 ));
448
449 builder.call(Call {
450 func: fn_op,
451 args: vec![Operand::Move(reborrow), Operand::Move(args)],
452 dest: output,
453 });
454
455 Body::Unstructured(builder.build())
456 }
457 (Fn, FnOnce) | (Fn, FnMut) | (FnMut, FnOnce) => {
458 panic!(
459 "Can't make a closure body for a more restrictive kind \
460 than the closure kind"
461 )
462 }
463 })
464 }
465
466 #[tracing::instrument(skip(self, item_meta))]
469 pub fn translate_closure_method(
470 mut self,
471 def_id: FunDeclId,
472 item_meta: ItemMeta,
473 def: &hax::FullDef,
474 target_kind: ClosureKind,
475 ) -> Result<FunDecl, Error> {
476 let span = item_meta.span;
477 let hax::FullDefKind::Closure {
478 args,
479 fn_once_impl,
480 fn_mut_impl,
481 fn_impl,
482 ..
483 } = &def.kind
484 else {
485 unreachable!()
486 };
487
488 let vimpl = match target_kind {
490 ClosureKind::FnOnce => fn_once_impl,
491 ClosureKind::FnMut => fn_mut_impl.as_ref().unwrap(),
492 ClosureKind::Fn => fn_impl.as_ref().unwrap(),
493 };
494 let implemented_trait = self.translate_trait_predicate(span, &vimpl.trait_pred)?;
495
496 let impl_ref = self.translate_closure_impl_ref(span, args, target_kind)?;
497 let src = ItemSource::TraitImpl {
498 impl_ref,
499 trait_ref: implemented_trait,
500 item_name: TraitItemName(target_kind.method_name().into()),
501 reuses_default: false,
502 };
503
504 let bound_sig = self.translate_closure_method_sig(def, span, args, target_kind)?;
506 let signature = bound_sig.apply(
508 self.the_only_binder()
509 .closure_call_method_region
510 .iter()
511 .map(|r| Region::Var(DeBruijnVar::new_at_zero(*r)))
512 .collect(),
513 );
514
515 let body = if item_meta.opacity.with_private_contents().is_opaque() {
516 Body::Opaque
517 } else {
518 self.translate_closure_method_body(span, def, target_kind, args, &signature)?
519 };
520
521 Ok(FunDecl {
522 def_id,
523 item_meta,
524 generics: self.into_generics(),
525 signature,
526 src,
527 is_global_initializer: None,
528 body,
529 })
530 }
531
532 #[tracing::instrument(skip(self, item_meta))]
533 pub fn translate_closure_trait_impl(
534 mut self,
535 def_id: TraitImplId,
536 item_meta: ItemMeta,
537 def: &hax::FullDef,
538 target_kind: ClosureKind,
539 ) -> Result<TraitImpl, Error> {
540 let span = item_meta.span;
541 let hax::FullDefKind::Closure {
542 args,
543 fn_once_impl,
544 fn_mut_impl,
545 fn_impl,
546 ..
547 } = def.kind()
548 else {
549 unreachable!()
550 };
551
552 let vimpl = match target_kind {
554 ClosureKind::FnOnce => fn_once_impl,
555 ClosureKind::FnMut => fn_mut_impl.as_ref().unwrap(),
556 ClosureKind::Fn => fn_impl.as_ref().unwrap(),
557 };
558 let mut timpl = self.translate_virtual_trait_impl(def_id, item_meta, vimpl)?;
559
560 if self.monomorphize() {
561 return Ok(timpl);
562 }
563
564 let call_fn_name = TraitItemName(target_kind.method_name().into());
566 let call_fn_binder = {
567 let kind = TransItemSourceKind::ClosureMethod(target_kind);
568 let bound_method_ref: RegionBinder<DeclRef<ItemId>> =
569 self.translate_closure_bound_ref_with_method_bound(span, args, kind, target_kind)?;
570 let params = GenericParams {
571 regions: bound_method_ref.regions,
572 ..GenericParams::empty()
573 };
574 let fn_decl_ref: FunDeclRef = bound_method_ref.skip_binder.try_into().unwrap();
575 Binder::new(
576 BinderKind::TraitMethod(timpl.impl_trait.id, call_fn_name),
577 params,
578 fn_decl_ref,
579 )
580 };
581 timpl.methods.push((call_fn_name, call_fn_binder));
582
583 Ok(timpl)
584 }
585
586 #[tracing::instrument(skip(self, item_meta))]
589 pub fn translate_stateless_closure_as_fn(
590 mut self,
591 def_id: FunDeclId,
592 item_meta: ItemMeta,
593 def: &hax::FullDef,
594 ) -> Result<FunDecl, Error> {
595 let span = item_meta.span;
596 let hax::FullDefKind::Closure { args: closure, .. } = &def.kind else {
597 unreachable!()
598 };
599
600 trace!("About to translate closure as fn:\n{:?}", def.def_id());
601
602 assert!(
603 closure.upvar_tys.is_empty(),
604 "Only stateless closures can be translated as functions"
605 );
606
607 let signature = self.translate_fun_sig(span, closure.fn_sig.hax_skip_binder_ref())?;
609 let state_ty = self.get_closure_state_ty(span, closure)?;
610
611 let body = if item_meta.opacity.with_private_contents().is_opaque() {
612 Body::Opaque
613 } else {
614 let fun_id: FunDeclId = self.register_item(
622 span,
623 def.this(),
624 TransItemSourceKind::ClosureMethod(ClosureKind::FnOnce),
625 );
626 let impl_ref = self.translate_closure_impl_ref(span, closure, ClosureKind::FnOnce)?;
627 let fn_op = FnOperand::Regular(FnPtr::new(fun_id.into(), impl_ref.generics.clone()));
628
629 let mut builder = BodyBuilder::new(span, signature.inputs.len());
630
631 let output = builder.new_var(None, signature.output.clone());
632 let args: Vec<Place> = signature
633 .inputs
634 .iter()
635 .enumerate()
636 .map(|(i, ty)| builder.new_var(Some(format!("arg{}", i + 1)), ty.clone()))
637 .collect();
638 let args_tupled_ty = Ty::mk_tuple(signature.inputs.clone());
639 let args_tupled = builder.new_var(Some("args".to_string()), args_tupled_ty.clone());
640 let state = builder.new_var(Some("state".to_string()), state_ty.clone());
641
642 builder.push_statement(StatementKind::Assign(
643 args_tupled.clone(),
644 Rvalue::Aggregate(
645 AggregateKind::Adt(args_tupled_ty.as_adt().unwrap().clone(), None, None),
646 args.into_iter().map(Operand::Move).collect(),
647 ),
648 ));
649
650 let state_ty_adt = state_ty.as_adt().unwrap();
651 builder.push_statement(StatementKind::Assign(
652 state.clone(),
653 Rvalue::Aggregate(AggregateKind::Adt(state_ty_adt.clone(), None, None), vec![]),
654 ));
655
656 builder.call(Call {
657 func: fn_op,
658 args: vec![Operand::Move(state), Operand::Move(args_tupled)],
659 dest: output,
660 });
661
662 Body::Unstructured(builder.build())
663 };
664
665 Ok(FunDecl {
666 def_id,
667 item_meta,
668 generics: self.into_generics(),
669 signature,
670 src: ItemSource::TopLevel,
671 is_global_initializer: None,
672 body,
673 })
674 }
675}