use crate::common::*;
use crate::formatter::{AstFormatter, IntoFormatter};
use crate::graphs::*;
use crate::transform::TransformCtx;
use crate::ullbc_ast::*;
use derive_visitor::{Drive, Visitor};
use hashlink::{LinkedHashMap, LinkedHashSet};
use macros::{EnumAsGetters, EnumIsA, VariantIndexArity, VariantName};
use petgraph::algo::tarjan_scc;
use petgraph::graphmap::DiGraphMap;
use serde::{Deserialize, Serialize};
use std::fmt::{Debug, Display, Error};
use std::vec::Vec;
/// A group of declarations whose ids all have type `Id`, as produced by the reordering
/// pass: either a single declaration treated as non-recursive, or a recursive group.
#[derive(
Debug, Clone, VariantIndexArity, VariantName, EnumAsGetters, EnumIsA, Serialize, Deserialize,
)]
#[charon::variants_suffix("Group")]
pub enum GDeclarationGroup<Id> {
/// A single declaration that is not treated as recursive.
/// Note: a lone trait declaration is always grouped as `NonRec`, even when it depends
/// on itself (see `group_declarations_from_scc`).
NonRec(Id),
/// A recursive group: several mutually dependent declarations, or a single
/// declaration that depends on itself.
Rec(Vec<Id>),
}
/// A declaration group tagged with the kind of items it contains.
///
/// When the members of a group are not all of the same kind, the group is `Mixed`
/// and holds `AnyTransId`s.
#[derive(
Debug, Clone, VariantIndexArity, VariantName, EnumAsGetters, EnumIsA, Serialize, Deserialize,
)]
#[charon::variants_suffix("Group")]
pub enum DeclarationGroup {
/// Group of type declarations.
Type(GDeclarationGroup<TypeDeclId>),
/// Group of function declarations.
Fun(GDeclarationGroup<FunDeclId>),
/// Group of global declarations.
Global(GDeclarationGroup<GlobalDeclId>),
/// Group of trait declarations.
TraitDecl(GDeclarationGroup<TraitDeclId>),
/// Group of trait implementations.
TraitImpl(GDeclarationGroup<TraitImplId>),
/// Group whose members are of different item kinds.
Mixed(GDeclarationGroup<AnyTransId>),
}
impl<Id: Copy> GDeclarationGroup<Id> {
pub fn get_ids(&self) -> &[Id] {
use GDeclarationGroup::*;
match self {
NonRec(id) => std::slice::from_ref(id),
Rec(ids) => ids.as_slice(),
}
}
pub fn get_any_trans_ids(&self) -> Vec<AnyTransId>
where
Id: Into<AnyTransId>,
{
self.get_ids().iter().copied().map(|id| id.into()).collect()
}
fn make_group(is_rec: bool, gr: impl Iterator<Item = AnyTransId>) -> Self
where
Id: TryFrom<AnyTransId>,
Id::Error: Debug,
{
let gr: Vec<_> = gr.map(|x| x.try_into().unwrap()).collect();
if is_rec {
GDeclarationGroup::Rec(gr)
} else {
assert!(gr.len() == 1);
GDeclarationGroup::NonRec(gr[0])
}
}
fn to_mixed(&self) -> GDeclarationGroup<AnyTransId>
where
Id: Into<AnyTransId>,
{
match self {
GDeclarationGroup::NonRec(x) => GDeclarationGroup::NonRec((*x).into()),
GDeclarationGroup::Rec(_) => GDeclarationGroup::Rec(self.get_any_trans_ids()),
}
}
}
impl DeclarationGroup {
pub fn to_mixed_group(&self) -> GDeclarationGroup<AnyTransId> {
use DeclarationGroup::*;
match self {
Type(gr) => gr.to_mixed(),
Fun(gr) => gr.to_mixed(),
Global(gr) => gr.to_mixed(),
TraitDecl(gr) => gr.to_mixed(),
TraitImpl(gr) => gr.to_mixed(),
Mixed(gr) => gr.clone(),
}
}
pub fn get_ids(&self) -> Vec<AnyTransId> {
use DeclarationGroup::*;
match self {
Type(gr) => gr.get_any_trans_ids(),
Fun(gr) => gr.get_any_trans_ids(),
Global(gr) => gr.get_any_trans_ids(),
TraitDecl(gr) => gr.get_any_trans_ids(),
TraitImpl(gr) => gr.get_any_trans_ids(),
Mixed(gr) => gr.get_any_trans_ids(),
}
}
}
/// Per-declaration information.
///
/// NOTE(review): `DeclInfo` is not used in this part of the file; presumably
/// `is_transparent` records whether the declaration's definition is visible (as opposed
/// to opaque) — confirm against its users.
#[derive(Clone, Copy)]
pub struct DeclInfo {
pub is_transparent: bool,
}
/// Ordered list of declaration groups, as returned by `compute_reordered_decls`.
pub type DeclarationsGroups = Vec<DeclarationGroup>;
impl<Id: Display> Display for GDeclarationGroup<Id> {
    /// Writes `non-rec: <id>` for a single declaration, or `rec:` followed by the
    /// listing that `pretty_display_list` produces for the members of a recursive group.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::result::Result<(), Error> {
        let members = match self {
            GDeclarationGroup::NonRec(id) => return write!(f, "non-rec: {id}"),
            GDeclarationGroup::Rec(members) => members,
        };
        let listing = pretty_display_list(|id| format!(" {id}"), members);
        write!(f, "rec: {}", listing)
    }
}
impl Display for DeclarationGroup {
    /// Writes `{ <Kind>(s): <group> }`, where `<group>` is rendered by the
    /// `Display` impl of `GDeclarationGroup`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::result::Result<(), Error> {
        match self {
            DeclarationGroup::Type(decl) => write!(f, "{{ Type(s): {decl} }}"),
            DeclarationGroup::Fun(decl) => write!(f, "{{ Fun(s): {decl} }}"),
            DeclarationGroup::Global(decl) => write!(f, "{{ Global(s): {decl} }}"),
            // Label fixed from the typo "Trait decls(s)" to match the other "(s)" labels.
            DeclarationGroup::TraitDecl(decl) => write!(f, "{{ Trait decl(s): {decl} }}"),
            DeclarationGroup::TraitImpl(decl) => write!(f, "{{ Trait impl(s): {decl} }}"),
            DeclarationGroup::Mixed(decl) => write!(f, "{{ Mixed items: {decl} }}"),
        }
    }
}
/// Visitor that collects the dependency graph between translated items.
///
/// While an item is being processed (`current_id` is set), every id of the types listed
/// in the `#[visitor(...)]` attribute reaches an `enter_*` hook of `Deps`. Item ids are
/// recorded as edges from the current item, `BodyId`s are followed into the referenced
/// body, and `Ty` values are visited through `drive_inner`.
#[derive(Visitor)]
#[visitor(
TypeDeclId(enter),
FunDeclId(enter),
GlobalDeclId(enter),
TraitImplId(enter),
TraitDeclId(enter),
BodyId(enter),
Ty(enter)
)]
pub struct Deps<'tcx, 'ctx> {
/// Translation context; used to look up the bodies referenced by `BodyId`s.
ctx: &'tcx TransformCtx<'ctx>,
/// The dependency graph in petgraph form, used for SCC computation and debug printing.
dgraph: DiGraphMap<AnyTransId, ()>,
/// The same nodes and edges as `dgraph`, as an insertion-ordered adjacency map from
/// an item to the items it refers to.
graph: LinkedHashMap<AnyTransId, LinkedHashSet<AnyTransId>>,
/// Item whose dependencies are currently being collected; the source of every new edge.
current_id: Option<AnyTransId>,
/// Trait impl the current item belongs to (taken from its `ItemKind`), if any.
/// References to this impl are not recorded as edges.
parent_trait_impl: Option<TraitImplId>,
/// Trait declaration the current item belongs to (taken from its `ItemKind`), if any.
/// References to this declaration are not recorded as edges.
parent_trait_decl: Option<TraitDeclId>,
}
impl<'tcx, 'ctx> Deps<'tcx, 'ctx> {
fn new(ctx: &'tcx TransformCtx<'ctx>) -> Self {
Deps {
ctx,
dgraph: DiGraphMap::new(),
graph: LinkedHashMap::new(),
current_id: None,
parent_trait_impl: None,
parent_trait_decl: None,
}
}
fn set_impl_or_trait_id(&mut self, kind: &ItemKind) {
match kind {
ItemKind::Regular => {}
ItemKind::TraitDecl { trait_id, .. } => self.parent_trait_decl = Some(*trait_id),
ItemKind::TraitImpl { impl_id, .. } => self.parent_trait_impl = Some(*impl_id),
}
}
fn set_current_id(&mut self, ctx: &TransformCtx, id: AnyTransId) {
self.insert_node(id);
self.current_id = Some(id);
use AnyTransId::*;
match id {
TraitDecl(_) | TraitImpl(_) | Type(_) => (),
Global(id) => {
if let Some(decl) = ctx.translated.global_decls.get(id) {
self.set_impl_or_trait_id(&decl.kind);
}
}
Fun(id) => {
if let Some(decl) = ctx.translated.fun_decls.get(id) {
self.set_impl_or_trait_id(&decl.kind);
}
}
}
}
fn unset_current_id(&mut self) {
self.current_id = None;
self.parent_trait_impl = None;
self.parent_trait_decl = None;
}
fn insert_node(&mut self, id: AnyTransId) {
if !self.dgraph.contains_node(id) {
self.dgraph.add_node(id);
assert!(!self.graph.contains_key(&id));
self.graph.insert(id, LinkedHashSet::new());
}
}
fn insert_edge(&mut self, id1: AnyTransId) {
let id0 = self.current_id.unwrap();
self.insert_node(id1);
if !self.dgraph.contains_edge(id0, id1) {
self.dgraph.add_edge(id0, id1, ());
self.graph.get_mut(&id0).unwrap().insert(id1);
}
}
}
impl Deps<'_, '_> {
fn enter_type_decl_id(&mut self, id: &TypeDeclId) {
self.insert_edge((*id).into());
}
fn enter_global_decl_id(&mut self, id: &GlobalDeclId) {
self.insert_edge((*id).into());
}
fn enter_trait_impl_id(&mut self, id: &TraitImplId) {
if let Some(impl_id) = &self.parent_trait_impl
&& impl_id == id
{
return;
}
self.insert_edge((*id).into());
}
fn enter_trait_decl_id(&mut self, id: &TraitDeclId) {
if let Some(trait_id) = &self.parent_trait_decl
&& trait_id == id
{
return;
}
self.insert_edge((*id).into());
}
fn enter_fun_decl_id(&mut self, id: &FunDeclId) {
self.insert_edge((*id).into());
}
fn enter_body_id(&mut self, id: &BodyId) {
if let Some(body) = self.ctx.translated.bodies.get(*id) {
body.drive(self);
}
}
fn enter_ty(&mut self, ty: &Ty) {
ty.drive_inner(self);
}
}
impl AnyTransId {
fn fmt_with_ctx(&self, ctx: &TransformCtx) -> String {
use AnyTransId::*;
let ctx = ctx.into_fmt();
match self {
Type(id) => ctx.format_object(*id),
Fun(id) => ctx.format_object(*id),
Global(id) => ctx.format_object(*id),
TraitDecl(id) => ctx.format_object(*id),
TraitImpl(id) => ctx.format_object(*id),
}
}
}
impl Deps<'_, '_> {
fn fmt_with_ctx(&self, ctx: &TransformCtx) -> String {
self.dgraph
.nodes()
.map(|node| {
let edges = self
.dgraph
.edges(node)
.map(|e| format!("\n {}", e.1.fmt_with_ctx(ctx)))
.collect::<Vec<String>>()
.join(",");
format!("{} -> [{}\n]", node.fmt_with_ctx(ctx), edges)
})
.collect::<Vec<String>>()
.join(",\n")
}
}
/// Builds the dependency graph of every translated item.
///
/// For each item yielded by `ctx.translated.all_items_with_ids()`, the item becomes the
/// current node (`set_current_id`) and selected parts of its definition are visited.
/// Every item id reached during the visit becomes an edge from the current item (through
/// the `enter_*` hooks of `Deps`). Which parts are visited depends on the item kind.
///
/// NOTE(review): edges are stored in insertion-ordered sets (`LinkedHashSet`) and then fed
/// to `reorder_sccs`, so the order of the `drive` calls below presumably affects the final
/// declaration order — keep it stable unless a change is intended.
fn compute_declarations_graph<'tcx, 'ctx>(ctx: &'tcx TransformCtx<'ctx>) -> Deps<'tcx, 'ctx> {
let mut graph = Deps::new(ctx);
for (id, item) in ctx.translated.all_items_with_ids() {
// Register the node; for functions and globals this also records their parent
// trait decl/impl, so references to that parent are not recorded as edges.
graph.set_current_id(ctx, id);
match item {
// Types: the whole declaration is visited.
AnyTransItem::Type(d) => {
d.drive(&mut graph);
}
// Functions: only the signature and the body are visited.
AnyTransItem::Fun(d) => {
d.signature.drive(&mut graph);
d.body.drive(&mut graph);
}
// Globals: the whole declaration is visited.
AnyTransItem::Global(d) => {
d.drive(&mut graph);
}
AnyTransItem::TraitDecl(d) => {
// Exhaustive destructuring (no `..`), so adding a field to `TraitDecl` is a compile
// error here and forces a decision about whether to visit it.
let TraitDecl {
def_id: _,
item_meta: _,
generics,
parent_clauses,
consts,
const_defaults,
types,
type_defaults,
type_clauses,
required_methods,
provided_methods,
} = d;
generics.drive(&mut graph);
parent_clauses.drive(&mut graph);
// Trait declarations are expected to carry no `type_clauses`; they are not handled here.
assert!(type_clauses.is_empty());
consts.drive(&mut graph);
types.drive(&mut graph);
const_defaults.drive(&mut graph);
type_defaults.drive(&mut graph);
// The method `FunDeclId`s themselves are not driven. Instead, the signature of each
// method's `FunDecl` (when present) is visited.
// NOTE(review): presumably this avoids making a trait and its method bodies mutually
// recursive — confirm.
let method_ids = required_methods
.iter()
.chain(provided_methods.iter())
.map(|(_, id)| id)
.copied();
for id in method_ids {
if let Some(decl) = ctx.translated.fun_decls.get(id) {
decl.signature.drive(&mut graph);
}
}
}
AnyTransItem::TraitImpl(d) => {
// Exhaustive destructuring, as for `TraitDecl` above.
let TraitImpl {
def_id: _,
item_meta: _,
impl_trait,
generics,
parent_trait_refs,
consts,
types,
type_clauses,
required_methods,
provided_methods,
} = d;
impl_trait.drive(&mut graph);
generics.drive(&mut graph);
parent_trait_refs.drive(&mut graph);
consts.drive(&mut graph);
types.drive(&mut graph);
type_clauses.drive(&mut graph);
// Unlike trait declarations, the method lists are driven directly, so the impl
// records edges to its method `FunDeclId`s.
required_methods.drive(&mut graph);
provided_methods.drive(&mut graph);
}
}
graph.unset_current_id();
}
graph
}
/// Turns reordered strongly connected components into declaration groups.
///
/// Empty components are skipped. A component is recursive when it has several members,
/// or when its only member depends on itself. A component whose members are not all of
/// the same item kind becomes a `Mixed` group.
///
/// Trait declarations are the exception: only the component size counts, so a lone trait
/// declaration is always `NonRec`, even when it depends on itself.
/// NOTE(review): presumably because a trait routinely refers to itself (e.g. through its
/// `Self` clause in method signatures) — confirm.
fn group_declarations_from_scc(
    _ctx: &TransformCtx,
    graph: Deps<'_, '_>,
    reordered_sccs: SCCs<AnyTransId>,
) -> DeclarationsGroups {
    let mut groups: DeclarationsGroups = Vec::new();
    for scc in reordered_sccs.sccs.iter() {
        // The first member of a component is used as its representative.
        let Some(&first) = scc.first() else {
            continue;
        };
        let first_deps = graph.graph.get(&first).unwrap();
        let is_rec = scc.len() > 1 || first_deps.contains(&first);
        let all_same_kind = scc
            .iter()
            .all(|id| first.variant_index_arity() == id.variant_index_arity());
        let ids = scc.iter().copied();
        let group = if !all_same_kind {
            DeclarationGroup::Mixed(GDeclarationGroup::make_group(is_rec, ids))
        } else {
            match first {
                AnyTransId::Type(_) => {
                    DeclarationGroup::Type(GDeclarationGroup::make_group(is_rec, ids))
                }
                AnyTransId::Fun(_) => {
                    DeclarationGroup::Fun(GDeclarationGroup::make_group(is_rec, ids))
                }
                AnyTransId::Global(_) => {
                    DeclarationGroup::Global(GDeclarationGroup::make_group(is_rec, ids))
                }
                AnyTransId::TraitDecl(_) => {
                    // Self-dependencies are ignored: only mutual recursion makes a `Rec` group.
                    let is_mutually_recursive = scc.len() > 1;
                    DeclarationGroup::TraitDecl(GDeclarationGroup::make_group(
                        is_mutually_recursive,
                        ids,
                    ))
                }
                AnyTransId::TraitImpl(_) => {
                    DeclarationGroup::TraitImpl(GDeclarationGroup::make_group(is_rec, ids))
                }
            }
        };
        groups.push(group);
    }
    groups
}
/// Computes the crate's declaration groups in dependency order.
///
/// The steps are:
/// 1. Build the item dependency graph.
/// 2. Split it into strongly connected components with Tarjan's algorithm.
/// 3. Reorder the components with `reorder_sccs`, using only ids of items that exist in
///    `ctx.translated`.
/// 4. Turn each component into a `DeclarationGroup`.
pub fn compute_reordered_decls(ctx: &TransformCtx) -> DeclarationsGroups {
    trace!();
    let deps = compute_declarations_graph(ctx);
    trace!("Graph:\n{}\n", deps.fmt_with_ctx(ctx));

    let components = tarjan_scc(&deps.dgraph);

    // Direct dependencies of an id, in insertion order.
    let dependencies_of = &|id| deps.graph.get(&id).unwrap().iter().copied().collect();

    // Keep only the ids that correspond to an item actually present in the crate.
    let mut existing_ids: Vec<AnyTransId> = Vec::new();
    for &id in deps.graph.keys() {
        if ctx.translated.get_item(id).is_some() {
            existing_ids.push(id);
        }
    }

    let ordered = reorder_sccs::<AnyTransId>(dependencies_of, &existing_ids, &components);
    let groups = group_declarations_from_scc(ctx, deps, ordered);
    trace!("{:?}", groups);
    groups
}
#[cfg(test)]
mod tests {
    /// `reorder_sccs` must place each component after the components it depends on and
    /// report, for each reordered component, the indices of its dependencies.
    #[test]
    fn test_reorder_sccs1() {
        use std::collections::BTreeSet as OrdSet;
        // {0} depends on {3, 4, 5}; {1, 2} depends on {0} and {3, 4, 5}.
        let sccs = vec![vec![0], vec![1, 2], vec![3, 4, 5]];
        let ids = vec![0, 1, 2, 3, 4, 5];
        let get_deps = &|x| match x {
            0 => vec![3],
            1 => vec![0, 3],
            _ => vec![],
        };
        let reordered = crate::reorder_decls::reorder_sccs(get_deps, &ids, &sccs);
        assert_eq!(reordered.sccs, vec![vec![3, 4, 5], vec![0], vec![1, 2]]);
        assert_eq!(reordered.scc_deps[0], OrdSet::from([]));
        assert_eq!(reordered.scc_deps[1], OrdSet::from([0]));
        assert_eq!(reordered.scc_deps[2], OrdSet::from([0, 1]));
    }
}