charon_lib/transform/add_missing_info/
reorder_decls.rs

//! Compute an ordering on declarations that:
//! - Detects mutually-recursive groups;
//! - Always orders an item before any of its uses (except for recursive cases);
//! - Otherwise keeps a stable order.
//!
//! Aeneas needs this because proof assistant languages are sensitive to declaration order and need
//! to be explicit about mutual recursion. This should prove useful for translation to any other
//! language with these properties.
9use crate::common::*;
10use crate::formatter::{FmtCtx, IntoFormatter};
11use crate::pretty::FmtWithCtx;
12use crate::transform::TransformCtx;
13use crate::ullbc_ast::*;
14use derive_generic_visitor::*;
15use indexmap::{IndexMap, IndexSet};
16use itertools::Itertools;
17use petgraph::algo::tarjan_scc;
18use petgraph::graphmap::DiGraphMap;
19use std::fmt::{Debug, Display, Error};
20use std::vec::Vec;
21
22use super::sccs::*;
23use crate::transform::ctx::TransformPass;
24
25impl<Id: Copy> GDeclarationGroup<Id> {
26    pub fn get_ids(&self) -> &[Id] {
27        use GDeclarationGroup::*;
28        match self {
29            NonRec(id) => std::slice::from_ref(id),
30            Rec(ids) => ids.as_slice(),
31        }
32    }
33
34    pub fn get_any_trans_ids(&self) -> Vec<ItemId>
35    where
36        Id: Into<ItemId>,
37    {
38        self.get_ids().iter().copied().map(|id| id.into()).collect()
39    }
40
41    fn make_group(is_rec: bool, gr: impl Iterator<Item = ItemId>) -> Self
42    where
43        Id: TryFrom<ItemId>,
44        Id::Error: Debug,
45    {
46        let gr: Vec<_> = gr.map(|x| x.try_into().unwrap()).collect();
47        if is_rec {
48            GDeclarationGroup::Rec(gr)
49        } else {
50            assert!(gr.len() == 1);
51            GDeclarationGroup::NonRec(gr[0])
52        }
53    }
54
55    fn to_mixed(&self) -> GDeclarationGroup<ItemId>
56    where
57        Id: Into<ItemId>,
58    {
59        match self {
60            GDeclarationGroup::NonRec(x) => GDeclarationGroup::NonRec((*x).into()),
61            GDeclarationGroup::Rec(_) => GDeclarationGroup::Rec(self.get_any_trans_ids()),
62        }
63    }
64}
65
66impl DeclarationGroup {
67    pub fn to_mixed_group(&self) -> GDeclarationGroup<ItemId> {
68        use DeclarationGroup::*;
69        match self {
70            Type(gr) => gr.to_mixed(),
71            Fun(gr) => gr.to_mixed(),
72            Global(gr) => gr.to_mixed(),
73            TraitDecl(gr) => gr.to_mixed(),
74            TraitImpl(gr) => gr.to_mixed(),
75            Mixed(gr) => gr.clone(),
76        }
77    }
78
79    pub fn get_ids(&self) -> Vec<ItemId> {
80        use DeclarationGroup::*;
81        match self {
82            Type(gr) => gr.get_any_trans_ids(),
83            Fun(gr) => gr.get_any_trans_ids(),
84            Global(gr) => gr.get_any_trans_ids(),
85            TraitDecl(gr) => gr.get_any_trans_ids(),
86            TraitImpl(gr) => gr.get_any_trans_ids(),
87            Mixed(gr) => gr.get_any_trans_ids(),
88        }
89    }
90}
91
92impl<Id: Display> Display for GDeclarationGroup<Id> {
93    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::result::Result<(), Error> {
94        match self {
95            GDeclarationGroup::NonRec(id) => write!(f, "non-rec: {id}"),
96            GDeclarationGroup::Rec(ids) => {
97                write!(
98                    f,
99                    "rec: {}",
100                    pretty_display_list(|id| format!("    {id}"), ids)
101                )
102            }
103        }
104    }
105}
106
107impl Display for DeclarationGroup {
108    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::result::Result<(), Error> {
109        match self {
110            DeclarationGroup::Type(decl) => write!(f, "{{ Type(s): {decl} }}"),
111            DeclarationGroup::Fun(decl) => write!(f, "{{ Fun(s): {decl} }}"),
112            DeclarationGroup::Global(decl) => write!(f, "{{ Global(s): {decl} }}"),
113            DeclarationGroup::TraitDecl(decl) => write!(f, "{{ Trait decls(s): {decl} }}"),
114            DeclarationGroup::TraitImpl(decl) => write!(f, "{{ Trait impl(s): {decl} }}"),
115            DeclarationGroup::Mixed(decl) => write!(f, "{{ Mixed items: {decl} }}"),
116        }
117    }
118}
119
120#[derive(Visitor)]
121pub struct Deps {
122    dgraph: DiGraphMap<ItemId, ()>,
123    // Want to make sure we remember the order of insertion
124    graph: IndexMap<ItemId, IndexSet<ItemId>>,
125    // We use this when computing the graph
126    current_id: Option<ItemId>,
127    // We use this to track the trait impl block the current item belongs to
128    // (if relevant).
129    //
130    // We use this to ignore the references to the parent impl block.
131    //
132    // If we don't do so, when computing our dependency graph we end up with
133    // mutually recursive trait impl blocks/trait method impls in the presence
134    // of associated types (the deepest reason is that we don't normalize the
135    // types we query from rustc when translating the types from function
136    // signatures - we avoid doing so because as of now it makes resolving
137    // the trait params harder: if we get normalized types, we have to
138    // implement a normalizer on our side to make sure we correctly match
139    // types...).
140    //
141    //
142    // For instance, the problem happens if in Rust we have:
143    // ```text
144    // pub trait WithConstTy {
145    //     type W;
146    //     fn f(x: &mut Self::W);
147    // }
148    //
149    // impl WithConstTy for bool {
150    //     type W = u64;
151    //     fn f(_: &mut Self::W) {}
152    // }
153    // ```
154    //
155    // In LLBC we get:
156    //
157    // ```text
158    // impl traits::Bool::0 : traits::WithConstTy<bool>
159    // {
160    //     type W = u64 with []
161    //     fn f = traits::Bool::0::f
162    // }
163    //
164    // fn traits::Bool::0::f<@R0>(@1: &@R0 mut (traits::Bool::0::W)) { .. }
165    // //                                       ^^^^^^^^^^^^^^^
166    // //                                    refers to the trait impl
167    // ```
168    parent_trait_impl: Option<TraitImplId>,
169    parent_trait_decl: Option<TraitDeclId>,
170}
171
172impl Deps {
173    fn new() -> Self {
174        Deps {
175            dgraph: DiGraphMap::new(),
176            graph: IndexMap::new(),
177            current_id: None,
178            parent_trait_impl: None,
179            parent_trait_decl: None,
180        }
181    }
182
183    fn set_impl_or_trait_id(&mut self, kind: &ItemSource) {
184        match kind {
185            ItemSource::TraitDecl { trait_ref, .. } => self.parent_trait_decl = Some(trait_ref.id),
186            ItemSource::TraitImpl { impl_ref, .. } => self.parent_trait_impl = Some(impl_ref.id),
187            _ => {}
188        }
189    }
190    fn set_current_id(&mut self, ctx: &TransformCtx, id: ItemId) {
191        self.insert_node(id);
192        self.current_id = Some(id);
193
194        // Add the id of the impl/trait this item belongs to, if necessary
195        use ItemId::*;
196        match id {
197            TraitDecl(_) | TraitImpl(_) | Type(_) => (),
198            Global(id) => {
199                if let Some(decl) = ctx.translated.global_decls.get(id) {
200                    self.set_impl_or_trait_id(&decl.src);
201                }
202            }
203            Fun(id) => {
204                if let Some(decl) = ctx.translated.fun_decls.get(id) {
205                    self.set_impl_or_trait_id(&decl.src);
206                }
207            }
208        }
209    }
210
211    fn unset_current_id(&mut self) {
212        self.current_id = None;
213        self.parent_trait_impl = None;
214        self.parent_trait_decl = None;
215    }
216
217    fn insert_node(&mut self, id: ItemId) {
218        // We have to be careful about duplicate nodes
219        if !self.dgraph.contains_node(id) {
220            self.dgraph.add_node(id);
221            assert!(!self.graph.contains_key(&id));
222            self.graph.insert(id, IndexSet::new());
223        }
224    }
225
226    fn insert_edge(&mut self, id1: ItemId) {
227        let id0 = self.current_id.unwrap();
228        self.insert_node(id1);
229        if !self.dgraph.contains_edge(id0, id1) {
230            self.dgraph.add_edge(id0, id1, ());
231            self.graph.get_mut(&id0).unwrap().insert(id1);
232        }
233    }
234}
235
236impl VisitAst for Deps {
237    fn enter_type_decl_id(&mut self, id: &TypeDeclId) {
238        self.insert_edge((*id).into());
239    }
240
241    fn enter_global_decl_id(&mut self, id: &GlobalDeclId) {
242        self.insert_edge((*id).into());
243    }
244
245    fn enter_trait_impl_id(&mut self, id: &TraitImplId) {
246        // If the impl is the impl this item belongs to, we ignore it
247        // TODO: this is not very satisfying but this is the only way we have of preventing
248        // mutually recursive groups between method impls and trait impls in the presence of
249        // associated types...
250        if let Some(impl_id) = &self.parent_trait_impl
251            && impl_id == id
252        {
253            return;
254        }
255        self.insert_edge((*id).into());
256    }
257
258    fn enter_trait_decl_id(&mut self, id: &TraitDeclId) {
259        // If the trait is the trait this item belongs to, we ignore it. This is to avoid mutually
260        // recursive groups between e.g. traits decls and their globals. We treat methods
261        // specifically.
262        if let Some(trait_id) = &self.parent_trait_decl
263            && trait_id == id
264        {
265            return;
266        }
267        self.insert_edge((*id).into());
268    }
269
270    fn enter_fun_decl_id(&mut self, id: &FunDeclId) {
271        self.insert_edge((*id).into());
272    }
273
274    fn visit_item_meta(&mut self, _: &ItemMeta) -> ControlFlow<Self::Break> {
275        // Don't look inside because trait impls contain their own id in their name.
276        Continue(())
277    }
278    fn visit_item_source(&mut self, _: &ItemSource) -> ControlFlow<Self::Break> {
279        // Don't look inside to avoid recording a dependency from a method impl to the impl block
280        // it belongs to.
281        Continue(())
282    }
283}
284
285impl Deps {
286    fn fmt_with_ctx(&self, ctx: &FmtCtx<'_>) -> String {
287        self.dgraph
288            .nodes()
289            .map(|node| {
290                let edges = self
291                    .dgraph
292                    .edges(node)
293                    .map(|e| format!("\n  {}", e.1.with_ctx(ctx)))
294                    .collect::<Vec<String>>()
295                    .join(",");
296
297                format!("{} -> [{}\n]", node.with_ctx(ctx), edges)
298            })
299            .format(",\n")
300            .to_string()
301    }
302}
303
304fn compute_declarations_graph<'tcx>(ctx: &'tcx TransformCtx) -> Deps {
305    // The order we explore the items in will dictate the final order. We do the following: std
306    // items, then items from foreign crates (sorted by crate name), then local items. Within a
307    // crate, we sort by file then by source order.
308    let mut sorted_items = ctx.translated.all_items().collect_vec();
309    // Pre-sort files to avoid costly string comparisons. Maps file ids to an index that reflects
310    // ordering on the crates (with `core` and `std` sorted first) and file names.
311    let sorted_file_ids: Vector<FileId, usize> = ctx
312        .translated
313        .files
314        .all_indices()
315        .sorted_by_key(|&file_id| {
316            let file = &ctx.translated.files[file_id];
317            let is_std = file.crate_name == "std" || file.crate_name == "core";
318            (!is_std, &file.crate_name, &file.name)
319        })
320        .enumerate()
321        .sorted_by_key(|(_i, file_id)| *file_id)
322        .map(|(i, _file_id)| i)
323        .collect();
324    assert_eq!(
325        ctx.translated.files.slot_count(),
326        sorted_file_ids.slot_count()
327    );
328    sorted_items.sort_by_key(|item| {
329        let item_meta = item.item_meta();
330        let span = item_meta.span.data;
331        let file_name_order = sorted_file_ids[span.file_id];
332        (item_meta.is_local, file_name_order, span.beg)
333    });
334
335    let mut graph = Deps::new();
336    for item in sorted_items {
337        let id = item.id();
338        graph.set_current_id(ctx, id);
339        match item {
340            ItemRef::Type(..) | ItemRef::TraitImpl(..) | ItemRef::Global(..) => {
341                let _ = item.drive(&mut graph);
342            }
343            ItemRef::Fun(d) => {
344                // Skip `d.is_global_initializer` to avoid incorrect mutual dependencies.
345                // TODO: add `is_global_initializer` to `ItemKind`.
346                let _ = d.signature.drive(&mut graph);
347                let _ = d.body.drive(&mut graph);
348                // FIXME(#514): A method declaration depends on its declaring trait because of its
349                // `Self` clause. While the clause is implicit, we make sure to record the
350                // dependency manually.
351                if let ItemSource::TraitDecl { trait_ref, .. } = &d.src {
352                    graph.insert_edge(trait_ref.id.into());
353                }
354            }
355            ItemRef::TraitDecl(d) => {
356                let TraitDecl {
357                    def_id: _,
358                    item_meta: _,
359                    generics,
360                    implied_clauses: parent_clauses,
361                    consts,
362                    types,
363                    methods,
364                    vtable,
365                } = d;
366                // Visit the traits referenced in the generics
367                let _ = generics.drive(&mut graph);
368
369                // Visit the parent clauses
370                let _ = parent_clauses.drive(&mut graph);
371
372                // Visit the items
373                let _ = types.drive(&mut graph);
374                let _ = vtable.drive(&mut graph);
375
376                // We consider that a trait decl only contains the function/constant signatures.
377                // Therefore we don't explore the default const/method ids.
378                for assoc_const in consts {
379                    let TraitAssocConst {
380                        name: _,
381                        ty,
382                        default,
383                    } = assoc_const;
384                    let _ = ty.drive(&mut graph);
385                    if let Some(gref) = default {
386                        let _ = gref.generics.drive(&mut graph);
387                    }
388                }
389                for bound_method in methods {
390                    let id = bound_method.skip_binder.item.id;
391                    let _ = bound_method.params.drive(&mut graph);
392                    if let Some(decl) = ctx.translated.fun_decls.get(id) {
393                        let _ = decl.signature.drive(&mut graph);
394                    }
395                }
396            }
397        }
398        graph.unset_current_id();
399    }
400    graph
401}
402
403fn group_declarations_from_scc(
404    _ctx: &TransformCtx,
405    graph: Deps,
406    reordered_sccs: SCCs<ItemId>,
407) -> DeclarationsGroups {
408    let reordered_sccs = &reordered_sccs.sccs;
409    let mut reordered_decls: DeclarationsGroups = Vec::new();
410
411    // Iterate over the SCC ids in the proper order
412    for scc in reordered_sccs.iter() {
413        if scc.is_empty() {
414            // This can happen if we failed to translate the item in this group.
415            continue;
416        }
417
418        // Note that the length of an SCC should be at least 1.
419        let mut it = scc.iter();
420        let id0 = *it.next().unwrap();
421        let decl = graph.graph.get(&id0).unwrap();
422
423        // If an SCC has length one, the declaration may be simply recursive:
424        // we determine whether it is the case by checking if the def id is in
425        // its own set of dependencies.
426        let is_mutually_recursive = scc.len() > 1;
427        let is_simply_recursive = !is_mutually_recursive && decl.contains(&id0);
428        let is_rec = is_mutually_recursive || is_simply_recursive;
429
430        let all_same_kind = scc
431            .iter()
432            .all(|id| id0.variant_index_arity() == id.variant_index_arity());
433        let ids = scc.iter().copied();
434        let group: DeclarationGroup = match id0 {
435            _ if !all_same_kind => {
436                DeclarationGroup::Mixed(GDeclarationGroup::make_group(is_rec, ids))
437            }
438            ItemId::Type(_) => DeclarationGroup::Type(GDeclarationGroup::make_group(is_rec, ids)),
439            ItemId::Fun(_) => DeclarationGroup::Fun(GDeclarationGroup::make_group(is_rec, ids)),
440            ItemId::Global(_) => {
441                DeclarationGroup::Global(GDeclarationGroup::make_group(is_rec, ids))
442            }
443            ItemId::TraitDecl(_) => {
444                let gr: Vec<_> = ids.map(|x| x.try_into().unwrap()).collect();
445                // Trait declarations often refer to `Self`, like below,
446                // which means they are often considered as recursive by our
447                // analysis. TODO: do something more precise. What is important
448                // is that we never use the "whole" self clause as argument,
449                // but rather projections over the self clause (like `<Self as Foo>::u`,
450                // in the declaration for `Foo`).
451                if gr.len() == 1 {
452                    DeclarationGroup::TraitDecl(GDeclarationGroup::NonRec(gr[0]))
453                } else {
454                    DeclarationGroup::TraitDecl(GDeclarationGroup::Rec(gr))
455                }
456            }
457            ItemId::TraitImpl(_) => {
458                DeclarationGroup::TraitImpl(GDeclarationGroup::make_group(is_rec, ids))
459            }
460        };
461
462        reordered_decls.push(group);
463    }
464    reordered_decls
465}
466
467fn compute_reordered_decls(ctx: &TransformCtx) -> DeclarationsGroups {
468    trace!();
469
470    // Step 1: explore the declarations to build the graph
471    let graph = compute_declarations_graph(ctx);
472    trace!("Graph:\n{}\n", graph.fmt_with_ctx(&ctx.into_fmt()));
473
474    // Step 2: Apply Tarjan's SCC (Strongly Connected Components) algorithm
475    let sccs = tarjan_scc(&graph.dgraph);
476
477    // Step 3: Reorder the declarations in an order as close as possible to the one
478    // given by the user. To be more precise, if we don't need to move
479    // definitions, the order in which we generate the declarations should
480    // be the same as the one in which the user wrote them.
481    // Remark: the [get_id_dependencies] function will be called once per id, meaning
482    // it is ok if it is not very efficient and clones values.
483    let get_id_dependencies = &|id| graph.graph.get(&id).unwrap().iter().copied().collect();
484    let all_ids: Vec<ItemId> = graph
485        .graph
486        .keys()
487        .copied()
488        // Don't list ids that weren't translated.
489        .filter(|id| ctx.translated.get_item(*id).is_some())
490        .collect();
491    let reordered_sccs = reorder_sccs::<ItemId>(get_id_dependencies, &all_ids, &sccs);
492
493    // Finally, generate the list of declarations
494    let reordered_decls = group_declarations_from_scc(ctx, graph, reordered_sccs);
495
496    trace!("{:?}", reordered_decls);
497    reordered_decls
498}
499
500pub struct Transform;
501impl TransformPass for Transform {
502    fn transform_ctx(&self, ctx: &mut TransformCtx) {
503        let reordered_decls = compute_reordered_decls(&ctx);
504        ctx.translated.ordered_decls = Some(reordered_decls);
505    }
506}
507
#[cfg(test)]
mod tests {
    /// Id 0 depends on 3, and id 1 depends on 0 and 3: the component containing 3 is expected
    /// first, then `[0]`, then `[1, 2]`.
    #[test]
    fn test_reorder_sccs1() {
        let sccs = vec![vec![0], vec![1, 2], vec![3, 4, 5]];
        let ids = vec![0, 1, 2, 3, 4, 5];

        let get_deps = &|x| match x {
            0 => vec![3],
            1 => vec![0, 3],
            _ => vec![],
        };
        let reordered = super::reorder_sccs(get_deps, &ids, &sccs);

        // `assert_eq!` prints both sides on failure, unlike `assert!(a == b)`.
        assert_eq!(reordered.sccs, vec![vec![3, 4, 5], vec![0], vec![1, 2]]);
    }
}