charon_lib/transform/
monomorphize.rs

1//! # Micro-pass: monomorphize all functions and types; at the end of this pass, all functions and types are monomorphic.
2use derive_generic_visitor::*;
3use std::collections::{HashMap, HashSet};
4
5use crate::ast::*;
6use crate::transform::TransformCtx;
7use std::fmt::Debug;
8
9use super::ctx::TransformPass;
10
/// A three-state slot: a resolved value, nothing at all, or a hint that can be
/// consulted while the value is still unresolved.
enum OptionHint<T, H> {
    /// The value has been resolved.
    Some(T),
    /// No value and no hint.
    None,
    /// No value yet, but a hint is available.
    Hint(H),
}

impl<T, H> OptionHint<T, H> {
    /// Returns `true` only for [`OptionHint::Some`]; both `None` and `Hint`
    /// count as "not resolved".
    fn is_some(&self) -> bool {
        matches!(self, OptionHint::Some(_))
    }

    /// Returns the stored hint when this is a [`OptionHint::Hint`], and the
    /// caller-provided `hint` in every other case.
    fn hint_or<'a>(&'a self, hint: &'a H) -> &'a H {
        if let OptionHint::Hint(stored) = self {
            stored
        } else {
            hint
        }
    }
}
34
35struct PassData {
36    // Map of (poly item, generic args) -> mono item
37    // None indicates the item hasn't been monomorphized yet
38    items: HashMap<(AnyTransId, GenericArgs), OptionHint<AnyTransId, (AnyTransId, GenericArgs)>>,
39    worklist: Vec<AnyTransId>,
40    visited: HashSet<AnyTransId>,
41}
42
43impl PassData {
44    fn new() -> Self {
45        PassData {
46            items: HashMap::new(),
47            worklist: Vec::new(),
48            visited: HashSet::new(),
49        }
50    }
51}
52
53impl TranslatedCrate {
54    // FIXME(Nadrieril): implement type&tref normalization and use that instead
55    fn find_trait_impl_and_gargs(self: &Self, kind: &TraitRefKind) -> (&TraitImpl, GenericArgs) {
56        match kind {
57            TraitRefKind::TraitImpl(impl_id, gargs) => {
58                let trait_impl = self.trait_impls.get(*impl_id).unwrap();
59                (trait_impl, gargs.clone())
60            }
61            TraitRefKind::ParentClause(p, _, clause) => {
62                let (trait_impl, _) = self.find_trait_impl_and_gargs(p);
63                let t_ref = trait_impl.parent_trait_refs.get(*clause).unwrap();
64                self.find_trait_impl_and_gargs(&t_ref.kind)
65            }
66            _ => panic!("Unexpected trait reference kind"),
67        }
68    }
69}
70
71#[derive(Visitor)]
72struct UsageVisitor<'a> {
73    data: &'a mut PassData,
74    krate: &'a TranslatedCrate,
75}
76impl UsageVisitor<'_> {
77    fn found_use(
78        &mut self,
79        id: &AnyTransId,
80        gargs: &GenericArgs,
81        default: OptionHint<AnyTransId, (AnyTransId, GenericArgs)>,
82    ) {
83        trace!("Mono: Found use: {:?} / {:?}", id, gargs);
84        self.data
85            .items
86            .entry((*id, gargs.clone()))
87            .or_insert(default);
88    }
89    fn found_use_ty(&mut self, id: &TypeDeclId, gargs: &GenericArgs) {
90        self.found_use(&AnyTransId::Type(*id), gargs, OptionHint::None);
91    }
92    fn found_use_fn(&mut self, id: &FunDeclId, gargs: &GenericArgs) {
93        self.found_use(&AnyTransId::Fun(*id), gargs, OptionHint::None);
94    }
95    fn found_use_global_decl_ref(&mut self, id: &GlobalDeclId, gargs: &GenericArgs) {
96        self.found_use(&AnyTransId::Global(*id), gargs, OptionHint::None);
97    }
98    fn found_use_fn_hinted(
99        &mut self,
100        id: &FunDeclId,
101        gargs: &GenericArgs,
102        (h_id, h_args): (FunDeclId, GenericArgs),
103    ) {
104        self.found_use(
105            &AnyTransId::Fun(*id),
106            gargs,
107            OptionHint::Hint((AnyTransId::Fun(h_id), h_args)),
108        );
109    }
110}
111impl VisitAst for UsageVisitor<'_> {
112    fn enter_aggregate_kind(&mut self, kind: &AggregateKind) {
113        match kind {
114            AggregateKind::Adt(TypeId::Adt(id), _, _, gargs) => self.found_use_ty(id, gargs),
115            AggregateKind::Closure(fun_id, gargs) => self.found_use_fn(fun_id, gargs),
116            _ => {}
117        }
118    }
119
120    fn enter_ty_kind(&mut self, kind: &TyKind) {
121        match kind {
122            TyKind::Adt(TypeId::Adt(id), gargs) => {
123                self.found_use_ty(id, gargs);
124            }
125            _ => {}
126        }
127    }
128
129    fn enter_fn_ptr(&mut self, fn_ptr: &FnPtr) {
130        match &fn_ptr.func {
131            FunIdOrTraitMethodRef::Fun(FunId::Regular(id)) => {
132                self.found_use_fn(id, &fn_ptr.generics)
133            }
134            FunIdOrTraitMethodRef::Trait(t_ref, name, id) => {
135                let (trait_impl, impl_gargs) = self.krate.find_trait_impl_and_gargs(&t_ref.kind);
136                let (_, bound_fn) = trait_impl.methods().find(|(n, _)| n == name).unwrap();
137                let fn_ref: Binder<Binder<FunDeclRef>> = Binder::new(
138                    BinderKind::Other,
139                    trait_impl.generics.clone(),
140                    bound_fn.clone(),
141                );
142                // This is the actual function we need to call!
143                // Whereas id is the trait method reference(?)
144                let fn_ref = fn_ref.apply(&impl_gargs).apply(&fn_ptr.generics);
145                let gargs_key = fn_ptr.generics.clone().concat(
146                    GenericsSource::Builtin,
147                    &t_ref.trait_decl_ref.skip_binder.generics,
148                );
149                self.found_use_fn_hinted(id, &gargs_key, (fn_ref.id, fn_ref.generics))
150            }
151            // These can't be monomorphized, since they're builtins
152            FunIdOrTraitMethodRef::Fun(FunId::Builtin(..)) => {}
153        }
154    }
155
156    fn enter_global_decl_ref(&mut self, glob: &GlobalDeclRef) {
157        self.found_use_global_decl_ref(&glob.id, &glob.generics);
158    }
159}
160
161// Akin to UsageVisitor, but substitutes all uses of generics with the monomorphized versions
162// This is a two-step process, because we can't mutate the translation context with new definitions
163// while also mutating the existing definitions.
164#[derive(Visitor)]
165struct SubstVisitor<'a> {
166    data: &'a PassData,
167}
168impl SubstVisitor<'_> {
169    fn subst_use_ty(&mut self, id: &mut TypeDeclId, gargs: &mut GenericArgs) {
170        trace!("Mono: Subst use: {:?} / {:?}", id, gargs);
171        let key = (AnyTransId::Type(*id), gargs.clone());
172        let subst = self.data.items.get(&key);
173        let Some(OptionHint::Some(AnyTransId::Type(subst_id))) = subst else {
174            panic!("Substitution missing for {:?}", key);
175        };
176        *id = *subst_id;
177        // *gargs = GenericArgs::empty(GenericsSource::Builtin);
178    }
179    fn subst_use_fun(&mut self, id: &mut FunDeclId, gargs: &mut GenericArgs) {
180        trace!("Mono: Subst use: {:?} / {:?}", id, gargs);
181        let key = (AnyTransId::Fun(*id), gargs.clone());
182        let subst = self.data.items.get(&key);
183        let Some(OptionHint::Some(AnyTransId::Fun(subst_id))) = subst else {
184            panic!("Substitution missing for {:?}", key);
185        };
186        *id = *subst_id;
187        // *gargs = GenericArgs::empty(GenericsSource::Builtin);
188    }
189    fn subst_use_glob(&mut self, id: &mut GlobalDeclId, gargs: &mut GenericArgs) {
190        trace!("Mono: Subst use: {:?} / {:?}", id, gargs);
191        let key = (AnyTransId::Global(*id), gargs.clone());
192        let subst = self.data.items.get(&key);
193        let Some(OptionHint::Some(AnyTransId::Global(subst_id))) = subst else {
194            panic!("Substitution missing for {:?}", key);
195        };
196        *id = *subst_id;
197        // *gargs = GenericArgs::empty(GenericsSource::Builtin);
198    }
199}
200
201impl VisitAstMut for SubstVisitor<'_> {
202    fn exit_ullbc_statement(&mut self, stt: &mut ullbc_ast::Statement) {
203        match &mut stt.content {
204            ullbc_ast::RawStatement::Assign(_, Rvalue::Discriminant(Place { ty, .. }, id)) => {
205                // FIXME(Nadrieril): remove this id, replace with a helper fn
206                match ty.as_adt() {
207                    Some((TypeId::Adt(new_enum_id), _)) => {
208                        // Small trick; the discriminant doesn't carry the information on the
209                        // generics of the enum, since it's irrelevant, but we need it to do
210                        // the substitution, so we look at the type of the place we read from
211                        *id = new_enum_id;
212                    }
213                    _ => {}
214                }
215            }
216            _ => {}
217        }
218    }
219    fn exit_llbc_statement(&mut self, stt: &mut llbc_ast::Statement) {
220        match &mut stt.content {
221            llbc_ast::RawStatement::Assign(_, Rvalue::Discriminant(Place { ty, .. }, id)) => {
222                match ty.as_adt() {
223                    Some((TypeId::Adt(new_enum_id), _)) => {
224                        // Small trick; the discriminant doesn't carry the information on the
225                        // generics of the enum, since it's irrelevant, but we need it to do
226                        // the substitution, so we look at the type of the place we read from
227                        *id = new_enum_id;
228                    }
229                    _ => {}
230                }
231            }
232            _ => {}
233        }
234    }
235
236    fn enter_aggregate_kind(&mut self, kind: &mut AggregateKind) {
237        match kind {
238            AggregateKind::Adt(TypeId::Adt(id), _, _, gargs) => self.subst_use_ty(id, gargs),
239            AggregateKind::Closure(fun_id, gargs) => self.subst_use_fun(fun_id, gargs),
240            _ => {}
241        }
242    }
243
244    fn enter_ty_kind(&mut self, kind: &mut TyKind) {
245        match kind {
246            TyKind::Adt(TypeId::Adt(id), gargs) => {
247                self.subst_use_ty(id, gargs);
248            }
249            _ => {}
250        }
251    }
252
253    fn enter_fn_ptr(&mut self, fn_ptr: &mut FnPtr) {
254        match &mut fn_ptr.func {
255            FunIdOrTraitMethodRef::Fun(FunId::Regular(fun_id)) => {
256                self.subst_use_fun(fun_id, &mut fn_ptr.generics)
257            }
258            FunIdOrTraitMethodRef::Trait(t_ref, _, fun_id) => {
259                let mut gargs_key = fn_ptr.generics.clone().concat(
260                    GenericsSource::Builtin,
261                    &t_ref.trait_decl_ref.skip_binder.generics,
262                );
263                self.subst_use_fun(fun_id, &mut gargs_key);
264                fn_ptr.generics = gargs_key;
265            }
266            // These can't be monomorphized, since they're builtins
267            FunIdOrTraitMethodRef::Fun(FunId::Builtin(..)) => {}
268        }
269    }
270
271    fn exit_place(&mut self, place: &mut Place) {
272        match &mut place.kind {
273            // FIXME(Nadrieril): remove this id, replace with a helper fn
274            PlaceKind::Projection(inner, ProjectionElem::Field(FieldProjKind::Adt(id, _), _)) => {
275                // Trick, we don't know the generics but the projected place does, so
276                // we substitute it there, then update our current id.
277                let (inner_id, _) = inner.ty.as_adt().unwrap();
278                *id = *inner_id.as_adt().unwrap()
279            }
280            _ => {}
281        }
282    }
283
284    fn enter_global_decl_ref(&mut self, glob: &mut GlobalDeclRef) {
285        self.subst_use_glob(&mut glob.id, &mut glob.generics);
286    }
287}
288
289#[derive(Visitor)]
290#[allow(dead_code)]
291struct MissingIndexChecker<'a> {
292    krate: &'a TranslatedCrate,
293    current_item: Option<AnyTransItem<'a>>,
294}
295impl VisitAst for MissingIndexChecker<'_> {
296    fn enter_fun_decl_id(&mut self, id: &FunDeclId) {
297        if self.krate.fun_decls.get(*id).is_none() {
298            panic!(
299                "Missing function declaration for id: {:?}, in {:?}",
300                id, self.current_item
301            );
302        }
303    }
304
305    fn enter_trait_impl_id(&mut self, id: &TraitImplId) {
306        if self.krate.trait_impls.get(*id).is_none() {
307            panic!(
308                "Missing trait implementation for id: {:?}, in {:?}",
309                id, self.current_item
310            );
311        }
312    }
313
314    fn enter_trait_decl_id(&mut self, id: &TraitDeclId) {
315        if self.krate.trait_decls.get(*id).is_none() {
316            panic!(
317                "Missing trait declaration for id: {:?}, in {:?}",
318                id, self.current_item
319            );
320        }
321    }
322
323    fn enter_type_decl_id(&mut self, id: &TypeDeclId) {
324        if self.krate.type_decls.get(*id).is_none() {
325            panic!(
326                "Missing type declaration for id: {:?}, in {:?}",
327                id, self.current_item
328            );
329        }
330    }
331}
332
333fn find_uses(data: &mut PassData, krate: &TranslatedCrate, item: &AnyTransItem) {
334    let mut visitor = UsageVisitor { data, krate };
335    let _ = item.drive(&mut visitor);
336}
337
338fn subst_uses<T: AstVisitable + Debug>(data: &PassData, item: &mut T) {
339    let mut visitor = SubstVisitor { data };
340    let _ = item.drive_mut(&mut visitor);
341}
342
343// fn check_missing_indices(krate: &TranslatedCrate) {
344//     let mut visitor = MissingIndexChecker {
345//         krate,
346//         current_item: None,
347//     };
348//     for item in krate.all_items() {
349//         visitor.current_item = Some(item);
350//         item.drive(&mut visitor);
351//     }
352// }
353
354// fn path_for_generics(gargs: &GenericArgs) -> PathElem {
355//     PathElem::Ident(gargs.to_string(), Disambiguator::ZERO)
356// }
357
358impl GenericArgs {
359    fn is_empty_ignoring_regions(&self) -> bool {
360        let GenericArgs {
361            types,
362            const_generics,
363            trait_refs,
364            regions: _,
365            target: _,
366        } = self;
367        let len = types.elem_count() + const_generics.elem_count() + trait_refs.elem_count();
368        len == 0
369    }
370}
371
372// impl GenericParams {
373//     fn is_empty_ignoring_regions(&self) -> bool {
374//         let GenericParams {
375//             regions: _,
376//             types,
377//             const_generics,
378//             trait_clauses,
379//             regions_outlive: _,
380//             types_outlive,
381//             trait_type_constraints,
382//         } = self;
383//         let len = types.elem_count()
384//             + const_generics.elem_count()
385//             + trait_clauses.elem_count()
386//             + trait_type_constraints.elem_count()
387//             + types_outlive.len();
388//         len == 0
389//     }
390// }
391
392pub struct Transform;
393impl TransformPass for Transform {
394    fn transform_ctx(&self, ctx: &mut TransformCtx) {
395        // Check the option which instructs to ignore this pass
396        if !ctx.options.monomorphize {
397            return;
398        }
399
400        // From https://doc.rust-lang.org/nightly/nightly-rustc/rustc_monomorphize/collector/index.html#general-algorithm
401        //
402        // The purpose of the algorithm implemented in this module is to build the mono item
403        // graph for the current crate. It runs in two phases:
404        // 1. Discover the roots of the graph by traversing the HIR of the crate.
405        // 2. Starting from the roots, find uses by inspecting the MIR representation of the
406        //    item corresponding to a given node, until no more new nodes are found.
407        //
408        // The roots of the mono item graph correspond to the public non-generic syntactic
409        // items in the source code. We find them by walking the HIR of the crate, and whenever
410        // we hit upon a public function, method, or static item, we create a mono item
411        // consisting of the items DefId and, since we only consider non-generic items, an
412        // empty type-parameters set.
413        //
414        // Given a mono item node, we can discover uses by inspecting its MIR. We walk the MIR
415        // to find other mono items used by each mono item. Since the mono item we are
416        // currently at is always monomorphic, we also know the concrete type arguments of its
417        // used mono items. The specific forms a use can take in MIR are quite diverse: it
418        // includes calling functions/methods, taking a reference to a function/method, drop
419        // glue, and unsizing casts.
420
421        // In our version of the algorithm, we do the following:
422        // 1. Find all the roots, adding them to the worklist.
423        // 2. For each item in the worklist:
424        //    a. Find all the items it uses, adding them to the worklist and the generic
425        //      arguments to the item.
426        //    b. Mark the item as visited
427
428        // Final list of monomorphized items: { (poly item, generic args) -> mono item }
429        let mut data = PassData::new();
430
431        let empty_gargs = GenericArgs::empty(GenericsSource::Builtin);
432
433        // Find the roots of the mono item graph
434        for (id, item) in ctx.translated.all_items_with_ids() {
435            match item {
436                AnyTransItem::Fun(f) if f.signature.generics.is_empty() => {
437                    data.items
438                        .insert((id, empty_gargs.clone()), OptionHint::Some(id));
439                    data.worklist.push(id);
440                }
441                _ => {}
442            }
443        }
444
445        // Iterate over worklist -- these items are always monomorphic!
446        while let Some(id) = data.worklist.pop() {
447            if data.visited.contains(&id) {
448                continue;
449            }
450            data.visited.insert(id);
451
452            // 1. Find new uses
453            let Some(item) = ctx.translated.get_item(id) else {
454                panic!("Couldn't find item {:} in translated items.", id)
455            };
456            find_uses(&mut data, &ctx.translated, &item);
457
458            // 2. Iterate through all newly discovered uses
459            for ((id, gargs), mono) in data.items.iter_mut() {
460                if mono.is_some() {
461                    continue;
462                }
463
464                // a. Monomorphize the items if they're polymorphic, add them to the worklist
465                let new_mono = if gargs.is_empty_ignoring_regions() {
466                    *id
467                } else {
468                    match id {
469                        AnyTransId::Fun(_) => {
470                            let key_pair = (id.clone(), gargs.clone());
471                            let (AnyTransId::Fun(fun_id), gargs) = mono.hint_or(&key_pair) else {
472                                panic!("Unexpected ID type in hint_or");
473                            };
474                            let fun = ctx.translated.fun_decls.get(*fun_id).unwrap();
475                            let mut fun_sub = fun.clone().substitute(gargs);
476                            // fun_sub.signature.generics = GenericParams::empty();
477
478                            let fun_id_sub = ctx.translated.fun_decls.push_with(|id| {
479                                fun_sub.def_id = id;
480                                fun_sub
481                            });
482
483                            AnyTransId::Fun(fun_id_sub)
484                        }
485                        AnyTransId::Type(typ_id) => {
486                            let typ = ctx.translated.type_decls.get(*typ_id).unwrap();
487                            let mut typ_sub = typ.clone().substitute(gargs);
488
489                            let typ_id_sub = ctx.translated.type_decls.push_with(|id| {
490                                typ_sub.def_id = id;
491                                typ_sub
492                            });
493
494                            AnyTransId::Type(typ_id_sub)
495                        }
496                        AnyTransId::Global(g_id) => {
497                            let glob = ctx.translated.global_decls.get(*g_id).unwrap();
498                            let mut glob_sub = glob.clone().substitute(gargs);
499
500                            let init = ctx.translated.fun_decls.get(glob.init).unwrap();
501                            let mut init_sub = init.clone().substitute(gargs);
502
503                            let init_id_sub = ctx.translated.fun_decls.push_with(|id| {
504                                init_sub.def_id = id;
505                                glob_sub.init = id;
506                                init_sub
507                            });
508
509                            let g_id_sub = ctx.translated.global_decls.push_with(|id| {
510                                glob_sub.def_id = id;
511                                glob_sub
512                            });
513
514                            data.worklist.push(AnyTransId::Fun(init_id_sub));
515
516                            AnyTransId::Global(g_id_sub)
517                        }
518                        _ => todo!("Unhandled monomorphization target ID {:?}", id),
519                    }
520                };
521                trace!(
522                    "Mono: Monomorphized {:?} with {:?} to {:?}",
523                    id,
524                    gargs,
525                    new_mono
526                );
527                if id != &new_mono {
528                    trace!(" - From {:?}", ctx.translated.get_item(id.clone()));
529                    trace!(" - To {:?}", ctx.translated.get_item(new_mono.clone()));
530                }
531                *mono = OptionHint::Some(new_mono);
532                data.worklist.push(new_mono);
533            }
534
535            // 3. Substitute all generics with the monomorphized versions
536            let Some(item) = ctx.translated.get_item_mut(id) else {
537                panic!("Couldn't find item {:} in translated items.", id)
538            };
539            match item {
540                AnyTransItemMut::Fun(f) => subst_uses(&data, f),
541                AnyTransItemMut::Type(t) => subst_uses(&data, t),
542                AnyTransItemMut::TraitImpl(t) => subst_uses(&data, t),
543                AnyTransItemMut::Global(g) => subst_uses(&data, g),
544                AnyTransItemMut::TraitDecl(t) => subst_uses(&data, t),
545            };
546        }
547
548        // Now, remove all polymorphic items from the translation context, as all their
549        // uses have been monomorphized and substituted
550        ctx.translated
551            .fun_decls
552            .retain(|f| data.visited.contains(&AnyTransId::Fun(f.def_id)));
553        ctx.translated
554            .type_decls
555            .retain(|t| data.visited.contains(&AnyTransId::Type(t.def_id)));
556        ctx.translated
557            .global_decls
558            .retain(|g| data.visited.contains(&AnyTransId::Global(g.def_id)));
559        // ctx.translated.trait_impls.retain(|t| t.generics.is_empty());
560
561        // TODO: Currently we don't update all TraitImpls/TraitDecls with the monomorphized versions
562        //       and removing the polymorphic ones, so this fails.
563        // Finally, ensure we didn't leave any IDs un-replaced
564        // check_missing_indices(&ctx.translated);
565    }
566}