1use std::mem;
42
43use crate::translate::translate_bodies::BodyTransCtx;
44
45use super::translate_ctx::*;
46use charon_lib::ast::*;
47use charon_lib::common::*;
48use charon_lib::formatter::IntoFormatter;
49use charon_lib::ids::Vector;
50use charon_lib::pretty::FmtWithCtx;
51use charon_lib::ullbc_ast::*;
52use hax_frontend_exporter as hax;
53use itertools::Itertools;
54
55pub fn translate_closure_kind(kind: &hax::ClosureKind) -> ClosureKind {
56 match kind {
57 hax::ClosureKind::Fn => ClosureKind::Fn,
58 hax::ClosureKind::FnMut => ClosureKind::FnMut,
59 hax::ClosureKind::FnOnce => ClosureKind::FnOnce,
60 }
61}
62
63impl ItemTransCtx<'_, '_> {
64 pub fn translate_closure_info(
65 &mut self,
66 span: Span,
67 args: &hax::ClosureArgs,
68 ) -> Result<ClosureInfo, Error> {
69 let kind = translate_closure_kind(&args.kind);
70 let signature = self.translate_region_binder(span, &args.untupled_sig, |ctx, sig| {
71 let inputs = sig
72 .inputs
73 .iter()
74 .map(|x| ctx.translate_ty(span, x))
75 .try_collect()?;
76 let output = ctx.translate_ty(span, &sig.output)?;
77 Ok((inputs, output))
78 })?;
79 Ok(ClosureInfo { kind, signature })
80 }
81
82 pub fn translate_closure_type_ref(
84 &mut self,
85 span: Span,
86 def_id: &hax::DefId,
87 closure: &hax::ClosureArgs,
88 ) -> Result<TypeDeclRef, Error> {
89 let type_id = self.register_type_decl_id(span, def_id);
90 let mut args = self.translate_generic_args(
91 span,
92 &closure.parent_args,
93 &closure.parent_trait_refs,
94 None,
95 GenericsSource::item(type_id),
96 )?;
97 if self.def_id == *def_id {
99 args.regions.extend(
100 self.outermost_binder()
101 .by_ref_upvar_regions
102 .iter()
103 .map(|r| Region::Var(DeBruijnVar::bound(self.binding_levels.depth(), *r))),
104 );
105 } else {
106 args.regions.extend(
107 closure
108 .upvar_tys
109 .iter()
110 .filter(|ty| {
111 matches!(
112 ty.kind(),
113 hax::TyKind::Ref(
114 hax::Region {
115 kind: hax::RegionKind::ReErased
116 },
117 ..
118 )
119 )
120 })
121 .map(|_| Region::Erased),
122 );
123 }
124
125 Ok(TypeDeclRef {
126 id: TypeId::Adt(type_id),
127 generics: Box::new(args),
128 })
129 }
130
131 pub fn translate_closure_impl_ref(
133 &mut self,
134 span: Span,
135 def_id: &hax::DefId,
136 closure: &hax::ClosureArgs,
137 target_kind: ClosureKind,
138 ) -> Result<TraitImplRef, Error> {
139 let impl_id = self.register_closure_trait_impl_id(span, def_id, target_kind);
140 let adt_ref = self.translate_closure_type_ref(span, def_id, closure)?;
141 let mut args = adt_ref.generics.with_target(GenericsSource::item(impl_id));
142 if self.def_id == *def_id {
144 args.regions.extend(
145 self.outermost_binder()
146 .bound_region_vars
147 .iter()
148 .map(|r| Region::Var(DeBruijnVar::bound(self.binding_levels.depth(), *r))),
149 );
150 } else {
151 args.regions
152 .extend(closure.tupled_sig.bound_vars.iter().map(|_| Region::Erased));
153 }
154
155 Ok(TraitImplRef {
156 impl_id,
157 generics: Box::new(args),
158 })
159 }
160
161 pub fn translate_closure_trait_ref(
163 &mut self,
164 span: Span,
165 def_id: &hax::DefId,
166 args: &hax::ClosureArgs,
167 target_kind: ClosureKind,
168 ) -> Result<TraitDeclRef, Error> {
169 let fn_trait = match target_kind {
171 ClosureKind::FnOnce => self.get_lang_item(rustc_hir::LangItem::FnOnce),
172 ClosureKind::FnMut => self.get_lang_item(rustc_hir::LangItem::FnMut),
173 ClosureKind::Fn => self.get_lang_item(rustc_hir::LangItem::Fn),
174 };
175 let trait_id = self.register_trait_decl_id(span, &fn_trait);
176
177 let state_ty = self.get_closure_state_ty(span, def_id, args)?;
178 let (inputs, _) = self.translate_closure_info(span, args)?.signature.erase();
180 let input_tuple = Ty::mk_tuple(inputs);
181
182 Ok(TraitDeclRef {
183 trait_id,
184 generics: Box::new(GenericArgs::new_types(
185 [state_ty, input_tuple].into(),
186 GenericsSource::item(trait_id),
187 )),
188 })
189 }
190
191 pub fn get_closure_state_ty(
192 &mut self,
193 span: Span,
194 def_id: &hax::DefId,
195 args: &hax::ClosureArgs,
196 ) -> Result<Ty, Error> {
197 let ty_ref = self.translate_closure_type_ref(span, def_id, args)?;
198 Ok(TyKind::Adt(ty_ref.id, *ty_ref.generics).into_ty())
199 }
200
201 pub fn translate_closure_adt(
202 &mut self,
203 _trans_id: TypeDeclId,
204 span: Span,
205 args: &hax::ClosureArgs,
206 ) -> Result<TypeDeclKind, Error> {
207 let mut by_ref_upvar_regions = self
208 .the_only_binder()
209 .by_ref_upvar_regions
210 .clone()
211 .into_iter();
212 let fields: Vector<FieldId, Field> = args
213 .upvar_tys
214 .iter()
215 .map(|ty| {
216 let mut ty = self.translate_ty(span, ty)?;
217 if let TyKind::Ref(Region::Erased, deref_ty, kind) = ty.kind() {
219 let region_id = by_ref_upvar_regions.next().unwrap();
220 ty = TyKind::Ref(
221 Region::Var(DeBruijnVar::new_at_zero(region_id)),
222 deref_ty.clone(),
223 *kind,
224 )
225 .into_ty();
226 }
227 Ok(Field {
228 span,
229 attr_info: AttrInfo {
230 attributes: vec![],
231 inline: None,
232 rename: None,
233 public: false,
234 },
235 name: None,
236 ty,
237 })
238 })
239 .try_collect()?;
240 Ok(TypeDeclKind::Struct(fields))
241 }
242
243 fn translate_closure_method_sig(
246 &mut self,
247 def: &hax::FullDef,
248 span: Span,
249 args: &hax::ClosureArgs,
250 target_kind: ClosureKind,
251 ) -> Result<FunSig, Error> {
252 let signature = &args.tupled_sig;
253
254 trace!(
256 "signature of closure {:?}:\n{:?}",
257 def.def_id,
258 signature.value
259 );
260 let mut inputs: Vec<Ty> = signature
261 .value
262 .inputs
263 .iter()
264 .map(|ty| self.translate_ty(span, ty))
265 .try_collect()?;
266 let output = self.translate_ty(span, &signature.value.output)?;
267
268 let fmt_ctx = &self.into_fmt();
269 trace!(
270 "# Input variables types:\n{}",
271 pretty_display_list(|x| x.to_string_with_ctx(fmt_ctx), &inputs)
272 );
273 trace!("# Output variable type:\n{}", output.with_ctx(fmt_ctx));
274
275 let is_unsafe = match signature.value.safety {
276 hax::Safety::Unsafe => true,
277 hax::Safety::Safe => false,
278 };
279
280 let state_ty = self.get_closure_state_ty(span, def.def_id(), args)?;
281
282 let state_ty = match target_kind {
284 ClosureKind::FnOnce => state_ty,
285 ClosureKind::Fn | ClosureKind::FnMut => {
286 let rid = self
287 .innermost_generics_mut()
288 .regions
289 .push_with(|index| RegionVar { index, name: None });
290 let r = Region::Var(DeBruijnVar::new_at_zero(rid));
291 let mutability = if target_kind == ClosureKind::Fn {
292 RefKind::Shared
293 } else {
294 RefKind::Mut
295 };
296 TyKind::Ref(r, state_ty, mutability).into_ty()
297 }
298 };
299 assert_eq!(inputs.len(), 1);
300 inputs.insert(0, state_ty);
301
302 Ok(FunSig {
303 generics: self.the_only_binder().params.clone(),
304 is_unsafe,
305 inputs,
306 output,
307 })
308 }
309
310 fn translate_closure_method_body(
311 mut self,
312 span: Span,
313 def: &hax::FullDef,
314 target_kind: ClosureKind,
315 args: &hax::ClosureArgs,
316 signature: &FunSig,
317 ) -> Result<Result<Body, Opaque>, Error> {
318 use ClosureKind::*;
319 let closure_kind = translate_closure_kind(&args.kind);
320 let mk_stt = |content| Statement::new(span, content);
321 let mk_block = |statements, terminator| -> BlockData {
322 BlockData {
323 statements,
324 terminator: Terminator::new(span, terminator),
325 }
326 };
327
328 Ok(match (target_kind, closure_kind) {
329 (Fn, Fn) | (FnMut, FnMut) | (FnOnce, FnOnce) => {
330 let mut bt_ctx = BodyTransCtx::new(&mut self);
332 match bt_ctx.translate_def_body(span, def) {
333 Ok(Ok(mut body)) => {
334 let GExprBody {
343 locals,
344 body: blocks,
345 ..
346 } = body.as_unstructured_mut().unwrap();
347
348 blocks.dyn_visit_mut(|local: &mut LocalId| {
349 let idx = local.index();
350 if idx >= 2 {
351 *local = LocalId::new(idx + 1)
352 }
353 });
354
355 let mut old_locals = mem::take(&mut locals.locals).into_iter();
356 locals.arg_count = 2;
357 locals.locals.push(old_locals.next().unwrap()); locals.locals.push(old_locals.next().unwrap()); let tupled_arg = locals
360 .new_var(Some("tupled_args".to_string()), signature.inputs[1].clone());
361 locals.locals.extend(old_locals.map(|mut l| {
362 l.index += 1;
363 l
364 }));
365
366 let untupled_args = signature.inputs[1].as_tuple().unwrap();
367 let closure_arg_count = untupled_args.elem_count();
368 let new_stts = untupled_args.iter().cloned().enumerate().map(|(i, ty)| {
369 let nth_field = tupled_arg.clone().project(
370 ProjectionElem::Field(
371 FieldProjKind::Tuple(closure_arg_count),
372 FieldId::new(i),
373 ),
374 ty,
375 );
376 mk_stt(RawStatement::Assign(
377 locals.place_for_var(LocalId::new(i + 3)),
378 Rvalue::Use(Operand::Move(nth_field)),
379 ))
380 });
381 blocks[BlockId::ZERO].statements.splice(0..0, new_stts);
382
383 Ok(body)
384 }
385 Ok(Err(Opaque)) => Err(Opaque),
386 Err(_) => Err(Opaque),
387 }
388 }
389 (FnOnce, Fn | FnMut) => {
399 let hax::FullDefKind::Closure {
401 once_shim: Some(body),
402 ..
403 } = &def.kind
404 else {
405 panic!("missing shim for closure")
406 };
407 let mut bt_ctx = BodyTransCtx::new(&mut self);
408 match bt_ctx.translate_body(span, body, &def.source_text) {
409 Ok(Ok(body)) => Ok(body),
410 Ok(Err(Opaque)) => Err(Opaque),
411 Err(_) => Err(Opaque),
412 }
413 }
414 (FnMut, Fn) => {
421 let fun_id = self.register_closure_method_decl_id(span, def.def_id(), closure_kind);
422 let impl_ref =
423 self.translate_closure_impl_ref(span, def.def_id(), args, closure_kind)?;
424 let fn_op = FnOperand::Regular(FnPtr {
427 func: FunIdOrTraitMethodRef::Fun(FunId::Regular(fun_id.clone())).into(),
428 generics: Box::new(impl_ref.generics.concat(
429 GenericsSource::item(fun_id),
430 &GenericArgs {
431 regions: vec![Region::Erased].into(),
432 ..GenericArgs::empty(GenericsSource::item(fun_id))
433 },
434 )),
435 });
436
437 let mut locals = Locals {
438 arg_count: 2,
439 locals: Vector::new(),
440 };
441 let mut statements = vec![];
442 let mut blocks = Vector::default();
443
444 let output = locals.new_var(None, signature.output.clone());
445 let state = locals.new_var(Some("state".to_string()), signature.inputs[0].clone());
446 let args = locals.new_var(Some("args".to_string()), signature.inputs[1].clone());
447 let deref_state = state.deref();
448 let reborrow_ty =
449 TyKind::Ref(Region::Erased, deref_state.ty.clone(), RefKind::Shared).into_ty();
450 let reborrow = locals.new_var(None, reborrow_ty);
451
452 statements.push(mk_stt(RawStatement::Assign(
453 reborrow.clone(),
454 Rvalue::Ref(deref_state, BorrowKind::Shared),
455 )));
456
457 let start_block = blocks.reserve_slot();
458 let ret_block = blocks.push(mk_block(vec![], RawTerminator::Return));
459 let unwind_block = blocks.push(mk_block(vec![], RawTerminator::UnwindResume));
460 let call = RawTerminator::Call {
461 target: ret_block,
462 call: Call {
463 func: fn_op,
464 args: vec![Operand::Move(reborrow), Operand::Move(args)],
465 dest: output,
466 },
467 on_unwind: unwind_block,
468 };
469 blocks.set_slot(start_block, mk_block(statements, call));
470
471 let body: ExprBody = GExprBody {
472 span,
473 locals,
474 comments: vec![],
475 body: blocks,
476 };
477 Ok(Body::Unstructured(body))
478 }
479 (Fn, FnOnce) | (Fn, FnMut) | (FnMut, FnOnce) => {
480 panic!(
481 "Can't make a closure body for a more restrictive kind \
482 than the closure kind"
483 )
484 }
485 })
486 }
487
488 #[tracing::instrument(skip(self, item_meta))]
491 pub fn translate_closure_method(
492 mut self,
493 def_id: FunDeclId,
494 item_meta: ItemMeta,
495 def: &hax::FullDef,
496 target_kind: ClosureKind,
497 ) -> Result<FunDecl, Error> {
498 let span = item_meta.span;
499 let hax::FullDefKind::Closure { args, .. } = &def.kind else {
500 unreachable!()
501 };
502
503 trace!("About to translate closure:\n{:?}", def.def_id);
504
505 self.translate_def_generics(span, def)?;
506 assert!(self.innermost_binder_mut().bound_region_vars.is_empty(),);
508 self.innermost_binder_mut()
509 .push_params_from_binder(args.tupled_sig.rebind(()))?;
510
511 let impl_ref = self.translate_closure_impl_ref(span, def.def_id(), args, target_kind)?;
512 let implemented_trait =
513 self.translate_closure_trait_ref(span, def.def_id(), args, target_kind)?;
514 let kind = ItemKind::TraitImpl {
515 impl_ref,
516 trait_ref: implemented_trait,
517 item_name: TraitItemName(target_kind.method_name().to_owned()),
518 reuses_default: false,
519 };
520
521 let signature = self.translate_closure_method_sig(def, span, args, target_kind)?;
523
524 let body = if item_meta.opacity.with_private_contents().is_opaque() {
525 Err(Opaque)
526 } else {
527 self.translate_closure_method_body(span, def, target_kind, args, &signature)?
528 };
529
530 Ok(FunDecl {
531 def_id,
532 item_meta,
533 signature,
534 kind,
535 is_global_initializer: None,
536 body,
537 })
538 }
539
540 #[tracing::instrument(skip(self, item_meta))]
541 pub fn translate_closure_trait_impl(
542 mut self,
543 def_id: TraitImplId,
544 item_meta: ItemMeta,
545 def: &hax::FullDef,
546 target_kind: ClosureKind,
547 ) -> Result<TraitImpl, Error> {
548 let span = item_meta.span;
549 let hax::FullDefKind::Closure { args, .. } = &def.kind else {
550 unreachable!()
551 };
552
553 self.translate_def_generics(span, def)?;
554 assert!(self.innermost_binder_mut().bound_region_vars.is_empty(),);
556 self.innermost_binder_mut()
557 .push_params_from_binder(args.tupled_sig.rebind(()))?;
558
559 let sized_trait = self.get_lang_item(rustc_hir::LangItem::Sized);
561 let sized_trait = self.register_trait_decl_id(span, &sized_trait);
562
563 let tuple_trait = self.get_lang_item(rustc_hir::LangItem::Tuple);
564 let tuple_trait = self.register_trait_decl_id(span, &tuple_trait);
565
566 let implemented_trait =
567 self.translate_closure_trait_ref(span, def.def_id(), args, target_kind)?;
568 let fn_trait = implemented_trait.trait_id;
569
570 let (inputs, output) = self.translate_closure_info(span, args)?.signature.erase();
572 let input = Ty::mk_tuple(inputs);
573
574 let parent_trait_refs = {
575 let builtin_tref = |trait_id, ty| {
577 let generics = Box::new(GenericArgs::new_types(
578 vec![ty].into(),
579 GenericsSource::item(trait_id),
580 ));
581 let trait_decl_ref = TraitDeclRef { trait_id, generics };
582 let trait_decl_ref = RegionBinder::empty(trait_decl_ref);
583 TraitRef {
584 kind: TraitRefKind::BuiltinOrAuto {
585 trait_decl_ref: trait_decl_ref.clone(),
586 parent_trait_refs: Vector::new(),
587 types: vec![],
588 },
589 trait_decl_ref,
590 }
591 };
592
593 match target_kind {
594 ClosureKind::FnOnce => [
595 builtin_tref(sized_trait, input.clone()),
596 builtin_tref(tuple_trait, input.clone()),
597 builtin_tref(sized_trait, output.clone()),
598 ]
599 .into(),
600 ClosureKind::FnMut | ClosureKind::Fn => {
601 let parent_kind = match target_kind {
602 ClosureKind::FnOnce => unreachable!(),
603 ClosureKind::FnMut => ClosureKind::FnOnce,
604 ClosureKind::Fn => ClosureKind::FnMut,
605 };
606 let parent_impl_ref =
607 self.translate_closure_impl_ref(span, def.def_id(), args, parent_kind)?;
608 let parent_predicate =
609 self.translate_closure_trait_ref(span, def.def_id(), args, parent_kind)?;
610 let parent_trait_ref = TraitRef {
611 kind: TraitRefKind::TraitImpl(
612 parent_impl_ref.impl_id,
613 parent_impl_ref.generics,
614 ),
615 trait_decl_ref: RegionBinder::empty(parent_predicate),
616 };
617 [
618 parent_trait_ref,
619 builtin_tref(sized_trait, input.clone()),
620 builtin_tref(tuple_trait, input.clone()),
621 ]
622 .into()
623 }
624 }
625 };
626 let types = match target_kind {
627 ClosureKind::FnOnce => vec![(TraitItemName("Output".into()), output.clone())],
628 ClosureKind::FnMut | ClosureKind::Fn => vec![],
629 };
630
631 let call_fn_id = self.register_closure_method_decl_id(span, &def.def_id, target_kind);
633 let call_fn_name = TraitItemName(target_kind.method_name().to_string());
634 let call_fn_binder = {
635 let mut method_params = GenericParams::empty();
636 match target_kind {
637 ClosureKind::FnOnce => {}
638 ClosureKind::FnMut | ClosureKind::Fn => {
639 method_params
640 .regions
641 .push_with(|index| RegionVar { index, name: None });
642 }
643 };
644
645 let generics = self
646 .outermost_binder()
647 .params
648 .identity_args_at_depth(GenericsSource::item(def_id), DeBruijnId::one())
649 .concat(
650 GenericsSource::item(call_fn_id),
651 &method_params.identity_args_at_depth(
652 GenericsSource::Method(fn_trait, call_fn_name.clone()),
653 DeBruijnId::zero(),
654 ),
655 );
656 Binder::new(
657 BinderKind::TraitMethod(fn_trait, call_fn_name.clone()),
658 method_params,
659 FunDeclRef {
660 id: call_fn_id,
661 generics: Box::new(generics),
662 },
663 )
664 };
665
666 let self_generics = self.into_generics();
667
668 Ok(TraitImpl {
669 def_id,
670 item_meta,
671 impl_trait: implemented_trait,
672 generics: self_generics,
673 parent_trait_refs,
674 type_clauses: vec![],
675 consts: vec![],
676 types,
677 methods: vec![(call_fn_name, call_fn_binder)],
678 })
679 }
680}