1use itertools::Itertools;
42use std::mem;
43
44use super::translate_crate::TransItemSourceKind;
45use super::translate_ctx::*;
46use charon_lib::ast::ullbc_ast_utils::BodyBuilder;
47use charon_lib::ast::*;
48use charon_lib::ids::IndexVec;
49use charon_lib::ullbc_ast::*;
50
51pub fn translate_closure_kind(kind: &hax::ClosureKind) -> ClosureKind {
52 match kind {
53 hax::ClosureKind::Fn => ClosureKind::Fn,
54 hax::ClosureKind::FnMut => ClosureKind::FnMut,
55 hax::ClosureKind::FnOnce => ClosureKind::FnOnce,
56 }
57}
58
59impl ItemTransCtx<'_, '_> {
64 fn translate_closure_bound_ref_with_late_bound(
67 &mut self,
68 span: Span,
69 closure: &hax::ClosureArgs,
70 kind: TransItemSourceKind,
71 ) -> Result<RegionBinder<DeclRef<ItemId>>, Error> {
72 if !matches!(
73 kind,
74 TransItemSourceKind::TraitImpl(..) | TransItemSourceKind::ClosureAsFnCast
75 ) {
76 raise_error!(
77 self,
78 span,
79 "Called `translate_closure_bound_ref_with_late_bound` on a `{kind:?}`; \
80 use `translate_closure_ref_with_upvars` \
81 or `translate_closure_bound_ref_with_method_bound` instead"
82 )
83 }
84 let dref: DeclRef<ItemId> = self.translate_item(span, &closure.item, kind)?;
85 self.translate_region_binder(span, &closure.fn_sig, |ctx, _| {
86 let mut dref = dref.move_under_binder();
87 for (a, b) in dref.generics.regions.iter_mut().rev().zip(
89 ctx.innermost_binder()
90 .params
91 .identity_args()
92 .regions
93 .into_iter()
94 .rev(),
95 ) {
96 *a = b;
97 }
98 Ok(dref)
99 })
100 }
101
102 fn translate_closure_bound_ref_with_method_bound(
106 &mut self,
107 span: Span,
108 closure: &hax::ClosureArgs,
109 kind: TransItemSourceKind,
110 target_kind: ClosureKind,
111 ) -> Result<RegionBinder<DeclRef<ItemId>>, Error> {
112 if !matches!(kind, TransItemSourceKind::ClosureMethod(..)) {
113 raise_error!(
114 self,
115 span,
116 "Called `translate_closure_bound_ref_with_method_bound` on a `{kind:?}`; \
117 use `translate_closure_ref_with_upvars` \
118 or `translate_closure_bound_ref_with_late_bound` instead"
119 )
120 }
121 let dref: DeclRef<ItemId> = self.translate_item(span, &closure.item, kind)?;
122 let mut dref = dref.move_under_binder();
123 let mut regions = IndexMap::new();
124 match target_kind {
125 ClosureKind::FnOnce => {}
126 ClosureKind::FnMut | ClosureKind::Fn => {
127 let rid = regions.push_with(|index| RegionParam { index, name: None });
128 *dref.generics.regions.iter_mut().last().unwrap() =
129 Region::Var(DeBruijnVar::new_at_zero(rid));
130 }
131 }
132 Ok(RegionBinder {
133 regions,
134 skip_binder: dref,
135 })
136 }
137}
138
139impl ItemTransCtx<'_, '_> {
140 pub fn translate_closure_type_ref(
142 &mut self,
143 span: Span,
144 closure: &hax::ClosureArgs,
145 ) -> Result<TypeDeclRef, Error> {
146 self.translate_item(span, &closure.item, TransItemSourceKind::Type)
147 }
148
149 pub fn translate_stateless_closure_as_fn_ref(
153 &mut self,
154 span: Span,
155 closure: &hax::ClosureArgs,
156 ) -> Result<RegionBinder<FunDeclRef>, Error> {
157 let kind = TransItemSourceKind::ClosureAsFnCast;
158 let bound_dref = self.translate_closure_bound_ref_with_late_bound(span, closure, kind)?;
159 Ok(bound_dref.map(|dref| dref.try_into().unwrap()))
160 }
161
162 pub fn translate_closure_bound_impl_ref(
166 &mut self,
167 span: Span,
168 closure: &hax::ClosureArgs,
169 target_kind: ClosureKind,
170 ) -> Result<RegionBinder<TraitImplRef>, Error> {
171 let kind = TransItemSourceKind::TraitImpl(TraitImplSource::Closure(target_kind));
172 let bound_dref = self.translate_closure_bound_ref_with_late_bound(span, closure, kind)?;
173 Ok(bound_dref.map(|dref| dref.try_into().unwrap()))
174 }
175
176 pub fn translate_closure_impl_ref(
178 &mut self,
179 span: Span,
180 closure: &hax::ClosureArgs,
181 target_kind: ClosureKind,
182 ) -> Result<TraitImplRef, Error> {
183 self.translate_item(
184 span,
185 &closure.item,
186 TransItemSourceKind::TraitImpl(TraitImplSource::Closure(target_kind)),
187 )
188 }
189
190 pub fn translate_closure_info(
191 &mut self,
192 span: Span,
193 args: &hax::ClosureArgs,
194 ) -> Result<ClosureInfo, Error> {
195 use ClosureKind::*;
196 let kind = translate_closure_kind(&args.kind);
197
198 let fn_once_impl = self.translate_closure_bound_impl_ref(span, args, FnOnce)?;
199 let fn_mut_impl = if matches!(kind, FnMut | Fn) {
200 Some(self.translate_closure_bound_impl_ref(span, args, FnMut)?)
201 } else {
202 None
203 };
204 let fn_impl = if matches!(kind, Fn) {
205 Some(self.translate_closure_bound_impl_ref(span, args, Fn)?)
206 } else {
207 None
208 };
209 let signature = self.translate_poly_fun_sig(span, &args.fn_sig)?;
210 Ok(ClosureInfo {
211 kind,
212 fn_once_impl,
213 fn_mut_impl,
214 fn_impl,
215 signature,
216 })
217 }
218
219 pub fn get_closure_state_ty(
220 &mut self,
221 span: Span,
222 args: &hax::ClosureArgs,
223 ) -> Result<Ty, Error> {
224 let tref = self.translate_closure_type_ref(span, args)?;
225 Ok(TyKind::Adt(tref).into_ty())
226 }
227
228 pub fn translate_closure_upvar_tys(
232 &mut self,
233 span: Span,
234 args: &hax::ClosureArgs,
235 ) -> Result<IndexVec<FieldId, Ty>, Error> {
236 args.upvar_tys
237 .iter()
238 .map(|ty| self.translate_ty(span, ty))
239 .try_collect()
240 }
241
242 pub fn translate_closure_adt(
243 &mut self,
244 span: Span,
245 _args: &hax::ClosureArgs,
246 ) -> Result<TypeDeclKind, Error> {
247 let fields: IndexVec<FieldId, Field> = self
248 .the_only_binder()
249 .closure_upvar_tys
250 .as_ref()
251 .unwrap()
252 .iter()
253 .cloned()
254 .map(|ty| Field {
255 span,
256 attr_info: AttrInfo::dummy_private(),
257 name: None,
258 ty: ty,
259 })
260 .collect();
261 Ok(TypeDeclKind::Struct(fields))
262 }
263
264 fn translate_closure_method_sig(
267 &mut self,
268 def: &hax::FullDef,
269 span: Span,
270 args: &hax::ClosureArgs,
271 target_kind: ClosureKind,
272 ) -> Result<RegionBinder<FunSig>, Error> {
273 let signature = &args.fn_sig;
274 trace!(
275 "signature of closure {:?}:\n{:?}",
276 def.def_id(),
277 signature.value,
278 );
279
280 let mut bound_regions = IndexMap::new();
281 let mut fun_sig = self
282 .translate_fun_sig(span, signature.hax_skip_binder_ref())?
283 .move_under_binder();
284 let state_ty = self.get_closure_state_ty(span, args)?.move_under_binder();
285
286 let state_ty = match target_kind {
288 ClosureKind::FnOnce => state_ty,
289 ClosureKind::Fn | ClosureKind::FnMut => {
290 let rid = bound_regions.push_with(|index| RegionParam { index, name: None });
291 let r = Region::Var(DeBruijnVar::new_at_zero(rid));
292 let mutability = if target_kind == ClosureKind::Fn {
293 RefKind::Shared
294 } else {
295 RefKind::Mut
296 };
297 TyKind::Ref(r, state_ty, mutability).into_ty()
298 }
299 };
300
301 let input_tys: Vec<Ty> = mem::take(&mut fun_sig.inputs);
303 fun_sig.inputs = vec![state_ty, Ty::mk_tuple(input_tys)];
305
306 Ok(RegionBinder {
307 regions: bound_regions,
308 skip_binder: fun_sig,
309 })
310 }
311
312 fn translate_closure_method_body(
313 &mut self,
314 span: Span,
315 def: &hax::FullDef,
316 target_kind: ClosureKind,
317 args: &hax::ClosureArgs,
318 signature: &FunSig,
319 ) -> Result<Body, Error> {
320 use ClosureKind::*;
321 let closure_kind = translate_closure_kind(&args.kind);
322 Ok(match (target_kind, closure_kind) {
323 (Fn, Fn) | (FnMut, FnMut) | (FnOnce, FnOnce) => {
324 let mut body = self.translate_def_body(span, def);
326 let Body::Unstructured(GExprBody {
335 locals,
336 body: blocks,
337 ..
338 }) = &mut body
339 else {
340 return Ok(body);
341 };
342
343 let tupled_ty = &signature.inputs[1];
345
346 blocks.dyn_visit_mut(|local: &mut LocalId| {
347 if local.index() >= 2 {
348 *local += 1;
349 }
350 });
351
352 let mut old_locals = mem::take(&mut locals.locals).into_iter();
353 locals.arg_count = 2;
354 locals.locals.push(old_locals.next().unwrap()); locals.locals.push(old_locals.next().unwrap()); let tupled_arg = locals.new_var(Some("tupled_args".to_string()), tupled_ty.clone());
357 locals.locals.extend(old_locals.map(|mut l| {
358 l.index += 1;
359 l
360 }));
361
362 let untupled_args = tupled_ty.as_tuple().unwrap();
363 let closure_arg_count = untupled_args.elem_count();
364 let new_stts = untupled_args.iter().cloned().enumerate().map(|(i, ty)| {
365 let nth_field = tupled_arg.clone().project(
366 ProjectionElem::Field(
367 FieldProjKind::Tuple(closure_arg_count),
368 FieldId::new(i),
369 ),
370 ty,
371 );
372 let local_id = LocalId::new(i + 3);
373 Statement::new(
374 span,
375 StatementKind::Assign(
376 locals.place_for_var(local_id),
377 Rvalue::Use(Operand::Move(nth_field)),
378 ),
379 )
380 });
381 blocks[BlockId::ZERO].statements.splice(0..0, new_stts);
382
383 body
384 }
385 (FnOnce, Fn | FnMut) => {
395 let Some(body) = def.this.closure_once_shim(self.hax_state()) else {
397 panic!("missing shim for closure")
398 };
399 self.translate_body(span, body, &def.source_text)
400 }
401 (FnMut, Fn) => {
408 let fun_id: FunDeclId = self.register_item(
409 span,
410 def.this(),
411 TransItemSourceKind::ClosureMethod(closure_kind),
412 );
413 let impl_ref = self.translate_closure_impl_ref(span, args, closure_kind)?;
414 let fn_op = FnOperand::Regular(FnPtr::new(
417 fun_id.into(),
418 impl_ref.generics.concat(&GenericArgs {
419 regions: vec![self.translate_erased_region()].into(),
420 ..GenericArgs::empty()
421 }),
422 ));
423
424 let mut builder = BodyBuilder::new(span, 2);
425
426 let output = builder.new_var(None, signature.output.clone());
427 let state = builder.new_var(Some("state".to_string()), signature.inputs[0].clone());
428 let args = builder.new_var(Some("args".to_string()), signature.inputs[1].clone());
429 let deref_state = state.deref();
430 let reborrow_ty = TyKind::Ref(
431 self.translate_erased_region(),
432 deref_state.ty.clone(),
433 RefKind::Shared,
434 )
435 .into_ty();
436 let reborrow = builder.new_var(None, reborrow_ty);
437
438 builder.push_statement(StatementKind::Assign(
439 reborrow.clone(),
440 Rvalue::Ref {
441 place: deref_state,
442 kind: BorrowKind::Shared,
443 ptr_metadata: Operand::mk_const_unit(),
445 },
446 ));
447
448 builder.call(Call {
449 func: fn_op,
450 args: vec![Operand::Move(reborrow), Operand::Move(args)],
451 dest: output,
452 });
453
454 Body::Unstructured(builder.build())
455 }
456 (Fn, FnOnce) | (Fn, FnMut) | (FnMut, FnOnce) => {
457 panic!(
458 "Can't make a closure body for a more restrictive kind \
459 than the closure kind"
460 )
461 }
462 })
463 }
464
465 #[tracing::instrument(skip(self, item_meta))]
468 pub fn translate_closure_method(
469 mut self,
470 def_id: FunDeclId,
471 item_meta: ItemMeta,
472 def: &hax::FullDef,
473 target_kind: ClosureKind,
474 ) -> Result<FunDecl, Error> {
475 let span = item_meta.span;
476 let hax::FullDefKind::Closure {
477 args,
478 fn_once_impl,
479 fn_mut_impl,
480 fn_impl,
481 ..
482 } = &def.kind
483 else {
484 unreachable!()
485 };
486
487 let vimpl = match target_kind {
489 ClosureKind::FnOnce => fn_once_impl,
490 ClosureKind::FnMut => fn_mut_impl.as_ref().unwrap(),
491 ClosureKind::Fn => fn_impl.as_ref().unwrap(),
492 };
493 let implemented_trait = self.translate_trait_predicate(span, &vimpl.trait_pred)?;
494
495 let impl_ref = self.translate_closure_impl_ref(span, args, target_kind)?;
496 let src = ItemSource::TraitImpl {
497 impl_ref,
498 trait_ref: implemented_trait,
499 item_name: TraitItemName(target_kind.method_name().into()),
500 reuses_default: false,
501 };
502
503 let bound_sig = self.translate_closure_method_sig(def, span, args, target_kind)?;
505 let signature = bound_sig.apply(
507 self.the_only_binder()
508 .closure_call_method_region
509 .iter()
510 .map(|r| Region::Var(DeBruijnVar::new_at_zero(*r)))
511 .collect(),
512 );
513
514 let body = if item_meta.opacity.with_private_contents().is_opaque() {
515 Body::Opaque
516 } else {
517 self.translate_closure_method_body(span, def, target_kind, args, &signature)?
518 };
519
520 Ok(FunDecl {
521 def_id,
522 item_meta,
523 generics: self.into_generics(),
524 signature,
525 src,
526 is_global_initializer: None,
527 body,
528 })
529 }
530
531 #[tracing::instrument(skip(self, item_meta))]
532 pub fn translate_closure_trait_impl(
533 mut self,
534 def_id: TraitImplId,
535 item_meta: ItemMeta,
536 def: &hax::FullDef,
537 target_kind: ClosureKind,
538 ) -> Result<TraitImpl, Error> {
539 let span = item_meta.span;
540 let hax::FullDefKind::Closure {
541 args,
542 fn_once_impl,
543 fn_mut_impl,
544 fn_impl,
545 ..
546 } = def.kind()
547 else {
548 unreachable!()
549 };
550
551 let vimpl = match target_kind {
553 ClosureKind::FnOnce => fn_once_impl,
554 ClosureKind::FnMut => fn_mut_impl.as_ref().unwrap(),
555 ClosureKind::Fn => fn_impl.as_ref().unwrap(),
556 };
557 let mut timpl = self.translate_virtual_trait_impl(def_id, item_meta, vimpl)?;
558
559 let call_fn_name = TraitItemName(target_kind.method_name().into());
561 let call_fn_binder = {
562 let kind = TransItemSourceKind::ClosureMethod(target_kind);
563 let bound_method_ref: RegionBinder<DeclRef<ItemId>> =
564 self.translate_closure_bound_ref_with_method_bound(span, args, kind, target_kind)?;
565 let params = GenericParams {
566 regions: bound_method_ref.regions,
567 ..GenericParams::empty()
568 };
569 let fn_decl_ref: FunDeclRef = bound_method_ref.skip_binder.try_into().unwrap();
570 Binder::new(
571 BinderKind::TraitMethod(timpl.impl_trait.id, call_fn_name),
572 params,
573 fn_decl_ref,
574 )
575 };
576 timpl.methods.push((call_fn_name, call_fn_binder));
577
578 Ok(timpl)
579 }
580
581 #[tracing::instrument(skip(self, item_meta))]
584 pub fn translate_stateless_closure_as_fn(
585 mut self,
586 def_id: FunDeclId,
587 item_meta: ItemMeta,
588 def: &hax::FullDef,
589 ) -> Result<FunDecl, Error> {
590 let span = item_meta.span;
591 let hax::FullDefKind::Closure { args: closure, .. } = &def.kind else {
592 unreachable!()
593 };
594
595 trace!("About to translate closure as fn:\n{:?}", def.def_id());
596
597 assert!(
598 closure.upvar_tys.is_empty(),
599 "Only stateless closures can be translated as functions"
600 );
601
602 let signature = self.translate_fun_sig(span, closure.fn_sig.hax_skip_binder_ref())?;
604 let state_ty = self.get_closure_state_ty(span, closure)?;
605
606 let body = if item_meta.opacity.with_private_contents().is_opaque() {
607 Body::Opaque
608 } else {
609 let fun_id: FunDeclId = self.register_item(
617 span,
618 def.this(),
619 TransItemSourceKind::ClosureMethod(ClosureKind::FnOnce),
620 );
621 let impl_ref = self.translate_closure_impl_ref(span, closure, ClosureKind::FnOnce)?;
622 let fn_op = FnOperand::Regular(FnPtr::new(fun_id.into(), impl_ref.generics.clone()));
623
624 let mut builder = BodyBuilder::new(span, signature.inputs.len());
625
626 let output = builder.new_var(None, signature.output.clone());
627 let args: Vec<Place> = signature
628 .inputs
629 .iter()
630 .enumerate()
631 .map(|(i, ty)| builder.new_var(Some(format!("arg{}", i + 1)), ty.clone()))
632 .collect();
633 let args_tupled_ty = Ty::mk_tuple(signature.inputs.clone());
634 let args_tupled = builder.new_var(Some("args".to_string()), args_tupled_ty.clone());
635 let state = builder.new_var(Some("state".to_string()), state_ty.clone());
636
637 builder.push_statement(StatementKind::Assign(
638 args_tupled.clone(),
639 Rvalue::Aggregate(
640 AggregateKind::Adt(args_tupled_ty.as_adt().unwrap().clone(), None, None),
641 args.into_iter().map(Operand::Move).collect(),
642 ),
643 ));
644
645 let state_ty_adt = state_ty.as_adt().unwrap();
646 builder.push_statement(StatementKind::Assign(
647 state.clone(),
648 Rvalue::Aggregate(AggregateKind::Adt(state_ty_adt.clone(), None, None), vec![]),
649 ));
650
651 builder.call(Call {
652 func: fn_op,
653 args: vec![Operand::Move(state), Operand::Move(args_tupled)],
654 dest: output,
655 });
656
657 Body::Unstructured(builder.build())
658 };
659
660 Ok(FunDecl {
661 def_id,
662 item_meta,
663 generics: self.into_generics(),
664 signature,
665 src: ItemSource::TopLevel,
666 is_global_initializer: None,
667 body,
668 })
669 }
670}