// charon_lib/transform/normalize/partial_monomorphization.rs

//! This module implements partial monomorphization, which allows specializing generic items on
//! some specific instantiation patterns. This is used by Aeneas to avoid nested mutable borrows:
//! we transform `Iter<'a, &'b mut T>` to `{Iter::<'a, &'b mut U>}<'a, 'b, T>`, where
4//! ```ignore
5//! struct {Iter::<'a, &'b mut U>}<'a, 'b, U> {
6//!   // the field of `Iter` but instantiated with `T -> &'b mut U`.
7//! }
8//! ```
9//!
//! Note: we may need to partially monomorphize the same item multiple times, e.g. for both
//! `Foo::<&mut A, B>` and `Foo::<A, &mut B>`. Note also that partial monomorphization is
//! infectious: `Foo<Bar<&mut A>>` generates `Bar::<&mut A>` and then `Foo::<Bar::<&mut A>>`.
13use std::collections::{HashMap, HashSet, VecDeque};
14use std::fmt::Display;
15use std::mem;
16
17use derive_generic_visitor::Visitor;
18use index_vec::Idx;
19use indexmap::IndexMap;
20
21use crate::ast::types_utils::TyVisitable;
22use crate::ast::visitor::{VisitWithBinderDepth, VisitorWithBinderDepth};
23use crate::formatter::IntoFormatter;
24use crate::options::MonomorphizeMut;
25use crate::pretty::FmtWithCtx;
26use crate::register_error;
27use crate::transform::ctx::TransformPass;
28use crate::{transform::TransformCtx, ullbc_ast::*};
29
30type MutabilityShape = Binder<GenericArgs>;
31
32/// See the docs of `MutabilityShapeBuilder::compute_shape`.
33#[derive(Visitor)]
34struct MutabilityShapeBuilder<'pm, 'ctx> {
35    pm: &'pm PartialMonomorphizer<'ctx>,
36    /// The parameters that will constitute the final binder.
37    params: GenericParams,
38    /// The arguments to pass to the final binder to recover the input arguments.
39    extracted: GenericArgs,
40    /// Current depth under which we're visiting.
41    binder_depth: DeBruijnId,
42}
43
44impl<'pm, 'ctx> MutabilityShapeBuilder<'pm, 'ctx> {
45    /// Compute the mutability "shape" of a set of generic arguments by factoring out the minimal
46    /// amount of information that still allows reconstructing the original arguments while keeping the
47    /// "shape arguments" free of mutable borrows.
48    ///
49    /// For example, for input:
50    ///   <u32, &'a mut &'b A, Option::<&'a mut bool>>
51    /// we want to build:
52    ///   binder<'a, 'b, A, B, C> <A, &'a mut B, Option::<&'b mut C>>
53    /// from which we can recover the original arguments by instantiating it with:
54    ///   <'a, 'a, u32, &'b A, bool>
55    ///
56    /// Formally, given `let (shape, shape_args) = get_mutability_shape(args);`, we have the following:
57    /// - `shape.substitute(shape_args) == args`;
58    /// - `shape_args` contains no infected types;
59    /// - `shape` is as shallow as possible (i.e. takes just enough to get all the infected types
60    ///     and not more).
61    ///
62    /// Note: the input arguments are assumed to have been already partially monomorphized, in the
63    /// sense that we won't recurse inside ADT args because we assume any ADT applied to infected
64    /// args to have been replaced with a fresh infected ADT.
65    fn compute_shape(
66        pm: &'pm PartialMonomorphizer<'ctx>,
67        target_params: &GenericParams,
68        args: &GenericArgs,
69    ) -> (MutabilityShape, GenericArgs) {
70        // We start with the implicit parameters from the original item. We'll need to substitute
71        // them once we've figured out the mapping of explicit parameters, but we'll also be adding
72        // new trait clauses potentially so we can't leave the vector empty (the ids would be
73        // wrong).
74        let mut shape_contents = args.clone();
75        let mut builder = Self {
76            pm,
77            params: GenericParams {
78                regions: Vector::new(),
79                types: Vector::new(),
80                const_generics: Vector::new(),
81                ..target_params.clone()
82            },
83            extracted: GenericArgs {
84                regions: Vector::new(),
85                types: Vector::new(),
86                const_generics: Vector::new(),
87                trait_refs: mem::take(&mut shape_contents.trait_refs),
88            },
89            binder_depth: DeBruijnId::zero(),
90        };
91
92        // Traverse the generics and replace any non-infected type, region or const generic with a
93        // fresh variable.
94        let _ = VisitWithBinderDepth::new(&mut builder).visit(&mut shape_contents);
95
96        let shape_params = {
97            let mut shape_params = builder.params;
98            // Now the explicit params in `shape_params` are correct, and the implicit params are a mix
99            // of the old params and new trait clauses. The old params may refer to the old explicit
100            // params which is wrong and must be fixed up.
101            shape_params.trait_clauses = shape_params.trait_clauses.map_indexed(|i, x| {
102                if i.index() < target_params.trait_clauses.slot_count() {
103                    x.substitute_explicits(&shape_contents)
104                } else {
105                    x
106                }
107            });
108            shape_params.trait_type_constraints =
109                shape_params.trait_type_constraints.map_indexed(|i, x| {
110                    if i.index() < target_params.trait_type_constraints.slot_count() {
111                        x.substitute_explicits(&shape_contents)
112                    } else {
113                        x
114                    }
115                });
116            shape_params.regions_outlive = shape_params
117                .regions_outlive
118                .into_iter()
119                .enumerate()
120                .map(|(i, x)| {
121                    if i < target_params.regions_outlive.len() {
122                        x.substitute_explicits(&shape_contents)
123                    } else {
124                        x
125                    }
126                })
127                .collect();
128            shape_params.types_outlive = shape_params
129                .types_outlive
130                .into_iter()
131                .enumerate()
132                .map(|(i, x)| {
133                    if i < target_params.types_outlive.len() {
134                        x.substitute_explicits(&shape_contents)
135                    } else {
136                        x
137                    }
138                })
139                .collect();
140            shape_params
141        };
142
143        // The first half of the trait params correspond to the original item clauses so we can
144        // pass them unmodified.
145        shape_contents.trait_refs = shape_params.identity_args().trait_refs;
146        shape_contents
147            .trait_refs
148            .truncate(target_params.trait_clauses.slot_count());
149
150        let shape_args = builder.extracted;
151        let shape = Binder::new(BinderKind::Other, shape_params, shape_contents);
152        (shape, shape_args)
153    }
154
155    /// Replace this value with a fresh variable, and record that we did so.
156    fn replace_with_fresh_var<Id, Param, Arg>(
157        &mut self,
158        val: &mut Arg,
159        mk_param: impl FnOnce(Id) -> Param,
160        mk_value: impl FnOnce(DeBruijnVar<Id>) -> Arg,
161    ) where
162        Id: Idx + Display,
163        Arg: TyVisitable + Clone,
164        GenericParams: HasVectorOf<Id, Output = Param>,
165        GenericArgs: HasVectorOf<Id, Output = Arg>,
166    {
167        let Some(shifted_val) = val.clone().move_from_under_binders(self.binder_depth) else {
168            // Give up on this value.
169            return;
170        };
171        // Record the mapping in the output `GenericArgs`.
172        self.extracted.get_vector_mut().push(shifted_val);
173        // Put a fresh param in place of `val`.
174        let id = self.params.get_vector_mut().push_with(mk_param);
175        *val = mk_value(DeBruijnVar::bound(self.binder_depth, id));
176    }
177}
178
179impl<'pm, 'ctx> VisitorWithBinderDepth for MutabilityShapeBuilder<'pm, 'ctx> {
180    fn binder_depth_mut(&mut self) -> &mut DeBruijnId {
181        &mut self.binder_depth
182    }
183}
184
185impl<'pm, 'ctx> VisitAstMut for MutabilityShapeBuilder<'pm, 'ctx> {
186    fn visit<'a, T: AstVisitable>(&'a mut self, x: &mut T) -> ControlFlow<Self::Break> {
187        VisitWithBinderDepth::new(self).visit(x)
188    }
189
190    fn enter_ty(&mut self, ty: &mut Ty) {
191        if !self.pm.is_infected(ty) {
192            self.replace_with_fresh_var(
193                ty,
194                |id| TypeParam::new(id, format!("T{id}")),
195                |v| v.into(),
196            );
197        }
198    }
199    fn exit_ty_kind(&mut self, kind: &mut TyKind) {
200        if let TyKind::Adt(TypeDeclRef {
201            id: TypeId::Adt(id),
202            generics,
203        }) = kind
204        {
205            // Since the type was not replaced with a type var, it's an infected type. We've
206            // traversed it so we have its final explicit arguments. Now we need to satisfy its
207            // predicates. For that we add all its predicates to the new item, and pass those new
208            // trait clauses to it.
209            let Some(target_params) = self.pm.generic_params.get(&(*id).into()) else {
210                return;
211            };
212            let Some(shifted_generics) =
213                generics.clone().move_from_under_binders(self.binder_depth)
214            else {
215                // Give up on this value.
216                return;
217            };
218
219            // Add the target predicates (properly substituted) to the new item params.
220            let num_clauses_before_merge = self.params.trait_clauses.slot_count();
221            self.params.merge_predicates_from(
222                target_params
223                    .clone()
224                    .substitute_explicits(&shifted_generics),
225            );
226
227            // Record the trait arguments in the output `GenericArgs`.
228            self.extracted
229                .trait_refs
230                .extend(shifted_generics.trait_refs);
231
232            // Replace each trait ref with a clause var.
233            for (target_clause_id, tref) in generics.trait_refs.iter_mut_indexed() {
234                let clause_id = target_clause_id + num_clauses_before_merge;
235                *tref =
236                    self.params.trait_clauses[clause_id].identity_tref_at_depth(self.binder_depth);
237            }
238        }
239    }
240    fn enter_region(&mut self, r: &mut Region) {
241        self.replace_with_fresh_var(r, |id| RegionParam::new(id, None), |v| v.into());
242    }
243    // TODO: we're missing type info for this
244    // fn enter_const_generic(&mut self, cg: &mut ConstGeneric) {
245    //     self.replace_with_fresh_var(cg, |id| {
246    //         ConstGenericParam::new(id, format!("N{id}"), cg.ty().clone())
247    //     });
248    // }
249    fn visit_trait_ref(&mut self, _tref: &mut TraitRef) -> ControlFlow<Self::Break> {
250        // We don't touch trait refs or we'd risk adding duplicated extra params. Instead, we fix
251        // them up in `exit_ty_kind` and `compute_shape`.
252        ControlFlow::Continue(())
253    }
254}
255
256#[derive(Visitor)]
257struct PartialMonomorphizer<'a> {
258    ctx: &'a mut TransformCtx,
259    /// Tracks the closest span to emit useful errors.
260    span: Span,
261    /// Whether we partially-monomorphize type declarations.
262    instantiate_types: bool,
263    /// Types that contain mutable references.
264    infected_types: HashSet<TypeDeclId>,
265    /// Map of generic params for each item. We can't use `ctx.translated` because while iterating
266    /// over items the current item isn't available anymore, which would break recursive types.
267    /// This also makes it possible to record the generics of our to-be-added items without adding
268    /// them.
269    generic_params: HashMap<ItemId, GenericParams>,
270    /// Map of partial monomorphizations. The source item applied with the generic params gives the
271    /// target item. The resulting partially-monomorphized item will have the binder params as
272    /// generic params.
273    partial_mono_shapes: IndexMap<(ItemId, MutabilityShape), ItemId>,
274    /// Reverse of `partial_mono_shapes`.
275    reverse_shape_map: HashMap<ItemId, (ItemId, MutabilityShape)>,
276    /// Items that need to be processed.
277    to_process: VecDeque<ItemId>,
278}
279
280impl<'a> PartialMonomorphizer<'a> {
281    pub fn new(ctx: &'a mut TransformCtx, instantiate_types: bool) -> Self {
282        use petgraph::graphmap::DiGraphMap;
283        use petgraph::visit::Dfs;
284        use petgraph::visit::Walker;
285
286        // Compute the types that contain `&mut` (even indirectly).
287        let infected_types: HashSet<_> = {
288            // We build a graph that has one node per type decl plus a special `None` node. If type A
289            // contains a reference to type B we add a B->A edge; if it contains a mutable reference we
290            // add a None->A edge. Then a type contains `&mut` iff it is reachable from the `None`
291            // node.
292            let mut graph: DiGraphMap<Option<TypeDeclId>, ()> = Default::default();
293            for (id, tdecl) in ctx.translated.type_decls.iter_indexed() {
294                tdecl.dyn_visit(|x: &Ty| match x.kind() {
295                    TyKind::Ref(_, _, RefKind::Mut) => {
296                        graph.add_edge(None, Some(id), ());
297                    }
298                    TyKind::Adt(tref) if let TypeId::Adt(other_id) = tref.id => {
299                        graph.add_edge(Some(other_id), Some(id), ());
300                    }
301                    _ => {}
302                });
303            }
304            let start = graph.add_node(None);
305            Dfs::new(&graph, start)
306                .iter(&graph)
307                .filter_map(|opt_id| opt_id)
308                .collect()
309        };
310
311        // Record the generic params of all items.
312        let generic_params: HashMap<ItemId, GenericParams> = ctx
313            .translated
314            .all_items()
315            .map(|item| (item.id(), item.generic_params().clone()))
316            .collect();
317
318        // Enqueue all items to be processed.
319        let to_process = ctx.translated.all_ids().collect();
320        PartialMonomorphizer {
321            ctx,
322            span: Span::dummy(),
323            instantiate_types,
324            infected_types,
325            generic_params,
326            to_process,
327            partial_mono_shapes: IndexMap::default(),
328            reverse_shape_map: Default::default(),
329        }
330    }
331
332    /// Whether this type is or contains a `&mut`. This assumes that we've already visited this
333    /// type and partially monomorphized any ADT references.
334    fn is_infected(&self, ty: &Ty) -> bool {
335        match ty.kind() {
336            TyKind::Ref(_, _, RefKind::Mut) => true,
337            TyKind::Ref(_, ty, _) | TyKind::RawPtr(ty, _) => self.is_infected(ty),
338            TyKind::Adt(tref) => {
339                let ty_infected =
340                    matches!(&tref.id, TypeId::Adt(id) if self.infected_types.contains(id));
341                let args_infected = if tref.id.is_adt() && self.instantiate_types {
342                    // Since we make sure to only call the method on a processed type, any type
343                    // with infected arguments would have been replaced with a fresh instantiated
344                    // (and infected type). Hence we don't need to check the arguments here, only
345                    // the type id.
346                    false
347                } else {
348                    tref.generics.types.iter().any(|ty| self.is_infected(ty))
349                };
350                ty_infected || args_infected
351            }
352            TyKind::FnDef(..) | TyKind::FnPtr(..) => {
353                register_error!(
354                    self.ctx,
355                    self.span,
356                    "function pointers are unsupported with `--monomorphize-mut`"
357                );
358                false
359            }
360            TyKind::DynTrait(_) => {
361                register_error!(
362                    self.ctx,
363                    self.span,
364                    "`dyn Trait` is unsupported with `--monomorphize-mut`"
365                );
366                false
367            }
368            TyKind::TypeVar(..)
369            | TyKind::Literal(..)
370            | TyKind::Never
371            | TyKind::TraitType(..)
372            | TyKind::PtrMetadata(..)
373            | TyKind::Error(_) => false,
374        }
375    }
376
377    /// Given that `generics` apply to item `id`, if any of the generics is infected we generate a
378    /// reference to a new item obtained by partially instantiating item `id`. (That new item isn't
379    /// added immediately but is added to the `to_process` queue to be created later).
380    fn process_generics(&mut self, id: ItemId, generics: &GenericArgs) -> Option<DeclRef<ItemId>> {
381        if !generics.types.iter().any(|ty| self.is_infected(ty)) {
382            return None;
383        }
384
385        // If the type is already an instantiation, transform this reference into a reference to
386        // the original type so we don't instantiate the instantiation.
387        let mut new_generics;
388        let (id, generics) = if let Some(&(base_id, ref shape)) = self.reverse_shape_map.get(&id) {
389            new_generics = shape.clone().apply(generics);
390            let _ = self.visit(&mut new_generics); // New instantiation may require cleanup.
391            (base_id, &new_generics)
392        } else {
393            (id, generics)
394        };
395
396        // Split the args between the infected part and the non-infected part.
397        let item_params = self.generic_params.get(&id)?;
398        let (shape, shape_args) =
399            MutabilityShapeBuilder::compute_shape(self, item_params, generics);
400
401        // Create a new type id.
402        let new_params = shape.params.clone();
403        let key: (ItemId, MutabilityShape) = (id, shape);
404        let new_id = *self
405            .partial_mono_shapes
406            .entry(key.clone())
407            .or_insert_with(|| {
408                let new_id = match id {
409                    ItemId::Type(_) => {
410                        let new_id = self.ctx.translated.type_decls.reserve_slot();
411                        self.infected_types.insert(new_id);
412                        new_id.into()
413                    }
414                    ItemId::Fun(_) => self.ctx.translated.fun_decls.reserve_slot().into(),
415                    ItemId::Global(_) => self.ctx.translated.global_decls.reserve_slot().into(),
416                    ItemId::TraitDecl(_) => self.ctx.translated.trait_decls.reserve_slot().into(),
417                    ItemId::TraitImpl(_) => self.ctx.translated.trait_impls.reserve_slot().into(),
418                };
419                self.generic_params.insert(new_id, new_params);
420                self.reverse_shape_map.insert(new_id, key);
421                self.to_process.push_back(new_id);
422                new_id
423            });
424
425        let fmt_ctx = self.ctx.into_fmt();
426        trace!(
427            "processing {}{}\n output: {}{}",
428            id.with_ctx(&fmt_ctx),
429            generics.with_ctx(&fmt_ctx),
430            new_id.with_ctx(&fmt_ctx),
431            shape_args.with_ctx(&fmt_ctx),
432        );
433        Some(DeclRef {
434            id: new_id,
435            generics: Box::new(shape_args),
436        })
437    }
438
439    /// Traverse the item, replacing any type instantiations we don't want with references to
440    /// soon-to-be-created partially-monomorphized types. This does not access the items in
441    /// `self.translated`, which may be missing since we took `item` out for processing.
442    pub fn process_item(&mut self, item: &mut ItemRefMut<'_>) {
443        let _ = item.drive_mut(self);
444    }
445
446    /// Creates the item corresponding to this id by instantiating the item it is based on.
447    ///
448    /// This accesses the items in `self.translated`, which must therefore all be there.
449    /// That's why items are created outside of `process_item`.
450    pub fn create_pending_instantiation(&mut self, new_id: ItemId) -> ItemByVal {
451        let (orig_id, shape) = &self.reverse_shape_map[&new_id];
452        let mut decl = self
453            .ctx
454            .translated
455            .get_item(*orig_id)
456            .unwrap()
457            .clone()
458            .substitute(&shape.skip_binder);
459
460        let mut decl_mut = decl.as_mut();
461        decl_mut.set_id(new_id);
462        *decl_mut.generic_params() = shape.params.clone();
463
464        let name_ref = &mut decl_mut.item_meta().name;
465        *name_ref = mem::take(name_ref).instantiate(shape.clone());
466        self.ctx
467            .translated
468            .item_names
469            .insert(new_id, decl.as_ref().item_meta().name.clone());
470
471        decl
472    }
473}
474
475impl VisitorWithSpan for PartialMonomorphizer<'_> {
476    fn current_span(&mut self) -> &mut Span {
477        &mut self.span
478    }
479}
480impl VisitAstMut for PartialMonomorphizer<'_> {
481    fn visit<'a, T: AstVisitable>(&'a mut self, x: &mut T) -> ControlFlow<Self::Break> {
482        // Track a useful enclosing span, for error messages.
483        VisitWithSpan::new(self).visit(x)
484    }
485
486    fn exit_type_decl_ref(&mut self, x: &mut TypeDeclRef) {
487        if self.instantiate_types
488            && let TypeId::Adt(id) = x.id
489            && let Some(new_decl_ref) = self.process_generics(id.into(), &x.generics)
490        {
491            *x = new_decl_ref.try_into().unwrap()
492        }
493    }
494    fn exit_fn_ptr(&mut self, x: &mut FnPtr) {
495        // TODO: methods. any `Trait::method<&mut A>` requires monomorphizing all the instances of
496        // that method just in case :>>>
497        if let FnPtrKind::Fun(FunId::Regular(id)) = *x.kind
498            && let Some(new_decl_ref) = self.process_generics(id.into(), &x.generics)
499        {
500            *x = new_decl_ref.try_into().unwrap()
501        }
502    }
503    fn exit_fun_decl_ref(&mut self, x: &mut FunDeclRef) {
504        if let Some(new_decl_ref) = self.process_generics(x.id.into(), &x.generics) {
505            *x = new_decl_ref.try_into().unwrap()
506        }
507    }
508    fn exit_global_decl_ref(&mut self, x: &mut GlobalDeclRef) {
509        if let Some(new_decl_ref) = self.process_generics(x.id.into(), &x.generics) {
510            *x = new_decl_ref.try_into().unwrap()
511        }
512    }
513    fn exit_trait_decl_ref(&mut self, x: &mut TraitDeclRef) {
514        if let Some(new_decl_ref) = self.process_generics(x.id.into(), &x.generics) {
515            *x = new_decl_ref.try_into().unwrap()
516        }
517    }
518    fn exit_trait_impl_ref(&mut self, x: &mut TraitImplRef) {
519        if let Some(new_decl_ref) = self.process_generics(x.id.into(), &x.generics) {
520            *x = new_decl_ref.try_into().unwrap()
521        }
522    }
523}
524
525pub struct Transform;
526impl TransformPass for Transform {
527    fn transform_ctx(&self, ctx: &mut TransformCtx) {
528        let Some(include_types) = ctx.options.monomorphize_mut else {
529            return;
530        };
531        // TODO: test name matcher, also with methods
532        let mut visitor =
533            PartialMonomorphizer::new(ctx, matches!(include_types, MonomorphizeMut::All));
534        while let Some(id) = visitor.to_process.pop_front() {
535            // Get the item corresponding to this id, either by creating it or by getting an
536            // existing one.
537            let mut decl = if visitor.reverse_shape_map.get(&id).is_some() {
538                // Create the required item by instantiating the item it's based on.
539                visitor.create_pending_instantiation(id)
540            } else {
541                // Take the item out so we can modify it. Warning: don't look up other items in the
542                // meantime as this would break in recursive cases.
543                match visitor.ctx.translated.remove_item(id) {
544                    Some(decl) => decl,
545                    None => continue,
546                }
547            };
548            // Visit the item, replacing type instantiations with references to soon-to-be-created
549            // partially-monomorphized types.
550            visitor.process_item(&mut decl.as_mut());
551            // Put the item back.
552            visitor.ctx.translated.set_item_slot(id, decl);
553        }
554    }
555}