charon_driver/translate/
translate_closures.rs

1//! In rust, closures behave like ADTs that implement the FnOnce/FnMut/Fn traits automatically.
2//!
3//! Here we convert closures to a struct containing the closure's state (upvars), along with
4//! matching trait impls and fun decls (e.g. a Fn closure will have a trait impl for Fn, FnMut and
5//! FnOnce, along with 3 matching method implementations for call, call_mut and call_once).
6//!
7//! For example, given the following Rust code:
8//! ```rust
9//! pub fn test_closure_capture<T: Clone>() {
10//!     let mut v = vec![];
11//!     let mut add = |x: &u32| v.push(*x);
12//!     add(&0);
13//!     add(&1);
14//! }
15//! ```
16//!
17//! We generate the equivalent desugared code:
18//! ```rust
19//! struct {test_closure_capture::closure#0}<'a, T: Clone> (&'a mut Vec<u32>);
20//!
//! // The 'a comes from captured variables; the 'b comes from the closure's higher-kinded signature.
22//! impl<'a, 'b, T: Clone> FnMut<(&'b u32,)> for {test_closure_capture::closure#0}<'a, T> {
23//!     fn call_mut<'c>(&'c mut self, arg: (&'b u32,)) {
24//!         self.0.push(*arg.0);
25//!     }
26//! }
27//!
28//! impl<'a, 'b, T: Clone> FnOnce<(&'b u32,)> for {test_closure_capture::closure#0}<'a, T> {
29//!     type Output = ();
30//!     ...
31//! }
32//!
33//! pub fn test_closure_capture<T: Clone>() {
34//!     let mut v = vec![];
35//!     let mut add = {test_closure_capture::closure#0} (&mut v);
//!     add.call_mut((&0,));
//!     add.call_mut((&1,));
38//! }
39//! ```
40
41use std::mem;
42
43use crate::translate::translate_bodies::BodyTransCtx;
44
45use super::translate_ctx::*;
46use charon_lib::ast::*;
47use charon_lib::ids::Vector;
48use charon_lib::ullbc_ast::*;
49use hax_frontend_exporter as hax;
50use itertools::Itertools;
51
52pub fn translate_closure_kind(kind: &hax::ClosureKind) -> ClosureKind {
53    match kind {
54        hax::ClosureKind::Fn => ClosureKind::Fn,
55        hax::ClosureKind::FnMut => ClosureKind::FnMut,
56        hax::ClosureKind::FnOnce => ClosureKind::FnOnce,
57    }
58}
59
60impl ItemTransCtx<'_, '_> {
61    pub fn translate_closure_info(
62        &mut self,
63        span: Span,
64        args: &hax::ClosureArgs,
65    ) -> Result<ClosureInfo, Error> {
66        let kind = translate_closure_kind(&args.kind);
67        let signature = self.translate_region_binder(span, &args.fn_sig, |ctx, sig| {
68            let inputs = sig
69                .inputs
70                .iter()
71                .map(|x| ctx.translate_ty(span, x))
72                .try_collect()?;
73            let output = ctx.translate_ty(span, &sig.output)?;
74            Ok((inputs, output))
75        })?;
76        Ok(ClosureInfo { kind, signature })
77    }
78
79    /// Translate a reference to the closure ADT.
80    pub fn translate_closure_type_ref(
81        &mut self,
82        span: Span,
83        closure: &hax::ClosureArgs,
84    ) -> Result<TypeDeclRef, Error> {
85        let mut tref = self.translate_type_decl_ref(span, &closure.item)?;
86        // We add lifetime args for each borrowing upvar, gotta supply them here.
87        if self.def_id == closure.item.def_id {
88            tref.generics.regions.extend(
89                self.outermost_binder()
90                    .by_ref_upvar_regions
91                    .iter()
92                    .map(|r| Region::Var(DeBruijnVar::bound(self.binding_levels.depth(), *r))),
93            );
94        } else {
95            tref.generics.regions.extend(
96                closure
97                    .upvar_tys
98                    .iter()
99                    .filter(|ty| {
100                        matches!(
101                            ty.kind(),
102                            hax::TyKind::Ref(
103                                hax::Region {
104                                    kind: hax::RegionKind::ReErased
105                                },
106                                ..
107                            )
108                        )
109                    })
110                    .map(|_| Region::Erased),
111            );
112        }
113        Ok(tref)
114    }
115
116    /// Translate a reference to the chosen closure impl.
117    pub fn translate_closure_impl_ref(
118        &mut self,
119        span: Span,
120        closure: &hax::ClosureArgs,
121        target_kind: ClosureKind,
122    ) -> Result<TraitImplRef, Error> {
123        let impl_id = self.register_closure_trait_impl_id(span, &closure.item.def_id, target_kind);
124        let adt_ref = self.translate_closure_type_ref(span, closure)?;
125        let mut args = adt_ref.generics;
126        // Add the lifetime generics coming from the higher-kindedness of the signature.
127        if self.def_id == closure.item.def_id {
128            args.regions.extend(
129                self.outermost_binder()
130                    .bound_region_vars
131                    .iter()
132                    .map(|r| Region::Var(DeBruijnVar::bound(self.binding_levels.depth(), *r))),
133            );
134        } else {
135            args.regions
136                .extend(closure.fn_sig.bound_vars.iter().map(|_| Region::Erased));
137        }
138
139        Ok(TraitImplRef {
140            id: impl_id,
141            generics: args,
142        })
143    }
144
145    /// Translate a trait reference to the chosen closure impl.
146    pub fn translate_closure_trait_ref(
147        &mut self,
148        span: Span,
149        args: &hax::ClosureArgs,
150        target_kind: ClosureKind,
151    ) -> Result<TraitDeclRef, Error> {
152        // TODO: how much can we ask hax for this?
153        let fn_trait = match target_kind {
154            ClosureKind::FnOnce => self.get_lang_item(rustc_hir::LangItem::FnOnce),
155            ClosureKind::FnMut => self.get_lang_item(rustc_hir::LangItem::FnMut),
156            ClosureKind::Fn => self.get_lang_item(rustc_hir::LangItem::Fn),
157        };
158        let trait_id = self.register_trait_decl_id(span, &fn_trait);
159
160        let state_ty = self.get_closure_state_ty(span, args)?;
161        // The input tuple type and output type of the signature.
162        let (inputs, _) = self.translate_closure_info(span, args)?.signature.erase();
163        let input_tuple = Ty::mk_tuple(inputs);
164
165        Ok(TraitDeclRef {
166            id: trait_id,
167            generics: Box::new(GenericArgs::new_types([state_ty, input_tuple].into())),
168        })
169    }
170
171    pub fn get_closure_state_ty(
172        &mut self,
173        span: Span,
174        args: &hax::ClosureArgs,
175    ) -> Result<Ty, Error> {
176        let tref = self.translate_closure_type_ref(span, args)?;
177        Ok(TyKind::Adt(tref).into_ty())
178    }
179
180    pub fn translate_closure_adt(
181        &mut self,
182        _trans_id: TypeDeclId,
183        span: Span,
184        args: &hax::ClosureArgs,
185    ) -> Result<TypeDeclKind, Error> {
186        let mut by_ref_upvar_regions = self
187            .the_only_binder()
188            .by_ref_upvar_regions
189            .clone()
190            .into_iter();
191        let fields: Vector<FieldId, Field> = args
192            .upvar_tys
193            .iter()
194            .map(|ty| {
195                let mut ty = self.translate_ty(span, ty)?;
196                // We supply fresh regions for the by-ref upvars.
197                if let TyKind::Ref(Region::Erased, deref_ty, kind) = ty.kind() {
198                    let region_id = by_ref_upvar_regions.next().unwrap();
199                    ty = TyKind::Ref(
200                        Region::Var(DeBruijnVar::new_at_zero(region_id)),
201                        deref_ty.clone(),
202                        *kind,
203                    )
204                    .into_ty();
205                }
206                Ok(Field {
207                    span,
208                    attr_info: AttrInfo {
209                        attributes: vec![],
210                        inline: None,
211                        rename: None,
212                        public: false,
213                    },
214                    name: None,
215                    ty,
216                })
217            })
218            .try_collect()?;
219        Ok(TypeDeclKind::Struct(fields))
220    }
221
222    /// Given an item that is a closure, generate the signature of the
223    /// `call_once`/`call_mut`/`call` method (depending on `target_kind`).
224    fn translate_closure_method_sig(
225        &mut self,
226        def: &hax::FullDef,
227        span: Span,
228        args: &hax::ClosureArgs,
229        target_kind: ClosureKind,
230    ) -> Result<FunSig, Error> {
231        let signature = &args.fn_sig;
232        trace!(
233            "signature of closure {:?}:\n{:?}",
234            def.def_id,
235            signature.value,
236        );
237
238        let is_unsafe = match signature.value.safety {
239            hax::Safety::Unsafe => true,
240            hax::Safety::Safe => false,
241        };
242
243        let state_ty = self.get_closure_state_ty(span, args)?;
244
245        // Depending on the kind of the closure generated, add a reference
246        let state_ty = match target_kind {
247            ClosureKind::FnOnce => state_ty,
248            ClosureKind::Fn | ClosureKind::FnMut => {
249                let rid = self
250                    .innermost_generics_mut()
251                    .regions
252                    .push_with(|index| RegionVar { index, name: None });
253                let r = Region::Var(DeBruijnVar::new_at_zero(rid));
254                let mutability = if target_kind == ClosureKind::Fn {
255                    RefKind::Shared
256                } else {
257                    RefKind::Mut
258                };
259                TyKind::Ref(r, state_ty, mutability).into_ty()
260            }
261        };
262
263        // The types that the closure takes as input.
264        let input_tys: Vec<Ty> = signature
265            .value
266            .inputs
267            .iter()
268            .map(|ty| self.translate_ty(span, ty))
269            .try_collect()?;
270        // The method takes `self` and the closure inputs as a tuple.
271        let inputs = vec![state_ty, Ty::mk_tuple(input_tys)];
272        let output = self.translate_ty(span, &signature.value.output)?;
273
274        Ok(FunSig {
275            generics: self.the_only_binder().params.clone(),
276            is_unsafe,
277            inputs,
278            output,
279        })
280    }
281
282    fn translate_closure_method_body(
283        mut self,
284        span: Span,
285        def: &hax::FullDef,
286        target_kind: ClosureKind,
287        args: &hax::ClosureArgs,
288        signature: &FunSig,
289    ) -> Result<Result<Body, Opaque>, Error> {
290        use ClosureKind::*;
291        let closure_kind = translate_closure_kind(&args.kind);
292        let mk_stt = |content| Statement::new(span, content);
293        let mk_block = |statements, terminator| -> BlockData {
294            BlockData {
295                statements,
296                terminator: Terminator::new(span, terminator),
297            }
298        };
299
300        Ok(match (target_kind, closure_kind) {
301            (Fn, Fn) | (FnMut, FnMut) | (FnOnce, FnOnce) => {
302                // Translate the function's body normally
303                let mut bt_ctx = BodyTransCtx::new(&mut self);
304                match bt_ctx.translate_def_body(span, def) {
305                    Ok(Ok(mut body)) => {
306                        // The body is translated as if the locals are: ret value, state, arg-1,
307                        // ..., arg-N, rest...
308                        // However, there is only one argument with the tupled closure arguments;
309                        // we must thus shift all locals with index >=2 by 1, and add a new local
310                        // for the tupled arg, giving us: ret value, state, args, arg-1, ...,
311                        // arg-N, rest...
312                        // We then add N statements of the form `locals[N+3] := move locals[2].N`,
313                        // to destructure the arguments.
314                        let GExprBody {
315                            locals,
316                            body: blocks,
317                            ..
318                        } = body.as_unstructured_mut().unwrap();
319
320                        blocks.dyn_visit_mut(|local: &mut LocalId| {
321                            let idx = local.index();
322                            if idx >= 2 {
323                                *local = LocalId::new(idx + 1)
324                            }
325                        });
326
327                        let mut old_locals = mem::take(&mut locals.locals).into_iter();
328                        locals.arg_count = 2;
329                        locals.locals.push(old_locals.next().unwrap()); // ret
330                        locals.locals.push(old_locals.next().unwrap()); // state
331                        let tupled_arg = locals
332                            .new_var(Some("tupled_args".to_string()), signature.inputs[1].clone());
333                        locals.locals.extend(old_locals.map(|mut l| {
334                            l.index += 1;
335                            l
336                        }));
337
338                        let untupled_args = signature.inputs[1].as_tuple().unwrap();
339                        let closure_arg_count = untupled_args.elem_count();
340                        let new_stts = untupled_args.iter().cloned().enumerate().map(|(i, ty)| {
341                            let nth_field = tupled_arg.clone().project(
342                                ProjectionElem::Field(
343                                    FieldProjKind::Tuple(closure_arg_count),
344                                    FieldId::new(i),
345                                ),
346                                ty,
347                            );
348                            mk_stt(RawStatement::Assign(
349                                locals.place_for_var(LocalId::new(i + 3)),
350                                Rvalue::Use(Operand::Move(nth_field)),
351                            ))
352                        });
353                        blocks[BlockId::ZERO].statements.splice(0..0, new_stts);
354
355                        Ok(body)
356                    }
357                    Ok(Err(Opaque)) => Err(Opaque),
358                    Err(_) => Err(Opaque),
359                }
360            }
361            // Target translation:
362            //
363            // fn call_once(state: Self, args: Args) -> Output {
364            //   let temp_ref = &[mut] state;
365            //   let ret = self.call[_mut](temp, args);
366            //   drop state;
367            //   return ret;
368            // }
369            //
370            (FnOnce, Fn | FnMut) => {
371                // Hax (via rustc) gives us the MIR to do this.
372                let hax::FullDefKind::Closure {
373                    once_shim: Some(body),
374                    ..
375                } = &def.kind
376                else {
377                    panic!("missing shim for closure")
378                };
379                let mut bt_ctx = BodyTransCtx::new(&mut self);
380                match bt_ctx.translate_body(span, body, &def.source_text) {
381                    Ok(Ok(body)) => Ok(body),
382                    Ok(Err(Opaque)) => Err(Opaque),
383                    Err(_) => Err(Opaque),
384                }
385            }
386            // Target translation:
387            //
388            // fn call_mut(state: &mut Self, args: Args) -> Output {
389            //   let reborrow = &*state;
390            //   self.call(reborrow, args)
391            // }
392            (FnMut, Fn) => {
393                let fun_id = self.register_closure_method_decl_id(span, def.def_id(), closure_kind);
394                let impl_ref = self.translate_closure_impl_ref(span, args, closure_kind)?;
395                // TODO: make a trait call to avoid needing to concatenate things ourselves.
396                // TODO: can we ask hax for the trait ref?
397                let fn_op = FnOperand::Regular(FnPtr {
398                    func: FunIdOrTraitMethodRef::Fun(FunId::Regular(fun_id.clone())).into(),
399                    generics: Box::new(impl_ref.generics.concat(&GenericArgs {
400                        regions: vec![Region::Erased].into(),
401                        ..GenericArgs::empty()
402                    })),
403                });
404
405                let mut locals = Locals {
406                    arg_count: 2,
407                    locals: Vector::new(),
408                };
409                let mut statements = vec![];
410                let mut blocks = Vector::default();
411
412                let output = locals.new_var(None, signature.output.clone());
413                let state = locals.new_var(Some("state".to_string()), signature.inputs[0].clone());
414                let args = locals.new_var(Some("args".to_string()), signature.inputs[1].clone());
415                let deref_state = state.deref();
416                let reborrow_ty =
417                    TyKind::Ref(Region::Erased, deref_state.ty.clone(), RefKind::Shared).into_ty();
418                let reborrow = locals.new_var(None, reborrow_ty);
419
420                statements.push(mk_stt(RawStatement::Assign(
421                    reborrow.clone(),
422                    Rvalue::Ref(deref_state, BorrowKind::Shared),
423                )));
424
425                let start_block = blocks.reserve_slot();
426                let ret_block = blocks.push(mk_block(vec![], RawTerminator::Return));
427                let unwind_block = blocks.push(mk_block(vec![], RawTerminator::UnwindResume));
428                let call = RawTerminator::Call {
429                    target: ret_block,
430                    call: Call {
431                        func: fn_op,
432                        args: vec![Operand::Move(reborrow), Operand::Move(args)],
433                        dest: output,
434                    },
435                    on_unwind: unwind_block,
436                };
437                blocks.set_slot(start_block, mk_block(statements, call));
438
439                let body: ExprBody = GExprBody {
440                    span,
441                    locals,
442                    comments: vec![],
443                    body: blocks,
444                };
445                Ok(Body::Unstructured(body))
446            }
447            (Fn, FnOnce) | (Fn, FnMut) | (FnMut, FnOnce) => {
448                panic!(
449                    "Can't make a closure body for a more restrictive kind \
450                    than the closure kind"
451                )
452            }
453        })
454    }
455
456    /// Given an item that is a closure, generate the `call_once`/`call_mut`/`call` method
457    /// (depending on `target_kind`).
458    #[tracing::instrument(skip(self, item_meta))]
459    pub fn translate_closure_method(
460        mut self,
461        def_id: FunDeclId,
462        item_meta: ItemMeta,
463        def: &hax::FullDef,
464        target_kind: ClosureKind,
465    ) -> Result<FunDecl, Error> {
466        let span = item_meta.span;
467        let hax::FullDefKind::Closure { args, .. } = &def.kind else {
468            unreachable!()
469        };
470
471        trace!("About to translate closure:\n{:?}", def.def_id);
472
473        self.translate_def_generics(span, def)?;
474        // Add the lifetime generics coming from the higher-kindedness of the signature.
475        assert!(self.innermost_binder_mut().bound_region_vars.is_empty(),);
476        self.innermost_binder_mut()
477            .push_params_from_binder(args.fn_sig.rebind(()))?;
478
479        let impl_ref = self.translate_closure_impl_ref(span, args, target_kind)?;
480        let implemented_trait = self.translate_closure_trait_ref(span, args, target_kind)?;
481        let kind = ItemKind::TraitImpl {
482            impl_ref,
483            trait_ref: implemented_trait,
484            item_name: TraitItemName(target_kind.method_name().to_owned()),
485            reuses_default: false,
486        };
487
488        // Translate the function signature
489        let signature = self.translate_closure_method_sig(def, span, args, target_kind)?;
490
491        let body = if item_meta.opacity.with_private_contents().is_opaque() {
492            Err(Opaque)
493        } else {
494            self.translate_closure_method_body(span, def, target_kind, args, &signature)?
495        };
496
497        Ok(FunDecl {
498            def_id,
499            item_meta,
500            signature,
501            kind,
502            is_global_initializer: None,
503            body,
504        })
505    }
506
507    #[tracing::instrument(skip(self, item_meta))]
508    pub fn translate_closure_trait_impl(
509        mut self,
510        def_id: TraitImplId,
511        item_meta: ItemMeta,
512        def: &hax::FullDef,
513        target_kind: ClosureKind,
514    ) -> Result<TraitImpl, Error> {
515        let span = item_meta.span;
516        let hax::FullDefKind::Closure { args, .. } = &def.kind else {
517            unreachable!()
518        };
519
520        self.translate_def_generics(span, def)?;
521        // Add the lifetime generics coming from the higher-kindedness of the signature.
522        assert!(self.innermost_binder_mut().bound_region_vars.is_empty(),);
523        self.innermost_binder_mut()
524            .push_params_from_binder(args.fn_sig.rebind(()))?;
525
526        // The builtin traits we need.
527        let sized_trait = self.get_lang_item(rustc_hir::LangItem::Sized);
528        let sized_trait = self.register_trait_decl_id(span, &sized_trait);
529
530        let tuple_trait = self.get_lang_item(rustc_hir::LangItem::Tuple);
531        let tuple_trait = self.register_trait_decl_id(span, &tuple_trait);
532
533        let implemented_trait = self.translate_closure_trait_ref(span, args, target_kind)?;
534        let fn_trait = implemented_trait.id;
535
536        // The input tuple type and output type of the signature.
537        let (inputs, output) = self.translate_closure_info(span, args)?.signature.erase();
538        let input = Ty::mk_tuple(inputs);
539
540        let parent_trait_refs = {
541            // Makes a built-in trait ref for `ty: trait`.
542            let builtin_tref = |trait_id, ty| {
543                let generics = Box::new(GenericArgs::new_types([ty].into()));
544                let trait_decl_ref = TraitDeclRef {
545                    id: trait_id,
546                    generics,
547                };
548                let trait_decl_ref = RegionBinder::empty(trait_decl_ref);
549                TraitRef {
550                    kind: TraitRefKind::BuiltinOrAuto {
551                        trait_decl_ref: trait_decl_ref.clone(),
552                        parent_trait_refs: Vector::new(),
553                        types: vec![],
554                    },
555                    trait_decl_ref,
556                }
557            };
558
559            match target_kind {
560                ClosureKind::FnOnce => [
561                    builtin_tref(sized_trait, input.clone()),
562                    builtin_tref(tuple_trait, input.clone()),
563                    builtin_tref(sized_trait, output.clone()),
564                ]
565                .into(),
566                ClosureKind::FnMut | ClosureKind::Fn => {
567                    let parent_kind = match target_kind {
568                        ClosureKind::FnOnce => unreachable!(),
569                        ClosureKind::FnMut => ClosureKind::FnOnce,
570                        ClosureKind::Fn => ClosureKind::FnMut,
571                    };
572                    let parent_impl_ref =
573                        self.translate_closure_impl_ref(span, args, parent_kind)?;
574                    let parent_predicate =
575                        self.translate_closure_trait_ref(span, args, parent_kind)?;
576                    let parent_trait_ref = TraitRef {
577                        kind: TraitRefKind::TraitImpl(parent_impl_ref),
578                        trait_decl_ref: RegionBinder::empty(parent_predicate),
579                    };
580                    [
581                        parent_trait_ref,
582                        builtin_tref(sized_trait, input.clone()),
583                        builtin_tref(tuple_trait, input.clone()),
584                    ]
585                    .into()
586                }
587            }
588        };
589        let types = match target_kind {
590            ClosureKind::FnOnce => vec![(TraitItemName("Output".into()), output.clone())],
591            ClosureKind::FnMut | ClosureKind::Fn => vec![],
592        };
593
594        // Construct the `call_*` method reference.
595        let call_fn_id = self.register_closure_method_decl_id(span, &def.def_id, target_kind);
596        let call_fn_name = TraitItemName(target_kind.method_name().to_string());
597        let call_fn_binder = {
598            let mut method_params = GenericParams::empty();
599            match target_kind {
600                ClosureKind::FnOnce => {}
601                ClosureKind::FnMut | ClosureKind::Fn => {
602                    method_params
603                        .regions
604                        .push_with(|index| RegionVar { index, name: None });
605                }
606            };
607
608            let generics = self
609                .outermost_binder()
610                .params
611                .identity_args_at_depth(DeBruijnId::one())
612                .concat(&method_params.identity_args_at_depth(DeBruijnId::zero()));
613            Binder::new(
614                BinderKind::TraitMethod(fn_trait, call_fn_name.clone()),
615                method_params,
616                FunDeclRef {
617                    id: call_fn_id,
618                    generics: Box::new(generics),
619                },
620            )
621        };
622
623        let self_generics = self.into_generics();
624
625        Ok(TraitImpl {
626            def_id,
627            item_meta,
628            impl_trait: implemented_trait,
629            generics: self_generics,
630            parent_trait_refs,
631            type_clauses: vec![],
632            consts: vec![],
633            types,
634            methods: vec![(call_fn_name, call_fn_binder)],
635        })
636    }
637}