1use crate::common::*;
10use crate::formatter::{FmtCtx, IntoFormatter};
11use crate::graphs::*;
12use crate::pretty::FmtWithCtx;
13use crate::transform::TransformCtx;
14use crate::ullbc_ast::*;
15use derive_generic_visitor::*;
16use indexmap::{IndexMap, IndexSet};
17use itertools::Itertools;
18use macros::{EnumAsGetters, EnumIsA, VariantIndexArity, VariantName};
19use petgraph::algo::tarjan_scc;
20use petgraph::graphmap::DiGraphMap;
21use serde::{Deserialize, Serialize};
22use std::fmt::{Debug, Display, Error};
23use std::vec::Vec;
24
25use super::ctx::TransformPass;
26
/// A group of declarations whose ids all have the same type `Id`.
///
/// Groups are produced by the declaration-reordering pass from the strongly connected components
/// of the dependency graph.
#[derive(
    Debug, Clone, VariantIndexArity, VariantName, EnumAsGetters, EnumIsA, Serialize, Deserialize,
)]
#[charon::variants_suffix("Group")]
pub enum GDeclarationGroup<Id> {
    /// A single declaration that is not considered recursive.
    NonRec(Id),
    /// Declarations considered (mutually) recursive. This may contain a single declaration that
    /// depends on itself.
    Rec(Vec<Id>),
}
39
/// A group of declarations, tagged with the kind of items it contains.
///
/// When the items of a group are not all of the same kind, the group is `Mixed`. This can only
/// happen for groups of several mutually recursive items.
#[derive(
    Debug, Clone, VariantIndexArity, VariantName, EnumAsGetters, EnumIsA, Serialize, Deserialize,
)]
#[charon::variants_suffix("Group")]
pub enum DeclarationGroup {
    /// A group of type declarations.
    Type(GDeclarationGroup<TypeDeclId>),
    /// A group of function declarations.
    Fun(GDeclarationGroup<FunDeclId>),
    /// A group of global declarations.
    Global(GDeclarationGroup<GlobalDeclId>),
    /// A group of trait declarations.
    TraitDecl(GDeclarationGroup<TraitDeclId>),
    /// A group of trait implementations.
    TraitImpl(GDeclarationGroup<TraitImplId>),
    /// A group of mutually recursive items of different kinds.
    Mixed(GDeclarationGroup<AnyTransId>),
}
59
60impl<Id: Copy> GDeclarationGroup<Id> {
61 pub fn get_ids(&self) -> &[Id] {
62 use GDeclarationGroup::*;
63 match self {
64 NonRec(id) => std::slice::from_ref(id),
65 Rec(ids) => ids.as_slice(),
66 }
67 }
68
69 pub fn get_any_trans_ids(&self) -> Vec<AnyTransId>
70 where
71 Id: Into<AnyTransId>,
72 {
73 self.get_ids().iter().copied().map(|id| id.into()).collect()
74 }
75
76 fn make_group(is_rec: bool, gr: impl Iterator<Item = AnyTransId>) -> Self
77 where
78 Id: TryFrom<AnyTransId>,
79 Id::Error: Debug,
80 {
81 let gr: Vec<_> = gr.map(|x| x.try_into().unwrap()).collect();
82 if is_rec {
83 GDeclarationGroup::Rec(gr)
84 } else {
85 assert!(gr.len() == 1);
86 GDeclarationGroup::NonRec(gr[0])
87 }
88 }
89
90 fn to_mixed(&self) -> GDeclarationGroup<AnyTransId>
91 where
92 Id: Into<AnyTransId>,
93 {
94 match self {
95 GDeclarationGroup::NonRec(x) => GDeclarationGroup::NonRec((*x).into()),
96 GDeclarationGroup::Rec(_) => GDeclarationGroup::Rec(self.get_any_trans_ids()),
97 }
98 }
99}
100
101impl DeclarationGroup {
102 pub fn to_mixed_group(&self) -> GDeclarationGroup<AnyTransId> {
103 use DeclarationGroup::*;
104 match self {
105 Type(gr) => gr.to_mixed(),
106 Fun(gr) => gr.to_mixed(),
107 Global(gr) => gr.to_mixed(),
108 TraitDecl(gr) => gr.to_mixed(),
109 TraitImpl(gr) => gr.to_mixed(),
110 Mixed(gr) => gr.clone(),
111 }
112 }
113
114 pub fn get_ids(&self) -> Vec<AnyTransId> {
115 use DeclarationGroup::*;
116 match self {
117 Type(gr) => gr.get_any_trans_ids(),
118 Fun(gr) => gr.get_any_trans_ids(),
119 Global(gr) => gr.get_any_trans_ids(),
120 TraitDecl(gr) => gr.get_any_trans_ids(),
121 TraitImpl(gr) => gr.get_any_trans_ids(),
122 Mixed(gr) => gr.get_any_trans_ids(),
123 }
124 }
125}
126
/// Per-declaration information.
///
/// NOTE(review): not referenced by the code shown in this file; confirm it is still needed.
#[derive(Clone, Copy)]
pub struct DeclInfo {
    /// Presumably whether the declaration is transparent (its definition is visible). TODO confirm.
    pub is_transparent: bool,
}
131
/// The ordered list of declaration groups computed by this pass, in which a group comes after
/// the groups it depends on.
pub type DeclarationsGroups = Vec<DeclarationGroup>;
133
134impl<Id: Display> Display for GDeclarationGroup<Id> {
135 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::result::Result<(), Error> {
136 match self {
137 GDeclarationGroup::NonRec(id) => write!(f, "non-rec: {id}"),
138 GDeclarationGroup::Rec(ids) => {
139 write!(
140 f,
141 "rec: {}",
142 pretty_display_list(|id| format!(" {id}"), ids)
143 )
144 }
145 }
146 }
147}
148
149impl Display for DeclarationGroup {
150 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::result::Result<(), Error> {
151 match self {
152 DeclarationGroup::Type(decl) => write!(f, "{{ Type(s): {decl} }}"),
153 DeclarationGroup::Fun(decl) => write!(f, "{{ Fun(s): {decl} }}"),
154 DeclarationGroup::Global(decl) => write!(f, "{{ Global(s): {decl} }}"),
155 DeclarationGroup::TraitDecl(decl) => write!(f, "{{ Trait decls(s): {decl} }}"),
156 DeclarationGroup::TraitImpl(decl) => write!(f, "{{ Trait impl(s): {decl} }}"),
157 DeclarationGroup::Mixed(decl) => write!(f, "{{ Mixed items: {decl} }}"),
158 }
159 }
160}
161
/// Dependency graph between translated items.
///
/// While an item is being visited (`current_id` is set), every item id encountered adds an edge
/// from the current item to that id.
#[derive(Visitor)]
pub struct Deps {
    /// The dependency graph, used for the SCC computation.
    dgraph: DiGraphMap<AnyTransId, ()>,
    /// The same nodes and edges as `dgraph`, stored as adjacency sets in insertion order.
    graph: IndexMap<AnyTransId, IndexSet<AnyTransId>>,
    /// The item whose dependencies are currently being collected; every new edge starts here.
    current_id: Option<AnyTransId>,
    /// If the current item belongs to a trait impl, the id of that impl. References to it are not
    /// recorded as dependencies.
    parent_trait_impl: Option<TraitImplId>,
    /// If the current item belongs to a trait declaration, the id of that declaration. References
    /// to it are not recorded as dependencies.
    parent_trait_decl: Option<TraitDeclId>,
}
213
214impl Deps {
215 fn new() -> Self {
216 Deps {
217 dgraph: DiGraphMap::new(),
218 graph: IndexMap::new(),
219 current_id: None,
220 parent_trait_impl: None,
221 parent_trait_decl: None,
222 }
223 }
224
225 fn set_impl_or_trait_id(&mut self, kind: &ItemKind) {
226 match kind {
227 ItemKind::TraitDecl { trait_ref, .. } => self.parent_trait_decl = Some(trait_ref.id),
228 ItemKind::TraitImpl { impl_ref, .. } => self.parent_trait_impl = Some(impl_ref.id),
229 _ => {}
230 }
231 }
232 fn set_current_id(&mut self, ctx: &TransformCtx, id: AnyTransId) {
233 self.insert_node(id);
234 self.current_id = Some(id);
235
236 use AnyTransId::*;
238 match id {
239 TraitDecl(_) | TraitImpl(_) | Type(_) => (),
240 Global(id) => {
241 if let Some(decl) = ctx.translated.global_decls.get(id) {
242 self.set_impl_or_trait_id(&decl.kind);
243 }
244 }
245 Fun(id) => {
246 if let Some(decl) = ctx.translated.fun_decls.get(id) {
247 self.set_impl_or_trait_id(&decl.kind);
248 }
249 }
250 }
251 }
252
253 fn unset_current_id(&mut self) {
254 self.current_id = None;
255 self.parent_trait_impl = None;
256 self.parent_trait_decl = None;
257 }
258
259 fn insert_node(&mut self, id: AnyTransId) {
260 if !self.dgraph.contains_node(id) {
262 self.dgraph.add_node(id);
263 assert!(!self.graph.contains_key(&id));
264 self.graph.insert(id, IndexSet::new());
265 }
266 }
267
268 fn insert_edge(&mut self, id1: AnyTransId) {
269 let id0 = self.current_id.unwrap();
270 self.insert_node(id1);
271 if !self.dgraph.contains_edge(id0, id1) {
272 self.dgraph.add_edge(id0, id1, ());
273 self.graph.get_mut(&id0).unwrap().insert(id1);
274 }
275 }
276}
277
278impl VisitAst for Deps {
279 fn enter_type_decl_id(&mut self, id: &TypeDeclId) {
280 self.insert_edge((*id).into());
281 }
282
283 fn enter_global_decl_id(&mut self, id: &GlobalDeclId) {
284 self.insert_edge((*id).into());
285 }
286
287 fn enter_trait_impl_id(&mut self, id: &TraitImplId) {
288 if let Some(impl_id) = &self.parent_trait_impl
293 && impl_id == id
294 {
295 return;
296 }
297 self.insert_edge((*id).into());
298 }
299
300 fn enter_trait_decl_id(&mut self, id: &TraitDeclId) {
301 if let Some(trait_id) = &self.parent_trait_decl
305 && trait_id == id
306 {
307 return;
308 }
309 self.insert_edge((*id).into());
310 }
311
312 fn enter_fun_decl_id(&mut self, id: &FunDeclId) {
313 self.insert_edge((*id).into());
314 }
315
316 fn visit_item_meta(&mut self, _: &ItemMeta) -> ControlFlow<Self::Break> {
317 Continue(())
319 }
320 fn visit_item_kind(&mut self, _: &ItemKind) -> ControlFlow<Self::Break> {
321 Continue(())
324 }
325}
326
327impl Deps {
328 fn fmt_with_ctx(&self, ctx: &FmtCtx<'_>) -> String {
329 self.dgraph
330 .nodes()
331 .map(|node| {
332 let edges = self
333 .dgraph
334 .edges(node)
335 .map(|e| format!("\n {}", e.1.with_ctx(ctx)))
336 .collect::<Vec<String>>()
337 .join(",");
338
339 format!("{} -> [{}\n]", node.with_ctx(ctx), edges)
340 })
341 .format(",\n")
342 .to_string()
343 }
344}
345
/// Build the dependency graph of all translated items.
///
/// Items are visited in a deterministic order, so that the insertion order of nodes in the
/// resulting graph is stable. Non-local items come before local ones. Within each, items are
/// ordered by the rank of their file, then by span start. For each item, an edge is added from
/// the item to every item id mentioned in the parts that are visited:
/// - types, trait impls and globals: the whole item, except its metadata and kind, which
///   `Deps` does not descend into;
/// - functions: signature and body only, plus an explicit edge to their parent trait declaration
///   when they belong to one;
/// - trait declarations: generics, parent clauses, associated consts, types and type defaults,
///   only the generics of const defaults, and only the signatures (not the bodies) of methods.
fn compute_declarations_graph<'tcx>(ctx: &'tcx TransformCtx) -> Deps {
    let mut sorted_items = ctx.translated.all_items().collect_vec();
    // Rank every file: std/core files first (`!is_std` is `false` for them), then by crate name
    // and file name. `sorted_file_ids[file_id]` is the rank of `file_id` in that order.
    let sorted_file_ids: Vector<FileId, usize> = ctx
        .translated
        .files
        .all_indices()
        .sorted_by_key(|&file_id| {
            let file = &ctx.translated.files[file_id];
            let is_std = file.crate_name == "std" || file.crate_name == "core";
            (!is_std, &file.crate_name, &file.name)
        })
        .enumerate()
        .sorted_by_key(|(_i, file_id)| *file_id)
        .map(|(i, _file_id)| i)
        .collect();
    // The ranks are collected by position, so they only line up with file ids if both vectors
    // have the same number of slots. NOTE(review): presumably guards against empty slots in
    // `files`; confirm.
    assert_eq!(
        ctx.translated.files.slot_count(),
        sorted_file_ids.slot_count()
    );
    sorted_items.sort_by_key(|item| {
        let item_meta = item.item_meta();
        let span = item_meta.span.span;
        let file_name_order = sorted_file_ids[span.file_id];
        (item_meta.is_local, file_name_order, span.beg)
    });

    let mut graph = Deps::new();
    for item in sorted_items {
        let id = item.id();
        graph.set_current_id(ctx, id);
        match item {
            AnyTransItem::Type(..) | AnyTransItem::TraitImpl(..) | AnyTransItem::Global(..) => {
                let _ = item.drive(&mut graph);
            }
            AnyTransItem::Fun(d) => {
                let _ = d.signature.drive(&mut graph);
                let _ = d.body.drive(&mut graph);
                // While visiting, references to the parent trait are skipped (see
                // `enter_trait_decl_id`), so record the method -> trait dependency explicitly.
                if let ItemKind::TraitDecl { trait_ref, .. } = &d.kind {
                    graph.insert_edge(trait_ref.id.into());
                }
            }
            AnyTransItem::TraitDecl(d) => {
                // Exhaustive destructuring: adding a field to `TraitDecl` makes this fail to
                // compile until the new field is handled here.
                let TraitDecl {
                    def_id: _,
                    item_meta: _,
                    generics,
                    parent_clauses,
                    consts,
                    const_defaults,
                    types,
                    type_defaults,
                    type_clauses,
                    methods,
                } = d;
                let _ = generics.drive(&mut graph);

                let _ = parent_clauses.drive(&mut graph);
                // Clauses on associated types are not visited; they are expected to be empty.
                assert!(type_clauses.is_empty());

                let _ = consts.drive(&mut graph);
                let _ = types.drive(&mut graph);
                let _ = type_defaults.drive(&mut graph);

                // Only the generics of const defaults are visited, not the referenced globals.
                for (_name, gref) in const_defaults {
                    let _ = gref.generics.drive(&mut graph);
                }
                // For methods, visit the binder parameters and the signature of the method's
                // function declaration, never its body.
                for (_, bound_fn) in methods {
                    let id = bound_fn.skip_binder.id;
                    let _ = bound_fn.params.drive(&mut graph);
                    if let Some(decl) = ctx.translated.fun_decls.get(id) {
                        let _ = decl.signature.drive(&mut graph);
                    }
                }
            }
        }
        graph.unset_current_id();
    }
    graph
}
440
441fn group_declarations_from_scc(
442 _ctx: &TransformCtx,
443 graph: Deps,
444 reordered_sccs: SCCs<AnyTransId>,
445) -> DeclarationsGroups {
446 let reordered_sccs = &reordered_sccs.sccs;
447 let mut reordered_decls: DeclarationsGroups = Vec::new();
448
449 for scc in reordered_sccs.iter() {
451 if scc.is_empty() {
452 continue;
454 }
455
456 let mut it = scc.iter();
458 let id0 = *it.next().unwrap();
459 let decl = graph.graph.get(&id0).unwrap();
460
461 let is_mutually_recursive = scc.len() > 1;
465 let is_simply_recursive = !is_mutually_recursive && decl.contains(&id0);
466 let is_rec = is_mutually_recursive || is_simply_recursive;
467
468 let all_same_kind = scc
469 .iter()
470 .all(|id| id0.variant_index_arity() == id.variant_index_arity());
471 let ids = scc.iter().copied();
472 let group: DeclarationGroup = match id0 {
473 _ if !all_same_kind => {
474 DeclarationGroup::Mixed(GDeclarationGroup::make_group(is_rec, ids))
475 }
476 AnyTransId::Type(_) => {
477 DeclarationGroup::Type(GDeclarationGroup::make_group(is_rec, ids))
478 }
479 AnyTransId::Fun(_) => DeclarationGroup::Fun(GDeclarationGroup::make_group(is_rec, ids)),
480 AnyTransId::Global(_) => {
481 DeclarationGroup::Global(GDeclarationGroup::make_group(is_rec, ids))
482 }
483 AnyTransId::TraitDecl(_) => {
484 let gr: Vec<_> = ids.map(|x| x.try_into().unwrap()).collect();
485 if gr.len() == 1 {
492 DeclarationGroup::TraitDecl(GDeclarationGroup::NonRec(gr[0]))
493 } else {
494 DeclarationGroup::TraitDecl(GDeclarationGroup::Rec(gr))
495 }
496 }
497 AnyTransId::TraitImpl(_) => {
498 DeclarationGroup::TraitImpl(GDeclarationGroup::make_group(is_rec, ids))
499 }
500 };
501
502 reordered_decls.push(group);
503 }
504 reordered_decls
505}
506
507fn compute_reordered_decls(ctx: &TransformCtx) -> DeclarationsGroups {
508 trace!();
509
510 let graph = compute_declarations_graph(ctx);
512 trace!("Graph:\n{}\n", graph.fmt_with_ctx(&ctx.into_fmt()));
513
514 let sccs = tarjan_scc(&graph.dgraph);
516
517 let get_id_dependencies = &|id| graph.graph.get(&id).unwrap().iter().copied().collect();
524 let all_ids: Vec<AnyTransId> = graph
525 .graph
526 .keys()
527 .copied()
528 .filter(|id| ctx.translated.get_item(*id).is_some())
530 .collect();
531 let reordered_sccs = reorder_sccs::<AnyTransId>(get_id_dependencies, &all_ids, &sccs);
532
533 let reordered_decls = group_declarations_from_scc(ctx, graph, reordered_sccs);
535
536 trace!("{:?}", reordered_decls);
537 reordered_decls
538}
539
540pub struct Transform;
541impl TransformPass for Transform {
542 fn transform_ctx(&self, ctx: &mut TransformCtx) {
543 let reordered_decls = compute_reordered_decls(&ctx);
544 ctx.translated.ordered_decls = Some(reordered_decls);
545 }
546}
547
#[cfg(test)]
mod tests {
    #[test]
    fn test_reorder_sccs1() {
        use std::collections::BTreeSet as OrdSet;
        let sccs = vec![vec![0], vec![1, 2], vec![3, 4, 5]];
        let ids = vec![0, 1, 2, 3, 4, 5];

        // 0 depends on 3; 1 depends on 0 and 3; the other ids have no dependencies.
        let get_deps = &|x| match x {
            0 => vec![3],
            1 => vec![0, 3],
            _ => vec![],
        };
        let reordered = crate::reorder_decls::reorder_sccs(get_deps, &ids, &sccs);

        // Dependencies come first: {3, 4, 5}, then {0}, then {1, 2}.
        assert_eq!(reordered.sccs, vec![vec![3, 4, 5], vec![0], vec![1, 2]]);
        // `scc_deps[i]` holds the indices (in the reordered list) of the SCCs that SCC `i`
        // depends on. `assert_eq!` prints both sides on failure, unlike `assert!(a == b)`.
        assert_eq!(reordered.scc_deps[0], OrdSet::from([]));
        assert_eq!(reordered.scc_deps[1], OrdSet::from([0]));
        assert_eq!(reordered.scc_deps[2], OrdSet::from([0, 1]));
    }
}