charon_lib/transform/
reorder_decls.rs

1//! Compute an ordering on declarations that:
2//! - Detects mutually-recursive groups;
3//! - Always orders an item before any of its uses (except for recursive cases);
4//! - Otherwise keeps a stable order.
5//!
//! Aeneas needs this because proof assistant languages are sensitive to declaration order and need
//! to be explicit about mutual recursion. This should prove useful for translation to any other
//! language with these properties.
9use crate::common::*;
10use crate::formatter::{AstFormatter, IntoFormatter};
11use crate::graphs::*;
12use crate::transform::TransformCtx;
13use crate::ullbc_ast::*;
14use derive_generic_visitor::*;
15use indexmap::{IndexMap, IndexSet};
16use macros::{EnumAsGetters, EnumIsA, VariantIndexArity, VariantName};
17use petgraph::algo::tarjan_scc;
18use petgraph::graphmap::DiGraphMap;
19use serde::{Deserialize, Serialize};
20use std::fmt::{Debug, Display, Error};
21use std::vec::Vec;
22
23use super::ctx::TransformPass;
24
/// A (group of) top-level declaration(s), properly reordered.
/// "G" stands for "generic": the group is parameterized by the kind of id it contains.
#[derive(
    Debug, Clone, VariantIndexArity, VariantName, EnumAsGetters, EnumIsA, Serialize, Deserialize,
)]
#[charon::variants_suffix("Group")]
pub enum GDeclarationGroup<Id> {
    /// A non-recursive declaration: exactly one id.
    NonRec(Id),
    /// A (group of mutually) recursive declaration(s). This may hold a single id when a
    /// declaration depends on itself.
    Rec(Vec<Id>),
}
37
/// A (group of) top-level declaration(s), properly reordered.
///
/// Each variant wraps a [`GDeclarationGroup`] whose ids all have the same kind, except for
/// `Mixed` which holds ids of heterogeneous kinds.
#[derive(
    Debug, Clone, VariantIndexArity, VariantName, EnumAsGetters, EnumIsA, Serialize, Deserialize,
)]
#[charon::variants_suffix("Group")]
pub enum DeclarationGroup {
    /// A type declaration group
    Type(GDeclarationGroup<TypeDeclId>),
    /// A function declaration group
    Fun(GDeclarationGroup<FunDeclId>),
    /// A global declaration group
    Global(GDeclarationGroup<GlobalDeclId>),
    /// A trait declaration group
    TraitDecl(GDeclarationGroup<TraitDeclId>),
    /// A trait implementation group
    TraitImpl(GDeclarationGroup<TraitImplId>),
    /// Anything that doesn't fit into these categories.
    Mixed(GDeclarationGroup<AnyTransId>),
}
57
58impl<Id: Copy> GDeclarationGroup<Id> {
59    pub fn get_ids(&self) -> &[Id] {
60        use GDeclarationGroup::*;
61        match self {
62            NonRec(id) => std::slice::from_ref(id),
63            Rec(ids) => ids.as_slice(),
64        }
65    }
66
67    pub fn get_any_trans_ids(&self) -> Vec<AnyTransId>
68    where
69        Id: Into<AnyTransId>,
70    {
71        self.get_ids().iter().copied().map(|id| id.into()).collect()
72    }
73
74    fn make_group(is_rec: bool, gr: impl Iterator<Item = AnyTransId>) -> Self
75    where
76        Id: TryFrom<AnyTransId>,
77        Id::Error: Debug,
78    {
79        let gr: Vec<_> = gr.map(|x| x.try_into().unwrap()).collect();
80        if is_rec {
81            GDeclarationGroup::Rec(gr)
82        } else {
83            assert!(gr.len() == 1);
84            GDeclarationGroup::NonRec(gr[0])
85        }
86    }
87
88    fn to_mixed(&self) -> GDeclarationGroup<AnyTransId>
89    where
90        Id: Into<AnyTransId>,
91    {
92        match self {
93            GDeclarationGroup::NonRec(x) => GDeclarationGroup::NonRec((*x).into()),
94            GDeclarationGroup::Rec(_) => GDeclarationGroup::Rec(self.get_any_trans_ids()),
95        }
96    }
97}
98
99impl DeclarationGroup {
100    pub fn to_mixed_group(&self) -> GDeclarationGroup<AnyTransId> {
101        use DeclarationGroup::*;
102        match self {
103            Type(gr) => gr.to_mixed(),
104            Fun(gr) => gr.to_mixed(),
105            Global(gr) => gr.to_mixed(),
106            TraitDecl(gr) => gr.to_mixed(),
107            TraitImpl(gr) => gr.to_mixed(),
108            Mixed(gr) => gr.clone(),
109        }
110    }
111
112    pub fn get_ids(&self) -> Vec<AnyTransId> {
113        use DeclarationGroup::*;
114        match self {
115            Type(gr) => gr.get_any_trans_ids(),
116            Fun(gr) => gr.get_any_trans_ids(),
117            Global(gr) => gr.get_any_trans_ids(),
118            TraitDecl(gr) => gr.get_any_trans_ids(),
119            TraitImpl(gr) => gr.get_any_trans_ids(),
120            Mixed(gr) => gr.get_any_trans_ids(),
121        }
122    }
123}
124
/// Information attached to a declaration.
// NOTE(review): this type is not used in the visible part of this file; presumably
// `is_transparent` tells whether the body of the declaration is available — confirm against
// the users of this struct.
#[derive(Clone, Copy)]
pub struct DeclInfo {
    pub is_transparent: bool,
}
129
/// An ordered list of declaration groups, as computed by the reordering in this module.
pub type DeclarationsGroups = Vec<DeclarationGroup>;
131
132impl<Id: Display> Display for GDeclarationGroup<Id> {
133    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::result::Result<(), Error> {
134        match self {
135            GDeclarationGroup::NonRec(id) => write!(f, "non-rec: {id}"),
136            GDeclarationGroup::Rec(ids) => {
137                write!(
138                    f,
139                    "rec: {}",
140                    pretty_display_list(|id| format!("    {id}"), ids)
141                )
142            }
143        }
144    }
145}
146
147impl Display for DeclarationGroup {
148    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::result::Result<(), Error> {
149        match self {
150            DeclarationGroup::Type(decl) => write!(f, "{{ Type(s): {decl} }}"),
151            DeclarationGroup::Fun(decl) => write!(f, "{{ Fun(s): {decl} }}"),
152            DeclarationGroup::Global(decl) => write!(f, "{{ Global(s): {decl} }}"),
153            DeclarationGroup::TraitDecl(decl) => write!(f, "{{ Trait decls(s): {decl} }}"),
154            DeclarationGroup::TraitImpl(decl) => write!(f, "{{ Trait impl(s): {decl} }}"),
155            DeclarationGroup::Mixed(decl) => write!(f, "{{ Mixed items: {decl} }}"),
156        }
157    }
158}
159
/// The dependency graph between items, together with the state used while building it.
///
/// The graph is stored twice: once as a `DiGraphMap` (the structure the SCC computation runs
/// on) and once as an insertion-ordered map from each item to the set of items it depends on.
#[derive(Visitor)]
pub struct Deps {
    // The graph in the form expected by the SCC algorithm. Also used to detect
    // already-inserted nodes and edges.
    dgraph: DiGraphMap<AnyTransId, ()>,
    // Want to make sure we remember the order of insertion
    graph: IndexMap<AnyTransId, IndexSet<AnyTransId>>,
    // We use this when computing the graph: it is the item being explored, i.e. the source
    // of every edge we insert.
    current_id: Option<AnyTransId>,
    // We use this to track the trait impl block the current item belongs to
    // (if relevant).
    //
    // We use this to ignore the references to the parent impl block.
    //
    // If we don't do so, when computing our dependency graph we end up with
    // mutually recursive trait impl blocks/trait method impls in the presence
    // of associated types (the deepest reason is that we don't normalize the
    // types we query from rustc when translating the types from function
    // signatures - we avoid doing so because as of now it makes resolving
    // the trait params harder: if we get normalized types, we have to
    // implement a normalizer on our side to make sure we correctly match
    // types...).
    //
    //
    // For instance, the problem happens if in Rust we have:
    // ```text
    // pub trait WithConstTy {
    //     type W;
    //     fn f(x: &mut Self::W);
    // }
    //
    // impl WithConstTy for bool {
    //     type W = u64;
    //     fn f(_: &mut Self::W) {}
    // }
    // ```
    //
    // In LLBC we get:
    //
    // ```text
    // impl traits::Bool::0 : traits::WithConstTy<bool>
    // {
    //     type W = u64 with []
    //     fn f = traits::Bool::0::f
    // }
    //
    // fn traits::Bool::0::f<@R0>(@1: &@R0 mut (traits::Bool::0::W)) { .. }
    // //                                       ^^^^^^^^^^^^^^^
    // //                                    refers to the trait impl
    // ```
    parent_trait_impl: Option<TraitImplId>,
    // Same idea for trait declarations: the trait declaration the current item belongs to
    // (if relevant). References to it found while visiting the item are ignored.
    parent_trait_decl: Option<TraitDeclId>,
}
211
212impl Deps {
213    fn new() -> Self {
214        Deps {
215            dgraph: DiGraphMap::new(),
216            graph: IndexMap::new(),
217            current_id: None,
218            parent_trait_impl: None,
219            parent_trait_decl: None,
220        }
221    }
222
223    fn set_impl_or_trait_id(&mut self, kind: &ItemKind) {
224        match kind {
225            ItemKind::Regular => {}
226            ItemKind::TraitDecl { trait_ref, .. } => {
227                self.parent_trait_decl = Some(trait_ref.trait_id)
228            }
229            ItemKind::TraitImpl { impl_ref, .. } => self.parent_trait_impl = Some(impl_ref.impl_id),
230        }
231    }
232    fn set_current_id(&mut self, ctx: &TransformCtx, id: AnyTransId) {
233        self.insert_node(id);
234        self.current_id = Some(id);
235
236        // Add the id of the impl/trait this item belongs to, if necessary
237        use AnyTransId::*;
238        match id {
239            TraitDecl(_) | TraitImpl(_) | Type(_) => (),
240            Global(id) => {
241                if let Some(decl) = ctx.translated.global_decls.get(id) {
242                    self.set_impl_or_trait_id(&decl.kind);
243                }
244            }
245            Fun(id) => {
246                if let Some(decl) = ctx.translated.fun_decls.get(id) {
247                    self.set_impl_or_trait_id(&decl.kind);
248                }
249            }
250        }
251    }
252
253    fn unset_current_id(&mut self) {
254        self.current_id = None;
255        self.parent_trait_impl = None;
256        self.parent_trait_decl = None;
257    }
258
259    fn insert_node(&mut self, id: AnyTransId) {
260        // We have to be careful about duplicate nodes
261        if !self.dgraph.contains_node(id) {
262            self.dgraph.add_node(id);
263            assert!(!self.graph.contains_key(&id));
264            self.graph.insert(id, IndexSet::new());
265        }
266    }
267
268    fn insert_edge(&mut self, id1: AnyTransId) {
269        let id0 = self.current_id.unwrap();
270        self.insert_node(id1);
271        if !self.dgraph.contains_edge(id0, id1) {
272            self.dgraph.add_edge(id0, id1, ());
273            self.graph.get_mut(&id0).unwrap().insert(id1);
274        }
275    }
276}
277
278impl VisitAst for Deps {
279    fn enter_type_decl_id(&mut self, id: &TypeDeclId) {
280        self.insert_edge((*id).into());
281    }
282
283    fn enter_global_decl_id(&mut self, id: &GlobalDeclId) {
284        self.insert_edge((*id).into());
285    }
286
287    fn enter_trait_impl_id(&mut self, id: &TraitImplId) {
288        // If the impl is the impl this item belongs to, we ignore it
289        // TODO: this is not very satisfying but this is the only way we have of preventing
290        // mutually recursive groups between method impls and trait impls in the presence of
291        // associated types...
292        if let Some(impl_id) = &self.parent_trait_impl
293            && impl_id == id
294        {
295            return;
296        }
297        self.insert_edge((*id).into());
298    }
299
300    fn enter_trait_decl_id(&mut self, id: &TraitDeclId) {
301        // If the trait is the trait this item belongs to, we ignore it. This is to avoid mutually
302        // recursive groups between e.g. traits decls and their globals. We treat methods
303        // specifically.
304        if let Some(trait_id) = &self.parent_trait_decl
305            && trait_id == id
306        {
307            return;
308        }
309        self.insert_edge((*id).into());
310    }
311
312    fn enter_fun_decl_id(&mut self, id: &FunDeclId) {
313        self.insert_edge((*id).into());
314    }
315
316    fn visit_item_meta(&mut self, _: &ItemMeta) -> ControlFlow<Self::Break> {
317        // Don't look inside because trait impls contain their own id in their name.
318        Continue(())
319    }
320    fn visit_item_kind(&mut self, _: &ItemKind) -> ControlFlow<Self::Break> {
321        // Don't look inside to avoid recording a dependency from a method impl to the impl block
322        // it belongs to.
323        Continue(())
324    }
325}
326
327impl AnyTransId {
328    fn fmt_with_ctx(&self, ctx: &TransformCtx) -> String {
329        use AnyTransId::*;
330        let ctx = ctx.into_fmt();
331        match self {
332            Type(id) => ctx.format_object(*id),
333            Fun(id) => ctx.format_object(*id),
334            Global(id) => ctx.format_object(*id),
335            TraitDecl(id) => ctx.format_object(*id),
336            TraitImpl(id) => ctx.format_object(*id),
337        }
338    }
339}
340
341impl Deps {
342    fn fmt_with_ctx(&self, ctx: &TransformCtx) -> String {
343        self.dgraph
344            .nodes()
345            .map(|node| {
346                let edges = self
347                    .dgraph
348                    .edges(node)
349                    .map(|e| format!("\n  {}", e.1.fmt_with_ctx(ctx)))
350                    .collect::<Vec<String>>()
351                    .join(",");
352
353                format!("{} -> [{}\n]", node.fmt_with_ctx(ctx), edges)
354            })
355            .collect::<Vec<String>>()
356            .join(",\n")
357    }
358}
359
/// Builds the dependency graph of all the translated items.
///
/// Every item yielded by `all_items_with_ids` becomes a node, and every id encountered while
/// visiting the relevant parts of the item becomes an edge from the item to that id. Nodes and
/// edges are recorded in the order in which they are encountered.
fn compute_declarations_graph<'tcx>(ctx: &'tcx TransformCtx) -> Deps {
    let mut graph = Deps::new();
    for (id, item) in ctx.translated.all_items_with_ids() {
        // Every edge inserted until `unset_current_id` has `id` as its source.
        graph.set_current_id(ctx, id);
        match item {
            // These items are visited in full.
            AnyTransItem::Type(..) | AnyTransItem::TraitImpl(..) | AnyTransItem::Global(..) => {
                let _ = item.drive(&mut graph);
            }
            AnyTransItem::Fun(d) => {
                // Skip `d.is_global_initializer` to avoid incorrect mutual dependencies.
                // TODO: add `is_global_initializer` to `ItemKind`.
                let _ = d.signature.drive(&mut graph);
                let _ = d.body.drive(&mut graph);
                // FIXME(#514): A method declaration depends on its declaring trait because of its
                // `Self` clause. While the clause is implicit, we make sure to record the
                // dependency manually.
                if let ItemKind::TraitDecl { trait_ref, .. } = &d.kind {
                    graph.insert_edge(trait_ref.trait_id.into());
                }
            }
            AnyTransItem::TraitDecl(d) => {
                // Exhaustive destructuring: adding a field to `TraitDecl` makes this pattern
                // fail to compile, which forces a decision on whether to visit the new field.
                let TraitDecl {
                    def_id: _,
                    item_meta: _,
                    generics,
                    parent_clauses,
                    consts,
                    const_defaults,
                    types,
                    type_defaults,
                    type_clauses,
                    methods,
                } = d;
                // Visit the traits referenced in the generics
                let _ = generics.drive(&mut graph);

                // Visit the parent clauses
                let _ = parent_clauses.drive(&mut graph);
                assert!(type_clauses.is_empty());

                // Visit the items
                let _ = consts.drive(&mut graph);
                let _ = types.drive(&mut graph);
                let _ = type_defaults.drive(&mut graph);

                // We consider that a trait decl only contains the function/constant signatures.
                // Therefore we don't explore the default const/method ids.
                for (_name, gref) in const_defaults {
                    let _ = gref.generics.drive(&mut graph);
                }
                for (_, bound_fn) in methods {
                    let id = bound_fn.skip_binder.id;
                    let _ = bound_fn.params.drive(&mut graph);
                    // Only the signature of the method is visited (if the method was
                    // translated), not its body nor its id.
                    if let Some(decl) = ctx.translated.fun_decls.get(id) {
                        let _ = decl.signature.drive(&mut graph);
                    }
                }
            }
        }
        graph.unset_current_id();
    }
    graph
}
423
424fn group_declarations_from_scc(
425    _ctx: &TransformCtx,
426    graph: Deps,
427    reordered_sccs: SCCs<AnyTransId>,
428) -> DeclarationsGroups {
429    let reordered_sccs = &reordered_sccs.sccs;
430    let mut reordered_decls: DeclarationsGroups = Vec::new();
431
432    // Iterate over the SCC ids in the proper order
433    for scc in reordered_sccs.iter() {
434        if scc.is_empty() {
435            // This can happen if we failed to translate the item in this group.
436            continue;
437        }
438
439        // Note that the length of an SCC should be at least 1.
440        let mut it = scc.iter();
441        let id0 = *it.next().unwrap();
442        let decl = graph.graph.get(&id0).unwrap();
443
444        // If an SCC has length one, the declaration may be simply recursive:
445        // we determine whether it is the case by checking if the def id is in
446        // its own set of dependencies.
447        let is_mutually_recursive = scc.len() > 1;
448        let is_simply_recursive = !is_mutually_recursive && decl.contains(&id0);
449        let is_rec = is_mutually_recursive || is_simply_recursive;
450
451        let all_same_kind = scc
452            .iter()
453            .all(|id| id0.variant_index_arity() == id.variant_index_arity());
454        let ids = scc.iter().copied();
455        let group: DeclarationGroup = match id0 {
456            _ if !all_same_kind => {
457                DeclarationGroup::Mixed(GDeclarationGroup::make_group(is_rec, ids))
458            }
459            AnyTransId::Type(_) => {
460                DeclarationGroup::Type(GDeclarationGroup::make_group(is_rec, ids))
461            }
462            AnyTransId::Fun(_) => DeclarationGroup::Fun(GDeclarationGroup::make_group(is_rec, ids)),
463            AnyTransId::Global(_) => {
464                DeclarationGroup::Global(GDeclarationGroup::make_group(is_rec, ids))
465            }
466            AnyTransId::TraitDecl(_) => {
467                let gr: Vec<_> = ids.map(|x| x.try_into().unwrap()).collect();
468                // Trait declarations often refer to `Self`, like below,
469                // which means they are often considered as recursive by our
470                // analysis. TODO: do something more precise. What is important
471                // is that we never use the "whole" self clause as argument,
472                // but rather projections over the self clause (like `<Self as Foo>::u`,
473                // in the declaration for `Foo`).
474                if gr.len() == 1 {
475                    DeclarationGroup::TraitDecl(GDeclarationGroup::NonRec(gr[0]))
476                } else {
477                    DeclarationGroup::TraitDecl(GDeclarationGroup::Rec(gr))
478                }
479            }
480            AnyTransId::TraitImpl(_) => {
481                DeclarationGroup::TraitImpl(GDeclarationGroup::make_group(is_rec, ids))
482            }
483        };
484
485        reordered_decls.push(group);
486    }
487    reordered_decls
488}
489
490fn compute_reordered_decls(ctx: &TransformCtx) -> DeclarationsGroups {
491    trace!();
492
493    // Step 1: explore the declarations to build the graph
494    let graph = compute_declarations_graph(ctx);
495    trace!("Graph:\n{}\n", graph.fmt_with_ctx(ctx));
496
497    // Step 2: Apply Tarjan's SCC (Strongly Connected Components) algorithm
498    let sccs = tarjan_scc(&graph.dgraph);
499
500    // Step 3: Reorder the declarations in an order as close as possible to the one
501    // given by the user. To be more precise, if we don't need to move
502    // definitions, the order in which we generate the declarations should
503    // be the same as the one in which the user wrote them.
504    // Remark: the [get_id_dependencies] function will be called once per id, meaning
505    // it is ok if it is not very efficient and clones values.
506    let get_id_dependencies = &|id| graph.graph.get(&id).unwrap().iter().copied().collect();
507    let all_ids: Vec<AnyTransId> = graph
508        .graph
509        .keys()
510        .copied()
511        // Don't list ids that weren't translated.
512        .filter(|id| ctx.translated.get_item(*id).is_some())
513        .collect();
514    let reordered_sccs = reorder_sccs::<AnyTransId>(get_id_dependencies, &all_ids, &sccs);
515
516    // Finally, generate the list of declarations
517    let reordered_decls = group_declarations_from_scc(ctx, graph, reordered_sccs);
518
519    trace!("{:?}", reordered_decls);
520    reordered_decls
521}
522
523pub struct Transform;
524impl TransformPass for Transform {
525    fn transform_ctx(&self, ctx: &mut TransformCtx) {
526        let reordered_decls = compute_reordered_decls(&ctx);
527        ctx.translated.ordered_decls = Some(reordered_decls);
528    }
529}
530
#[cfg(test)]
mod tests {
    /// Checks that `reorder_sccs` moves an SCC before the SCCs that depend on it, and that it
    /// reports the dependencies between the reordered SCCs.
    #[test]
    fn test_reorder_sccs1() {
        use std::collections::BTreeSet as OrdSet;

        // Dependencies: 0 -> {3} and 1 -> {0, 3}; the other ids depend on nothing.
        let deps_of = &|node| match node {
            1 => vec![0, 3],
            0 => vec![3],
            _ => vec![],
        };
        let all_ids = vec![0, 1, 2, 3, 4, 5];
        let input_sccs = vec![vec![0], vec![1, 2], vec![3, 4, 5]];

        let result = crate::reorder_decls::reorder_sccs(deps_of, &all_ids, &input_sccs);

        // The SCC of 3 comes first, then the one of 0, then the one of 1.
        let expected_sccs = vec![vec![3, 4, 5], vec![0], vec![1, 2]];
        assert!(result.sccs == expected_sccs);
        // Dependencies are expressed with the indices of the reordered SCCs.
        assert!(result.scc_deps[0] == OrdSet::from([]));
        assert!(result.scc_deps[1] == OrdSet::from([0]));
        assert!(result.scc_deps[2] == OrdSet::from([0, 1]));
    }
}