charon_lib/transform/
monomorphize.rs

1//! # Micro-pass: monomorphize all functions and types; at the end of this pass, all functions and types are monomorphic.
2use derive_generic_visitor::*;
3use std::collections::{HashMap, HashSet};
4
5use crate::ast::*;
6use crate::transform::TransformCtx;
7use std::fmt::Debug;
8
9use super::ctx::TransformPass;
10
/// A three-state slot: a resolved value, nothing at all, or a hint recorded
/// in place of the value.
enum OptionHint<T, H> {
    /// A value is present.
    Some(T),
    /// Neither a value nor a hint is present.
    None,
    /// No value is present, but a hint of type `H` was recorded.
    Hint(H),
}

impl<T, H> OptionHint<T, H> {
    /// Returns `true` only for the `Some` variant; both `None` and `Hint`
    /// yield `false`.
    fn is_some(&self) -> bool {
        matches!(self, OptionHint::Some(_))
    }

    /// Returns the hint stored in the `Hint` variant, or the caller-supplied
    /// `hint` for the `Some` and `None` variants.
    fn hint_or<'a>(&'a self, hint: &'a H) -> &'a H {
        if let OptionHint::Hint(stored) = self {
            stored
        } else {
            hint
        }
    }
}
34
/// Mutable state shared by the usage-collection and substitution steps of the pass.
struct PassData {
    // Map of (poly item, generic args) -> mono item.
    // `OptionHint::Some` holds the monomorphized item. `OptionHint::None` and
    // `OptionHint::Hint` both mean the pair hasn't been monomorphized yet; a `Hint` carries
    // the (item, generic args) pair to instantiate in place of the key itself.
    items: HashMap<(AnyTransId, GenericArgs), OptionHint<AnyTransId, (AnyTransId, BoxedArgs)>>,
    // Items that still have to be processed by the main loop.
    worklist: Vec<AnyTransId>,
    // Items already popped from the worklist; also used at the end of the pass to decide
    // which declarations are kept.
    visited: HashSet<AnyTransId>,
}
42
43impl PassData {
44    fn new() -> Self {
45        PassData {
46            items: HashMap::new(),
47            worklist: Vec::new(),
48            visited: HashSet::new(),
49        }
50    }
51}
52
53impl TranslatedCrate {
54    // FIXME(Nadrieril): implement type&tref normalization and use that instead
55    fn find_trait_impl_and_gargs(
56        self: &Self,
57        kind: &TraitRefKind,
58    ) -> Option<(&TraitImpl, GenericArgs)> {
59        match kind {
60            TraitRefKind::TraitImpl(impl_ref) => {
61                let trait_impl = self.trait_impls.get(impl_ref.id)?;
62                Some((trait_impl, impl_ref.generics.as_ref().clone()))
63            }
64            TraitRefKind::ParentClause(p, _, clause) => {
65                let (trait_impl, _) = self.find_trait_impl_and_gargs(p)?;
66                let t_ref = trait_impl.parent_trait_refs.get(*clause)?;
67                self.find_trait_impl_and_gargs(&t_ref.kind)
68            }
69            _ => None,
70        }
71    }
72}
73
/// Read-only AST visitor that records the (item, generic args) pairs it encounters into
/// `data.items`.
#[derive(Visitor)]
struct UsageVisitor<'a> {
    // Pass state; newly found uses are inserted into `data.items`.
    data: &'a mut PassData,
    // The crate being transformed; used to resolve trait method references to the impl
    // method they designate.
    krate: &'a TranslatedCrate,
}
79impl UsageVisitor<'_> {
80    fn found_use(
81        &mut self,
82        id: &AnyTransId,
83        gargs: &GenericArgs,
84        default: OptionHint<AnyTransId, (AnyTransId, BoxedArgs)>,
85    ) {
86        trace!("Mono: Found use: {:?} / {:?}", id, gargs);
87        self.data
88            .items
89            .entry((*id, gargs.clone()))
90            .or_insert(default);
91    }
92    fn found_use_ty(&mut self, tref: &TypeDeclRef) {
93        match tref.id {
94            TypeId::Adt(id) => {
95                self.found_use(&AnyTransId::Type(id), &tref.generics, OptionHint::None)
96            }
97            _ => {}
98        }
99    }
100    fn found_use_fn(&mut self, id: &FunDeclId, gargs: &GenericArgs) {
101        self.found_use(&AnyTransId::Fun(*id), gargs, OptionHint::None);
102    }
103    fn found_use_global_decl_ref(&mut self, id: &GlobalDeclId, gargs: &GenericArgs) {
104        self.found_use(&AnyTransId::Global(*id), gargs, OptionHint::None);
105    }
106    fn found_use_fn_hinted(
107        &mut self,
108        id: &FunDeclId,
109        gargs: &GenericArgs,
110        (h_id, h_args): (FunDeclId, BoxedArgs),
111    ) {
112        self.found_use(
113            &AnyTransId::Fun(*id),
114            gargs,
115            OptionHint::Hint((AnyTransId::Fun(h_id), h_args)),
116        );
117    }
118}
119impl VisitAst for UsageVisitor<'_> {
120    // we need to skip ItemMeta, as we don't want to collect the types in PathElem::Impl
121    fn visit_item_meta(&mut self, _: &ItemMeta) -> ControlFlow<Infallible> {
122        Continue(())
123    }
124
125    fn enter_aggregate_kind(&mut self, kind: &AggregateKind) {
126        match kind {
127            AggregateKind::Adt(tref, _, _) => self.found_use_ty(tref),
128            _ => {}
129        }
130    }
131
132    fn enter_ty_kind(&mut self, kind: &TyKind) {
133        match kind {
134            TyKind::Adt(tref) => {
135                self.found_use_ty(tref);
136            }
137            _ => {}
138        }
139    }
140
141    fn enter_fn_ptr(&mut self, fn_ptr: &FnPtr) {
142        match fn_ptr.func.as_ref() {
143            FunIdOrTraitMethodRef::Fun(FunId::Regular(id)) => {
144                self.found_use_fn(&id, &fn_ptr.generics)
145            }
146            FunIdOrTraitMethodRef::Trait(t_ref, name, id) => {
147                let Some((trait_impl, impl_gargs)) =
148                    self.krate.find_trait_impl_and_gargs(&t_ref.kind)
149                else {
150                    return;
151                };
152                let (_, bound_fn) = trait_impl.methods().find(|(n, _)| n == name).unwrap();
153                let fn_ref: Binder<Binder<FunDeclRef>> = Binder::new(
154                    BinderKind::Other,
155                    trait_impl.generics.clone(),
156                    bound_fn.clone(),
157                );
158                // This is the actual function we need to call!
159                // Whereas id is the trait method reference(?)
160                let fn_ref = fn_ref.apply(&impl_gargs).apply(&fn_ptr.generics);
161                let gargs_key = fn_ptr
162                    .generics
163                    .clone()
164                    .concat(&t_ref.trait_decl_ref.skip_binder.generics);
165                self.found_use_fn_hinted(&id, &gargs_key, (fn_ref.id, fn_ref.generics))
166            }
167            // These can't be monomorphized, since they're builtins
168            FunIdOrTraitMethodRef::Fun(FunId::Builtin(..)) => {}
169        }
170    }
171
172    fn enter_global_decl_ref(&mut self, glob: &GlobalDeclRef) {
173        self.found_use_global_decl_ref(&glob.id, &glob.generics);
174    }
175}
176
// Akin to UsageVisitor, but substitutes all uses of generic items with their monomorphized
// versions. Collection and substitution are two separate steps, because we can't add new
// definitions to the translation context while also mutating the existing definitions.
#[derive(Visitor)]
struct SubstVisitor<'a> {
    // Read-only pass state; `data.items` provides the (item, generic args) -> mono item map.
    data: &'a PassData,
}
184impl SubstVisitor<'_> {
185    fn subst_use<T, F>(&mut self, id: &mut T, gargs: &mut GenericArgs, of: F)
186    where
187        T: Into<AnyTransId> + Debug + Copy,
188        F: Fn(&AnyTransId) -> Option<&T>,
189    {
190        trace!("Mono: Subst use: {:?} / {:?}", id, gargs);
191        // Erase regions.
192        gargs.regions.iter_mut().for_each(|r| *r = Region::Erased);
193        let key = ((*id).into(), gargs.clone());
194        let subst = self.data.items.get(&key);
195        if let Some(OptionHint::Some(any_id)) = subst
196            && let Some(subst_id) = of(any_id)
197        {
198            *id = *subst_id;
199            *gargs = GenericArgs::empty();
200        } else {
201            warn!("Substitution missing for {:?} / {:?}", id, gargs);
202        }
203    }
204    fn subst_use_ty(&mut self, tref: &mut TypeDeclRef) {
205        match &mut tref.id {
206            TypeId::Adt(id) => {
207                self.subst_use(id, &mut tref.generics, AnyTransId::as_type);
208            }
209            _ => {}
210        }
211    }
212    fn subst_use_fun(&mut self, id: &mut FunDeclId, gargs: &mut GenericArgs) {
213        self.subst_use(id, gargs, AnyTransId::as_fun);
214    }
215    fn subst_use_glob(&mut self, id: &mut GlobalDeclId, gargs: &mut GenericArgs) {
216        self.subst_use(id, gargs, AnyTransId::as_global);
217    }
218}
219
220impl VisitAstMut for SubstVisitor<'_> {
221    fn exit_rvalue(&mut self, rval: &mut Rvalue) {
222        if let Rvalue::Discriminant(place, id) = rval
223            && let Some(tref) = place.ty.as_adt()
224            && let TypeId::Adt(new_enum_id) = tref.id
225        {
226            // Small trick; the discriminant doesn't carry the information on the
227            // generics of the enum, since it's irrelevant, but we need it to do
228            // the substitution, so we look at the type of the place we read from
229            *id = new_enum_id;
230        }
231    }
232
233    fn enter_aggregate_kind(&mut self, kind: &mut AggregateKind) {
234        match kind {
235            AggregateKind::Adt(tref, _, _) => self.subst_use_ty(tref),
236            _ => {}
237        }
238    }
239
240    fn enter_ty_kind(&mut self, kind: &mut TyKind) {
241        match kind {
242            TyKind::Adt(tref) => self.subst_use_ty(tref),
243            _ => {}
244        }
245    }
246
247    fn enter_fn_ptr(&mut self, fn_ptr: &mut FnPtr) {
248        match fn_ptr.func.as_mut() {
249            FunIdOrTraitMethodRef::Fun(FunId::Regular(fun_id)) => {
250                self.subst_use_fun(fun_id, &mut fn_ptr.generics)
251            }
252            FunIdOrTraitMethodRef::Trait(t_ref, _, fun_id) => {
253                let mut gargs_key = fn_ptr
254                    .generics
255                    .clone()
256                    .concat(&t_ref.trait_decl_ref.skip_binder.generics);
257                self.subst_use_fun(fun_id, &mut gargs_key);
258                fn_ptr.generics = Box::new(gargs_key);
259            }
260            // These can't be monomorphized, since they're builtins
261            FunIdOrTraitMethodRef::Fun(FunId::Builtin(..)) => {}
262        }
263    }
264
265    fn exit_place(&mut self, place: &mut Place) {
266        match &mut place.kind {
267            // FIXME(Nadrieril): remove this id, replace with a helper fn
268            PlaceKind::Projection(inner, ProjectionElem::Field(FieldProjKind::Adt(id, _), _)) => {
269                // Trick, we don't know the generics but the projected place does, so
270                // we substitute it there, then update our current id.
271                let tref = inner.ty.as_adt().unwrap();
272                *id = *tref.id.as_adt().unwrap()
273            }
274            _ => {}
275        }
276    }
277
278    fn enter_global_decl_ref(&mut self, glob: &mut GlobalDeclRef) {
279        self.subst_use_glob(&mut glob.id, &mut glob.generics);
280    }
281}
282
/// Sanity-check visitor: panics when it meets an id that has no declaration in `krate`.
#[derive(Visitor)]
// The only driver of this visitor (`check_missing_indices`) is currently commented out, so
// the struct and its fields are unused.
#[allow(dead_code)]
struct MissingIndexChecker<'a> {
    // The crate whose declaration tables are consulted.
    krate: &'a TranslatedCrate,
    // The item currently being checked; only used in panic messages.
    current_item: Option<AnyTransItem<'a>>,
}
289impl VisitAst for MissingIndexChecker<'_> {
290    fn enter_fun_decl_id(&mut self, id: &FunDeclId) {
291        if self.krate.fun_decls.get(*id).is_none() {
292            panic!(
293                "Missing function declaration for id: {:?}, in {:?}",
294                id, self.current_item
295            );
296        }
297    }
298
299    fn enter_trait_impl_id(&mut self, id: &TraitImplId) {
300        if self.krate.trait_impls.get(*id).is_none() {
301            panic!(
302                "Missing trait implementation for id: {:?}, in {:?}",
303                id, self.current_item
304            );
305        }
306    }
307
308    fn enter_trait_decl_id(&mut self, id: &TraitDeclId) {
309        if self.krate.trait_decls.get(*id).is_none() {
310            panic!(
311                "Missing trait declaration for id: {:?}, in {:?}",
312                id, self.current_item
313            );
314        }
315    }
316
317    fn enter_type_decl_id(&mut self, id: &TypeDeclId) {
318        if self.krate.type_decls.get(*id).is_none() {
319            panic!(
320                "Missing type declaration for id: {:?}, in {:?}",
321                id, self.current_item
322            );
323        }
324    }
325}
326
327fn find_uses(data: &mut PassData, krate: &TranslatedCrate, item: &AnyTransItem) {
328    let mut visitor = UsageVisitor { data, krate };
329    let _ = item.drive(&mut visitor);
330}
331
332fn subst_uses<T: AstVisitable + Debug>(data: &PassData, item: &mut T) {
333    let mut visitor = SubstVisitor { data };
334    let _ = item.drive_mut(&mut visitor);
335}
336
337// fn check_missing_indices(krate: &TranslatedCrate) {
338//     let mut visitor = MissingIndexChecker {
339//         krate,
340//         current_item: None,
341//     };
342//     for item in krate.all_items() {
343//         visitor.current_item = Some(item);
344//         item.drive(&mut visitor);
345//     }
346// }
347
348// fn path_for_generics(gargs: &GenericArgs) -> PathElem {
349//     PathElem::Ident(gargs.to_string(), Disambiguator::ZERO)
350// }
351
/// The monomorphization pass. It does nothing unless `ctx.options.monomorphize` is set.
pub struct Transform;
impl TransformPass for Transform {
    /// Monomorphizes the functions, types and globals reachable from the non-generic
    /// functions of the crate, then drops every function, type and global declaration that
    /// was not visited along the way.
    ///
    /// # Panics
    ///
    /// Panics if a worklist item or a function/type to instantiate cannot be found in the
    /// crate, and (via `todo!`) if a use with non-empty generics targets an item kind other
    /// than a function, type or global.
    fn transform_ctx(&self, ctx: &mut TransformCtx) {
        // Check the option which instructs to ignore this pass
        if !ctx.options.monomorphize {
            return;
        }

        // From https://doc.rust-lang.org/nightly/nightly-rustc/rustc_monomorphize/collector/index.html#general-algorithm
        //
        // The purpose of the algorithm implemented in this module is to build the mono item
        // graph for the current crate. It runs in two phases:
        // 1. Discover the roots of the graph by traversing the HIR of the crate.
        // 2. Starting from the roots, find uses by inspecting the MIR representation of the
        //    item corresponding to a given node, until no more new nodes are found.
        //
        // The roots of the mono item graph correspond to the public non-generic syntactic
        // items in the source code. We find them by walking the HIR of the crate, and whenever
        // we hit upon a public function, method, or static item, we create a mono item
        // consisting of the items DefId and, since we only consider non-generic items, an
        // empty type-parameters set.
        //
        // Given a mono item node, we can discover uses by inspecting its MIR. We walk the MIR
        // to find other mono items used by each mono item. Since the mono item we are
        // currently at is always monomorphic, we also know the concrete type arguments of its
        // used mono items. The specific forms a use can take in MIR are quite diverse: it
        // includes calling functions/methods, taking a reference to a function/method, drop
        // glue, and unsizing casts.

        // In our version of the algorithm, we do the following:
        // 1. Find all the roots, adding them to the worklist.
        // 2. For each item in the worklist:
        //    a. Find all the items it uses, adding them to the worklist and the generic
        //      arguments to the item.
        //    b. Mark the item as visited

        // Final list of monomorphized items: { (poly item, generic args) -> mono item }
        let mut data = PassData::new();

        let empty_gargs = GenericArgs::empty();

        // Find the roots of the mono item graph: here, every function whose signature has no
        // generic parameters. Each root maps to itself.
        for (id, item) in ctx.translated.all_items_with_ids() {
            match item {
                AnyTransItem::Fun(f) if f.signature.generics.is_empty() => {
                    data.items
                        .insert((id, empty_gargs.clone()), OptionHint::Some(id));
                    data.worklist.push(id);
                }
                _ => {}
            }
        }

        // Iterate over worklist -- these items are always monomorphic!
        while let Some(id) = data.worklist.pop() {
            // An item may be queued several times; process it only once.
            if data.visited.contains(&id) {
                continue;
            }
            data.visited.insert(id);

            // 1. Find new uses
            let Some(item) = ctx.translated.get_item(id) else {
                panic!("Couldn't find item {:} in translated items.", id)
            };
            find_uses(&mut data, &ctx.translated, &item);

            // 2. Iterate through all newly discovered uses. This scans every recorded entry
            //    and skips the ones that already have a monomorphized item.
            for ((id, gargs), mono) in data.items.iter_mut() {
                if mono.is_some() {
                    continue;
                }

                // a. Monomorphize the items if they're polymorphic, add them to the worklist.
                //    A use without generic arguments maps to the item itself.
                let new_mono = if gargs.is_empty() {
                    *id
                } else {
                    match id {
                        AnyTransId::Fun(_) => {
                            // If a hint was recorded for this use, instantiate the hinted
                            // function with the hinted arguments; otherwise instantiate the
                            // key itself.
                            let key_pair = (id.clone(), Box::new(gargs.clone()));
                            let (AnyTransId::Fun(fun_id), gargs) = mono.hint_or(&key_pair) else {
                                panic!("Unexpected ID type in hint_or");
                            };
                            let fun = ctx.translated.fun_decls.get(*fun_id).unwrap();
                            let mut fun_sub = fun.clone().substitute(gargs);
                            fun_sub.signature.generics = GenericParams::empty();
                            // Tag the copy's name with the arguments it was instantiated with.
                            fun_sub
                                .item_meta
                                .name
                                .name
                                .push(PathElem::Monomorphized(gargs.clone()));

                            let fun_id_sub = ctx.translated.fun_decls.push_with(|id| {
                                fun_sub.def_id = id;
                                fun_sub
                            });

                            AnyTransId::Fun(fun_id_sub)
                        }
                        AnyTransId::Type(typ_id) => {
                            let typ = ctx.translated.type_decls.get(*typ_id).unwrap();
                            let mut typ_sub = typ.clone().substitute(gargs);
                            typ_sub.generics = GenericParams::empty();
                            typ_sub
                                .item_meta
                                .name
                                .name
                                .push(PathElem::Monomorphized(gargs.clone().into()));

                            let typ_id_sub = ctx.translated.type_decls.push_with(|id| {
                                typ_sub.def_id = id;
                                typ_sub
                            });

                            AnyTransId::Type(typ_id_sub)
                        }
                        AnyTransId::Global(g_id) => {
                            let Some(glob) = ctx.translated.global_decls.get(*g_id) else {
                                // Something odd happened -- we ignore and move on
                                *mono = OptionHint::Some(*id);
                                warn!("Found a global that has no associated declaration");
                                continue;
                            };
                            let mut glob_sub = glob.clone().substitute(gargs);
                            glob_sub.generics = GenericParams::empty();
                            glob_sub
                                .item_meta
                                .name
                                .name
                                .push(PathElem::Monomorphized(gargs.clone().into()));

                            // The global's initializer function is instantiated with the
                            // same arguments, and the new global is pointed at that copy.
                            let init = ctx.translated.fun_decls.get(glob.init).unwrap();
                            let mut init_sub = init.clone().substitute(gargs);
                            init_sub.signature.generics = GenericParams::empty();
                            init_sub
                                .item_meta
                                .name
                                .name
                                .push(PathElem::Monomorphized(gargs.clone().into()));

                            let init_id_sub = ctx.translated.fun_decls.push_with(|id| {
                                init_sub.def_id = id;
                                glob_sub.init = id;
                                init_sub
                            });

                            let g_id_sub = ctx.translated.global_decls.push_with(|id| {
                                glob_sub.def_id = id;
                                glob_sub
                            });

                            // The new initializer has to be processed as well.
                            data.worklist.push(AnyTransId::Fun(init_id_sub));

                            AnyTransId::Global(g_id_sub)
                        }
                        _ => todo!("Unhandled monomorphization target ID {:?}", id),
                    }
                };
                trace!(
                    "Mono: Monomorphized {:?} with {:?} to {:?}",
                    id,
                    gargs,
                    new_mono
                );
                if id != &new_mono {
                    trace!(" - From {:?}", ctx.translated.get_item(id.clone()));
                    trace!(" - To {:?}", ctx.translated.get_item(new_mono.clone()));
                }
                // Record the result and queue the monomorphic item for processing.
                *mono = OptionHint::Some(new_mono);
                data.worklist.push(new_mono);

                // Register the name of the monomorphic item in the crate's name table.
                let item = ctx.translated.get_item(new_mono).unwrap();
                ctx.translated
                    .item_names
                    .insert(new_mono, item.item_meta().name.clone());
            }

            // 3. Substitute all generics with the monomorphized versions
            let Some(item) = ctx.translated.get_item_mut(id) else {
                panic!("Couldn't find item {:} in translated items.", id)
            };
            match item {
                AnyTransItemMut::Fun(f) => subst_uses(&data, f),
                AnyTransItemMut::Type(t) => subst_uses(&data, t),
                AnyTransItemMut::TraitImpl(t) => subst_uses(&data, t),
                AnyTransItemMut::Global(g) => subst_uses(&data, g),
                AnyTransItemMut::TraitDecl(t) => subst_uses(&data, t),
            };
        }

        // Now, remove all polymorphic items from the translation context, as all their
        // uses have been monomorphized and substituted. Concretely, only the functions,
        // types and globals that were visited by the loop above are kept.
        ctx.translated
            .fun_decls
            .retain(|f| data.visited.contains(&AnyTransId::Fun(f.def_id)));
        ctx.translated
            .type_decls
            .retain(|t| data.visited.contains(&AnyTransId::Type(t.def_id)));
        ctx.translated
            .global_decls
            .retain(|g| data.visited.contains(&AnyTransId::Global(g.def_id)));
        // ctx.translated.trait_impls.retain(|t| t.generics.is_empty());

        // TODO: Currently we don't update all TraitImpls/TraitDecls with the monomorphized versions
        //       and removing the polymorphic ones, so this fails.
        // Finally, ensure we didn't leave any IDs un-replaced
        // check_missing_indices(&ctx.translated);
    }
}
559}