charon_lib/transform/
reorder_decls.rs

1//! Compute an ordering on declarations that:
2//! - Detects mutually-recursive groups;
3//! - Always orders an item before any of its uses (except for recursive cases);
4//! - Otherwise keeps a stable order.
5//!
//! Aeneas needs this because proof assistant languages are sensitive to declaration order and need
//! to be explicit about mutual recursion. This should also be useful for translation to any other
//! language with these properties.
9use crate::common::*;
10use crate::formatter::{FmtCtx, IntoFormatter};
11use crate::graphs::*;
12use crate::pretty::FmtWithCtx;
13use crate::transform::TransformCtx;
14use crate::ullbc_ast::*;
15use derive_generic_visitor::*;
16use indexmap::{IndexMap, IndexSet};
17use itertools::Itertools;
18use macros::{EnumAsGetters, EnumIsA, VariantIndexArity, VariantName};
19use petgraph::algo::tarjan_scc;
20use petgraph::graphmap::DiGraphMap;
21use serde::{Deserialize, Serialize};
22use std::fmt::{Debug, Display, Error};
23use std::vec::Vec;
24
25use super::ctx::TransformPass;
26
27/// A (group of) top-level declaration(s), properly reordered.
28/// "G" stands for "generic"
29#[derive(
30    Debug, Clone, VariantIndexArity, VariantName, EnumAsGetters, EnumIsA, Serialize, Deserialize,
31)]
32#[charon::variants_suffix("Group")]
33pub enum GDeclarationGroup<Id> {
34    /// A non-recursive declaration
35    NonRec(Id),
36    /// A (group of mutually) recursive declaration(s)
37    Rec(Vec<Id>),
38}
39
40/// A (group of) top-level declaration(s), properly reordered.
41#[derive(
42    Debug, Clone, VariantIndexArity, VariantName, EnumAsGetters, EnumIsA, Serialize, Deserialize,
43)]
44#[charon::variants_suffix("Group")]
45pub enum DeclarationGroup {
46    /// A type declaration group
47    Type(GDeclarationGroup<TypeDeclId>),
48    /// A function declaration group
49    Fun(GDeclarationGroup<FunDeclId>),
50    /// A global declaration group
51    Global(GDeclarationGroup<GlobalDeclId>),
52    ///
53    TraitDecl(GDeclarationGroup<TraitDeclId>),
54    ///
55    TraitImpl(GDeclarationGroup<TraitImplId>),
56    /// Anything that doesn't fit into these categories.
57    Mixed(GDeclarationGroup<AnyTransId>),
58}
59
60impl<Id: Copy> GDeclarationGroup<Id> {
61    pub fn get_ids(&self) -> &[Id] {
62        use GDeclarationGroup::*;
63        match self {
64            NonRec(id) => std::slice::from_ref(id),
65            Rec(ids) => ids.as_slice(),
66        }
67    }
68
69    pub fn get_any_trans_ids(&self) -> Vec<AnyTransId>
70    where
71        Id: Into<AnyTransId>,
72    {
73        self.get_ids().iter().copied().map(|id| id.into()).collect()
74    }
75
76    fn make_group(is_rec: bool, gr: impl Iterator<Item = AnyTransId>) -> Self
77    where
78        Id: TryFrom<AnyTransId>,
79        Id::Error: Debug,
80    {
81        let gr: Vec<_> = gr.map(|x| x.try_into().unwrap()).collect();
82        if is_rec {
83            GDeclarationGroup::Rec(gr)
84        } else {
85            assert!(gr.len() == 1);
86            GDeclarationGroup::NonRec(gr[0])
87        }
88    }
89
90    fn to_mixed(&self) -> GDeclarationGroup<AnyTransId>
91    where
92        Id: Into<AnyTransId>,
93    {
94        match self {
95            GDeclarationGroup::NonRec(x) => GDeclarationGroup::NonRec((*x).into()),
96            GDeclarationGroup::Rec(_) => GDeclarationGroup::Rec(self.get_any_trans_ids()),
97        }
98    }
99}
100
101impl DeclarationGroup {
102    pub fn to_mixed_group(&self) -> GDeclarationGroup<AnyTransId> {
103        use DeclarationGroup::*;
104        match self {
105            Type(gr) => gr.to_mixed(),
106            Fun(gr) => gr.to_mixed(),
107            Global(gr) => gr.to_mixed(),
108            TraitDecl(gr) => gr.to_mixed(),
109            TraitImpl(gr) => gr.to_mixed(),
110            Mixed(gr) => gr.clone(),
111        }
112    }
113
114    pub fn get_ids(&self) -> Vec<AnyTransId> {
115        use DeclarationGroup::*;
116        match self {
117            Type(gr) => gr.get_any_trans_ids(),
118            Fun(gr) => gr.get_any_trans_ids(),
119            Global(gr) => gr.get_any_trans_ids(),
120            TraitDecl(gr) => gr.get_any_trans_ids(),
121            TraitImpl(gr) => gr.get_any_trans_ids(),
122            Mixed(gr) => gr.get_any_trans_ids(),
123        }
124    }
125}
126
127#[derive(Clone, Copy)]
128pub struct DeclInfo {
129    pub is_transparent: bool,
130}
131
132pub type DeclarationsGroups = Vec<DeclarationGroup>;
133
134impl<Id: Display> Display for GDeclarationGroup<Id> {
135    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::result::Result<(), Error> {
136        match self {
137            GDeclarationGroup::NonRec(id) => write!(f, "non-rec: {id}"),
138            GDeclarationGroup::Rec(ids) => {
139                write!(
140                    f,
141                    "rec: {}",
142                    pretty_display_list(|id| format!("    {id}"), ids)
143                )
144            }
145        }
146    }
147}
148
149impl Display for DeclarationGroup {
150    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::result::Result<(), Error> {
151        match self {
152            DeclarationGroup::Type(decl) => write!(f, "{{ Type(s): {decl} }}"),
153            DeclarationGroup::Fun(decl) => write!(f, "{{ Fun(s): {decl} }}"),
154            DeclarationGroup::Global(decl) => write!(f, "{{ Global(s): {decl} }}"),
155            DeclarationGroup::TraitDecl(decl) => write!(f, "{{ Trait decls(s): {decl} }}"),
156            DeclarationGroup::TraitImpl(decl) => write!(f, "{{ Trait impl(s): {decl} }}"),
157            DeclarationGroup::Mixed(decl) => write!(f, "{{ Mixed items: {decl} }}"),
158        }
159    }
160}
161
162#[derive(Visitor)]
163pub struct Deps {
164    dgraph: DiGraphMap<AnyTransId, ()>,
165    // Want to make sure we remember the order of insertion
166    graph: IndexMap<AnyTransId, IndexSet<AnyTransId>>,
167    // We use this when computing the graph
168    current_id: Option<AnyTransId>,
169    // We use this to track the trait impl block the current item belongs to
170    // (if relevant).
171    //
172    // We use this to ignore the references to the parent impl block.
173    //
174    // If we don't do so, when computing our dependency graph we end up with
175    // mutually recursive trait impl blocks/trait method impls in the presence
176    // of associated types (the deepest reason is that we don't normalize the
177    // types we query from rustc when translating the types from function
178    // signatures - we avoid doing so because as of now it makes resolving
179    // the trait params harder: if we get normalized types, we have to
180    // implement a normalizer on our side to make sure we correctly match
181    // types...).
182    //
183    //
184    // For instance, the problem happens if in Rust we have:
185    // ```text
186    // pub trait WithConstTy {
187    //     type W;
188    //     fn f(x: &mut Self::W);
189    // }
190    //
191    // impl WithConstTy for bool {
192    //     type W = u64;
193    //     fn f(_: &mut Self::W) {}
194    // }
195    // ```
196    //
197    // In LLBC we get:
198    //
199    // ```text
200    // impl traits::Bool::0 : traits::WithConstTy<bool>
201    // {
202    //     type W = u64 with []
203    //     fn f = traits::Bool::0::f
204    // }
205    //
206    // fn traits::Bool::0::f<@R0>(@1: &@R0 mut (traits::Bool::0::W)) { .. }
207    // //                                       ^^^^^^^^^^^^^^^
208    // //                                    refers to the trait impl
209    // ```
210    parent_trait_impl: Option<TraitImplId>,
211    parent_trait_decl: Option<TraitDeclId>,
212}
213
214impl Deps {
215    fn new() -> Self {
216        Deps {
217            dgraph: DiGraphMap::new(),
218            graph: IndexMap::new(),
219            current_id: None,
220            parent_trait_impl: None,
221            parent_trait_decl: None,
222        }
223    }
224
225    fn set_impl_or_trait_id(&mut self, kind: &ItemKind) {
226        match kind {
227            ItemKind::TraitDecl { trait_ref, .. } => self.parent_trait_decl = Some(trait_ref.id),
228            ItemKind::TraitImpl { impl_ref, .. } => self.parent_trait_impl = Some(impl_ref.id),
229            _ => {}
230        }
231    }
232    fn set_current_id(&mut self, ctx: &TransformCtx, id: AnyTransId) {
233        self.insert_node(id);
234        self.current_id = Some(id);
235
236        // Add the id of the impl/trait this item belongs to, if necessary
237        use AnyTransId::*;
238        match id {
239            TraitDecl(_) | TraitImpl(_) | Type(_) => (),
240            Global(id) => {
241                if let Some(decl) = ctx.translated.global_decls.get(id) {
242                    self.set_impl_or_trait_id(&decl.kind);
243                }
244            }
245            Fun(id) => {
246                if let Some(decl) = ctx.translated.fun_decls.get(id) {
247                    self.set_impl_or_trait_id(&decl.kind);
248                }
249            }
250        }
251    }
252
253    fn unset_current_id(&mut self) {
254        self.current_id = None;
255        self.parent_trait_impl = None;
256        self.parent_trait_decl = None;
257    }
258
259    fn insert_node(&mut self, id: AnyTransId) {
260        // We have to be careful about duplicate nodes
261        if !self.dgraph.contains_node(id) {
262            self.dgraph.add_node(id);
263            assert!(!self.graph.contains_key(&id));
264            self.graph.insert(id, IndexSet::new());
265        }
266    }
267
268    fn insert_edge(&mut self, id1: AnyTransId) {
269        let id0 = self.current_id.unwrap();
270        self.insert_node(id1);
271        if !self.dgraph.contains_edge(id0, id1) {
272            self.dgraph.add_edge(id0, id1, ());
273            self.graph.get_mut(&id0).unwrap().insert(id1);
274        }
275    }
276}
277
278impl VisitAst for Deps {
279    fn enter_type_decl_id(&mut self, id: &TypeDeclId) {
280        self.insert_edge((*id).into());
281    }
282
283    fn enter_global_decl_id(&mut self, id: &GlobalDeclId) {
284        self.insert_edge((*id).into());
285    }
286
287    fn enter_trait_impl_id(&mut self, id: &TraitImplId) {
288        // If the impl is the impl this item belongs to, we ignore it
289        // TODO: this is not very satisfying but this is the only way we have of preventing
290        // mutually recursive groups between method impls and trait impls in the presence of
291        // associated types...
292        if let Some(impl_id) = &self.parent_trait_impl
293            && impl_id == id
294        {
295            return;
296        }
297        self.insert_edge((*id).into());
298    }
299
300    fn enter_trait_decl_id(&mut self, id: &TraitDeclId) {
301        // If the trait is the trait this item belongs to, we ignore it. This is to avoid mutually
302        // recursive groups between e.g. traits decls and their globals. We treat methods
303        // specifically.
304        if let Some(trait_id) = &self.parent_trait_decl
305            && trait_id == id
306        {
307            return;
308        }
309        self.insert_edge((*id).into());
310    }
311
312    fn enter_fun_decl_id(&mut self, id: &FunDeclId) {
313        self.insert_edge((*id).into());
314    }
315
316    fn visit_item_meta(&mut self, _: &ItemMeta) -> ControlFlow<Self::Break> {
317        // Don't look inside because trait impls contain their own id in their name.
318        Continue(())
319    }
320    fn visit_item_kind(&mut self, _: &ItemKind) -> ControlFlow<Self::Break> {
321        // Don't look inside to avoid recording a dependency from a method impl to the impl block
322        // it belongs to.
323        Continue(())
324    }
325}
326
327impl Deps {
328    fn fmt_with_ctx(&self, ctx: &FmtCtx<'_>) -> String {
329        self.dgraph
330            .nodes()
331            .map(|node| {
332                let edges = self
333                    .dgraph
334                    .edges(node)
335                    .map(|e| format!("\n  {}", e.1.with_ctx(ctx)))
336                    .collect::<Vec<String>>()
337                    .join(",");
338
339                format!("{} -> [{}\n]", node.with_ctx(ctx), edges)
340            })
341            .format(",\n")
342            .to_string()
343    }
344}
345
346fn compute_declarations_graph<'tcx>(ctx: &'tcx TransformCtx) -> Deps {
347    // The order we explore the items in will dictate the final order. We do the following: std
348    // items, then items from foreign crates (sorted by crate name), then local items. Within a
349    // crate, we sort by file then by source order.
350    let mut sorted_items = ctx.translated.all_items().collect_vec();
351    // Pre-sort files to avoid costly string comparisons. Maps file ids to an index that reflects
352    // ordering on the crates (with `core` and `std` sorted first) and file names.
353    let sorted_file_ids: Vector<FileId, usize> = ctx
354        .translated
355        .files
356        .all_indices()
357        .sorted_by_key(|&file_id| {
358            let file = &ctx.translated.files[file_id];
359            let is_std = file.crate_name == "std" || file.crate_name == "core";
360            (!is_std, &file.crate_name, &file.name)
361        })
362        .enumerate()
363        .sorted_by_key(|(_i, file_id)| *file_id)
364        .map(|(i, _file_id)| i)
365        .collect();
366    assert_eq!(
367        ctx.translated.files.slot_count(),
368        sorted_file_ids.slot_count()
369    );
370    sorted_items.sort_by_key(|item| {
371        let item_meta = item.item_meta();
372        let span = item_meta.span.span;
373        let file_name_order = sorted_file_ids[span.file_id];
374        (item_meta.is_local, file_name_order, span.beg)
375    });
376
377    let mut graph = Deps::new();
378    for item in sorted_items {
379        let id = item.id();
380        graph.set_current_id(ctx, id);
381        match item {
382            AnyTransItem::Type(..) | AnyTransItem::TraitImpl(..) | AnyTransItem::Global(..) => {
383                let _ = item.drive(&mut graph);
384            }
385            AnyTransItem::Fun(d) => {
386                // Skip `d.is_global_initializer` to avoid incorrect mutual dependencies.
387                // TODO: add `is_global_initializer` to `ItemKind`.
388                let _ = d.signature.drive(&mut graph);
389                let _ = d.body.drive(&mut graph);
390                // FIXME(#514): A method declaration depends on its declaring trait because of its
391                // `Self` clause. While the clause is implicit, we make sure to record the
392                // dependency manually.
393                if let ItemKind::TraitDecl { trait_ref, .. } = &d.kind {
394                    graph.insert_edge(trait_ref.id.into());
395                }
396            }
397            AnyTransItem::TraitDecl(d) => {
398                let TraitDecl {
399                    def_id: _,
400                    item_meta: _,
401                    generics,
402                    parent_clauses,
403                    consts,
404                    const_defaults,
405                    types,
406                    type_defaults,
407                    type_clauses,
408                    methods,
409                } = d;
410                // Visit the traits referenced in the generics
411                let _ = generics.drive(&mut graph);
412
413                // Visit the parent clauses
414                let _ = parent_clauses.drive(&mut graph);
415                assert!(type_clauses.is_empty());
416
417                // Visit the items
418                let _ = consts.drive(&mut graph);
419                let _ = types.drive(&mut graph);
420                let _ = type_defaults.drive(&mut graph);
421
422                // We consider that a trait decl only contains the function/constant signatures.
423                // Therefore we don't explore the default const/method ids.
424                for (_name, gref) in const_defaults {
425                    let _ = gref.generics.drive(&mut graph);
426                }
427                for (_, bound_fn) in methods {
428                    let id = bound_fn.skip_binder.id;
429                    let _ = bound_fn.params.drive(&mut graph);
430                    if let Some(decl) = ctx.translated.fun_decls.get(id) {
431                        let _ = decl.signature.drive(&mut graph);
432                    }
433                }
434            }
435        }
436        graph.unset_current_id();
437    }
438    graph
439}
440
441fn group_declarations_from_scc(
442    _ctx: &TransformCtx,
443    graph: Deps,
444    reordered_sccs: SCCs<AnyTransId>,
445) -> DeclarationsGroups {
446    let reordered_sccs = &reordered_sccs.sccs;
447    let mut reordered_decls: DeclarationsGroups = Vec::new();
448
449    // Iterate over the SCC ids in the proper order
450    for scc in reordered_sccs.iter() {
451        if scc.is_empty() {
452            // This can happen if we failed to translate the item in this group.
453            continue;
454        }
455
456        // Note that the length of an SCC should be at least 1.
457        let mut it = scc.iter();
458        let id0 = *it.next().unwrap();
459        let decl = graph.graph.get(&id0).unwrap();
460
461        // If an SCC has length one, the declaration may be simply recursive:
462        // we determine whether it is the case by checking if the def id is in
463        // its own set of dependencies.
464        let is_mutually_recursive = scc.len() > 1;
465        let is_simply_recursive = !is_mutually_recursive && decl.contains(&id0);
466        let is_rec = is_mutually_recursive || is_simply_recursive;
467
468        let all_same_kind = scc
469            .iter()
470            .all(|id| id0.variant_index_arity() == id.variant_index_arity());
471        let ids = scc.iter().copied();
472        let group: DeclarationGroup = match id0 {
473            _ if !all_same_kind => {
474                DeclarationGroup::Mixed(GDeclarationGroup::make_group(is_rec, ids))
475            }
476            AnyTransId::Type(_) => {
477                DeclarationGroup::Type(GDeclarationGroup::make_group(is_rec, ids))
478            }
479            AnyTransId::Fun(_) => DeclarationGroup::Fun(GDeclarationGroup::make_group(is_rec, ids)),
480            AnyTransId::Global(_) => {
481                DeclarationGroup::Global(GDeclarationGroup::make_group(is_rec, ids))
482            }
483            AnyTransId::TraitDecl(_) => {
484                let gr: Vec<_> = ids.map(|x| x.try_into().unwrap()).collect();
485                // Trait declarations often refer to `Self`, like below,
486                // which means they are often considered as recursive by our
487                // analysis. TODO: do something more precise. What is important
488                // is that we never use the "whole" self clause as argument,
489                // but rather projections over the self clause (like `<Self as Foo>::u`,
490                // in the declaration for `Foo`).
491                if gr.len() == 1 {
492                    DeclarationGroup::TraitDecl(GDeclarationGroup::NonRec(gr[0]))
493                } else {
494                    DeclarationGroup::TraitDecl(GDeclarationGroup::Rec(gr))
495                }
496            }
497            AnyTransId::TraitImpl(_) => {
498                DeclarationGroup::TraitImpl(GDeclarationGroup::make_group(is_rec, ids))
499            }
500        };
501
502        reordered_decls.push(group);
503    }
504    reordered_decls
505}
506
507fn compute_reordered_decls(ctx: &TransformCtx) -> DeclarationsGroups {
508    trace!();
509
510    // Step 1: explore the declarations to build the graph
511    let graph = compute_declarations_graph(ctx);
512    trace!("Graph:\n{}\n", graph.fmt_with_ctx(&ctx.into_fmt()));
513
514    // Step 2: Apply Tarjan's SCC (Strongly Connected Components) algorithm
515    let sccs = tarjan_scc(&graph.dgraph);
516
517    // Step 3: Reorder the declarations in an order as close as possible to the one
518    // given by the user. To be more precise, if we don't need to move
519    // definitions, the order in which we generate the declarations should
520    // be the same as the one in which the user wrote them.
521    // Remark: the [get_id_dependencies] function will be called once per id, meaning
522    // it is ok if it is not very efficient and clones values.
523    let get_id_dependencies = &|id| graph.graph.get(&id).unwrap().iter().copied().collect();
524    let all_ids: Vec<AnyTransId> = graph
525        .graph
526        .keys()
527        .copied()
528        // Don't list ids that weren't translated.
529        .filter(|id| ctx.translated.get_item(*id).is_some())
530        .collect();
531    let reordered_sccs = reorder_sccs::<AnyTransId>(get_id_dependencies, &all_ids, &sccs);
532
533    // Finally, generate the list of declarations
534    let reordered_decls = group_declarations_from_scc(ctx, graph, reordered_sccs);
535
536    trace!("{:?}", reordered_decls);
537    reordered_decls
538}
539
540pub struct Transform;
541impl TransformPass for Transform {
542    fn transform_ctx(&self, ctx: &mut TransformCtx) {
543        let reordered_decls = compute_reordered_decls(&ctx);
544        ctx.translated.ordered_decls = Some(reordered_decls);
545    }
546}
547
#[cfg(test)]
mod tests {
    /// SCCs `{0}`, `{1, 2}` and `{3, 4, 5}`, where `0` depends on `3` and `1` depends on `0`
    /// and `3`. The reordering must put `{3, 4, 5}` first, then `{0}`, then `{1, 2}`.
    /// `scc_deps` refers to SCCs by their index in the reordered list.
    #[test]
    fn test_reorder_sccs1() {
        use std::collections::BTreeSet as OrdSet;
        let sccs = vec![vec![0], vec![1, 2], vec![3, 4, 5]];
        let ids = vec![0, 1, 2, 3, 4, 5];

        let get_deps = &|x| match x {
            0 => vec![3],
            1 => vec![0, 3],
            _ => vec![],
        };
        let reordered = crate::reorder_decls::reorder_sccs(get_deps, &ids, &sccs);

        // `assert_eq!` prints both sides on failure, unlike `assert!(a == b)`.
        assert_eq!(reordered.sccs, vec![vec![3, 4, 5], vec![0], vec![1, 2]]);
        assert_eq!(reordered.scc_deps[0], OrdSet::from([]));
        assert_eq!(reordered.scc_deps[1], OrdSet::from([0]));
        assert_eq!(reordered.scc_deps[2], OrdSet::from([0, 1]));
    }
}