charon_lib/transform/add_missing_info/
reorder_decls.rs

1//! Compute an ordering on declarations that:
2//! - Detects mutually-recursive groups;
3//! - Always orders an item before any of its uses (except for recursive cases);
4//! - Otherwise keeps a stable order.
5//!
//! Aeneas needs this because proof assistant languages are sensitive to declaration order and need
//! to be explicit about mutual recursion. This should also be useful for translation to any other
//! language with these properties.
9use crate::common::*;
10use crate::formatter::{FmtCtx, IntoFormatter};
11use crate::pretty::FmtWithCtx;
12use crate::transform::TransformCtx;
13use crate::ullbc_ast::*;
14use derive_generic_visitor::*;
15use itertools::Itertools;
16use petgraph::algo::tarjan_scc;
17use petgraph::graphmap::DiGraphMap;
18use std::fmt::{Debug, Display, Error};
19use std::vec::Vec;
20
21use super::sccs::*;
22use crate::transform::ctx::TransformPass;
23
24impl<Id: Copy> GDeclarationGroup<Id> {
25    pub fn get_ids(&self) -> &[Id] {
26        use GDeclarationGroup::*;
27        match self {
28            NonRec(id) => std::slice::from_ref(id),
29            Rec(ids) => ids.as_slice(),
30        }
31    }
32
33    pub fn get_any_trans_ids(&self) -> Vec<ItemId>
34    where
35        Id: Into<ItemId>,
36    {
37        self.get_ids().iter().copied().map(|id| id.into()).collect()
38    }
39
40    fn make_group(is_rec: bool, gr: impl Iterator<Item = ItemId>) -> Self
41    where
42        Id: TryFrom<ItemId>,
43        Id::Error: Debug,
44    {
45        let gr: Vec<_> = gr.map(|x| x.try_into().unwrap()).collect();
46        if is_rec {
47            GDeclarationGroup::Rec(gr)
48        } else {
49            assert!(gr.len() == 1);
50            GDeclarationGroup::NonRec(gr[0])
51        }
52    }
53
54    fn to_mixed(&self) -> GDeclarationGroup<ItemId>
55    where
56        Id: Into<ItemId>,
57    {
58        match self {
59            GDeclarationGroup::NonRec(x) => GDeclarationGroup::NonRec((*x).into()),
60            GDeclarationGroup::Rec(_) => GDeclarationGroup::Rec(self.get_any_trans_ids()),
61        }
62    }
63}
64
65impl DeclarationGroup {
66    pub fn to_mixed_group(&self) -> GDeclarationGroup<ItemId> {
67        use DeclarationGroup::*;
68        match self {
69            Type(gr) => gr.to_mixed(),
70            Fun(gr) => gr.to_mixed(),
71            Global(gr) => gr.to_mixed(),
72            TraitDecl(gr) => gr.to_mixed(),
73            TraitImpl(gr) => gr.to_mixed(),
74            Mixed(gr) => gr.clone(),
75        }
76    }
77
78    pub fn get_ids(&self) -> Vec<ItemId> {
79        use DeclarationGroup::*;
80        match self {
81            Type(gr) => gr.get_any_trans_ids(),
82            Fun(gr) => gr.get_any_trans_ids(),
83            Global(gr) => gr.get_any_trans_ids(),
84            TraitDecl(gr) => gr.get_any_trans_ids(),
85            TraitImpl(gr) => gr.get_any_trans_ids(),
86            Mixed(gr) => gr.get_any_trans_ids(),
87        }
88    }
89}
90
91impl<Id: Display> Display for GDeclarationGroup<Id> {
92    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::result::Result<(), Error> {
93        match self {
94            GDeclarationGroup::NonRec(id) => write!(f, "non-rec: {id}"),
95            GDeclarationGroup::Rec(ids) => {
96                write!(
97                    f,
98                    "rec: {}",
99                    pretty_display_list(|id| format!("    {id}"), ids)
100                )
101            }
102        }
103    }
104}
105
106impl Display for DeclarationGroup {
107    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::result::Result<(), Error> {
108        match self {
109            DeclarationGroup::Type(decl) => write!(f, "{{ Type(s): {decl} }}"),
110            DeclarationGroup::Fun(decl) => write!(f, "{{ Fun(s): {decl} }}"),
111            DeclarationGroup::Global(decl) => write!(f, "{{ Global(s): {decl} }}"),
112            DeclarationGroup::TraitDecl(decl) => write!(f, "{{ Trait decls(s): {decl} }}"),
113            DeclarationGroup::TraitImpl(decl) => write!(f, "{{ Trait impl(s): {decl} }}"),
114            DeclarationGroup::Mixed(decl) => write!(f, "{{ Mixed items: {decl} }}"),
115        }
116    }
117}
118
/// Dependency graph between items, built by visiting the contents of each item.
///
/// An edge `a -> b` is recorded when a reference to item `b` is found while item `a` (the
/// `current_id`) is being explored.
#[derive(Visitor)]
pub struct Deps {
    // Graph fed to `tarjan_scc` to compute the strongly connected components.
    dgraph: DiGraphMap<ItemId, ()>,
    // Same nodes and edges as `dgraph`, stored in collections that remember the order of
    // insertion. This lets iteration over nodes and their dependencies follow the order in
    // which they were recorded.
    graph: SeqHashMap<ItemId, SeqHashSet<ItemId>>,
    // The item currently being explored. It is the source of every edge inserted while it is
    // set.
    current_id: Option<ItemId>,
    // We use this to track the trait impl block the current item belongs to
    // (if relevant).
    //
    // We use this to ignore the references to the parent impl block.
    //
    // If we don't do so, when computing our dependency graph we end up with
    // mutually recursive trait impl blocks/trait method impls in the presence
    // of associated types. The underlying reason is that we don't normalize the
    // types we query from rustc when translating the types from function
    // signatures. We avoid doing so because, as of now, it makes resolving
    // the trait params harder: if we get normalized types, we have to
    // implement a normalizer on our side to make sure we correctly match
    // types.
    //
    //
    // For instance, the problem happens if in Rust we have:
    // ```text
    // pub trait WithConstTy {
    //     type W;
    //     fn f(x: &mut Self::W);
    // }
    //
    // impl WithConstTy for bool {
    //     type W = u64;
    //     fn f(_: &mut Self::W) {}
    // }
    // ```
    //
    // In LLBC we get:
    //
    // ```text
    // impl traits::Bool::0 : traits::WithConstTy<bool>
    // {
    //     type W = u64 with []
    //     fn f = traits::Bool::0::f
    // }
    //
    // fn traits::Bool::0::f<@R0>(@1: &@R0 mut (traits::Bool::0::W)) { .. }
    // //                                       ^^^^^^^^^^^^^^^
    // //                                    refers to the trait impl
    // ```
    parent_trait_impl: Option<TraitImplId>,
    // The trait declaration the current item belongs to (if relevant). References to it are
    // ignored while exploring the item, to avoid mutually recursive groups between a trait decl
    // and the items declared inside it.
    parent_trait_decl: Option<TraitDeclId>,
}
170
171impl Deps {
172    fn new() -> Self {
173        Deps {
174            dgraph: DiGraphMap::new(),
175            graph: SeqHashMap::new(),
176            current_id: None,
177            parent_trait_impl: None,
178            parent_trait_decl: None,
179        }
180    }
181
182    fn set_impl_or_trait_id(&mut self, kind: &ItemSource) {
183        match kind {
184            ItemSource::TraitDecl { trait_ref, .. } => self.parent_trait_decl = Some(trait_ref.id),
185            ItemSource::TraitImpl { impl_ref, .. } => self.parent_trait_impl = Some(impl_ref.id),
186            _ => {}
187        }
188    }
189    fn set_current_id(&mut self, ctx: &TransformCtx, id: ItemId) {
190        self.insert_node(id);
191        self.current_id = Some(id);
192
193        // Add the id of the impl/trait this item belongs to, if necessary
194        use ItemId::*;
195        match id {
196            TraitDecl(_) | TraitImpl(_) | Type(_) => (),
197            Global(id) => {
198                if let Some(decl) = ctx.translated.global_decls.get(id) {
199                    self.set_impl_or_trait_id(&decl.src);
200                }
201            }
202            Fun(id) => {
203                if let Some(decl) = ctx.translated.fun_decls.get(id) {
204                    self.set_impl_or_trait_id(&decl.src);
205                }
206            }
207        }
208    }
209
210    fn unset_current_id(&mut self) {
211        self.current_id = None;
212        self.parent_trait_impl = None;
213        self.parent_trait_decl = None;
214    }
215
216    fn insert_node(&mut self, id: ItemId) {
217        // We have to be careful about duplicate nodes
218        if !self.dgraph.contains_node(id) {
219            self.dgraph.add_node(id);
220            assert!(!self.graph.contains_key(&id));
221            self.graph.insert(id, SeqHashSet::new());
222        }
223    }
224
225    fn insert_edge(&mut self, id1: ItemId) {
226        let id0 = self.current_id.unwrap();
227        self.insert_node(id1);
228        if !self.dgraph.contains_edge(id0, id1) {
229            self.dgraph.add_edge(id0, id1, ());
230            self.graph.get_mut(&id0).unwrap().insert(id1);
231        }
232    }
233}
234
235impl VisitAst for Deps {
236    fn enter_type_decl_id(&mut self, id: &TypeDeclId) {
237        self.insert_edge((*id).into());
238    }
239
240    fn enter_global_decl_id(&mut self, id: &GlobalDeclId) {
241        self.insert_edge((*id).into());
242    }
243
244    fn enter_trait_impl_id(&mut self, id: &TraitImplId) {
245        // If the impl is the impl this item belongs to, we ignore it
246        // TODO: this is not very satisfying but this is the only way we have of preventing
247        // mutually recursive groups between method impls and trait impls in the presence of
248        // associated types...
249        if let Some(impl_id) = &self.parent_trait_impl
250            && impl_id == id
251        {
252            return;
253        }
254        self.insert_edge((*id).into());
255    }
256
257    fn enter_trait_decl_id(&mut self, id: &TraitDeclId) {
258        // If the trait is the trait this item belongs to, we ignore it. This is to avoid mutually
259        // recursive groups between e.g. traits decls and their globals. We treat methods
260        // specifically.
261        if let Some(trait_id) = &self.parent_trait_decl
262            && trait_id == id
263        {
264            return;
265        }
266        self.insert_edge((*id).into());
267    }
268
269    fn enter_fun_decl_id(&mut self, id: &FunDeclId) {
270        self.insert_edge((*id).into());
271    }
272
273    fn visit_item_meta(&mut self, _: &ItemMeta) -> ControlFlow<Self::Break> {
274        // Don't look inside because trait impls contain their own id in their name.
275        Continue(())
276    }
277    fn visit_item_source(&mut self, _: &ItemSource) -> ControlFlow<Self::Break> {
278        // Don't look inside to avoid recording a dependency from a method impl to the impl block
279        // it belongs to.
280        Continue(())
281    }
282}
283
284impl Deps {
285    fn fmt_with_ctx(&self, ctx: &FmtCtx<'_>) -> String {
286        self.dgraph
287            .nodes()
288            .map(|node| {
289                let edges = self
290                    .dgraph
291                    .edges(node)
292                    .map(|e| format!("\n  {}", e.1.with_ctx(ctx)))
293                    .collect::<Vec<String>>()
294                    .join(",");
295
296                format!("{} -> [{}\n]", node.with_ctx(ctx), edges)
297            })
298            .format(",\n")
299            .to_string()
300    }
301}
302
/// Builds the dependency graph of all translated items.
///
/// Items are explored in a deterministic order (computed below). `Deps` remembers the order in
/// which nodes and edges are inserted, so this exploration order determines the final order
/// of the declarations.
fn compute_declarations_graph<'tcx>(ctx: &'tcx TransformCtx) -> Deps {
    // The order we explore the items in will dictate the final order. We do the following: std
    // items, then items from foreign crates (sorted by crate name), then local items. Within a
    // crate, we sort by file then by source order.
    let mut sorted_items = ctx.translated.all_items().collect_vec();
    // Pre-sort files to avoid costly string comparisons. Maps file ids to an index that reflects
    // ordering on the crates (with `core` and `std` sorted first) and file names.
    let sorted_file_ids: IndexMap<FileId, usize> = ctx
        .translated
        .files
        .all_indices()
        // Rank files: `std`/`core` first (`!is_std` is `false` for them), then by crate name,
        // then by file name.
        .sorted_by_key(|&file_id| {
            let file = &ctx.translated.files[file_id];
            let is_std = file.crate_name == "std" || file.crate_name == "core";
            (!is_std, &file.crate_name, &file.name)
        })
        // Pair each file with its rank, then restore file-id order so that collecting stores
        // each rank at its file's position.
        .enumerate()
        .sorted_by_key(|(_i, file_id)| *file_id)
        .map(|(i, _file_id)| i)
        .collect();
    // The positional collection above only lines up with file ids if no slot is missing.
    assert_eq!(ctx.translated.files.len(), sorted_file_ids.slot_count());
    // Non-local items (`is_local == false`) come before local ones, then items are ordered by
    // file rank, then by the start of their span. `sort_by_key` is stable, so ties keep the
    // `all_items` order.
    sorted_items.sort_by_key(|item| {
        let item_meta = item.item_meta();
        let span = item_meta.span.data;
        let file_name_order = sorted_file_ids.get(span.file_id);
        (item_meta.is_local, file_name_order, span.beg)
    });

    let mut graph = Deps::new();
    for item in sorted_items {
        let id = item.id();
        // `id` becomes the source of every edge recorded until `unset_current_id` below.
        graph.set_current_id(ctx, id);
        match item {
            // Visited in full. `Deps` does not look inside `item_meta` nor `src`.
            ItemRef::Type(..) | ItemRef::TraitImpl(..) | ItemRef::Global(..) => {
                let _ = item.drive(&mut graph);
            }
            ItemRef::Fun(d) => {
                // Exhaustive destructuring: adding a field to `FunDecl` makes this fail to
                // compile, forcing a decision on whether the new field should be visited.
                let FunDecl {
                    def_id: _,
                    item_meta: _,
                    generics,
                    signature,
                    src,
                    is_global_initializer: _,
                    body,
                } = d;
                // Skip `d.is_global_initializer` to avoid incorrect mutual dependencies.
                // TODO: add `is_global_initializer` to `ItemKind`.
                let _ = generics.drive(&mut graph);
                let _ = signature.drive(&mut graph);
                let _ = body.drive(&mut graph);
                // FIXME(#514): A method declaration depends on its declaring trait because of its
                // `Self` clause. While the clause is implicit, we make sure to record the
                // dependency manually.
                if let ItemSource::TraitDecl { trait_ref, .. } = src {
                    graph.insert_edge(trait_ref.id.into());
                }
            }
            ItemRef::TraitDecl(d) => {
                // Exhaustive destructuring, for the same reason as for `FunDecl` above.
                let TraitDecl {
                    def_id: _,
                    item_meta: _,
                    generics,
                    implied_clauses: parent_clauses,
                    consts,
                    types,
                    methods,
                    vtable,
                } = d;
                // Visit the traits referenced in the generics
                let _ = generics.drive(&mut graph);

                // Visit the parent clauses
                let _ = parent_clauses.drive(&mut graph);

                // Visit the items
                let _ = types.drive(&mut graph);
                let _ = vtable.drive(&mut graph);

                // We consider that a trait decl only contains the function/constant signatures.
                // Therefore we don't explore the default const/method ids.
                for assoc_const in consts {
                    let TraitAssocConst {
                        name: _,
                        ty,
                        default,
                    } = assoc_const;
                    let _ = ty.drive(&mut graph);
                    // Only the generics of the default value are visited, not its id.
                    if let Some(gref) = default {
                        let _ = gref.generics.drive(&mut graph);
                    }
                }
                // For each method, only its binder params and (if it was translated) its
                // signature are visited; the method body is not.
                for bound_method in methods {
                    let id = bound_method.skip_binder.item.id;
                    let _ = bound_method.params.drive(&mut graph);
                    if let Some(decl) = ctx.translated.fun_decls.get(id) {
                        let _ = decl.signature.drive(&mut graph);
                    }
                }
            }
        }
        graph.unset_current_id();
    }
    graph
}
408
409fn group_declarations_from_scc(
410    _ctx: &TransformCtx,
411    graph: Deps,
412    reordered_sccs: SCCs<ItemId>,
413) -> DeclarationsGroups {
414    let reordered_sccs = &reordered_sccs.sccs;
415    let mut reordered_decls: DeclarationsGroups = Vec::new();
416
417    // Iterate over the SCC ids in the proper order
418    for scc in reordered_sccs.iter() {
419        if scc.is_empty() {
420            // This can happen if we failed to translate the item in this group.
421            continue;
422        }
423
424        // Note that the length of an SCC should be at least 1.
425        let mut it = scc.iter();
426        let id0 = *it.next().unwrap();
427        let decl = graph.graph.get(&id0).unwrap();
428
429        // If an SCC has length one, the declaration may be simply recursive:
430        // we determine whether it is the case by checking if the def id is in
431        // its own set of dependencies.
432        let is_mutually_recursive = scc.len() > 1;
433        let is_simply_recursive = !is_mutually_recursive && decl.contains(&id0);
434        let is_rec = is_mutually_recursive || is_simply_recursive;
435
436        let all_same_kind = scc
437            .iter()
438            .all(|id| id0.variant_index_arity() == id.variant_index_arity());
439        let ids = scc.iter().copied();
440        let group: DeclarationGroup = match id0 {
441            _ if !all_same_kind => {
442                DeclarationGroup::Mixed(GDeclarationGroup::make_group(is_rec, ids))
443            }
444            ItemId::Type(_) => DeclarationGroup::Type(GDeclarationGroup::make_group(is_rec, ids)),
445            ItemId::Fun(_) => DeclarationGroup::Fun(GDeclarationGroup::make_group(is_rec, ids)),
446            ItemId::Global(_) => {
447                DeclarationGroup::Global(GDeclarationGroup::make_group(is_rec, ids))
448            }
449            ItemId::TraitDecl(_) => {
450                let gr: Vec<_> = ids.map(|x| x.try_into().unwrap()).collect();
451                // Trait declarations often refer to `Self`, like below,
452                // which means they are often considered as recursive by our
453                // analysis. TODO: do something more precise. What is important
454                // is that we never use the "whole" self clause as argument,
455                // but rather projections over the self clause (like `<Self as Foo>::u`,
456                // in the declaration for `Foo`).
457                if gr.len() == 1 {
458                    DeclarationGroup::TraitDecl(GDeclarationGroup::NonRec(gr[0]))
459                } else {
460                    DeclarationGroup::TraitDecl(GDeclarationGroup::Rec(gr))
461                }
462            }
463            ItemId::TraitImpl(_) => {
464                DeclarationGroup::TraitImpl(GDeclarationGroup::make_group(is_rec, ids))
465            }
466        };
467
468        reordered_decls.push(group);
469    }
470    reordered_decls
471}
472
473fn compute_reordered_decls(ctx: &TransformCtx) -> DeclarationsGroups {
474    trace!();
475
476    // Step 1: explore the declarations to build the graph
477    let graph = compute_declarations_graph(ctx);
478    trace!("Graph:\n{}\n", graph.fmt_with_ctx(&ctx.into_fmt()));
479
480    // Step 2: Apply Tarjan's SCC (Strongly Connected Components) algorithm
481    let sccs = tarjan_scc(&graph.dgraph);
482
483    // Step 3: Reorder the declarations in an order as close as possible to the one
484    // given by the user. To be more precise, if we don't need to move
485    // definitions, the order in which we generate the declarations should
486    // be the same as the one in which the user wrote them.
487    // Remark: the [get_id_dependencies] function will be called once per id, meaning
488    // it is ok if it is not very efficient and clones values.
489    let get_id_dependencies = &|id| graph.graph.get(&id).unwrap().iter().copied().collect();
490    let all_ids: Vec<ItemId> = graph
491        .graph
492        .keys()
493        .copied()
494        // Don't list ids that weren't translated.
495        .filter(|id| ctx.translated.get_item(*id).is_some())
496        .collect();
497    let reordered_sccs = reorder_sccs::<ItemId>(get_id_dependencies, &all_ids, &sccs);
498
499    // Finally, generate the list of declarations
500    let reordered_decls = group_declarations_from_scc(ctx, graph, reordered_sccs);
501
502    trace!("{:?}", reordered_decls);
503    reordered_decls
504}
505
506pub struct Transform;
507impl TransformPass for Transform {
508    fn transform_ctx(&self, ctx: &mut TransformCtx) {
509        let reordered_decls = compute_reordered_decls(&ctx);
510        ctx.translated.ordered_decls = Some(reordered_decls);
511    }
512}
513
#[cfg(test)]
mod tests {
    /// Checks the ordering produced by `reorder_sccs`. SCC `{0}` depends on `{3, 4, 5}`, and
    /// `{1, 2}` depends on both, so `{3, 4, 5}` must come first, then `{0}`, then `{1, 2}`.
    #[test]
    fn test_reorder_sccs1() {
        let sccs = vec![vec![0], vec![1, 2], vec![3, 4, 5]];
        let ids = vec![0, 1, 2, 3, 4, 5];

        let get_deps = &|x| match x {
            0 => vec![3],
            1 => vec![0, 3],
            _ => vec![],
        };
        let reordered = super::reorder_sccs(get_deps, &ids, &sccs);

        // `assert_eq!` reports both values on failure, unlike `assert!(a == b)`.
        assert_eq!(reordered.sccs, vec![vec![3, 4, 5], vec![0], vec![1, 2]]);
    }
}