charon_lib/transform/
reorder_decls.rs

//! Compute an ordering on declarations that:
//! - Detects mutually-recursive groups;
//! - Always orders an item before any of its uses (except for recursive cases);
//! - Otherwise keeps a stable order.
//!
//! Aeneas needs this because proof assistant languages are sensitive to declaration order and need
//! to be explicit about mutual recursion. This should prove useful for translation to any other
//! language with these properties.
9use crate::common::*;
10use crate::formatter::{FmtCtx, IntoFormatter};
11use crate::graphs::*;
12use crate::pretty::FmtWithCtx;
13use crate::transform::TransformCtx;
14use crate::ullbc_ast::*;
15use derive_generic_visitor::*;
16use indexmap::{IndexMap, IndexSet};
17use itertools::Itertools;
18use macros::{EnumAsGetters, EnumIsA, VariantIndexArity, VariantName};
19use petgraph::algo::tarjan_scc;
20use petgraph::graphmap::DiGraphMap;
21use serde::{Deserialize, Serialize};
22use std::fmt::{Debug, Display, Error};
23use std::vec::Vec;
24
25use super::ctx::TransformPass;
26
/// A (group of) top-level declaration(s), properly reordered.
///
/// "G" stands for "generic": the group is parameterized by the kind of id it contains, e.g.
/// [`TypeDeclId`], [`FunDeclId`], or [`AnyTransId`] for groups mixing several kinds of items.
#[derive(
    Debug, Clone, VariantIndexArity, VariantName, EnumAsGetters, EnumIsA, Serialize, Deserialize,
)]
#[charon::variants_suffix("Group")]
pub enum GDeclarationGroup<Id> {
    /// A non-recursive declaration
    NonRec(Id),
    /// A (group of mutually) recursive declaration(s).
    ///
    /// A single declaration that depends on itself is also represented as a `Rec` group with
    /// one element.
    Rec(Vec<Id>),
}
39
/// A (group of) top-level declaration(s), properly reordered.
///
/// When all the declarations of a group are of the same kind, the group is tagged with that
/// kind; otherwise it is a `Mixed` group.
#[derive(
    Debug, Clone, VariantIndexArity, VariantName, EnumAsGetters, EnumIsA, Serialize, Deserialize,
)]
#[charon::variants_suffix("Group")]
pub enum DeclarationGroup {
    /// A type declaration group
    Type(GDeclarationGroup<TypeDeclId>),
    /// A function declaration group
    Fun(GDeclarationGroup<FunDeclId>),
    /// A global declaration group
    Global(GDeclarationGroup<GlobalDeclId>),
    /// A trait declaration group
    TraitDecl(GDeclarationGroup<TraitDeclId>),
    /// A trait implementation group
    TraitImpl(GDeclarationGroup<TraitImplId>),
    /// Anything that doesn't fit into these categories, i.e. a recursive group whose
    /// declarations are not all of the same kind.
    Mixed(GDeclarationGroup<AnyTransId>),
}
59
60impl<Id: Copy> GDeclarationGroup<Id> {
61    pub fn get_ids(&self) -> &[Id] {
62        use GDeclarationGroup::*;
63        match self {
64            NonRec(id) => std::slice::from_ref(id),
65            Rec(ids) => ids.as_slice(),
66        }
67    }
68
69    pub fn get_any_trans_ids(&self) -> Vec<AnyTransId>
70    where
71        Id: Into<AnyTransId>,
72    {
73        self.get_ids().iter().copied().map(|id| id.into()).collect()
74    }
75
76    fn make_group(is_rec: bool, gr: impl Iterator<Item = AnyTransId>) -> Self
77    where
78        Id: TryFrom<AnyTransId>,
79        Id::Error: Debug,
80    {
81        let gr: Vec<_> = gr.map(|x| x.try_into().unwrap()).collect();
82        if is_rec {
83            GDeclarationGroup::Rec(gr)
84        } else {
85            assert!(gr.len() == 1);
86            GDeclarationGroup::NonRec(gr[0])
87        }
88    }
89
90    fn to_mixed(&self) -> GDeclarationGroup<AnyTransId>
91    where
92        Id: Into<AnyTransId>,
93    {
94        match self {
95            GDeclarationGroup::NonRec(x) => GDeclarationGroup::NonRec((*x).into()),
96            GDeclarationGroup::Rec(_) => GDeclarationGroup::Rec(self.get_any_trans_ids()),
97        }
98    }
99}
100
101impl DeclarationGroup {
102    pub fn to_mixed_group(&self) -> GDeclarationGroup<AnyTransId> {
103        use DeclarationGroup::*;
104        match self {
105            Type(gr) => gr.to_mixed(),
106            Fun(gr) => gr.to_mixed(),
107            Global(gr) => gr.to_mixed(),
108            TraitDecl(gr) => gr.to_mixed(),
109            TraitImpl(gr) => gr.to_mixed(),
110            Mixed(gr) => gr.clone(),
111        }
112    }
113
114    pub fn get_ids(&self) -> Vec<AnyTransId> {
115        use DeclarationGroup::*;
116        match self {
117            Type(gr) => gr.get_any_trans_ids(),
118            Fun(gr) => gr.get_any_trans_ids(),
119            Global(gr) => gr.get_any_trans_ids(),
120            TraitDecl(gr) => gr.get_any_trans_ids(),
121            TraitImpl(gr) => gr.get_any_trans_ids(),
122            Mixed(gr) => gr.get_any_trans_ids(),
123        }
124    }
125}
126
/// Per-declaration information.
///
/// NOTE(review): this type is not used in this file; presumably it is consumed by other passes
/// or downstream users — confirm.
#[derive(Clone, Copy)]
pub struct DeclInfo {
    // Whether the declaration is transparent. NOTE(review): presumably as opposed to opaque
    // (its definition is exposed rather than hidden) — confirm against the users of this type.
    pub is_transparent: bool,
}
131
/// The declaration groups of a crate. Except inside recursive groups, every declaration comes
/// after the declarations it uses. This is what the [`Transform`] pass stores in
/// `ctx.translated.ordered_decls`.
pub type DeclarationsGroups = Vec<DeclarationGroup>;
133
134impl<Id: Display> Display for GDeclarationGroup<Id> {
135    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::result::Result<(), Error> {
136        match self {
137            GDeclarationGroup::NonRec(id) => write!(f, "non-rec: {id}"),
138            GDeclarationGroup::Rec(ids) => {
139                write!(
140                    f,
141                    "rec: {}",
142                    pretty_display_list(|id| format!("    {id}"), ids)
143                )
144            }
145        }
146    }
147}
148
149impl Display for DeclarationGroup {
150    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::result::Result<(), Error> {
151        match self {
152            DeclarationGroup::Type(decl) => write!(f, "{{ Type(s): {decl} }}"),
153            DeclarationGroup::Fun(decl) => write!(f, "{{ Fun(s): {decl} }}"),
154            DeclarationGroup::Global(decl) => write!(f, "{{ Global(s): {decl} }}"),
155            DeclarationGroup::TraitDecl(decl) => write!(f, "{{ Trait decls(s): {decl} }}"),
156            DeclarationGroup::TraitImpl(decl) => write!(f, "{{ Trait impl(s): {decl} }}"),
157            DeclarationGroup::Mixed(decl) => write!(f, "{{ Mixed items: {decl} }}"),
158        }
159    }
160}
161
/// The dependency graph between translated items, built by visiting the contents of each item.
///
/// An edge `a -> b` records that item `a` refers to item `b`.
#[derive(Visitor)]
pub struct Deps {
    // The graph on which Tarjan's SCC algorithm is run.
    dgraph: DiGraphMap<AnyTransId, ()>,
    // The same nodes and edges as `dgraph`, as adjacency sets.
    // Want to make sure we remember the order of insertion: the node order and the order of each
    // dependency set are what `reorder_sccs` receives.
    graph: IndexMap<AnyTransId, IndexSet<AnyTransId>>,
    // We use this when computing the graph: the item whose dependencies are currently being
    // recorded. Set by `set_current_id`, read by `insert_edge`, cleared by `unset_current_id`.
    current_id: Option<AnyTransId>,
    // We use this to track the trait impl block the current item belongs to
    // (if relevant).
    //
    // We use this to ignore the references to the parent impl block.
    //
    // If we don't do so, when computing our dependency graph we end up with
    // mutually recursive trait impl blocks/trait method impls in the presence
    // of associated types (the deepest reason is that we don't normalize the
    // types we query from rustc when translating the types from function
    // signatures - we avoid doing so because as of now it makes resolving
    // the trait params harder: if we get normalized types, we have to
    // implement a normalizer on our side to make sure we correctly match
    // types...).
    //
    //
    // For instance, the problem happens if in Rust we have:
    // ```text
    // pub trait WithConstTy {
    //     type W;
    //     fn f(x: &mut Self::W);
    // }
    //
    // impl WithConstTy for bool {
    //     type W = u64;
    //     fn f(_: &mut Self::W) {}
    // }
    // ```
    //
    // In LLBC we get:
    //
    // ```text
    // impl traits::Bool::0 : traits::WithConstTy<bool>
    // {
    //     type W = u64 with []
    //     fn f = traits::Bool::0::f
    // }
    //
    // fn traits::Bool::0::f<@R0>(@1: &@R0 mut (traits::Bool::0::W)) { .. }
    // //                                       ^^^^^^^^^^^^^^^
    // //                                    refers to the trait impl
    // ```
    parent_trait_impl: Option<TraitImplId>,
    // Same as `parent_trait_impl`, but for the trait declaration the current item belongs to:
    // references to that trait are ignored by `enter_trait_decl_id`.
    parent_trait_decl: Option<TraitDeclId>,
}
213
214impl Deps {
215    fn new() -> Self {
216        Deps {
217            dgraph: DiGraphMap::new(),
218            graph: IndexMap::new(),
219            current_id: None,
220            parent_trait_impl: None,
221            parent_trait_decl: None,
222        }
223    }
224
225    fn set_impl_or_trait_id(&mut self, kind: &ItemKind) {
226        match kind {
227            ItemKind::TraitDecl { trait_ref, .. } => {
228                self.parent_trait_decl = Some(trait_ref.trait_id)
229            }
230            ItemKind::TraitImpl { impl_ref, .. } => self.parent_trait_impl = Some(impl_ref.impl_id),
231            _ => {}
232        }
233    }
234    fn set_current_id(&mut self, ctx: &TransformCtx, id: AnyTransId) {
235        self.insert_node(id);
236        self.current_id = Some(id);
237
238        // Add the id of the impl/trait this item belongs to, if necessary
239        use AnyTransId::*;
240        match id {
241            TraitDecl(_) | TraitImpl(_) | Type(_) => (),
242            Global(id) => {
243                if let Some(decl) = ctx.translated.global_decls.get(id) {
244                    self.set_impl_or_trait_id(&decl.kind);
245                }
246            }
247            Fun(id) => {
248                if let Some(decl) = ctx.translated.fun_decls.get(id) {
249                    self.set_impl_or_trait_id(&decl.kind);
250                }
251            }
252        }
253    }
254
255    fn unset_current_id(&mut self) {
256        self.current_id = None;
257        self.parent_trait_impl = None;
258        self.parent_trait_decl = None;
259    }
260
261    fn insert_node(&mut self, id: AnyTransId) {
262        // We have to be careful about duplicate nodes
263        if !self.dgraph.contains_node(id) {
264            self.dgraph.add_node(id);
265            assert!(!self.graph.contains_key(&id));
266            self.graph.insert(id, IndexSet::new());
267        }
268    }
269
270    fn insert_edge(&mut self, id1: AnyTransId) {
271        let id0 = self.current_id.unwrap();
272        self.insert_node(id1);
273        if !self.dgraph.contains_edge(id0, id1) {
274            self.dgraph.add_edge(id0, id1, ());
275            self.graph.get_mut(&id0).unwrap().insert(id1);
276        }
277    }
278}
279
280impl VisitAst for Deps {
281    fn enter_type_decl_id(&mut self, id: &TypeDeclId) {
282        self.insert_edge((*id).into());
283    }
284
285    fn enter_global_decl_id(&mut self, id: &GlobalDeclId) {
286        self.insert_edge((*id).into());
287    }
288
289    fn enter_trait_impl_id(&mut self, id: &TraitImplId) {
290        // If the impl is the impl this item belongs to, we ignore it
291        // TODO: this is not very satisfying but this is the only way we have of preventing
292        // mutually recursive groups between method impls and trait impls in the presence of
293        // associated types...
294        if let Some(impl_id) = &self.parent_trait_impl
295            && impl_id == id
296        {
297            return;
298        }
299        self.insert_edge((*id).into());
300    }
301
302    fn enter_trait_decl_id(&mut self, id: &TraitDeclId) {
303        // If the trait is the trait this item belongs to, we ignore it. This is to avoid mutually
304        // recursive groups between e.g. traits decls and their globals. We treat methods
305        // specifically.
306        if let Some(trait_id) = &self.parent_trait_decl
307            && trait_id == id
308        {
309            return;
310        }
311        self.insert_edge((*id).into());
312    }
313
314    fn enter_fun_decl_id(&mut self, id: &FunDeclId) {
315        self.insert_edge((*id).into());
316    }
317
318    fn visit_item_meta(&mut self, _: &ItemMeta) -> ControlFlow<Self::Break> {
319        // Don't look inside because trait impls contain their own id in their name.
320        Continue(())
321    }
322    fn visit_item_kind(&mut self, _: &ItemKind) -> ControlFlow<Self::Break> {
323        // Don't look inside to avoid recording a dependency from a method impl to the impl block
324        // it belongs to.
325        Continue(())
326    }
327}
328
329impl Deps {
330    fn fmt_with_ctx(&self, ctx: &FmtCtx<'_>) -> String {
331        self.dgraph
332            .nodes()
333            .map(|node| {
334                let edges = self
335                    .dgraph
336                    .edges(node)
337                    .map(|e| format!("\n  {}", e.1.with_ctx(ctx)))
338                    .collect::<Vec<String>>()
339                    .join(",");
340
341                format!("{} -> [{}\n]", node.with_ctx(ctx), edges)
342            })
343            .format(",\n")
344            .to_string()
345    }
346}
347
/// Builds the dependency graph of all translated items.
///
/// Every item becomes a node. Visiting (parts of) an item records an edge from the item to each
/// type, global, function, trait declaration and trait impl it refers to (see the `VisitAst`
/// implementation of [`Deps`]). References to the trait/impl the item belongs to are skipped.
///
/// NOTE: the order in which the parts of an item are visited determines the insertion order of
/// its dependencies in `Deps::graph`. That order is what `reorder_sccs` receives, so it may
/// influence the final declaration order. Keep the traversal order stable.
fn compute_declarations_graph<'tcx>(ctx: &'tcx TransformCtx) -> Deps {
    let mut graph = Deps::new();
    for (id, item) in ctx.translated.all_items_with_ids() {
        // Edges recorded from now until `unset_current_id` start from `id`.
        graph.set_current_id(ctx, id);
        match item {
            AnyTransItem::Type(..) | AnyTransItem::TraitImpl(..) | AnyTransItem::Global(..) => {
                // Visit the whole item (its `ItemMeta` and `ItemKind` are skipped by the visitor).
                let _ = item.drive(&mut graph);
            }
            AnyTransItem::Fun(d) => {
                // Skip `d.is_global_initializer` to avoid incorrect mutual dependencies.
                // TODO: add `is_global_initializer` to `ItemKind`.
                let _ = d.signature.drive(&mut graph);
                let _ = d.body.drive(&mut graph);
                // FIXME(#514): A method declaration depends on its declaring trait because of its
                // `Self` clause. While the clause is implicit, we make sure to record the
                // dependency manually.
                if let ItemKind::TraitDecl { trait_ref, .. } = &d.kind {
                    graph.insert_edge(trait_ref.trait_id.into());
                }
            }
            AnyTransItem::TraitDecl(d) => {
                // The exhaustive destructuring makes the compiler flag this code if
                // `TraitDecl` gains a field.
                let TraitDecl {
                    def_id: _,
                    item_meta: _,
                    generics,
                    parent_clauses,
                    consts,
                    const_defaults,
                    types,
                    type_defaults,
                    type_clauses,
                    methods,
                } = d;
                // Visit the traits referenced in the generics
                let _ = generics.drive(&mut graph);

                // Visit the parent clauses
                let _ = parent_clauses.drive(&mut graph);
                // NOTE(review): this pass assumes `type_clauses` is empty; presumably an
                // earlier step guarantees it — confirm.
                assert!(type_clauses.is_empty());

                // Visit the items
                let _ = consts.drive(&mut graph);
                let _ = types.drive(&mut graph);
                let _ = type_defaults.drive(&mut graph);

                // We consider that a trait decl only contains the function/constant signatures.
                // Therefore we don't explore the default const/method ids.
                for (_name, gref) in const_defaults {
                    let _ = gref.generics.drive(&mut graph);
                }
                // For methods, only the binder params and the signature of the method's
                // declaration are visited, not its body.
                for (_, bound_fn) in methods {
                    let id = bound_fn.skip_binder.id;
                    let _ = bound_fn.params.drive(&mut graph);
                    if let Some(decl) = ctx.translated.fun_decls.get(id) {
                        let _ = decl.signature.drive(&mut graph);
                    }
                }
            }
        }
        graph.unset_current_id();
    }
    graph
}
411
412fn group_declarations_from_scc(
413    _ctx: &TransformCtx,
414    graph: Deps,
415    reordered_sccs: SCCs<AnyTransId>,
416) -> DeclarationsGroups {
417    let reordered_sccs = &reordered_sccs.sccs;
418    let mut reordered_decls: DeclarationsGroups = Vec::new();
419
420    // Iterate over the SCC ids in the proper order
421    for scc in reordered_sccs.iter() {
422        if scc.is_empty() {
423            // This can happen if we failed to translate the item in this group.
424            continue;
425        }
426
427        // Note that the length of an SCC should be at least 1.
428        let mut it = scc.iter();
429        let id0 = *it.next().unwrap();
430        let decl = graph.graph.get(&id0).unwrap();
431
432        // If an SCC has length one, the declaration may be simply recursive:
433        // we determine whether it is the case by checking if the def id is in
434        // its own set of dependencies.
435        let is_mutually_recursive = scc.len() > 1;
436        let is_simply_recursive = !is_mutually_recursive && decl.contains(&id0);
437        let is_rec = is_mutually_recursive || is_simply_recursive;
438
439        let all_same_kind = scc
440            .iter()
441            .all(|id| id0.variant_index_arity() == id.variant_index_arity());
442        let ids = scc.iter().copied();
443        let group: DeclarationGroup = match id0 {
444            _ if !all_same_kind => {
445                DeclarationGroup::Mixed(GDeclarationGroup::make_group(is_rec, ids))
446            }
447            AnyTransId::Type(_) => {
448                DeclarationGroup::Type(GDeclarationGroup::make_group(is_rec, ids))
449            }
450            AnyTransId::Fun(_) => DeclarationGroup::Fun(GDeclarationGroup::make_group(is_rec, ids)),
451            AnyTransId::Global(_) => {
452                DeclarationGroup::Global(GDeclarationGroup::make_group(is_rec, ids))
453            }
454            AnyTransId::TraitDecl(_) => {
455                let gr: Vec<_> = ids.map(|x| x.try_into().unwrap()).collect();
456                // Trait declarations often refer to `Self`, like below,
457                // which means they are often considered as recursive by our
458                // analysis. TODO: do something more precise. What is important
459                // is that we never use the "whole" self clause as argument,
460                // but rather projections over the self clause (like `<Self as Foo>::u`,
461                // in the declaration for `Foo`).
462                if gr.len() == 1 {
463                    DeclarationGroup::TraitDecl(GDeclarationGroup::NonRec(gr[0]))
464                } else {
465                    DeclarationGroup::TraitDecl(GDeclarationGroup::Rec(gr))
466                }
467            }
468            AnyTransId::TraitImpl(_) => {
469                DeclarationGroup::TraitImpl(GDeclarationGroup::make_group(is_rec, ids))
470            }
471        };
472
473        reordered_decls.push(group);
474    }
475    reordered_decls
476}
477
478fn compute_reordered_decls(ctx: &TransformCtx) -> DeclarationsGroups {
479    trace!();
480
481    // Step 1: explore the declarations to build the graph
482    let graph = compute_declarations_graph(ctx);
483    trace!("Graph:\n{}\n", graph.fmt_with_ctx(&ctx.into_fmt()));
484
485    // Step 2: Apply Tarjan's SCC (Strongly Connected Components) algorithm
486    let sccs = tarjan_scc(&graph.dgraph);
487
488    // Step 3: Reorder the declarations in an order as close as possible to the one
489    // given by the user. To be more precise, if we don't need to move
490    // definitions, the order in which we generate the declarations should
491    // be the same as the one in which the user wrote them.
492    // Remark: the [get_id_dependencies] function will be called once per id, meaning
493    // it is ok if it is not very efficient and clones values.
494    let get_id_dependencies = &|id| graph.graph.get(&id).unwrap().iter().copied().collect();
495    let all_ids: Vec<AnyTransId> = graph
496        .graph
497        .keys()
498        .copied()
499        // Don't list ids that weren't translated.
500        .filter(|id| ctx.translated.get_item(*id).is_some())
501        .collect();
502    let reordered_sccs = reorder_sccs::<AnyTransId>(get_id_dependencies, &all_ids, &sccs);
503
504    // Finally, generate the list of declarations
505    let reordered_decls = group_declarations_from_scc(ctx, graph, reordered_sccs);
506
507    trace!("{:?}", reordered_decls);
508    reordered_decls
509}
510
511pub struct Transform;
512impl TransformPass for Transform {
513    fn transform_ctx(&self, ctx: &mut TransformCtx) {
514        let reordered_decls = compute_reordered_decls(&ctx);
515        ctx.translated.ordered_decls = Some(reordered_decls);
516    }
517}
518
#[cfg(test)]
mod tests {
    /// `reorder_sccs` must place each SCC after the SCCs it depends on, and otherwise keep the
    /// input order. Here `{0}` depends on `{3, 4, 5}`, and `{1, 2}` depends on both.
    /// `scc_deps[i]` holds the indices (in the reordered list) of the SCCs that SCC `i` uses.
    #[test]
    fn test_reorder_sccs1() {
        use std::collections::BTreeSet as OrdSet;
        let sccs = vec![vec![0], vec![1, 2], vec![3, 4, 5]];
        let ids = vec![0, 1, 2, 3, 4, 5];

        let get_deps = &|x| match x {
            0 => vec![3],
            1 => vec![0, 3],
            _ => vec![],
        };
        let reordered = crate::reorder_decls::reorder_sccs(get_deps, &ids, &sccs);

        // `assert_eq!` reports both sides on failure, unlike `assert!(a == b)`.
        assert_eq!(reordered.sccs, vec![vec![3, 4, 5], vec![0], vec![1, 2]]);
        assert_eq!(reordered.scc_deps[0], OrdSet::from([]));
        assert_eq!(reordered.scc_deps[1], OrdSet::from([0]));
        assert_eq!(reordered.scc_deps[2], OrdSet::from([0, 1]));
    }
}