1use std::mem;
42
43use crate::translate::translate_bodies::BodyTransCtx;
44
45use super::translate_ctx::*;
46use charon_lib::ast::*;
47use charon_lib::ids::Vector;
48use charon_lib::ullbc_ast::*;
49use hax_frontend_exporter as hax;
50use itertools::Itertools;
51
52pub fn translate_closure_kind(kind: &hax::ClosureKind) -> ClosureKind {
53 match kind {
54 hax::ClosureKind::Fn => ClosureKind::Fn,
55 hax::ClosureKind::FnMut => ClosureKind::FnMut,
56 hax::ClosureKind::FnOnce => ClosureKind::FnOnce,
57 }
58}
59
60impl ItemTransCtx<'_, '_> {
61 pub fn translate_closure_info(
62 &mut self,
63 span: Span,
64 args: &hax::ClosureArgs,
65 ) -> Result<ClosureInfo, Error> {
66 let kind = translate_closure_kind(&args.kind);
67 let signature = self.translate_region_binder(span, &args.fn_sig, |ctx, sig| {
68 let inputs = sig
69 .inputs
70 .iter()
71 .map(|x| ctx.translate_ty(span, x))
72 .try_collect()?;
73 let output = ctx.translate_ty(span, &sig.output)?;
74 Ok((inputs, output))
75 })?;
76 Ok(ClosureInfo { kind, signature })
77 }
78
79 pub fn translate_closure_type_ref(
81 &mut self,
82 span: Span,
83 closure: &hax::ClosureArgs,
84 ) -> Result<TypeDeclRef, Error> {
85 let mut tref = self.translate_type_decl_ref(span, &closure.item)?;
86 if self.def_id == closure.item.def_id {
88 tref.generics.regions.extend(
89 self.outermost_binder()
90 .by_ref_upvar_regions
91 .iter()
92 .map(|r| Region::Var(DeBruijnVar::bound(self.binding_levels.depth(), *r))),
93 );
94 } else {
95 tref.generics.regions.extend(
96 closure
97 .upvar_tys
98 .iter()
99 .filter(|ty| {
100 matches!(
101 ty.kind(),
102 hax::TyKind::Ref(
103 hax::Region {
104 kind: hax::RegionKind::ReErased
105 },
106 ..
107 )
108 )
109 })
110 .map(|_| Region::Erased),
111 );
112 }
113 Ok(tref)
114 }
115
116 pub fn translate_closure_impl_ref(
118 &mut self,
119 span: Span,
120 closure: &hax::ClosureArgs,
121 target_kind: ClosureKind,
122 ) -> Result<TraitImplRef, Error> {
123 let impl_id = self.register_closure_trait_impl_id(span, &closure.item.def_id, target_kind);
124 let adt_ref = self.translate_closure_type_ref(span, closure)?;
125 let mut args = adt_ref.generics;
126 if self.def_id == closure.item.def_id {
128 args.regions.extend(
129 self.outermost_binder()
130 .bound_region_vars
131 .iter()
132 .map(|r| Region::Var(DeBruijnVar::bound(self.binding_levels.depth(), *r))),
133 );
134 } else {
135 args.regions
136 .extend(closure.fn_sig.bound_vars.iter().map(|_| Region::Erased));
137 }
138
139 Ok(TraitImplRef {
140 id: impl_id,
141 generics: args,
142 })
143 }
144
145 pub fn translate_closure_trait_ref(
147 &mut self,
148 span: Span,
149 args: &hax::ClosureArgs,
150 target_kind: ClosureKind,
151 ) -> Result<TraitDeclRef, Error> {
152 let fn_trait = match target_kind {
154 ClosureKind::FnOnce => self.get_lang_item(rustc_hir::LangItem::FnOnce),
155 ClosureKind::FnMut => self.get_lang_item(rustc_hir::LangItem::FnMut),
156 ClosureKind::Fn => self.get_lang_item(rustc_hir::LangItem::Fn),
157 };
158 let trait_id = self.register_trait_decl_id(span, &fn_trait);
159
160 let state_ty = self.get_closure_state_ty(span, args)?;
161 let (inputs, _) = self.translate_closure_info(span, args)?.signature.erase();
163 let input_tuple = Ty::mk_tuple(inputs);
164
165 Ok(TraitDeclRef {
166 id: trait_id,
167 generics: Box::new(GenericArgs::new_types([state_ty, input_tuple].into())),
168 })
169 }
170
171 pub fn get_closure_state_ty(
172 &mut self,
173 span: Span,
174 args: &hax::ClosureArgs,
175 ) -> Result<Ty, Error> {
176 let tref = self.translate_closure_type_ref(span, args)?;
177 Ok(TyKind::Adt(tref).into_ty())
178 }
179
180 pub fn translate_closure_adt(
181 &mut self,
182 _trans_id: TypeDeclId,
183 span: Span,
184 args: &hax::ClosureArgs,
185 ) -> Result<TypeDeclKind, Error> {
186 let mut by_ref_upvar_regions = self
187 .the_only_binder()
188 .by_ref_upvar_regions
189 .clone()
190 .into_iter();
191 let fields: Vector<FieldId, Field> = args
192 .upvar_tys
193 .iter()
194 .map(|ty| {
195 let mut ty = self.translate_ty(span, ty)?;
196 if let TyKind::Ref(Region::Erased, deref_ty, kind) = ty.kind() {
198 let region_id = by_ref_upvar_regions.next().unwrap();
199 ty = TyKind::Ref(
200 Region::Var(DeBruijnVar::new_at_zero(region_id)),
201 deref_ty.clone(),
202 *kind,
203 )
204 .into_ty();
205 }
206 Ok(Field {
207 span,
208 attr_info: AttrInfo {
209 attributes: vec![],
210 inline: None,
211 rename: None,
212 public: false,
213 },
214 name: None,
215 ty,
216 })
217 })
218 .try_collect()?;
219 Ok(TypeDeclKind::Struct(fields))
220 }
221
222 fn translate_closure_method_sig(
225 &mut self,
226 def: &hax::FullDef,
227 span: Span,
228 args: &hax::ClosureArgs,
229 target_kind: ClosureKind,
230 ) -> Result<FunSig, Error> {
231 let signature = &args.fn_sig;
232 trace!(
233 "signature of closure {:?}:\n{:?}",
234 def.def_id,
235 signature.value,
236 );
237
238 let is_unsafe = match signature.value.safety {
239 hax::Safety::Unsafe => true,
240 hax::Safety::Safe => false,
241 };
242
243 let state_ty = self.get_closure_state_ty(span, args)?;
244
245 let state_ty = match target_kind {
247 ClosureKind::FnOnce => state_ty,
248 ClosureKind::Fn | ClosureKind::FnMut => {
249 let rid = self
250 .innermost_generics_mut()
251 .regions
252 .push_with(|index| RegionVar { index, name: None });
253 let r = Region::Var(DeBruijnVar::new_at_zero(rid));
254 let mutability = if target_kind == ClosureKind::Fn {
255 RefKind::Shared
256 } else {
257 RefKind::Mut
258 };
259 TyKind::Ref(r, state_ty, mutability).into_ty()
260 }
261 };
262
263 let input_tys: Vec<Ty> = signature
265 .value
266 .inputs
267 .iter()
268 .map(|ty| self.translate_ty(span, ty))
269 .try_collect()?;
270 let inputs = vec![state_ty, Ty::mk_tuple(input_tys)];
272 let output = self.translate_ty(span, &signature.value.output)?;
273
274 Ok(FunSig {
275 generics: self.the_only_binder().params.clone(),
276 is_unsafe,
277 inputs,
278 output,
279 })
280 }
281
282 fn translate_closure_method_body(
283 mut self,
284 span: Span,
285 def: &hax::FullDef,
286 target_kind: ClosureKind,
287 args: &hax::ClosureArgs,
288 signature: &FunSig,
289 ) -> Result<Result<Body, Opaque>, Error> {
290 use ClosureKind::*;
291 let closure_kind = translate_closure_kind(&args.kind);
292 let mk_stt = |content| Statement::new(span, content);
293 let mk_block = |statements, terminator| -> BlockData {
294 BlockData {
295 statements,
296 terminator: Terminator::new(span, terminator),
297 }
298 };
299
300 Ok(match (target_kind, closure_kind) {
301 (Fn, Fn) | (FnMut, FnMut) | (FnOnce, FnOnce) => {
302 let mut bt_ctx = BodyTransCtx::new(&mut self);
304 match bt_ctx.translate_def_body(span, def) {
305 Ok(Ok(mut body)) => {
306 let GExprBody {
315 locals,
316 body: blocks,
317 ..
318 } = body.as_unstructured_mut().unwrap();
319
320 blocks.dyn_visit_mut(|local: &mut LocalId| {
321 let idx = local.index();
322 if idx >= 2 {
323 *local = LocalId::new(idx + 1)
324 }
325 });
326
327 let mut old_locals = mem::take(&mut locals.locals).into_iter();
328 locals.arg_count = 2;
329 locals.locals.push(old_locals.next().unwrap()); locals.locals.push(old_locals.next().unwrap()); let tupled_arg = locals
332 .new_var(Some("tupled_args".to_string()), signature.inputs[1].clone());
333 locals.locals.extend(old_locals.map(|mut l| {
334 l.index += 1;
335 l
336 }));
337
338 let untupled_args = signature.inputs[1].as_tuple().unwrap();
339 let closure_arg_count = untupled_args.elem_count();
340 let new_stts = untupled_args.iter().cloned().enumerate().map(|(i, ty)| {
341 let nth_field = tupled_arg.clone().project(
342 ProjectionElem::Field(
343 FieldProjKind::Tuple(closure_arg_count),
344 FieldId::new(i),
345 ),
346 ty,
347 );
348 mk_stt(RawStatement::Assign(
349 locals.place_for_var(LocalId::new(i + 3)),
350 Rvalue::Use(Operand::Move(nth_field)),
351 ))
352 });
353 blocks[BlockId::ZERO].statements.splice(0..0, new_stts);
354
355 Ok(body)
356 }
357 Ok(Err(Opaque)) => Err(Opaque),
358 Err(_) => Err(Opaque),
359 }
360 }
361 (FnOnce, Fn | FnMut) => {
371 let hax::FullDefKind::Closure {
373 once_shim: Some(body),
374 ..
375 } = &def.kind
376 else {
377 panic!("missing shim for closure")
378 };
379 let mut bt_ctx = BodyTransCtx::new(&mut self);
380 match bt_ctx.translate_body(span, body, &def.source_text) {
381 Ok(Ok(body)) => Ok(body),
382 Ok(Err(Opaque)) => Err(Opaque),
383 Err(_) => Err(Opaque),
384 }
385 }
386 (FnMut, Fn) => {
393 let fun_id = self.register_closure_method_decl_id(span, def.def_id(), closure_kind);
394 let impl_ref = self.translate_closure_impl_ref(span, args, closure_kind)?;
395 let fn_op = FnOperand::Regular(FnPtr {
398 func: FunIdOrTraitMethodRef::Fun(FunId::Regular(fun_id.clone())).into(),
399 generics: Box::new(impl_ref.generics.concat(&GenericArgs {
400 regions: vec![Region::Erased].into(),
401 ..GenericArgs::empty()
402 })),
403 });
404
405 let mut locals = Locals {
406 arg_count: 2,
407 locals: Vector::new(),
408 };
409 let mut statements = vec![];
410 let mut blocks = Vector::default();
411
412 let output = locals.new_var(None, signature.output.clone());
413 let state = locals.new_var(Some("state".to_string()), signature.inputs[0].clone());
414 let args = locals.new_var(Some("args".to_string()), signature.inputs[1].clone());
415 let deref_state = state.deref();
416 let reborrow_ty =
417 TyKind::Ref(Region::Erased, deref_state.ty.clone(), RefKind::Shared).into_ty();
418 let reborrow = locals.new_var(None, reborrow_ty);
419
420 statements.push(mk_stt(RawStatement::Assign(
421 reborrow.clone(),
422 Rvalue::Ref(deref_state, BorrowKind::Shared),
423 )));
424
425 let start_block = blocks.reserve_slot();
426 let ret_block = blocks.push(mk_block(vec![], RawTerminator::Return));
427 let unwind_block = blocks.push(mk_block(vec![], RawTerminator::UnwindResume));
428 let call = RawTerminator::Call {
429 target: ret_block,
430 call: Call {
431 func: fn_op,
432 args: vec![Operand::Move(reborrow), Operand::Move(args)],
433 dest: output,
434 },
435 on_unwind: unwind_block,
436 };
437 blocks.set_slot(start_block, mk_block(statements, call));
438
439 let body: ExprBody = GExprBody {
440 span,
441 locals,
442 comments: vec![],
443 body: blocks,
444 };
445 Ok(Body::Unstructured(body))
446 }
447 (Fn, FnOnce) | (Fn, FnMut) | (FnMut, FnOnce) => {
448 panic!(
449 "Can't make a closure body for a more restrictive kind \
450 than the closure kind"
451 )
452 }
453 })
454 }
455
456 #[tracing::instrument(skip(self, item_meta))]
459 pub fn translate_closure_method(
460 mut self,
461 def_id: FunDeclId,
462 item_meta: ItemMeta,
463 def: &hax::FullDef,
464 target_kind: ClosureKind,
465 ) -> Result<FunDecl, Error> {
466 let span = item_meta.span;
467 let hax::FullDefKind::Closure { args, .. } = &def.kind else {
468 unreachable!()
469 };
470
471 trace!("About to translate closure:\n{:?}", def.def_id);
472
473 self.translate_def_generics(span, def)?;
474 assert!(self.innermost_binder_mut().bound_region_vars.is_empty(),);
476 self.innermost_binder_mut()
477 .push_params_from_binder(args.fn_sig.rebind(()))?;
478
479 let impl_ref = self.translate_closure_impl_ref(span, args, target_kind)?;
480 let implemented_trait = self.translate_closure_trait_ref(span, args, target_kind)?;
481 let kind = ItemKind::TraitImpl {
482 impl_ref,
483 trait_ref: implemented_trait,
484 item_name: TraitItemName(target_kind.method_name().to_owned()),
485 reuses_default: false,
486 };
487
488 let signature = self.translate_closure_method_sig(def, span, args, target_kind)?;
490
491 let body = if item_meta.opacity.with_private_contents().is_opaque() {
492 Err(Opaque)
493 } else {
494 self.translate_closure_method_body(span, def, target_kind, args, &signature)?
495 };
496
497 Ok(FunDecl {
498 def_id,
499 item_meta,
500 signature,
501 kind,
502 is_global_initializer: None,
503 body,
504 })
505 }
506
507 #[tracing::instrument(skip(self, item_meta))]
508 pub fn translate_closure_trait_impl(
509 mut self,
510 def_id: TraitImplId,
511 item_meta: ItemMeta,
512 def: &hax::FullDef,
513 target_kind: ClosureKind,
514 ) -> Result<TraitImpl, Error> {
515 let span = item_meta.span;
516 let hax::FullDefKind::Closure { args, .. } = &def.kind else {
517 unreachable!()
518 };
519
520 self.translate_def_generics(span, def)?;
521 assert!(self.innermost_binder_mut().bound_region_vars.is_empty(),);
523 self.innermost_binder_mut()
524 .push_params_from_binder(args.fn_sig.rebind(()))?;
525
526 let sized_trait = self.get_lang_item(rustc_hir::LangItem::Sized);
528 let sized_trait = self.register_trait_decl_id(span, &sized_trait);
529
530 let tuple_trait = self.get_lang_item(rustc_hir::LangItem::Tuple);
531 let tuple_trait = self.register_trait_decl_id(span, &tuple_trait);
532
533 let implemented_trait = self.translate_closure_trait_ref(span, args, target_kind)?;
534 let fn_trait = implemented_trait.id;
535
536 let (inputs, output) = self.translate_closure_info(span, args)?.signature.erase();
538 let input = Ty::mk_tuple(inputs);
539
540 let parent_trait_refs = {
541 let builtin_tref = |trait_id, ty| {
543 let generics = Box::new(GenericArgs::new_types([ty].into()));
544 let trait_decl_ref = TraitDeclRef {
545 id: trait_id,
546 generics,
547 };
548 let trait_decl_ref = RegionBinder::empty(trait_decl_ref);
549 TraitRef {
550 kind: TraitRefKind::BuiltinOrAuto {
551 trait_decl_ref: trait_decl_ref.clone(),
552 parent_trait_refs: Vector::new(),
553 types: vec![],
554 },
555 trait_decl_ref,
556 }
557 };
558
559 match target_kind {
560 ClosureKind::FnOnce => [
561 builtin_tref(sized_trait, input.clone()),
562 builtin_tref(tuple_trait, input.clone()),
563 builtin_tref(sized_trait, output.clone()),
564 ]
565 .into(),
566 ClosureKind::FnMut | ClosureKind::Fn => {
567 let parent_kind = match target_kind {
568 ClosureKind::FnOnce => unreachable!(),
569 ClosureKind::FnMut => ClosureKind::FnOnce,
570 ClosureKind::Fn => ClosureKind::FnMut,
571 };
572 let parent_impl_ref =
573 self.translate_closure_impl_ref(span, args, parent_kind)?;
574 let parent_predicate =
575 self.translate_closure_trait_ref(span, args, parent_kind)?;
576 let parent_trait_ref = TraitRef {
577 kind: TraitRefKind::TraitImpl(parent_impl_ref),
578 trait_decl_ref: RegionBinder::empty(parent_predicate),
579 };
580 [
581 parent_trait_ref,
582 builtin_tref(sized_trait, input.clone()),
583 builtin_tref(tuple_trait, input.clone()),
584 ]
585 .into()
586 }
587 }
588 };
589 let types = match target_kind {
590 ClosureKind::FnOnce => vec![(TraitItemName("Output".into()), output.clone())],
591 ClosureKind::FnMut | ClosureKind::Fn => vec![],
592 };
593
594 let call_fn_id = self.register_closure_method_decl_id(span, &def.def_id, target_kind);
596 let call_fn_name = TraitItemName(target_kind.method_name().to_string());
597 let call_fn_binder = {
598 let mut method_params = GenericParams::empty();
599 match target_kind {
600 ClosureKind::FnOnce => {}
601 ClosureKind::FnMut | ClosureKind::Fn => {
602 method_params
603 .regions
604 .push_with(|index| RegionVar { index, name: None });
605 }
606 };
607
608 let generics = self
609 .outermost_binder()
610 .params
611 .identity_args_at_depth(DeBruijnId::one())
612 .concat(&method_params.identity_args_at_depth(DeBruijnId::zero()));
613 Binder::new(
614 BinderKind::TraitMethod(fn_trait, call_fn_name.clone()),
615 method_params,
616 FunDeclRef {
617 id: call_fn_id,
618 generics: Box::new(generics),
619 },
620 )
621 };
622
623 let self_generics = self.into_generics();
624
625 Ok(TraitImpl {
626 def_id,
627 item_meta,
628 impl_trait: implemented_trait,
629 generics: self_generics,
630 parent_trait_refs,
631 type_clauses: vec![],
632 consts: vec![],
633 types,
634 methods: vec![(call_fn_name, call_fn_binder)],
635 })
636 }
637}