1use crate::common::*;
10use crate::formatter::{FmtCtx, IntoFormatter};
11use crate::graphs::*;
12use crate::pretty::FmtWithCtx;
13use crate::transform::TransformCtx;
14use crate::ullbc_ast::*;
15use derive_generic_visitor::*;
16use indexmap::{IndexMap, IndexSet};
17use itertools::Itertools;
18use macros::{EnumAsGetters, EnumIsA, VariantIndexArity, VariantName};
19use petgraph::algo::tarjan_scc;
20use petgraph::graphmap::DiGraphMap;
21use serde::{Deserialize, Serialize};
22use std::fmt::{Debug, Display, Error};
23use std::vec::Vec;
24
25use super::ctx::TransformPass;
26
/// A group of declarations whose ids have type `Id`.
///
/// A group is either a single declaration treated as non-recursive, or a list of
/// declarations treated as (mutually) recursive.
#[derive(
    Debug, Clone, VariantIndexArity, VariantName, EnumAsGetters, EnumIsA, Serialize, Deserialize,
)]
#[charon::variants_suffix("Group")]
pub enum GDeclarationGroup<Id> {
    /// A single declaration, treated as non-recursive.
    NonRec(Id),
    /// A list of declarations treated as (mutually) recursive.
    Rec(Vec<Id>),
}
39
/// A declaration group tagged with the kind of the items it contains.
#[derive(
    Debug, Clone, VariantIndexArity, VariantName, EnumAsGetters, EnumIsA, Serialize, Deserialize,
)]
#[charon::variants_suffix("Group")]
pub enum DeclarationGroup {
    /// A group of type declarations.
    Type(GDeclarationGroup<TypeDeclId>),
    /// A group of function declarations.
    Fun(GDeclarationGroup<FunDeclId>),
    /// A group of global declarations.
    Global(GDeclarationGroup<GlobalDeclId>),
    /// A group of trait declarations.
    TraitDecl(GDeclarationGroup<TraitDeclId>),
    /// A group of trait implementations.
    TraitImpl(GDeclarationGroup<TraitImplId>),
    /// A group whose items are not all of the same kind.
    Mixed(GDeclarationGroup<AnyTransId>),
}
59
60impl<Id: Copy> GDeclarationGroup<Id> {
61 pub fn get_ids(&self) -> &[Id] {
62 use GDeclarationGroup::*;
63 match self {
64 NonRec(id) => std::slice::from_ref(id),
65 Rec(ids) => ids.as_slice(),
66 }
67 }
68
69 pub fn get_any_trans_ids(&self) -> Vec<AnyTransId>
70 where
71 Id: Into<AnyTransId>,
72 {
73 self.get_ids().iter().copied().map(|id| id.into()).collect()
74 }
75
76 fn make_group(is_rec: bool, gr: impl Iterator<Item = AnyTransId>) -> Self
77 where
78 Id: TryFrom<AnyTransId>,
79 Id::Error: Debug,
80 {
81 let gr: Vec<_> = gr.map(|x| x.try_into().unwrap()).collect();
82 if is_rec {
83 GDeclarationGroup::Rec(gr)
84 } else {
85 assert!(gr.len() == 1);
86 GDeclarationGroup::NonRec(gr[0])
87 }
88 }
89
90 fn to_mixed(&self) -> GDeclarationGroup<AnyTransId>
91 where
92 Id: Into<AnyTransId>,
93 {
94 match self {
95 GDeclarationGroup::NonRec(x) => GDeclarationGroup::NonRec((*x).into()),
96 GDeclarationGroup::Rec(_) => GDeclarationGroup::Rec(self.get_any_trans_ids()),
97 }
98 }
99}
100
101impl DeclarationGroup {
102 pub fn to_mixed_group(&self) -> GDeclarationGroup<AnyTransId> {
103 use DeclarationGroup::*;
104 match self {
105 Type(gr) => gr.to_mixed(),
106 Fun(gr) => gr.to_mixed(),
107 Global(gr) => gr.to_mixed(),
108 TraitDecl(gr) => gr.to_mixed(),
109 TraitImpl(gr) => gr.to_mixed(),
110 Mixed(gr) => gr.clone(),
111 }
112 }
113
114 pub fn get_ids(&self) -> Vec<AnyTransId> {
115 use DeclarationGroup::*;
116 match self {
117 Type(gr) => gr.get_any_trans_ids(),
118 Fun(gr) => gr.get_any_trans_ids(),
119 Global(gr) => gr.get_any_trans_ids(),
120 TraitDecl(gr) => gr.get_any_trans_ids(),
121 TraitImpl(gr) => gr.get_any_trans_ids(),
122 Mixed(gr) => gr.get_any_trans_ids(),
123 }
124 }
125}
126
/// Per-declaration information.
///
/// NOTE(review): not referenced in the visible part of this file — confirm where it is used.
#[derive(Clone, Copy)]
pub struct DeclInfo {
    /// Whether the declaration is transparent.
    /// NOTE(review): the exact meaning of "transparent" is not established here — confirm.
    pub is_transparent: bool,
}
131
/// An ordered list of declaration groups, as computed by the reordering pass.
pub type DeclarationsGroups = Vec<DeclarationGroup>;
133
134impl<Id: Display> Display for GDeclarationGroup<Id> {
135 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::result::Result<(), Error> {
136 match self {
137 GDeclarationGroup::NonRec(id) => write!(f, "non-rec: {id}"),
138 GDeclarationGroup::Rec(ids) => {
139 write!(
140 f,
141 "rec: {}",
142 pretty_display_list(|id| format!(" {id}"), ids)
143 )
144 }
145 }
146 }
147}
148
149impl Display for DeclarationGroup {
150 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::result::Result<(), Error> {
151 match self {
152 DeclarationGroup::Type(decl) => write!(f, "{{ Type(s): {decl} }}"),
153 DeclarationGroup::Fun(decl) => write!(f, "{{ Fun(s): {decl} }}"),
154 DeclarationGroup::Global(decl) => write!(f, "{{ Global(s): {decl} }}"),
155 DeclarationGroup::TraitDecl(decl) => write!(f, "{{ Trait decls(s): {decl} }}"),
156 DeclarationGroup::TraitImpl(decl) => write!(f, "{{ Trait impl(s): {decl} }}"),
157 DeclarationGroup::Mixed(decl) => write!(f, "{{ Mixed items: {decl} }}"),
158 }
159 }
160}
161
/// Dependency graph between translated items.
///
/// It is filled by visiting each item while that item is set as the "current" one; see the
/// `VisitAst` implementation for `Deps`.
#[derive(Visitor)]
pub struct Deps {
    /// The dependency graph: an edge `a -> b` records that item `a` mentions item `b`.
    dgraph: DiGraphMap<AnyTransId, ()>,
    /// The same nodes and edges as `dgraph`, kept as an insertion-ordered adjacency map.
    graph: IndexMap<AnyTransId, IndexSet<AnyTransId>>,
    /// The item whose dependencies are currently being collected; edges start from it.
    current_id: Option<AnyTransId>,
    /// The impl id taken from the current item's `ItemKind::TraitImpl`, if any.
    /// References to this impl are not recorded as dependencies.
    parent_trait_impl: Option<TraitImplId>,
    /// The trait id taken from the current item's `ItemKind::TraitDecl`, if any.
    /// References to this trait are not recorded as dependencies.
    parent_trait_decl: Option<TraitDeclId>,
}
213
214impl Deps {
215 fn new() -> Self {
216 Deps {
217 dgraph: DiGraphMap::new(),
218 graph: IndexMap::new(),
219 current_id: None,
220 parent_trait_impl: None,
221 parent_trait_decl: None,
222 }
223 }
224
225 fn set_impl_or_trait_id(&mut self, kind: &ItemKind) {
226 match kind {
227 ItemKind::TraitDecl { trait_ref, .. } => {
228 self.parent_trait_decl = Some(trait_ref.trait_id)
229 }
230 ItemKind::TraitImpl { impl_ref, .. } => self.parent_trait_impl = Some(impl_ref.impl_id),
231 _ => {}
232 }
233 }
234 fn set_current_id(&mut self, ctx: &TransformCtx, id: AnyTransId) {
235 self.insert_node(id);
236 self.current_id = Some(id);
237
238 use AnyTransId::*;
240 match id {
241 TraitDecl(_) | TraitImpl(_) | Type(_) => (),
242 Global(id) => {
243 if let Some(decl) = ctx.translated.global_decls.get(id) {
244 self.set_impl_or_trait_id(&decl.kind);
245 }
246 }
247 Fun(id) => {
248 if let Some(decl) = ctx.translated.fun_decls.get(id) {
249 self.set_impl_or_trait_id(&decl.kind);
250 }
251 }
252 }
253 }
254
255 fn unset_current_id(&mut self) {
256 self.current_id = None;
257 self.parent_trait_impl = None;
258 self.parent_trait_decl = None;
259 }
260
261 fn insert_node(&mut self, id: AnyTransId) {
262 if !self.dgraph.contains_node(id) {
264 self.dgraph.add_node(id);
265 assert!(!self.graph.contains_key(&id));
266 self.graph.insert(id, IndexSet::new());
267 }
268 }
269
270 fn insert_edge(&mut self, id1: AnyTransId) {
271 let id0 = self.current_id.unwrap();
272 self.insert_node(id1);
273 if !self.dgraph.contains_edge(id0, id1) {
274 self.dgraph.add_edge(id0, id1, ());
275 self.graph.get_mut(&id0).unwrap().insert(id1);
276 }
277 }
278}
279
280impl VisitAst for Deps {
281 fn enter_type_decl_id(&mut self, id: &TypeDeclId) {
282 self.insert_edge((*id).into());
283 }
284
285 fn enter_global_decl_id(&mut self, id: &GlobalDeclId) {
286 self.insert_edge((*id).into());
287 }
288
289 fn enter_trait_impl_id(&mut self, id: &TraitImplId) {
290 if let Some(impl_id) = &self.parent_trait_impl
295 && impl_id == id
296 {
297 return;
298 }
299 self.insert_edge((*id).into());
300 }
301
302 fn enter_trait_decl_id(&mut self, id: &TraitDeclId) {
303 if let Some(trait_id) = &self.parent_trait_decl
307 && trait_id == id
308 {
309 return;
310 }
311 self.insert_edge((*id).into());
312 }
313
314 fn enter_fun_decl_id(&mut self, id: &FunDeclId) {
315 self.insert_edge((*id).into());
316 }
317
318 fn visit_item_meta(&mut self, _: &ItemMeta) -> ControlFlow<Self::Break> {
319 Continue(())
321 }
322 fn visit_item_kind(&mut self, _: &ItemKind) -> ControlFlow<Self::Break> {
323 Continue(())
326 }
327}
328
329impl Deps {
330 fn fmt_with_ctx(&self, ctx: &FmtCtx<'_>) -> String {
331 self.dgraph
332 .nodes()
333 .map(|node| {
334 let edges = self
335 .dgraph
336 .edges(node)
337 .map(|e| format!("\n {}", e.1.with_ctx(ctx)))
338 .collect::<Vec<String>>()
339 .join(",");
340
341 format!("{} -> [{}\n]", node.with_ctx(ctx), edges)
342 })
343 .format(",\n")
344 .to_string()
345 }
346}
347
348fn compute_declarations_graph<'tcx>(ctx: &'tcx TransformCtx) -> Deps {
349 let mut graph = Deps::new();
350 for (id, item) in ctx.translated.all_items_with_ids() {
351 graph.set_current_id(ctx, id);
352 match item {
353 AnyTransItem::Type(..) | AnyTransItem::TraitImpl(..) | AnyTransItem::Global(..) => {
354 let _ = item.drive(&mut graph);
355 }
356 AnyTransItem::Fun(d) => {
357 let _ = d.signature.drive(&mut graph);
360 let _ = d.body.drive(&mut graph);
361 if let ItemKind::TraitDecl { trait_ref, .. } = &d.kind {
365 graph.insert_edge(trait_ref.trait_id.into());
366 }
367 }
368 AnyTransItem::TraitDecl(d) => {
369 let TraitDecl {
370 def_id: _,
371 item_meta: _,
372 generics,
373 parent_clauses,
374 consts,
375 const_defaults,
376 types,
377 type_defaults,
378 type_clauses,
379 methods,
380 } = d;
381 let _ = generics.drive(&mut graph);
383
384 let _ = parent_clauses.drive(&mut graph);
386 assert!(type_clauses.is_empty());
387
388 let _ = consts.drive(&mut graph);
390 let _ = types.drive(&mut graph);
391 let _ = type_defaults.drive(&mut graph);
392
393 for (_name, gref) in const_defaults {
396 let _ = gref.generics.drive(&mut graph);
397 }
398 for (_, bound_fn) in methods {
399 let id = bound_fn.skip_binder.id;
400 let _ = bound_fn.params.drive(&mut graph);
401 if let Some(decl) = ctx.translated.fun_decls.get(id) {
402 let _ = decl.signature.drive(&mut graph);
403 }
404 }
405 }
406 }
407 graph.unset_current_id();
408 }
409 graph
410}
411
412fn group_declarations_from_scc(
413 _ctx: &TransformCtx,
414 graph: Deps,
415 reordered_sccs: SCCs<AnyTransId>,
416) -> DeclarationsGroups {
417 let reordered_sccs = &reordered_sccs.sccs;
418 let mut reordered_decls: DeclarationsGroups = Vec::new();
419
420 for scc in reordered_sccs.iter() {
422 if scc.is_empty() {
423 continue;
425 }
426
427 let mut it = scc.iter();
429 let id0 = *it.next().unwrap();
430 let decl = graph.graph.get(&id0).unwrap();
431
432 let is_mutually_recursive = scc.len() > 1;
436 let is_simply_recursive = !is_mutually_recursive && decl.contains(&id0);
437 let is_rec = is_mutually_recursive || is_simply_recursive;
438
439 let all_same_kind = scc
440 .iter()
441 .all(|id| id0.variant_index_arity() == id.variant_index_arity());
442 let ids = scc.iter().copied();
443 let group: DeclarationGroup = match id0 {
444 _ if !all_same_kind => {
445 DeclarationGroup::Mixed(GDeclarationGroup::make_group(is_rec, ids))
446 }
447 AnyTransId::Type(_) => {
448 DeclarationGroup::Type(GDeclarationGroup::make_group(is_rec, ids))
449 }
450 AnyTransId::Fun(_) => DeclarationGroup::Fun(GDeclarationGroup::make_group(is_rec, ids)),
451 AnyTransId::Global(_) => {
452 DeclarationGroup::Global(GDeclarationGroup::make_group(is_rec, ids))
453 }
454 AnyTransId::TraitDecl(_) => {
455 let gr: Vec<_> = ids.map(|x| x.try_into().unwrap()).collect();
456 if gr.len() == 1 {
463 DeclarationGroup::TraitDecl(GDeclarationGroup::NonRec(gr[0]))
464 } else {
465 DeclarationGroup::TraitDecl(GDeclarationGroup::Rec(gr))
466 }
467 }
468 AnyTransId::TraitImpl(_) => {
469 DeclarationGroup::TraitImpl(GDeclarationGroup::make_group(is_rec, ids))
470 }
471 };
472
473 reordered_decls.push(group);
474 }
475 reordered_decls
476}
477
478fn compute_reordered_decls(ctx: &TransformCtx) -> DeclarationsGroups {
479 trace!();
480
481 let graph = compute_declarations_graph(ctx);
483 trace!("Graph:\n{}\n", graph.fmt_with_ctx(&ctx.into_fmt()));
484
485 let sccs = tarjan_scc(&graph.dgraph);
487
488 let get_id_dependencies = &|id| graph.graph.get(&id).unwrap().iter().copied().collect();
495 let all_ids: Vec<AnyTransId> = graph
496 .graph
497 .keys()
498 .copied()
499 .filter(|id| ctx.translated.get_item(*id).is_some())
501 .collect();
502 let reordered_sccs = reorder_sccs::<AnyTransId>(get_id_dependencies, &all_ids, &sccs);
503
504 let reordered_decls = group_declarations_from_scc(ctx, graph, reordered_sccs);
506
507 trace!("{:?}", reordered_decls);
508 reordered_decls
509}
510
511pub struct Transform;
512impl TransformPass for Transform {
513 fn transform_ctx(&self, ctx: &mut TransformCtx) {
514 let reordered_decls = compute_reordered_decls(&ctx);
515 ctx.translated.ordered_decls = Some(reordered_decls);
516 }
517}
518
#[cfg(test)]
mod tests {
    use std::collections::BTreeSet;

    /// `reorder_sccs` places each SCC after the SCCs it depends on and reports, for each
    /// reordered SCC, the indices of the SCCs it depends on.
    #[test]
    fn test_reorder_sccs1() {
        let sccs = vec![vec![0], vec![1, 2], vec![3, 4, 5]];
        let ids = vec![0, 1, 2, 3, 4, 5];

        // 0 depends on 3; 1 depends on 0 and 3; every other id has no dependency.
        let get_deps = &|x| match x {
            0 => vec![3],
            1 => vec![0, 3],
            _ => vec![],
        };
        let reordered = crate::reorder_decls::reorder_sccs(get_deps, &ids, &sccs);

        assert_eq!(reordered.sccs, vec![vec![3, 4, 5], vec![0], vec![1, 2]]);
        assert_eq!(reordered.scc_deps[0], BTreeSet::new());
        assert_eq!(reordered.scc_deps[1], BTreeSet::from([0]));
        assert_eq!(reordered.scc_deps[2], BTreeSet::from([0, 1]));
    }
}