charon_lib/transform/normalize/
partial_monomorphization.rs

//! This module implements partial monomorphization, which allows specializing generic items on
//! some specific instantiation patterns. This is used by Aeneas to avoid nested mutable borrows:
//! we transform `Iter<'a, &'b mut T>` to `{Iter::<'a, &'b mut U>}<'a, 'b, T>`, where
//! ```ignore
//! struct {Iter::<'a, &'b mut U>}<'a, 'b, U> {
//!   // the fields of `Iter` but instantiated with `T -> &'b mut U`.
//! }
//! ```
//!
//! Note: We may need to partial-mono the same item multiple times: `Foo::<&mut A, B>`, `Foo::<A,
//! &mut B>`. Note also that partial-mono is infectious: `Foo<Bar<&mut A>>` generates `Bar::<&mut
//! A>` then `Foo::<Bar::<&mut A>>`.
13use std::collections::{HashMap, HashSet, VecDeque};
14use std::fmt::Display;
15use std::mem;
16
17use derive_generic_visitor::Visitor;
18use index_vec::Idx;
19
20use crate::ast::types_utils::TyVisitable;
21use crate::ast::visitor::{VisitWithBinderDepth, VisitorWithBinderDepth};
22use crate::formatter::IntoFormatter;
23use crate::options::MonomorphizeMut;
24use crate::pretty::FmtWithCtx;
25use crate::register_error;
26use crate::transform::ctx::TransformPass;
27use crate::{transform::TransformCtx, ullbc_ast::*};
28
/// A binder over generic arguments describing the mutability "shape" of some arguments.
///
/// The bound parameters stand for the factored-out, non-infected parts. The contents are the
/// original arguments with those parts replaced by bound variables. It is built by
/// `MutabilityShapeBuilder::compute_shape`. Together with the base item id, it is also the key
/// that identifies a partial instantiation.
type MutabilityShape = Binder<GenericArgs>;
30
/// Visitor that computes the mutability shape of a set of generic arguments.
///
/// It replaces the non-infected types and the regions it encounters with fresh bound variables,
/// recording each replaced value so the original arguments can be recovered.
///
/// See the docs of `MutabilityShapeBuilder::compute_shape`.
#[derive(Visitor)]
struct MutabilityShapeBuilder<'pm, 'ctx> {
    /// The partial monomorphizer. It is used to decide which types are infected and to look up
    /// the generic params of ADTs.
    pm: &'pm PartialMonomorphizer<'ctx>,
    /// The parameters that will constitute the final binder.
    params: GenericParams,
    /// The arguments to pass to the final binder to recover the input arguments.
    extracted: GenericArgs,
    /// Current depth under which we're visiting.
    binder_depth: DeBruijnId,
}
42
impl<'pm, 'ctx> MutabilityShapeBuilder<'pm, 'ctx> {
    /// Compute the mutability "shape" of a set of generic arguments by factoring out the minimal
    /// amount of information that still allows reconstructing the original arguments while keeping the
    /// "shape arguments" free of mutable borrows.
    ///
    /// For example, for input:
    ///   <u32, &'a mut &'b A, Option::<&'a mut bool>>
    /// we want to build:
    ///   binder<'a, 'b, A, B, C> <A, &'a mut B, Option::<&'b mut C>>
    /// from which we can recover the original arguments by instantiating it with:
    ///   <'a, 'a, u32, &'b A, bool>
    ///
    /// Formally, given `let (shape, shape_args) = compute_shape(pm, target_params, args);`, we
    /// have the following:
    /// - `shape.substitute(shape_args) == args`;
    /// - `shape_args` contains no infected types;
    /// - `shape` is as shallow as possible (i.e. takes just enough to get all the infected types
    ///     and not more).
    ///
    /// Parameters:
    /// - `pm`: decides which types are infected and provides the generic params of ADTs;
    /// - `target_params`: the generic params of the item that `args` are applied to;
    /// - `args`: the arguments to factor.
    ///
    /// The trait clauses of the resulting binder start with the clauses of `target_params`,
    /// rewritten in terms of the new explicit params. They are followed by the clauses of every
    /// infected ADT found inside `args`, added by `exit_ty_kind`. The matching trait arguments
    /// are in the same order in the returned `GenericArgs`. Const generics are not factored out
    /// (see the TODO near `enter_region`).
    ///
    /// Note: the input arguments are assumed to have been already partially monomorphized, in the
    /// sense that we won't recurse inside ADT args because we assume any ADT applied to infected
    /// args to have been replaced with a fresh infected ADT.
    fn compute_shape(
        pm: &'pm PartialMonomorphizer<'ctx>,
        target_params: &GenericParams,
        args: &GenericArgs,
    ) -> (MutabilityShape, GenericArgs) {
        // We start with the implicit parameters from the original item. We'll need to substitute
        // them once we've figured out the mapping of explicit parameters, but we'll also be adding
        // new trait clauses potentially so we can't leave the vector empty (the ids would be
        // wrong).
        let mut shape_contents = args.clone();
        let mut builder = Self {
            pm,
            params: GenericParams {
                regions: IndexMap::new(),
                types: IndexMap::new(),
                const_generics: IndexMap::new(),
                ..target_params.clone()
            },
            extracted: GenericArgs {
                regions: IndexMap::new(),
                types: IndexMap::new(),
                const_generics: IndexMap::new(),
                // The original trait arguments become the first extracted trait args. They are
                // removed from the contents (the visitor never touches trait refs) and replaced
                // by identity clause refs at the end of this function.
                trait_refs: mem::take(&mut shape_contents.trait_refs),
            },
            binder_depth: DeBruijnId::zero(),
        };

        // Traverse the generics and replace any non-infected type or region with a fresh
        // variable. (Const generics are not replaced yet.)
        let _ = VisitWithBinderDepth::new(&mut builder).visit(&mut shape_contents);

        let shape_params = {
            let mut shape_params = builder.params;
            // Now the explicit params in `shape_params` are correct, and the implicit params are a mix
            // of the old params and new trait clauses. The old params may refer to the old explicit
            // params which is wrong and must be fixed up.
            // Only the leading entries that came from `target_params` are rewritten. Those added
            // by `exit_ty_kind` were already expressed in terms of the new params.
            shape_params.trait_clauses = shape_params.trait_clauses.map_indexed(|i, x| {
                if i.index() < target_params.trait_clauses.slot_count() {
                    x.substitute_explicits(&shape_contents)
                } else {
                    x
                }
            });
            shape_params.trait_type_constraints =
                shape_params.trait_type_constraints.map_indexed(|i, x| {
                    if i.index() < target_params.trait_type_constraints.slot_count() {
                        x.substitute_explicits(&shape_contents)
                    } else {
                        x
                    }
                });
            shape_params.regions_outlive = shape_params
                .regions_outlive
                .into_iter()
                .enumerate()
                .map(|(i, x)| {
                    if i < target_params.regions_outlive.len() {
                        x.substitute_explicits(&shape_contents)
                    } else {
                        x
                    }
                })
                .collect();
            shape_params.types_outlive = shape_params
                .types_outlive
                .into_iter()
                .enumerate()
                .map(|(i, x)| {
                    if i < target_params.types_outlive.len() {
                        x.substitute_explicits(&shape_contents)
                    } else {
                        x
                    }
                })
                .collect();
            shape_params
        };

        // The first `target_params.trait_clauses.slot_count()` trait params correspond to the
        // original item clauses, so we pass them unmodified (as identity refs to those clauses).
        shape_contents.trait_refs = shape_params.identity_args().trait_refs;
        shape_contents
            .trait_refs
            .truncate(target_params.trait_clauses.slot_count());

        let shape_args = builder.extracted;
        let shape = Binder::new(BinderKind::Other, shape_params, shape_contents);
        (shape, shape_args)
    }

    /// Replace this value with a fresh variable, and record that we did so.
    ///
    /// The value is first moved out from under the binders currently being visited. If that
    /// fails, `val` is left untouched. Presumably this happens when it mentions a variable bound
    /// by one of those binders.
    ///
    /// Otherwise:
    /// - the moved value is pushed to `self.extracted`;
    /// - a fresh param built by `mk_param` is pushed to `self.params`;
    /// - `val` is overwritten with `mk_value` applied to a variable bound at the current depth
    ///   that points to the new param.
    ///
    /// Both maps start empty in `compute_shape` and are pushed in lockstep here, so the new
    /// param's id is also the position of the recorded argument.
    fn replace_with_fresh_var<Id, Param, Arg>(
        &mut self,
        val: &mut Arg,
        mk_param: impl FnOnce(Id) -> Param,
        mk_value: impl FnOnce(DeBruijnVar<Id>) -> Arg,
    ) where
        Id: Idx + Display,
        Arg: TyVisitable + Clone,
        GenericParams: HasIdxMapOf<Id, Output = Param>,
        GenericArgs: HasIdxMapOf<Id, Output = Arg>,
    {
        let Some(shifted_val) = val.clone().move_from_under_binders(self.binder_depth) else {
            // Give up on this value: it stays as-is inside the shape.
            return;
        };
        // Record the mapping in the output `GenericArgs`.
        self.extracted.get_idx_map_mut().push(shifted_val);
        // Put a fresh param in place of `val`.
        let id = self.params.get_idx_map_mut().push_with(mk_param);
        *val = mk_value(DeBruijnVar::bound(self.binder_depth, id));
    }
}
177
/// Exposes `binder_depth` to the binder-depth tracking wrapper (`VisitWithBinderDepth`).
/// Presumably the wrapper bumps the depth when entering a binder and restores it on exit.
impl<'pm, 'ctx> VisitorWithBinderDepth for MutabilityShapeBuilder<'pm, 'ctx> {
    fn binder_depth_mut(&mut self) -> &mut DeBruijnId {
        &mut self.binder_depth
    }
}
183
184impl<'pm, 'ctx> VisitAstMut for MutabilityShapeBuilder<'pm, 'ctx> {
185    fn visit<'a, T: AstVisitable>(&'a mut self, x: &mut T) -> ControlFlow<Self::Break> {
186        VisitWithBinderDepth::new(self).visit(x)
187    }
188
189    fn enter_ty(&mut self, ty: &mut Ty) {
190        if !self.pm.is_infected(ty) {
191            self.replace_with_fresh_var(
192                ty,
193                |id| TypeParam::new(id, format!("T{id}")),
194                |v| v.into(),
195            );
196        }
197    }
198    fn exit_ty_kind(&mut self, kind: &mut TyKind) {
199        if let TyKind::Adt(TypeDeclRef {
200            id: TypeId::Adt(id),
201            generics,
202        }) = kind
203        {
204            // Since the type was not replaced with a type var, it's an infected type. We've
205            // traversed it so we have its final explicit arguments. Now we need to satisfy its
206            // predicates. For that we add all its predicates to the new item, and pass those new
207            // trait clauses to it.
208            let Some(target_params) = self.pm.generic_params.get(&(*id).into()) else {
209                return;
210            };
211            let Some(shifted_generics) =
212                generics.clone().move_from_under_binders(self.binder_depth)
213            else {
214                // Give up on this value.
215                return;
216            };
217
218            // Add the target predicates (properly substituted) to the new item params.
219            let num_clauses_before_merge = self.params.trait_clauses.slot_count();
220            self.params.merge_predicates_from(
221                target_params
222                    .clone()
223                    .substitute_explicits(&shifted_generics),
224            );
225
226            // Record the trait arguments in the output `GenericArgs`.
227            self.extracted
228                .trait_refs
229                .extend(shifted_generics.trait_refs);
230
231            // Replace each trait ref with a clause var.
232            for (target_clause_id, tref) in generics.trait_refs.iter_mut_indexed() {
233                let clause_id = target_clause_id + num_clauses_before_merge;
234                *tref =
235                    self.params.trait_clauses[clause_id].identity_tref_at_depth(self.binder_depth);
236            }
237        }
238    }
239    fn enter_region(&mut self, r: &mut Region) {
240        self.replace_with_fresh_var(r, |id| RegionParam::new(id, None), |v| v.into());
241    }
242    // TODO: we're missing type info for this
243    // fn enter_const_generic(&mut self, cg: &mut ConstGeneric) {
244    //     self.replace_with_fresh_var(cg, |id| {
245    //         ConstGenericParam::new(id, format!("N{id}"), cg.ty().clone())
246    //     });
247    // }
248    fn visit_trait_ref(&mut self, _tref: &mut TraitRef) -> ControlFlow<Self::Break> {
249        // We don't touch trait refs or we'd risk adding duplicated extra params. Instead, we fix
250        // them up in `exit_ty_kind` and `compute_shape`.
251        ControlFlow::Continue(())
252    }
253}
254
/// State of the partial monomorphization pass.
///
/// It tracks which types are infected (contain a `&mut`), rewrites references to items applied
/// to infected arguments, and records the instantiations that still need to be created.
#[derive(Visitor)]
struct PartialMonomorphizer<'a> {
    /// The transformation context. Items are removed from and put back into `ctx.translated`
    /// while being processed, and slots for new instantiations are reserved in it.
    ctx: &'a mut TransformCtx,
    /// Tracks the closest span to emit useful errors.
    span: Span,
    /// Whether we partially-monomorphize type declarations.
    instantiate_types: bool,
    /// Types that contain mutable references, directly or through other type declarations.
    /// Newly-created type instantiations are added when their id is reserved.
    infected_types: HashSet<TypeDeclId>,
    /// Map of generic params for each item. We can't use `ctx.translated` because while iterating
    /// over items the current item isn't available anymore, which would break recursive types.
    /// This also makes it possible to record the generics of our to-be-added items without adding
    /// them.
    generic_params: HashMap<ItemId, GenericParams>,
    /// Map of partial monomorphizations. The source item applied with the generic params gives the
    /// target item. The resulting partially-monomorphized item will have the binder params as
    /// generic params.
    partial_mono_shapes: SeqHashMap<(ItemId, MutabilityShape), ItemId>,
    /// Reverse of `partial_mono_shapes`: maps each created instantiation to its base item and
    /// shape.
    reverse_shape_map: HashMap<ItemId, (ItemId, MutabilityShape)>,
    /// Items that need to be processed. It initially holds every item; newly reserved
    /// instantiations are appended at the back.
    to_process: VecDeque<ItemId>,
}
278
279impl<'a> PartialMonomorphizer<'a> {
280    pub fn new(ctx: &'a mut TransformCtx, instantiate_types: bool) -> Self {
281        use petgraph::graphmap::DiGraphMap;
282        use petgraph::visit::Dfs;
283        use petgraph::visit::Walker;
284
285        // Compute the types that contain `&mut` (even indirectly).
286        let infected_types: HashSet<_> = {
287            // We build a graph that has one node per type decl plus a special `None` node. If type A
288            // contains a reference to type B we add a B->A edge; if it contains a mutable reference we
289            // add a None->A edge. Then a type contains `&mut` iff it is reachable from the `None`
290            // node.
291            let mut graph: DiGraphMap<Option<TypeDeclId>, ()> = Default::default();
292            for (id, tdecl) in ctx.translated.type_decls.iter_indexed() {
293                tdecl.dyn_visit(|x: &Ty| match x.kind() {
294                    TyKind::Ref(_, _, RefKind::Mut) => {
295                        graph.add_edge(None, Some(id), ());
296                    }
297                    TyKind::Adt(tref) if let TypeId::Adt(other_id) = tref.id => {
298                        graph.add_edge(Some(other_id), Some(id), ());
299                    }
300                    _ => {}
301                });
302            }
303            let start = graph.add_node(None);
304            Dfs::new(&graph, start)
305                .iter(&graph)
306                .filter_map(|opt_id| opt_id)
307                .collect()
308        };
309
310        // Record the generic params of all items.
311        let generic_params: HashMap<ItemId, GenericParams> = ctx
312            .translated
313            .all_items()
314            .map(|item| (item.id(), item.generic_params().clone()))
315            .collect();
316
317        // Enqueue all items to be processed.
318        let to_process = ctx.translated.all_ids().collect();
319        PartialMonomorphizer {
320            ctx,
321            span: Span::dummy(),
322            instantiate_types,
323            infected_types,
324            generic_params,
325            to_process,
326            partial_mono_shapes: SeqHashMap::default(),
327            reverse_shape_map: Default::default(),
328        }
329    }
330
331    /// Whether this type is or contains a `&mut`. This assumes that we've already visited this
332    /// type and partially monomorphized any ADT references.
333    fn is_infected(&self, ty: &Ty) -> bool {
334        match ty.kind() {
335            TyKind::Ref(_, _, RefKind::Mut) => true,
336            TyKind::Ref(_, ty, _) | TyKind::RawPtr(ty, _) => self.is_infected(ty),
337            TyKind::Adt(tref) => {
338                let ty_infected =
339                    matches!(&tref.id, TypeId::Adt(id) if self.infected_types.contains(id));
340                let args_infected = if tref.id.is_adt() && self.instantiate_types {
341                    // Since we make sure to only call the method on a processed type, any type
342                    // with infected arguments would have been replaced with a fresh instantiated
343                    // (and infected type). Hence we don't need to check the arguments here, only
344                    // the type id.
345                    false
346                } else {
347                    tref.generics.types.iter().any(|ty| self.is_infected(ty))
348                };
349                ty_infected || args_infected
350            }
351            TyKind::FnDef(..) | TyKind::FnPtr(..) => {
352                register_error!(
353                    self.ctx,
354                    self.span,
355                    "function pointers are unsupported with `--monomorphize-mut`"
356                );
357                false
358            }
359            TyKind::DynTrait(_) => {
360                register_error!(
361                    self.ctx,
362                    self.span,
363                    "`dyn Trait` is unsupported with `--monomorphize-mut`"
364                );
365                false
366            }
367            TyKind::TypeVar(..)
368            | TyKind::Literal(..)
369            | TyKind::Never
370            | TyKind::TraitType(..)
371            | TyKind::PtrMetadata(..)
372            | TyKind::Error(_) => false,
373        }
374    }
375
376    /// Given that `generics` apply to item `id`, if any of the generics is infected we generate a
377    /// reference to a new item obtained by partially instantiating item `id`. (That new item isn't
378    /// added immediately but is added to the `to_process` queue to be created later).
379    fn process_generics(&mut self, id: ItemId, generics: &GenericArgs) -> Option<DeclRef<ItemId>> {
380        if !generics.types.iter().any(|ty| self.is_infected(ty)) {
381            return None;
382        }
383
384        // If the type is already an instantiation, transform this reference into a reference to
385        // the original type so we don't instantiate the instantiation.
386        let mut new_generics;
387        let (id, generics) = if let Some(&(base_id, ref shape)) = self.reverse_shape_map.get(&id) {
388            new_generics = shape.clone().apply(generics);
389            let _ = self.visit(&mut new_generics); // New instantiation may require cleanup.
390            (base_id, &new_generics)
391        } else {
392            (id, generics)
393        };
394
395        // Split the args between the infected part and the non-infected part.
396        let item_params = self.generic_params.get(&id)?;
397        let (shape, shape_args) =
398            MutabilityShapeBuilder::compute_shape(self, item_params, generics);
399
400        // Create a new type id.
401        let new_params = shape.params.clone();
402        let key: (ItemId, MutabilityShape) = (id, shape);
403        let new_id = *self
404            .partial_mono_shapes
405            .entry(key.clone())
406            .or_insert_with(|| {
407                let new_id = match id {
408                    ItemId::Type(_) => {
409                        let new_id = self.ctx.translated.type_decls.reserve_slot();
410                        self.infected_types.insert(new_id);
411                        new_id.into()
412                    }
413                    ItemId::Fun(_) => self.ctx.translated.fun_decls.reserve_slot().into(),
414                    ItemId::Global(_) => self.ctx.translated.global_decls.reserve_slot().into(),
415                    ItemId::TraitDecl(_) => self.ctx.translated.trait_decls.reserve_slot().into(),
416                    ItemId::TraitImpl(_) => self.ctx.translated.trait_impls.reserve_slot().into(),
417                };
418                self.generic_params.insert(new_id, new_params);
419                self.reverse_shape_map.insert(new_id, key);
420                self.to_process.push_back(new_id);
421                new_id
422            });
423
424        let fmt_ctx = self.ctx.into_fmt();
425        trace!(
426            "processing {}{}\n output: {}{}",
427            id.with_ctx(&fmt_ctx),
428            generics.with_ctx(&fmt_ctx),
429            new_id.with_ctx(&fmt_ctx),
430            shape_args.with_ctx(&fmt_ctx),
431        );
432        Some(DeclRef {
433            id: new_id,
434            generics: Box::new(shape_args),
435            trait_ref: None,
436        })
437    }
438
439    /// Traverse the item, replacing any type instantiations we don't want with references to
440    /// soon-to-be-created partially-monomorphized types. This does not access the items in
441    /// `self.translated`, which may be missing since we took `item` out for processing.
442    pub fn process_item(&mut self, item: &mut ItemRefMut<'_>) {
443        let _ = item.drive_mut(self);
444    }
445
446    /// Creates the item corresponding to this id by instantiating the item it is based on.
447    ///
448    /// This accesses the items in `self.translated`, which must therefore all be there.
449    /// That's why items are created outside of `process_item`.
450    pub fn create_pending_instantiation(&mut self, new_id: ItemId) -> ItemByVal {
451        let (orig_id, shape) = &self.reverse_shape_map[&new_id];
452        let mut decl = self
453            .ctx
454            .translated
455            .get_item(*orig_id)
456            .unwrap()
457            .clone()
458            .substitute(&shape.skip_binder);
459
460        let mut decl_mut = decl.as_mut();
461        decl_mut.set_id(new_id);
462        *decl_mut.generic_params() = shape.params.clone();
463
464        let name_ref = &mut decl_mut.item_meta().name;
465        *name_ref = mem::take(name_ref).instantiate(shape.clone());
466        self.ctx
467            .translated
468            .item_names
469            .insert(new_id, decl.as_ref().item_meta().name.clone());
470
471        decl
472    }
473}
474
/// Exposes `span` to the span-tracking wrapper (`VisitWithSpan`). Presumably the wrapper keeps
/// it set to the closest enclosing span, which `is_infected` uses when registering errors.
impl VisitorWithSpan for PartialMonomorphizer<'_> {
    fn current_span(&mut self) -> &mut Span {
        &mut self.span
    }
}
480impl VisitAstMut for PartialMonomorphizer<'_> {
481    fn visit<'a, T: AstVisitable>(&'a mut self, x: &mut T) -> ControlFlow<Self::Break> {
482        // Track a useful enclosing span, for error messages.
483        VisitWithSpan::new(self).visit(x)
484    }
485
486    fn exit_type_decl_ref(&mut self, x: &mut TypeDeclRef) {
487        if self.instantiate_types
488            && let TypeId::Adt(id) = x.id
489            && let Some(new_decl_ref) = self.process_generics(id.into(), &x.generics)
490        {
491            *x = new_decl_ref.try_into().unwrap()
492        }
493    }
494    fn exit_fn_ptr(&mut self, x: &mut FnPtr) {
495        // TODO: methods. any `Trait::method<&mut A>` requires monomorphizing all the instances of
496        // that method just in case :>>>
497        if let FnPtrKind::Fun(FunId::Regular(id)) = *x.kind
498            && let Some(new_decl_ref) = self.process_generics(id.into(), &x.generics)
499        {
500            *x = new_decl_ref.try_into().unwrap()
501        }
502    }
503    fn exit_fun_decl_ref(&mut self, x: &mut FunDeclRef) {
504        if let Some(new_decl_ref) = self.process_generics(x.id.into(), &x.generics) {
505            *x = new_decl_ref.try_into().unwrap()
506        }
507    }
508    fn exit_global_decl_ref(&mut self, x: &mut GlobalDeclRef) {
509        if let Some(new_decl_ref) = self.process_generics(x.id.into(), &x.generics) {
510            *x = new_decl_ref.try_into().unwrap()
511        }
512    }
513    fn exit_trait_decl_ref(&mut self, x: &mut TraitDeclRef) {
514        if let Some(new_decl_ref) = self.process_generics(x.id.into(), &x.generics) {
515            *x = new_decl_ref.try_into().unwrap()
516        }
517    }
518    fn exit_trait_impl_ref(&mut self, x: &mut TraitImplRef) {
519        if let Some(new_decl_ref) = self.process_generics(x.id.into(), &x.generics) {
520            *x = new_decl_ref.try_into().unwrap()
521        }
522    }
523}
524
525pub struct Transform;
526impl TransformPass for Transform {
527    fn transform_ctx(&self, ctx: &mut TransformCtx) {
528        let Some(include_types) = ctx.options.monomorphize_mut else {
529            return;
530        };
531        // TODO: test name matcher, also with methods
532        let mut visitor =
533            PartialMonomorphizer::new(ctx, matches!(include_types, MonomorphizeMut::All));
534        while let Some(id) = visitor.to_process.pop_front() {
535            // Get the item corresponding to this id, either by creating it or by getting an
536            // existing one.
537            let mut decl = if visitor.reverse_shape_map.get(&id).is_some() {
538                // Create the required item by instantiating the item it's based on.
539                visitor.create_pending_instantiation(id)
540            } else {
541                // Take the item out so we can modify it. Warning: don't look up other items in the
542                // meantime as this would break in recursive cases.
543                match visitor.ctx.translated.remove_item(id) {
544                    Some(decl) => decl,
545                    None => continue,
546                }
547            };
548            // Visit the item, replacing type instantiations with references to soon-to-be-created
549            // partially-monomorphized types.
550            visitor.process_item(&mut decl.as_mut());
551            // Put the item back.
552            visitor.ctx.translated.set_item_slot(id, decl);
553        }
554    }
555}