Skip to main content

charon_lib/transform/add_missing_info/
reorder_decls.rs

1//! Compute an ordering on declarations that:
2//! - Detects mutually-recursive groups;
3//! - Always orders an item before any of its uses (except for recursive cases);
4//! - Otherwise keeps a stable order.
5//!
//! Aeneas needs this because proof assistant languages are sensitive to declaration order and need
//! to be explicit about mutual recursion. This should also prove useful when translating to any
//! other language with these properties.
9use crate::common::*;
10use crate::transform::TransformCtx;
11use crate::ullbc_ast::*;
12use derive_generic_visitor::*;
13use itertools::Itertools;
14use petgraph::graphmap::DiGraphMap;
15use std::collections::{HashMap, HashSet};
16use std::fmt::{Debug, Display, Error};
17use std::vec::Vec;
18
19use crate::transform::ctx::TransformPass;
20
21impl<Id: Copy> GDeclarationGroup<Id> {
22    pub fn get_ids(&self) -> &[Id] {
23        use GDeclarationGroup::*;
24        match self {
25            NonRec(id) => std::slice::from_ref(id),
26            Rec(ids) => ids.as_slice(),
27        }
28    }
29
30    pub fn get_any_trans_ids(&self) -> Vec<ItemId>
31    where
32        Id: Into<ItemId>,
33    {
34        self.get_ids().iter().copied().map(|id| id.into()).collect()
35    }
36
37    fn make_group(is_rec: bool, ids: Vec<ItemId>) -> Self
38    where
39        Id: TryFrom<ItemId>,
40        Id::Error: Debug,
41    {
42        let ids: Vec<_> = ids.into_iter().map(|x| x.try_into().unwrap()).collect();
43        if is_rec {
44            GDeclarationGroup::Rec(ids)
45        } else {
46            assert!(ids.len() == 1);
47            GDeclarationGroup::NonRec(ids[0])
48        }
49    }
50
51    fn to_mixed(&self) -> GDeclarationGroup<ItemId>
52    where
53        Id: Into<ItemId>,
54    {
55        match self {
56            GDeclarationGroup::NonRec(x) => GDeclarationGroup::NonRec((*x).into()),
57            GDeclarationGroup::Rec(_) => GDeclarationGroup::Rec(self.get_any_trans_ids()),
58        }
59    }
60}
61
62impl DeclarationGroup {
63    fn make_group(is_rec: bool, ids: Vec<ItemId>) -> Self {
64        let id0 = ids[0];
65        let all_same_kind = ids
66            .iter()
67            .all(|id| id0.variant_index_arity() == id.variant_index_arity());
68        match id0 {
69            _ if !all_same_kind => {
70                DeclarationGroup::Mixed(GDeclarationGroup::make_group(is_rec, ids))
71            }
72            ItemId::Type(_) => DeclarationGroup::Type(GDeclarationGroup::make_group(is_rec, ids)),
73            ItemId::Fun(_) => DeclarationGroup::Fun(GDeclarationGroup::make_group(is_rec, ids)),
74            ItemId::Global(_) => {
75                DeclarationGroup::Global(GDeclarationGroup::make_group(is_rec, ids))
76            }
77            ItemId::TraitDecl(_) => {
78                DeclarationGroup::TraitDecl(GDeclarationGroup::make_group(is_rec, ids))
79            }
80            ItemId::TraitImpl(_) => {
81                DeclarationGroup::TraitImpl(GDeclarationGroup::make_group(is_rec, ids))
82            }
83        }
84    }
85
86    pub fn to_mixed_group(&self) -> GDeclarationGroup<ItemId> {
87        use DeclarationGroup::*;
88        match self {
89            Type(gr) => gr.to_mixed(),
90            Fun(gr) => gr.to_mixed(),
91            Global(gr) => gr.to_mixed(),
92            TraitDecl(gr) => gr.to_mixed(),
93            TraitImpl(gr) => gr.to_mixed(),
94            Mixed(gr) => gr.clone(),
95        }
96    }
97
98    pub fn get_ids(&self) -> Vec<ItemId> {
99        use DeclarationGroup::*;
100        match self {
101            Type(gr) => gr.get_any_trans_ids(),
102            Fun(gr) => gr.get_any_trans_ids(),
103            Global(gr) => gr.get_any_trans_ids(),
104            TraitDecl(gr) => gr.get_any_trans_ids(),
105            TraitImpl(gr) => gr.get_any_trans_ids(),
106            Mixed(gr) => gr.get_any_trans_ids(),
107        }
108    }
109}
110
111impl<Id: Display> Display for GDeclarationGroup<Id> {
112    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::result::Result<(), Error> {
113        match self {
114            GDeclarationGroup::NonRec(id) => write!(f, "non-rec: {id}"),
115            GDeclarationGroup::Rec(ids) => {
116                write!(
117                    f,
118                    "rec: {}",
119                    pretty_display_list(|id| format!("    {id}"), ids)
120                )
121            }
122        }
123    }
124}
125
126impl Display for DeclarationGroup {
127    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::result::Result<(), Error> {
128        match self {
129            DeclarationGroup::Type(decl) => write!(f, "{{ Type(s): {decl} }}"),
130            DeclarationGroup::Fun(decl) => write!(f, "{{ Fun(s): {decl} }}"),
131            DeclarationGroup::Global(decl) => write!(f, "{{ Global(s): {decl} }}"),
132            DeclarationGroup::TraitDecl(decl) => write!(f, "{{ Trait decls(s): {decl} }}"),
133            DeclarationGroup::TraitImpl(decl) => write!(f, "{{ Trait impl(s): {decl} }}"),
134            DeclarationGroup::Mixed(decl) => write!(f, "{{ Mixed items: {decl} }}"),
135        }
136    }
137}
138
/// State shared by all per-item visitors while exploring the items reachable from the
/// `start_from` roots (see `compute_declarations_graph`).
#[derive(Default)]
pub struct Deps {
    /// The dependency graph between translated items. We're careful to only add items that got
    /// translated.
    graph: DiGraphMap<ItemId, ()>,
    /// Worklist of discovered items that still have to be explored; the exploration loop pops
    /// from it until it is empty.
    unprocessed: Vec<ItemId>,
    /// Items that have already been explored, so that each item is processed at most once.
    visited: HashSet<ItemId>,
}
147
/// Visitor that records the dependencies of a single item (`current_id`) into a shared [`Deps`].
/// We use this when computing the graph.
#[derive(Visitor)]
pub struct DepsForItem<'a> {
    // Used to check whether a referenced item was actually translated.
    ctx: &'a TransformCtx,
    // The shared exploration state that nodes and edges are added to.
    deps: &'a mut Deps,
    // The item whose dependencies are being recorded; every edge starts from it.
    current_id: ItemId,
    // Each item contains its own id; we don't want to count that as a self-reference. Hence the
    // first time we see its own id, we skip.
    seen_current_id: bool,
    // We use this to track the trait impl block the current item belongs to
    // (if relevant).
    //
    // We use this to ignore the references to the parent impl block.
    //
    // If we don't do so, when computing our dependency graph we end up with
    // mutually recursive trait impl blocks/trait method impls in the presence
    // of associated types (the deepest reason is that we don't normalize the
    // types we query from rustc when translating the types from function
    // signatures - we avoid doing so because as of now it makes resolving
    // the trait params harder: if we get normalized types, we have to
    // implement a normalizer on our side to make sure we correctly match
    // types...).
    //
    //
    // For instance, the problem happens if in Rust we have:
    // ```text
    // pub trait WithConstTy {
    //     type W;
    //     fn f(x: &mut Self::W);
    // }
    //
    // impl WithConstTy for bool {
    //     type W = u64;
    //     fn f(_: &mut Self::W) {}
    // }
    // ```
    //
    // In LLBC we get:
    //
    // ```text
    // impl traits::Bool::0 : traits::WithConstTy<bool>
    // {
    //     type W = u64 with []
    //     fn f = traits::Bool::0::f
    // }
    //
    // fn traits::Bool::0::f<@R0>(@1: &@R0 mut (traits::Bool::0::W)) { .. }
    // //                                       ^^^^^^^^^^^^^^^
    // //                                    refers to the trait impl
    // ```
    parent_trait_impl: Option<TraitImplId>,
    // The trait declaration the current item belongs to (if relevant). References to it are
    // ignored, to avoid mutually recursive groups between e.g. trait decls and their globals.
    parent_trait_decl: Option<TraitDeclId>,
}
201
202impl Deps {
203    fn visitor_for_item<'a>(
204        &'a mut self,
205        ctx: &'a TransformCtx,
206        item: ItemRef<'_>,
207    ) -> DepsForItem<'a> {
208        let current_id = item.id();
209        self.graph.add_node(current_id);
210
211        let mut for_item = DepsForItem {
212            ctx,
213            deps: self,
214            seen_current_id: false,
215            current_id,
216            parent_trait_impl: None,
217            parent_trait_decl: None,
218        };
219
220        // Add the id of the impl/trait this item belongs to, if necessary
221        match item.parent_info() {
222            ItemSource::TraitDecl { trait_ref, .. } => {
223                for_item.parent_trait_decl = Some(trait_ref.id)
224            }
225            ItemSource::TraitImpl { impl_ref, .. } => {
226                for_item.parent_trait_impl = Some(impl_ref.id)
227            }
228            _ => {}
229        }
230
231        for_item
232    }
233}
234
235impl DepsForItem<'_> {
236    fn insert_node(&mut self, tgt: impl Into<ItemId>) {
237        let tgt = tgt.into();
238        // Only add translated items.
239        if self.ctx.translated.get_item(tgt).is_some() && !self.deps.visited.contains(&tgt) {
240            self.deps.unprocessed.push(tgt);
241        }
242    }
243    fn insert_edge(&mut self, tgt: impl Into<ItemId>) {
244        let tgt = tgt.into();
245        if tgt == self.current_id && !self.seen_current_id {
246            // Each item contains its own id; this is a hack to avoid considering that as a self
247            // loop.
248            self.seen_current_id = true;
249            return;
250        }
251        self.insert_node(tgt);
252        // Only add translated items.
253        if self.ctx.translated.get_item(tgt).is_some() {
254            self.deps.graph.add_edge(self.current_id, tgt, ());
255        }
256    }
257}
258
259impl VisitAst for DepsForItem<'_> {
260    fn enter_type_decl_id(&mut self, id: &TypeDeclId) {
261        self.insert_edge(*id);
262    }
263
264    fn enter_global_decl_id(&mut self, id: &GlobalDeclId) {
265        self.insert_edge(*id);
266    }
267
268    fn enter_trait_impl_id(&mut self, id: &TraitImplId) {
269        // If the impl is the impl this item belongs to, we ignore it
270        // TODO: this is not very satisfying but this is the only way we have of preventing
271        // mutually recursive groups between method impls and trait impls in the presence of
272        // associated types...
273        if self.parent_trait_impl != Some(*id) {
274            self.insert_edge(*id);
275        }
276    }
277
278    fn enter_trait_decl_id(&mut self, id: &TraitDeclId) {
279        // If the trait is the trait this item belongs to, we ignore it. This is to avoid mutually
280        // recursive groups between e.g. traits decls and their globals. We treat methods
281        // specifically.
282        if self.parent_trait_decl != Some(*id) {
283            self.insert_edge(*id);
284        }
285    }
286
287    fn enter_fun_decl_id(&mut self, id: &FunDeclId) {
288        self.insert_edge(*id);
289    }
290
291    fn visit_item_meta(&mut self, _: &ItemMeta) -> ControlFlow<Self::Break> {
292        // Don't look inside because trait impls contain their own id in their name.
293        Continue(())
294    }
295    fn visit_item_source(&mut self, _: &ItemSource) -> ControlFlow<Self::Break> {
296        // Don't look inside to avoid recording a dependency from a method impl to the impl block
297        // it belongs to.
298        Continue(())
299    }
300}
301
302fn compute_declarations_graph(ctx: &TransformCtx) -> DiGraphMap<ItemId, ()> {
303    let mut deps = Deps::default();
304    // Start from the items included in `start_from`. We've mostly only translated items accessible
305    // from that, but some passes render items inaccessible again, which we filter out here.
306    deps.unprocessed = ctx
307        .translated
308        .all_items()
309        .filter(|item| {
310            ctx.options
311                .start_from
312                .iter()
313                .any(|pat| pat.matches(&ctx.translated, item.item_meta()))
314        })
315        .map(|item| item.id())
316        .collect();
317
318    // Explore reachable items.
319    while let Some(id) = deps.unprocessed.pop() {
320        if !deps.visited.insert(id) {
321            continue;
322        }
323        let Some(item) = ctx.translated.get_item(id) else {
324            continue;
325        };
326        let mut visitor = deps.visitor_for_item(ctx, item);
327        match item {
328            ItemRef::Type(..) | ItemRef::TraitImpl(..) | ItemRef::Global(..) => {
329                let _ = item.drive(&mut visitor);
330            }
331            ItemRef::Fun(d) => {
332                let FunDecl {
333                    def_id,
334                    item_meta: _,
335                    generics,
336                    signature,
337                    src,
338                    is_global_initializer: _,
339                    body,
340                } = d;
341                let _ = def_id.drive(&mut visitor); // For `seen_current_id`
342                // Skip `d.is_global_initializer` to avoid incorrect mutual dependencies.
343                // TODO: add `is_global_initializer` to `ItemSource`.
344                let _ = generics.drive(&mut visitor);
345                let _ = signature.drive(&mut visitor);
346                let _ = body.drive(&mut visitor);
347                if let ItemSource::TraitDecl { trait_ref, .. } = src {
348                    visitor.insert_edge(trait_ref.id);
349                }
350            }
351            ItemRef::TraitDecl(d) => {
352                let TraitDecl {
353                    def_id,
354                    item_meta: _,
355                    generics,
356                    implied_clauses: parent_clauses,
357                    consts,
358                    types,
359                    methods,
360                    vtable,
361                } = d;
362                let _ = def_id.drive(&mut visitor); // For `seen_current_id`
363                // Visit the traits referenced in the generics
364                let _ = generics.drive(&mut visitor);
365
366                // Visit the parent clauses
367                let _ = parent_clauses.drive(&mut visitor);
368
369                // Visit the items
370                let _ = types.drive(&mut visitor);
371                let _ = vtable.drive(&mut visitor);
372
373                // We consider that a trait decl only contains the function/constant signatures.
374                // Therefore we don't explore the default const/method ids.
375                for assoc_const in consts {
376                    let TraitAssocConst {
377                        name: _,
378                        attr_info: _,
379                        ty,
380                        default,
381                    } = assoc_const;
382                    let _ = ty.drive(&mut visitor);
383                    if let Some(gref) = default {
384                        visitor.insert_node(gref.id); // Still count the item as reachable.
385                        let _ = gref.generics.drive(&mut visitor);
386                    }
387                }
388                for bound_method in methods {
389                    let id = bound_method.skip_binder.item.id;
390                    visitor.insert_node(id); // Still count the item as reachable.
391                    let _ = bound_method.params.drive(&mut visitor);
392                    if let Some(decl) = ctx.translated.fun_decls.get(id) {
393                        let _ = decl.signature.drive(&mut visitor);
394                    }
395                }
396            }
397        }
398    }
399    deps.graph
400}
401
402fn compute_reordered_decls(ctx: &mut TransformCtx) -> DeclarationsGroups {
403    // Build the graph of dependencies between items.
404    let graph = compute_declarations_graph(ctx);
405
406    // Pre-sort files to limit the number of costly string comparisons. Maps file ids to an index
407    // that reflects ordering on the crates (with `core` and `std` sorted first) and file names.
408    let sorted_file_ids: IndexMap<FileId, usize> = ctx
409        .translated
410        .files
411        .indices()
412        .sorted_by_cached_key(|&file_id| {
413            let file = &ctx.translated.files[file_id];
414            let is_std = file.crate_name == "std" || file.crate_name == "core";
415            (!is_std, &file.crate_name, &file.name)
416        })
417        .enumerate()
418        .sorted_by_key(|(_i, file_id)| *file_id)
419        .map(|(i, _file_id)| i)
420        .collect();
421    assert_eq!(ctx.translated.files.len(), sorted_file_ids.slot_count());
422
423    // We sort items as follows: std items, then items from foreign crates (sorted by crate name),
424    // then local items. Within a crate, we sort by file then by source order.
425    let sort_by = |item: &ItemRef| {
426        let item_meta = item.item_meta();
427        let span = item_meta.span.data;
428        let file_name_order = sorted_file_ids.get(span.file_id);
429        (item_meta.is_local, file_name_order, span.beg, item.id())
430    };
431    // We record for each item the order in which we're sorting it, to make `sort_by` cheap.
432    let item_sorted_index: HashMap<ItemId, usize> = ctx
433        .translated
434        .all_items()
435        .sorted_by_cached_key(sort_by)
436        .enumerate()
437        .map(|(i, item)| (item.id(), i))
438        .collect();
439    let sort_by = |id: &ItemId| item_sorted_index.get(id).unwrap();
440
441    // Compute SCCs (Strongly Connected Components) for the graph in a way that matches the chosen
442    // order as much as possible.
443    let reordered_sccs = super::sccs::ordered_scc(&graph, sort_by);
444
445    // Convert to a list of declarations.
446    let reordered_decls = reordered_sccs
447        .into_iter()
448        // This can happen if we failed to translate the item in this group.
449        .filter(|scc| !scc.is_empty())
450        .map(|scc| {
451            // If an SCC has length one, the declaration may be simply recursive: we determine whether
452            // it is the case by checking if the def id is in its own set of dependencies.
453            // Trait declarations often refer to `Self`, which means they are often considered as
454            // recursive by our analysis. So we cheat an declare them non-recursive.
455            // TODO: do something more precise. What is important is that we never use the "whole" self
456            // clause as argument, but rather projections over the self clause (like `<Self as
457            // Foo>::u`, in the declaration for `Foo`).
458            let id0 = scc[0];
459            let is_non_rec =
460                scc.len() == 1 && (id0.is_trait_decl() || !graph.neighbors(id0).contains(&id0));
461
462            DeclarationGroup::make_group(!is_non_rec, scc)
463        })
464        .collect();
465
466    trace!("{:?}", reordered_decls);
467    reordered_decls
468}
469
470pub struct Transform;
471impl TransformPass for Transform {
472    fn transform_ctx(&self, ctx: &mut TransformCtx) {
473        let reordered_decls = compute_reordered_decls(ctx);
474        ctx.translated.ordered_decls = Some(reordered_decls);
475    }
476}