1use crate::common::*;
10use crate::formatter::{AstFormatter, IntoFormatter};
11use crate::graphs::*;
12use crate::transform::TransformCtx;
13use crate::ullbc_ast::*;
14use derive_generic_visitor::*;
15use indexmap::{IndexMap, IndexSet};
16use macros::{EnumAsGetters, EnumIsA, VariantIndexArity, VariantName};
17use petgraph::algo::tarjan_scc;
18use petgraph::graphmap::DiGraphMap;
19use serde::{Deserialize, Serialize};
20use std::fmt::{Debug, Display, Error};
21use std::vec::Vec;
22
23use super::ctx::TransformPass;
24
25#[derive(
28 Debug, Clone, VariantIndexArity, VariantName, EnumAsGetters, EnumIsA, Serialize, Deserialize,
29)]
30#[charon::variants_suffix("Group")]
31pub enum GDeclarationGroup<Id> {
32 NonRec(Id),
34 Rec(Vec<Id>),
36}
37
38#[derive(
40 Debug, Clone, VariantIndexArity, VariantName, EnumAsGetters, EnumIsA, Serialize, Deserialize,
41)]
42#[charon::variants_suffix("Group")]
43pub enum DeclarationGroup {
44 Type(GDeclarationGroup<TypeDeclId>),
46 Fun(GDeclarationGroup<FunDeclId>),
48 Global(GDeclarationGroup<GlobalDeclId>),
50 TraitDecl(GDeclarationGroup<TraitDeclId>),
52 TraitImpl(GDeclarationGroup<TraitImplId>),
54 Mixed(GDeclarationGroup<AnyTransId>),
56}
57
58impl<Id: Copy> GDeclarationGroup<Id> {
59 pub fn get_ids(&self) -> &[Id] {
60 use GDeclarationGroup::*;
61 match self {
62 NonRec(id) => std::slice::from_ref(id),
63 Rec(ids) => ids.as_slice(),
64 }
65 }
66
67 pub fn get_any_trans_ids(&self) -> Vec<AnyTransId>
68 where
69 Id: Into<AnyTransId>,
70 {
71 self.get_ids().iter().copied().map(|id| id.into()).collect()
72 }
73
74 fn make_group(is_rec: bool, gr: impl Iterator<Item = AnyTransId>) -> Self
75 where
76 Id: TryFrom<AnyTransId>,
77 Id::Error: Debug,
78 {
79 let gr: Vec<_> = gr.map(|x| x.try_into().unwrap()).collect();
80 if is_rec {
81 GDeclarationGroup::Rec(gr)
82 } else {
83 assert!(gr.len() == 1);
84 GDeclarationGroup::NonRec(gr[0])
85 }
86 }
87
88 fn to_mixed(&self) -> GDeclarationGroup<AnyTransId>
89 where
90 Id: Into<AnyTransId>,
91 {
92 match self {
93 GDeclarationGroup::NonRec(x) => GDeclarationGroup::NonRec((*x).into()),
94 GDeclarationGroup::Rec(_) => GDeclarationGroup::Rec(self.get_any_trans_ids()),
95 }
96 }
97}
98
99impl DeclarationGroup {
100 pub fn to_mixed_group(&self) -> GDeclarationGroup<AnyTransId> {
101 use DeclarationGroup::*;
102 match self {
103 Type(gr) => gr.to_mixed(),
104 Fun(gr) => gr.to_mixed(),
105 Global(gr) => gr.to_mixed(),
106 TraitDecl(gr) => gr.to_mixed(),
107 TraitImpl(gr) => gr.to_mixed(),
108 Mixed(gr) => gr.clone(),
109 }
110 }
111
112 pub fn get_ids(&self) -> Vec<AnyTransId> {
113 use DeclarationGroup::*;
114 match self {
115 Type(gr) => gr.get_any_trans_ids(),
116 Fun(gr) => gr.get_any_trans_ids(),
117 Global(gr) => gr.get_any_trans_ids(),
118 TraitDecl(gr) => gr.get_any_trans_ids(),
119 TraitImpl(gr) => gr.get_any_trans_ids(),
120 Mixed(gr) => gr.get_any_trans_ids(),
121 }
122 }
123}
124
/// Information attached to a declaration.
///
/// Derives `Debug`, `PartialEq` and `Eq` in addition to `Clone`/`Copy` so that
/// values can be logged and compared.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct DeclInfo {
    /// NOTE(review): presumably whether the declaration's body is visible
    /// (transparent) rather than opaque — confirm against the users of this type.
    pub is_transparent: bool,
}
129
130pub type DeclarationsGroups = Vec<DeclarationGroup>;
131
132impl<Id: Display> Display for GDeclarationGroup<Id> {
133 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::result::Result<(), Error> {
134 match self {
135 GDeclarationGroup::NonRec(id) => write!(f, "non-rec: {id}"),
136 GDeclarationGroup::Rec(ids) => {
137 write!(
138 f,
139 "rec: {}",
140 pretty_display_list(|id| format!(" {id}"), ids)
141 )
142 }
143 }
144 }
145}
146
147impl Display for DeclarationGroup {
148 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::result::Result<(), Error> {
149 match self {
150 DeclarationGroup::Type(decl) => write!(f, "{{ Type(s): {decl} }}"),
151 DeclarationGroup::Fun(decl) => write!(f, "{{ Fun(s): {decl} }}"),
152 DeclarationGroup::Global(decl) => write!(f, "{{ Global(s): {decl} }}"),
153 DeclarationGroup::TraitDecl(decl) => write!(f, "{{ Trait decls(s): {decl} }}"),
154 DeclarationGroup::TraitImpl(decl) => write!(f, "{{ Trait impl(s): {decl} }}"),
155 DeclarationGroup::Mixed(decl) => write!(f, "{{ Mixed items: {decl} }}"),
156 }
157 }
158}
159
160#[derive(Visitor)]
161pub struct Deps {
162 dgraph: DiGraphMap<AnyTransId, ()>,
163 graph: IndexMap<AnyTransId, IndexSet<AnyTransId>>,
165 current_id: Option<AnyTransId>,
167 parent_trait_impl: Option<TraitImplId>,
209 parent_trait_decl: Option<TraitDeclId>,
210}
211
212impl Deps {
213 fn new() -> Self {
214 Deps {
215 dgraph: DiGraphMap::new(),
216 graph: IndexMap::new(),
217 current_id: None,
218 parent_trait_impl: None,
219 parent_trait_decl: None,
220 }
221 }
222
223 fn set_impl_or_trait_id(&mut self, kind: &ItemKind) {
224 match kind {
225 ItemKind::Regular => {}
226 ItemKind::TraitDecl { trait_ref, .. } => {
227 self.parent_trait_decl = Some(trait_ref.trait_id)
228 }
229 ItemKind::TraitImpl { impl_ref, .. } => self.parent_trait_impl = Some(impl_ref.impl_id),
230 }
231 }
232 fn set_current_id(&mut self, ctx: &TransformCtx, id: AnyTransId) {
233 self.insert_node(id);
234 self.current_id = Some(id);
235
236 use AnyTransId::*;
238 match id {
239 TraitDecl(_) | TraitImpl(_) | Type(_) => (),
240 Global(id) => {
241 if let Some(decl) = ctx.translated.global_decls.get(id) {
242 self.set_impl_or_trait_id(&decl.kind);
243 }
244 }
245 Fun(id) => {
246 if let Some(decl) = ctx.translated.fun_decls.get(id) {
247 self.set_impl_or_trait_id(&decl.kind);
248 }
249 }
250 }
251 }
252
253 fn unset_current_id(&mut self) {
254 self.current_id = None;
255 self.parent_trait_impl = None;
256 self.parent_trait_decl = None;
257 }
258
259 fn insert_node(&mut self, id: AnyTransId) {
260 if !self.dgraph.contains_node(id) {
262 self.dgraph.add_node(id);
263 assert!(!self.graph.contains_key(&id));
264 self.graph.insert(id, IndexSet::new());
265 }
266 }
267
268 fn insert_edge(&mut self, id1: AnyTransId) {
269 let id0 = self.current_id.unwrap();
270 self.insert_node(id1);
271 if !self.dgraph.contains_edge(id0, id1) {
272 self.dgraph.add_edge(id0, id1, ());
273 self.graph.get_mut(&id0).unwrap().insert(id1);
274 }
275 }
276}
277
278impl VisitAst for Deps {
279 fn enter_type_decl_id(&mut self, id: &TypeDeclId) {
280 self.insert_edge((*id).into());
281 }
282
283 fn enter_global_decl_id(&mut self, id: &GlobalDeclId) {
284 self.insert_edge((*id).into());
285 }
286
287 fn enter_trait_impl_id(&mut self, id: &TraitImplId) {
288 if let Some(impl_id) = &self.parent_trait_impl
293 && impl_id == id
294 {
295 return;
296 }
297 self.insert_edge((*id).into());
298 }
299
300 fn enter_trait_decl_id(&mut self, id: &TraitDeclId) {
301 if let Some(trait_id) = &self.parent_trait_decl
305 && trait_id == id
306 {
307 return;
308 }
309 self.insert_edge((*id).into());
310 }
311
312 fn enter_fun_decl_id(&mut self, id: &FunDeclId) {
313 self.insert_edge((*id).into());
314 }
315
316 fn visit_item_meta(&mut self, _: &ItemMeta) -> ControlFlow<Self::Break> {
317 Continue(())
319 }
320 fn visit_item_kind(&mut self, _: &ItemKind) -> ControlFlow<Self::Break> {
321 Continue(())
324 }
325}
326
327impl AnyTransId {
328 fn fmt_with_ctx(&self, ctx: &TransformCtx) -> String {
329 use AnyTransId::*;
330 let ctx = ctx.into_fmt();
331 match self {
332 Type(id) => ctx.format_object(*id),
333 Fun(id) => ctx.format_object(*id),
334 Global(id) => ctx.format_object(*id),
335 TraitDecl(id) => ctx.format_object(*id),
336 TraitImpl(id) => ctx.format_object(*id),
337 }
338 }
339}
340
341impl Deps {
342 fn fmt_with_ctx(&self, ctx: &TransformCtx) -> String {
343 self.dgraph
344 .nodes()
345 .map(|node| {
346 let edges = self
347 .dgraph
348 .edges(node)
349 .map(|e| format!("\n {}", e.1.fmt_with_ctx(ctx)))
350 .collect::<Vec<String>>()
351 .join(",");
352
353 format!("{} -> [{}\n]", node.fmt_with_ctx(ctx), edges)
354 })
355 .collect::<Vec<String>>()
356 .join(",\n")
357 }
358}
359
360fn compute_declarations_graph<'tcx>(ctx: &'tcx TransformCtx) -> Deps {
361 let mut graph = Deps::new();
362 for (id, item) in ctx.translated.all_items_with_ids() {
363 graph.set_current_id(ctx, id);
364 match item {
365 AnyTransItem::Type(..) | AnyTransItem::TraitImpl(..) | AnyTransItem::Global(..) => {
366 let _ = item.drive(&mut graph);
367 }
368 AnyTransItem::Fun(d) => {
369 let _ = d.signature.drive(&mut graph);
372 let _ = d.body.drive(&mut graph);
373 if let ItemKind::TraitDecl { trait_ref, .. } = &d.kind {
377 graph.insert_edge(trait_ref.trait_id.into());
378 }
379 }
380 AnyTransItem::TraitDecl(d) => {
381 let TraitDecl {
382 def_id: _,
383 item_meta: _,
384 generics,
385 parent_clauses,
386 consts,
387 const_defaults,
388 types,
389 type_defaults,
390 type_clauses,
391 methods,
392 } = d;
393 let _ = generics.drive(&mut graph);
395
396 let _ = parent_clauses.drive(&mut graph);
398 assert!(type_clauses.is_empty());
399
400 let _ = consts.drive(&mut graph);
402 let _ = types.drive(&mut graph);
403 let _ = type_defaults.drive(&mut graph);
404
405 for (_name, gref) in const_defaults {
408 let _ = gref.generics.drive(&mut graph);
409 }
410 for (_, bound_fn) in methods {
411 let id = bound_fn.skip_binder.id;
412 let _ = bound_fn.params.drive(&mut graph);
413 if let Some(decl) = ctx.translated.fun_decls.get(id) {
414 let _ = decl.signature.drive(&mut graph);
415 }
416 }
417 }
418 }
419 graph.unset_current_id();
420 }
421 graph
422}
423
424fn group_declarations_from_scc(
425 _ctx: &TransformCtx,
426 graph: Deps,
427 reordered_sccs: SCCs<AnyTransId>,
428) -> DeclarationsGroups {
429 let reordered_sccs = &reordered_sccs.sccs;
430 let mut reordered_decls: DeclarationsGroups = Vec::new();
431
432 for scc in reordered_sccs.iter() {
434 if scc.is_empty() {
435 continue;
437 }
438
439 let mut it = scc.iter();
441 let id0 = *it.next().unwrap();
442 let decl = graph.graph.get(&id0).unwrap();
443
444 let is_mutually_recursive = scc.len() > 1;
448 let is_simply_recursive = !is_mutually_recursive && decl.contains(&id0);
449 let is_rec = is_mutually_recursive || is_simply_recursive;
450
451 let all_same_kind = scc
452 .iter()
453 .all(|id| id0.variant_index_arity() == id.variant_index_arity());
454 let ids = scc.iter().copied();
455 let group: DeclarationGroup = match id0 {
456 _ if !all_same_kind => {
457 DeclarationGroup::Mixed(GDeclarationGroup::make_group(is_rec, ids))
458 }
459 AnyTransId::Type(_) => {
460 DeclarationGroup::Type(GDeclarationGroup::make_group(is_rec, ids))
461 }
462 AnyTransId::Fun(_) => DeclarationGroup::Fun(GDeclarationGroup::make_group(is_rec, ids)),
463 AnyTransId::Global(_) => {
464 DeclarationGroup::Global(GDeclarationGroup::make_group(is_rec, ids))
465 }
466 AnyTransId::TraitDecl(_) => {
467 let gr: Vec<_> = ids.map(|x| x.try_into().unwrap()).collect();
468 if gr.len() == 1 {
475 DeclarationGroup::TraitDecl(GDeclarationGroup::NonRec(gr[0]))
476 } else {
477 DeclarationGroup::TraitDecl(GDeclarationGroup::Rec(gr))
478 }
479 }
480 AnyTransId::TraitImpl(_) => {
481 DeclarationGroup::TraitImpl(GDeclarationGroup::make_group(is_rec, ids))
482 }
483 };
484
485 reordered_decls.push(group);
486 }
487 reordered_decls
488}
489
490fn compute_reordered_decls(ctx: &TransformCtx) -> DeclarationsGroups {
491 trace!();
492
493 let graph = compute_declarations_graph(ctx);
495 trace!("Graph:\n{}\n", graph.fmt_with_ctx(ctx));
496
497 let sccs = tarjan_scc(&graph.dgraph);
499
500 let get_id_dependencies = &|id| graph.graph.get(&id).unwrap().iter().copied().collect();
507 let all_ids: Vec<AnyTransId> = graph
508 .graph
509 .keys()
510 .copied()
511 .filter(|id| ctx.translated.get_item(*id).is_some())
513 .collect();
514 let reordered_sccs = reorder_sccs::<AnyTransId>(get_id_dependencies, &all_ids, &sccs);
515
516 let reordered_decls = group_declarations_from_scc(ctx, graph, reordered_sccs);
518
519 trace!("{:?}", reordered_decls);
520 reordered_decls
521}
522
523pub struct Transform;
524impl TransformPass for Transform {
525 fn transform_ctx(&self, ctx: &mut TransformCtx) {
526 let reordered_decls = compute_reordered_decls(&ctx);
527 ctx.translated.ordered_decls = Some(reordered_decls);
528 }
529}
530
#[cfg(test)]
mod tests {
    /// Checks `reorder_sccs` on three components (`[0]`, `[1, 2]`, `[3, 4, 5]`)
    /// where `0` depends on `3` and `1` depends on `0` and `3`: the expected
    /// result lists `[3, 4, 5]` first, then `[0]`, then `[1, 2]`, with the
    /// matching component-level dependencies.
    #[test]
    fn test_reorder_sccs1() {
        use std::collections::BTreeSet as OrdSet;
        let sccs = vec![vec![0], vec![1, 2], vec![3, 4, 5]];
        let ids = vec![0, 1, 2, 3, 4, 5];

        let get_deps = &|x| match x {
            0 => vec![3],
            1 => vec![0, 3],
            _ => vec![],
        };
        let reordered = crate::reorder_decls::reorder_sccs(get_deps, &ids, &sccs);

        // `assert_eq!` prints both sides on failure, unlike `assert!(a == b)`.
        assert_eq!(reordered.sccs, vec![vec![3, 4, 5], vec![0], vec![1, 2]]);
        assert_eq!(reordered.scc_deps[0], OrdSet::from([]));
        assert_eq!(reordered.scc_deps[1], OrdSet::from([0]));
        assert_eq!(reordered.scc_deps[2], OrdSet::from([0, 1]));
    }
}