Skip to main content

charon_lib/transform/normalize/
partial_monomorphization.rs

//! This module implements partial monomorphization, which allows specializing generic items on
//! some specific instantiation patterns. This is used by Aeneas to avoid nested mutable borrows:
//! we transform `Iter<'a, &'b mut T>` to `{Iter::<_, &mut U>}<'a, 'b, T>`, where
//! ```ignore
//! struct {Iter::<'a, &'b mut U>}<'a, 'b, U> {
//!   // the fields of `Iter` but instantiated with `T -> &'b mut U`.
//! }
//! ```
//!
//! Note: We may need to partial-mono the same item multiple times: `Foo::<&mut A, B>`, `Foo::<A,
//! &mut B>`. Note also that partial-mono is infectious: `Foo<Bar<&mut A>>` generates `Bar::<&mut
//! A>` then `Foo::<Bar::<&mut A>>`.
13use std::collections::{HashMap, HashSet, VecDeque};
14use std::fmt::Display;
15use std::mem;
16
17use derive_generic_visitor::Visitor;
18use index_vec::Idx;
19
20use crate::ast::types_utils::TyVisitable;
21use crate::ast::visitor::{VisitWithBinderDepth, VisitorWithBinderDepth};
22use crate::formatter::IntoFormatter;
23use crate::ids::IndexVec;
24use crate::options::MonomorphizeMut;
25use crate::pretty::FmtWithCtx;
26use crate::register_error;
27use crate::transform::ctx::TransformPass;
28use crate::{transform::TransformCtx, ullbc_ast::*};
29
30type MutabilityShape = Binder<GenericArgs>;
31
32/// See the docs of `MutabilityShapeBuilder::compute_shape`.
33#[derive(Visitor)]
34struct MutabilityShapeBuilder<'pm, 'ctx> {
35    pm: &'pm PartialMonomorphizer<'ctx>,
36    /// The parameters that will constitute the final binder.
37    params: GenericParams,
38    /// The arguments to pass to the final binder to recover the input arguments.
39    extracted: GenericArgs,
40    /// Current depth under which we're visiting.
41    binder_depth: DeBruijnId,
42}
43
44impl<'pm, 'ctx> MutabilityShapeBuilder<'pm, 'ctx> {
45    /// Compute the mutability "shape" of a set of generic arguments by factoring out the minimal
46    /// amount of information that still allows reconstructing the original arguments while keeping the
47    /// "shape arguments" free of mutable borrows.
48    ///
49    /// For example, for input:
50    ///   <u32, &'a mut &'b A, Option::<&'a mut bool>>
51    /// we want to build:
52    ///   binder<'a, 'b, A, B, C> <A, &'a mut B, Option::<&'b mut C>>
53    /// from which we can recover the original arguments by instantiating it with:
54    ///   <'a, 'a, u32, &'b A, bool>
55    ///
56    /// Formally, given `let (shape, shape_args) = get_mutability_shape(args);`, we have the following:
57    /// - `shape.substitute(shape_args) == args`;
58    /// - `shape_args` contains no infected types;
59    /// - `shape` is as shallow as possible (i.e. takes just enough to get all the infected types
60    ///   and not more).
61    ///
62    /// Note: the input arguments are assumed to have been already partially monomorphized, in the
63    /// sense that we won't recurse inside ADT args because we assume any ADT applied to infected
64    /// args to have been replaced with a fresh infected ADT.
65    fn compute_shape(
66        pm: &'pm PartialMonomorphizer<'ctx>,
67        target_params: &GenericParams,
68        args: &GenericArgs,
69    ) -> (MutabilityShape, GenericArgs) {
70        // We start with the implicit parameters from the original item. We'll need to substitute
71        // them once we've figured out the mapping of explicit parameters, but we'll also be adding
72        // new trait clauses potentially so we can't leave the vector empty (the ids would be
73        // wrong).
74        let mut shape_contents = args.clone();
75        let mut builder = Self {
76            pm,
77            params: GenericParams {
78                regions: IndexVec::new(),
79                types: IndexVec::new(),
80                const_generics: IndexVec::new(),
81                ..target_params.clone()
82            },
83            extracted: GenericArgs {
84                regions: IndexVec::new(),
85                types: IndexVec::new(),
86                const_generics: IndexVec::new(),
87                trait_refs: mem::take(&mut shape_contents.trait_refs),
88            },
89            binder_depth: DeBruijnId::zero(),
90        };
91
92        // Traverse the generics and replace any non-infected type, region or const generic with a
93        // fresh variable.
94        let _ = VisitWithBinderDepth::new(&mut builder).visit(&mut shape_contents);
95
96        let shape_params = {
97            let mut shape_params = builder.params;
98            // Now the explicit params in `shape_params` are correct, and the implicit params are a mix
99            // of the old params and new trait clauses. The old params may refer to the old explicit
100            // params which is wrong and must be fixed up.
101            shape_params.trait_clauses = shape_params.trait_clauses.map_indexed(|i, x| {
102                if i.index() < target_params.trait_clauses.len() {
103                    x.substitute_explicits(&shape_contents)
104                } else {
105                    x
106                }
107            });
108            shape_params.trait_type_constraints =
109                shape_params.trait_type_constraints.map_indexed(|i, x| {
110                    if i.index() < target_params.trait_type_constraints.len() {
111                        x.substitute_explicits(&shape_contents)
112                    } else {
113                        x
114                    }
115                });
116            shape_params.regions_outlive = shape_params
117                .regions_outlive
118                .into_iter()
119                .enumerate()
120                .map(|(i, x)| {
121                    if i < target_params.regions_outlive.len() {
122                        x.substitute_explicits(&shape_contents)
123                    } else {
124                        x
125                    }
126                })
127                .collect();
128            shape_params.types_outlive = shape_params
129                .types_outlive
130                .into_iter()
131                .enumerate()
132                .map(|(i, x)| {
133                    if i < target_params.types_outlive.len() {
134                        x.substitute_explicits(&shape_contents)
135                    } else {
136                        x
137                    }
138                })
139                .collect();
140            shape_params
141        };
142
143        // The first half of the trait params correspond to the original item clauses so we can
144        // pass them unmodified.
145        shape_contents.trait_refs = shape_params.identity_args().trait_refs;
146        shape_contents
147            .trait_refs
148            .truncate(target_params.trait_clauses.len());
149
150        let shape_args = builder.extracted;
151        let shape = Binder::new(BinderKind::Other, shape_params, shape_contents);
152        (shape, shape_args)
153    }
154
155    /// Replace this value with a fresh variable, and record that we did so.
156    fn replace_with_fresh_var<Id, Param, Arg>(
157        &mut self,
158        val: &mut Arg,
159        mk_param: impl FnOnce(Id) -> Param,
160        mk_value: impl FnOnce(DeBruijnVar<Id>) -> Arg,
161    ) where
162        Id: Idx + Display,
163        Arg: TyVisitable + Clone,
164        GenericParams: HasIdxVecOf<Id, Output = Param>,
165        GenericArgs: HasIdxVecOf<Id, Output = Arg>,
166    {
167        let Some(shifted_val) = val.clone().move_from_under_binders(self.binder_depth) else {
168            // Give up on this value.
169            return;
170        };
171        // Record the mapping in the output `GenericArgs`.
172        self.extracted.get_idx_vec_mut().push(shifted_val);
173        // Put a fresh param in place of `val`.
174        let id = self.params.get_idx_vec_mut().push_with(mk_param);
175        *val = mk_value(DeBruijnVar::bound(self.binder_depth, id));
176    }
177}
178
179impl<'pm, 'ctx> VisitorWithBinderDepth for MutabilityShapeBuilder<'pm, 'ctx> {
180    fn binder_depth_mut(&mut self) -> &mut DeBruijnId {
181        &mut self.binder_depth
182    }
183}
184
185impl<'pm, 'ctx> VisitAstMut for MutabilityShapeBuilder<'pm, 'ctx> {
186    fn visit<T: AstVisitable>(&mut self, x: &mut T) -> ControlFlow<Self::Break> {
187        VisitWithBinderDepth::new(self).visit(x)
188    }
189
190    fn enter_ty(&mut self, ty: &mut Ty) {
191        if !self.pm.is_infected(ty) {
192            self.replace_with_fresh_var(
193                ty,
194                |id| TypeParam::new(id, format!("T{id}")),
195                |v| v.into(),
196            );
197        }
198    }
199    fn exit_ty_kind(&mut self, kind: &mut TyKind) {
200        if let TyKind::Adt(TypeDeclRef {
201            id: TypeId::Adt(id),
202            generics,
203        }) = kind
204        {
205            // Since the type was not replaced with a type var, it's an infected type. We've
206            // traversed it so we have its final explicit arguments. Now we need to satisfy its
207            // predicates. For that we add all its predicates to the new item, and pass those new
208            // trait clauses to it.
209            let Some(target_params) = self.pm.generic_params.get(&(*id).into()) else {
210                return;
211            };
212            let Some(shifted_generics) =
213                generics.clone().move_from_under_binders(self.binder_depth)
214            else {
215                // Give up on this value.
216                return;
217            };
218
219            // Add the target predicates (properly substituted) to the new item params.
220            let num_clauses_before_merge = self.params.trait_clauses.len();
221            self.params.merge_predicates_from(
222                target_params
223                    .clone()
224                    .substitute_explicits(&shifted_generics),
225            );
226
227            // Record the trait arguments in the output `GenericArgs`.
228            self.extracted
229                .trait_refs
230                .extend(shifted_generics.trait_refs);
231
232            // Replace each trait ref with a clause var.
233            for (target_clause_id, tref) in generics.trait_refs.iter_mut_enumerated() {
234                let clause_id = target_clause_id + num_clauses_before_merge;
235                *tref =
236                    self.params.trait_clauses[clause_id].identity_tref_at_depth(self.binder_depth);
237            }
238        }
239    }
240    fn enter_region(&mut self, r: &mut Region) {
241        self.replace_with_fresh_var(r, |id| RegionParam::new(id, None), |v| v.into());
242    }
243    // TODO: we're missing type info for this
244    // fn enter_const_generic(&mut self, cg: &mut ConstGeneric) {
245    //     self.replace_with_fresh_var(cg, |id| {
246    //         ConstGenericParam::new(id, format!("N{id}"), cg.ty().clone())
247    //     });
248    // }
249    fn visit_trait_ref(&mut self, _tref: &mut TraitRef) -> ControlFlow<Self::Break> {
250        // We don't touch trait refs or we'd risk adding duplicated extra params. Instead, we fix
251        // them up in `exit_ty_kind` and `compute_shape`.
252        ControlFlow::Continue(())
253    }
254
255    fn visit_constant_expr(
256        &mut self,
257        _: &mut ConstantExpr,
258    ) -> ::std::ops::ControlFlow<Self::Break> {
259        ControlFlow::Continue(())
260    }
261}
262
263#[derive(Visitor)]
264struct PartialMonomorphizer<'a> {
265    ctx: &'a mut TransformCtx,
266    /// Tracks the closest span to emit useful errors.
267    span: Span,
268    /// Whether we partially-monomorphize type declarations.
269    specialize_adts: bool,
270    /// Types that contain mutable references.
271    infected_types: HashSet<TypeDeclId>,
272    /// Map of generic params for each item. We can't use `ctx.translated` because while iterating
273    /// over items the current item isn't available anymore, which would break recursive types.
274    /// This also makes it possible to record the generics of our to-be-added items without adding
275    /// them.
276    generic_params: HashMap<ItemId, GenericParams>,
277    /// Map of partial monomorphizations. The source item applied with the generic params gives the
278    /// target item. The resulting partially-monomorphized item will have the binder params as
279    /// generic params.
280    partial_mono_shapes: SeqHashMap<(ItemId, MutabilityShape), ItemId>,
281    /// Reverse of `partial_mono_shapes`.
282    reverse_shape_map: HashMap<ItemId, (ItemId, MutabilityShape)>,
283    /// Items that need to be processed.
284    to_process: VecDeque<ItemId>,
285}
286
287impl<'a> PartialMonomorphizer<'a> {
288    pub fn new(ctx: &'a mut TransformCtx, specialize_adts: bool) -> Self {
289        // Compute the types that contain `&mut` (even indirectly). We actually can ignore
290        // `&'static mut`, so we simply rely on our "lifetime mutability" computation.
291        let infected_types: HashSet<_> = ctx
292            .translated
293            .type_decls
294            .iter()
295            .filter(|tdecl| {
296                tdecl
297                    .generics
298                    .regions
299                    .iter()
300                    .any(|r| r.mutability.is_mutable())
301            })
302            .map(|tdecl| tdecl.def_id)
303            .collect();
304
305        // Record the generic params of all items.
306        let generic_params: HashMap<ItemId, GenericParams> = ctx
307            .translated
308            .all_items()
309            .map(|item| (item.id(), item.generic_params().clone()))
310            .collect();
311
312        // Enqueue all items to be processed.
313        let to_process = ctx.translated.all_ids().collect();
314        PartialMonomorphizer {
315            ctx,
316            span: Span::dummy(),
317            specialize_adts,
318            infected_types,
319            generic_params,
320            to_process,
321            partial_mono_shapes: SeqHashMap::default(),
322            reverse_shape_map: Default::default(),
323        }
324    }
325
326    /// Whether this type is or contains a `&mut`. This assumes that we've already visited this
327    /// type and partially monomorphized any ADT references.
328    fn is_infected(&self, ty: &Ty) -> bool {
329        match ty.kind() {
330            TyKind::Ref(_, _, RefKind::Mut) => true,
331            TyKind::Ref(_, ty, _)
332            | TyKind::RawPtr(ty, _)
333            | TyKind::Array(ty, _)
334            | TyKind::Pattern(ty, _)
335            | TyKind::Slice(ty) => self.is_infected(ty),
336            TyKind::Adt(tref) => match tref.id {
337                TypeId::Adt(id) => {
338                    let ty_infected = self.infected_types.contains(&id);
339                    let args_infected = if self.specialize_adts {
340                        // Since we make sure to only call the method on a processed type, any type
341                        // with infected arguments would have been replaced with a fresh instantiated
342                        // (and infected type). Hence we don't need to check the arguments here, only
343                        // the type id.
344                        false
345                    } else {
346                        tref.generics.types.iter().any(|ty| self.is_infected(ty))
347                    };
348                    ty_infected || args_infected
349                }
350                TypeId::Tuple | TypeId::Builtin(_) => {
351                    // Builtin types have no declaration to specialize, so infected arguments stay
352                    // visible inside them.
353                    tref.generics.types.iter().any(|ty| self.is_infected(ty))
354                }
355            },
356            // A function pointer/item by itself doesn't carry any mutable reference, even if it
357            // uses some in its signature. Compare with closures: a closure without captures
358            // doesn't trigger partial mono regardless of its signature.
359            TyKind::FnDef(..) | TyKind::FnPtr(..) => false,
360            TyKind::DynTrait(_) => {
361                register_error!(
362                    self.ctx,
363                    self.span,
364                    "`dyn Trait` is unsupported with `--monomorphize-mut`"
365                );
366                false
367            }
368            TyKind::TypeVar(..)
369            | TyKind::Literal(..)
370            | TyKind::Never
371            | TyKind::TraitType(..)
372            | TyKind::PtrMetadata(..)
373            | TyKind::Error(_) => false,
374        }
375    }
376
377    /// Given that `generics` apply to item `id`, if any of the generics is infected we generate a
378    /// reference to a new item obtained by partially instantiating item `id`. (That new item isn't
379    /// added immediately but is added to the `to_process` queue to be created later).
380    fn process_generics(&mut self, id: ItemId, generics: &GenericArgs) -> Option<DeclRef<ItemId>> {
381        if !generics.types.iter().any(|ty| self.is_infected(ty)) {
382            return None;
383        }
384
385        // If the type is already an instantiation, transform this reference into a reference to
386        // the original type so we don't instantiate the instantiation.
387        let mut new_generics;
388        let (id, generics) = if let Some(&(base_id, ref shape)) = self.reverse_shape_map.get(&id) {
389            new_generics = shape.clone().apply(generics);
390            let _ = self.visit(&mut new_generics); // New instantiation may require cleanup.
391            (base_id, &new_generics)
392        } else {
393            (id, generics)
394        };
395
396        // Split the args between the infected part and the non-infected part.
397        let item_params = self.generic_params.get(&id)?;
398        let (shape, shape_args) =
399            MutabilityShapeBuilder::compute_shape(self, item_params, generics);
400
401        // Create a new type id.
402        let new_params = shape.params.clone();
403        let key: (ItemId, MutabilityShape) = (id, shape);
404        let new_id = *self
405            .partial_mono_shapes
406            .entry(key.clone())
407            .or_insert_with(|| {
408                let new_id = match id {
409                    ItemId::Type(_) => {
410                        let new_id = self.ctx.translated.type_decls.reserve_slot();
411                        self.infected_types.insert(new_id);
412                        new_id.into()
413                    }
414                    ItemId::Fun(_) => self.ctx.translated.fun_decls.reserve_slot().into(),
415                    ItemId::Global(_) => self.ctx.translated.global_decls.reserve_slot().into(),
416                    ItemId::TraitDecl(_) => self.ctx.translated.trait_decls.reserve_slot().into(),
417                    ItemId::TraitImpl(_) => self.ctx.translated.trait_impls.reserve_slot().into(),
418                };
419                self.generic_params.insert(new_id, new_params);
420                self.reverse_shape_map.insert(new_id, key);
421                self.to_process.push_back(new_id);
422                new_id
423            });
424
425        let fmt_ctx = self.ctx.into_fmt();
426        trace!(
427            "processing {}{}\n output: {}{}",
428            id.with_ctx(&fmt_ctx),
429            generics.with_ctx(&fmt_ctx),
430            new_id.with_ctx(&fmt_ctx),
431            shape_args.with_ctx(&fmt_ctx),
432        );
433        Some(DeclRef {
434            id: new_id,
435            generics: Box::new(shape_args),
436            trait_ref: None,
437        })
438    }
439
440    /// Traverse the item, replacing any type instantiations we don't want with references to
441    /// soon-to-be-created partially-monomorphized types. This does not access the items in
442    /// `self.translated`, which may be missing since we took `item` out for processing.
443    pub fn process_item(&mut self, item: &mut ItemRefMut<'_>) {
444        let _ = item.drive_mut(self);
445    }
446
447    /// Creates the item corresponding to this id by instantiating the item it is based on.
448    ///
449    /// This accesses the items in `self.translated`, which must therefore all be there.
450    /// That's why items are created outside of `process_item`.
451    pub fn create_pending_instantiation(&mut self, new_id: ItemId) -> ItemByVal {
452        let (orig_id, shape) = &self.reverse_shape_map[&new_id];
453        let mut decl = self
454            .ctx
455            .translated
456            .get_item(*orig_id)
457            .unwrap()
458            .to_owned()
459            .substitute_with_self(&shape.skip_binder, &TraitRefKind::SelfId);
460
461        let mut decl_mut = decl.as_mut();
462        decl_mut.set_id(new_id);
463        *decl_mut.generic_params() = shape.params.clone();
464
465        let name_ref = &mut decl_mut.item_meta().name;
466        *name_ref = mem::take::<crate::ast::Name>(name_ref).instantiate(shape.clone());
467        self.ctx
468            .translated
469            .item_names
470            .insert(new_id, decl.as_ref().item_meta().name.clone());
471        if let (ItemId::TraitDecl(orig_trait_id), ItemId::TraitDecl(new_trait_id)) =
472            (*orig_id, new_id)
473        {
474            let names = self.ctx.translated.assoc_item_names[orig_trait_id].clone();
475            self.ctx
476                .translated
477                .assoc_item_names
478                .insert(new_trait_id, names);
479        }
480
481        decl
482    }
483}
484
485impl VisitorWithSpan for PartialMonomorphizer<'_> {
486    fn current_span(&mut self) -> &mut Span {
487        &mut self.span
488    }
489}
490impl VisitAstMut for PartialMonomorphizer<'_> {
491    fn visit<T: AstVisitable>(&mut self, x: &mut T) -> ControlFlow<Self::Break> {
492        // Track a useful enclosing span, for error messages.
493        VisitWithSpan::new(self).visit(x)
494    }
495
496    fn exit_type_decl_ref(&mut self, x: &mut TypeDeclRef) {
497        if self.specialize_adts
498            && let TypeId::Adt(id) = x.id
499            && let Some(new_decl_ref) = self.process_generics(id.into(), &x.generics)
500        {
501            *x = new_decl_ref.try_into().unwrap()
502        }
503    }
504    fn exit_fn_ptr(&mut self, x: &mut FnPtr) {
505        // TODO: methods. any `Trait::method<&mut A>` requires monomorphizing all the instances of
506        // that method just in case :>>>
507        if let FnPtrKind::Fun(FunId::Regular(id)) = *x.kind
508            && let Some(new_decl_ref) = self.process_generics(id.into(), &x.generics)
509        {
510            *x = new_decl_ref.try_into().unwrap()
511        }
512    }
513    fn exit_fun_decl_ref(&mut self, x: &mut FunDeclRef) {
514        if let Some(new_decl_ref) = self.process_generics(x.id.into(), &x.generics) {
515            *x = new_decl_ref.try_into().unwrap()
516        }
517    }
518    fn exit_global_decl_ref(&mut self, x: &mut GlobalDeclRef) {
519        if let Some(new_decl_ref) = self.process_generics(x.id.into(), &x.generics) {
520            *x = new_decl_ref.try_into().unwrap()
521        }
522    }
523    fn exit_trait_decl_ref(&mut self, x: &mut TraitDeclRef) {
524        if let Some(new_decl_ref) = self.process_generics(x.id.into(), &x.generics) {
525            *x = new_decl_ref.try_into().unwrap()
526        }
527    }
528    fn exit_trait_impl_ref(&mut self, x: &mut TraitImplRef) {
529        if let Some(new_decl_ref) = self.process_generics(x.id.into(), &x.generics) {
530            *x = new_decl_ref.try_into().unwrap()
531        }
532    }
533}
534
535pub struct Transform;
536impl TransformPass for Transform {
537    fn transform_ctx(&self, ctx: &mut TransformCtx) {
538        let Some(include_types) = ctx.options.monomorphize_mut else {
539            return;
540        };
541        // TODO: test name matcher, also with methods
542        let mut visitor =
543            PartialMonomorphizer::new(ctx, matches!(include_types, MonomorphizeMut::All));
544        while let Some(id) = visitor.to_process.pop_front() {
545            // Get the item corresponding to this id, either by creating it or by getting an
546            // existing one.
547            let mut decl = if visitor.reverse_shape_map.contains_key(&id) {
548                // Create the required item by instantiating the item it's based on.
549                visitor.create_pending_instantiation(id)
550            } else {
551                // Take the item out so we can modify it. Warning: don't look up other items in the
552                // meantime as this would break in recursive cases.
553                match visitor.ctx.translated.remove_item_temporarily(id) {
554                    Some(decl) => decl,
555                    None => continue,
556                }
557            };
558            // Visit the item, replacing type instantiations with references to soon-to-be-created
559            // partially-monomorphized types.
560            visitor.process_item(&mut decl.as_mut());
561            // Put the item back.
562            visitor.ctx.translated.set_item_slot(id, decl);
563        }
564    }
565}