charon_lib/common.rs

use itertools::Itertools;

pub static TAB_INCR: &str = "    ";

/// Pretty-prints the elements of an iterator. The output format is:
/// ```text
/// [
///   elem_0,
///   ...
///   elem_n
/// ]
/// ```
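///
/// For example (the `use` path below assumes this module is `charon_lib::common`):
/// ```
/// # use charon_lib::common::pretty_display_list;
/// let s = pretty_display_list(|x: u32| x.to_string(), [1, 2]);
/// assert_eq!(s, "[\n  1,\n  2,\n]");
/// assert_eq!(pretty_display_list(|x: u32| x.to_string(), []), "[]");
/// ```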
pub fn pretty_display_list<T>(
    t_to_string: impl Fn(T) -> String,
    it: impl IntoIterator<Item = T>,
) -> String {
    let mut elems = it
        .into_iter()
        .map(t_to_string)
        .map(|x| format!("  {},\n", x))
        .peekable();
    if elems.peek().is_none() {
        "[]".to_owned()
    } else {
        format!("[\n{}]", elems.format(""))
    }
}

/// Yield `None` then infinitely many `Some(x)`.
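///
/// Handy for emitting a separator before every element except the first:
/// ```
/// # use charon_lib::common::repeat_except_first;
/// let seps: Vec<_> = repeat_except_first(", ").take(3).collect();
/// assert_eq!(seps, [None, Some(", "), Some(", ")]);
/// ```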
pub fn repeat_except_first<T: Clone>(x: T) -> impl Iterator<Item = Option<T>> {
    [None].into_iter().chain(std::iter::repeat(Some(x)))
}

pub mod type_map {
    use std::{
        any::{Any, TypeId},
        collections::HashMap,
        marker::PhantomData,
    };

    pub trait Mappable = Any + Send + Sync;

    pub trait Mapper {
        type Value<T: Mappable>: Mappable;
    }

    /// A map that maps types to values in a generic manner: we store for each type `T` a value of
    /// type `M::Value<T>`.
    pub struct TypeMap<M> {
        data: HashMap<TypeId, Box<dyn Mappable>>,
        phantom: PhantomData<M>,
    }

    impl<M: Mapper> TypeMap<M> {
        pub fn get<T: Mappable>(&self) -> Option<&M::Value<T>> {
            self.data
                .get(&TypeId::of::<T>())
                // We must be careful to not accidentally cast the box itself as `dyn Any`.
                .map(|val: &Box<dyn Mappable>| &**val)
                .and_then(|val: &dyn Mappable| (val as &dyn Any).downcast_ref())
        }

        pub fn get_mut<T: Mappable>(&mut self) -> Option<&mut M::Value<T>> {
            self.data
                .get_mut(&TypeId::of::<T>())
                // We must be careful to not accidentally cast the box itself as `dyn Any`.
                .map(|val: &mut Box<dyn Mappable>| &mut **val)
                .and_then(|val: &mut dyn Mappable| (val as &mut dyn Any).downcast_mut())
        }

        pub fn insert<T: Mappable>(&mut self, val: M::Value<T>) -> Option<Box<M::Value<T>>> {
            self.data
                .insert(TypeId::of::<T>(), Box::new(val))
                .and_then(|val: Box<dyn Mappable>| (val as Box<dyn Any>).downcast().ok())
        }
    }

    impl<M> Default for TypeMap<M> {
        fn default() -> Self {
            Self {
                data: Default::default(),
                phantom: Default::default(),
            }
        }
    }
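
    // A minimal usage sketch (the `VecMapper` below is hypothetical, for illustration only):
    // a mapper that stores a `Vec<T>` for every type `T`.
    #[cfg(test)]
    mod tests {
        use super::*;

        struct VecMapper;
        impl Mapper for VecMapper {
            type Value<T: Mappable> = Vec<T>;
        }

        #[test]
        fn one_vec_per_type() {
            let mut map: TypeMap<VecMapper> = Default::default();
            map.insert::<u32>(vec![1, 2, 3]);
            map.insert::<String>(vec!["a".to_owned()]);
            // Each type gets its own, independently-typed value.
            assert_eq!(map.get::<u32>(), Some(&vec![1, 2, 3]));
            assert_eq!(map.get::<String>(), Some(&vec!["a".to_owned()]));
        }
    }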
}

pub mod hash_consing {
    use derive_generic_visitor::{Drive, DriveMut, Visit, VisitMut};

    use super::type_map::{Mappable, Mapper, TypeMap};
    use itertools::Either;
    use serde::{Deserialize, Serialize};
    use std::collections::HashMap;
    use std::hash::Hash;
    use std::ops::ControlFlow;
    use std::sync::{Arc, LazyLock, RwLock};

    /// Hash-consed data structure: a reference-counted wrapper that guarantees that two equal
    /// values will be stored at the same address. This makes it possible to use the pointer
    /// address as a hash value.
    #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
    pub struct HashConsed<T>(Arc<T>);

    impl<T> HashConsed<T> {
        pub fn inner(&self) -> &T {
            self.0.as_ref()
        }
    }

    impl<T> HashConsed<T>
    where
        T: Hash + PartialEq + Eq + Clone + Mappable,
    {
        pub fn new(inner: T) -> Self {
            Self::intern(Either::Left(inner))
        }

        /// Clones if needed to get mutable access to the inner value.
        pub fn with_inner_mut<R>(&mut self, f: impl FnOnce(&mut T) -> R) -> R {
            let kind = Arc::make_mut(&mut self.0);
            let ret = f(kind);
            // Re-establish sharing, crucial for the hashing function to be correct.
            *self = Self::intern(Either::Right(self.0.clone()));
            ret
        }

        /// Deduplicate the values by interning them in a global map. This deduplication is
        /// crucial for the hashing function to be correct. This is the only function allowed to
        /// create `Self` values.
        fn intern(inner: Either<T, Arc<T>>) -> Self {
            struct InternMapper;
            impl Mapper for InternMapper {
                type Value<T: Mappable> = HashMap<T, Arc<T>>;
            }
            static INTERNED: LazyLock<RwLock<TypeMap<InternMapper>>> =
                LazyLock::new(|| Default::default());

            if INTERNED.read().unwrap().get::<T>().is_none() {
                INTERNED.write().unwrap().insert::<T>(Default::default());
            }
            let read_guard = INTERNED.read().unwrap();
            if let Some(inner) = (*read_guard)
                .get::<T>()
                .unwrap()
                .get(inner.as_ref().either(|x| x, |x| x.as_ref()))
            {
                Self(inner.clone())
            } else {
                drop(read_guard);
                // We clone the value here in the slow path, which makes it possible to avoid an
                // allocation in the fast path.
                let raw_val: T = inner.as_ref().either(T::clone, |x| x.as_ref().clone());
                let arc: Arc<T> = inner.either(Arc::new, |x| x);
                // Use the `entry` API so that if another thread interned an equal value between
                // the release of the read lock and the acquisition of the write lock, we reuse
                // its `Arc` instead of overwriting it; otherwise two equal values could end up
                // at different addresses, breaking the hashing invariant.
                let arc = INTERNED
                    .write()
                    .unwrap()
                    .get_mut::<T>()
                    .unwrap()
                    .entry(raw_val)
                    .or_insert(arc)
                    .clone();
                Self(arc)
            }
        }
    }

    /// Hash the pointer; this is only correct if two identical values of `Self` are guaranteed to
    /// point to the same memory location, which we carefully enforce above.
    impl<T> std::hash::Hash for HashConsed<T> {
        fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
            Arc::as_ptr(&self.0).hash(state);
        }
    }

    impl<'s, T, V: Visit<'s, T>> Drive<'s, V> for HashConsed<T> {
        fn drive_inner(&'s self, v: &mut V) -> ControlFlow<V::Break> {
            v.visit(self.inner())
        }
    }

    /// Note: this visits the inner value mutably, cloning it if shared and re-interning it
    /// afterwards.
    impl<'s, T, V> DriveMut<'s, V> for HashConsed<T>
    where
        T: Hash + PartialEq + Eq + Clone + Mappable,
        V: for<'a> VisitMut<'a, T>,
    {
        fn drive_inner_mut(&'s mut self, v: &mut V) -> ControlFlow<V::Break> {
            self.with_inner_mut(|inner| v.visit(inner))
        }
    }
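
    // Sanity-check sketch of the interning invariant: equal values share one allocation, and
    // mutation through `with_inner_mut` re-interns so the invariant is preserved.
    #[cfg(test)]
    mod tests {
        use super::*;
        use std::sync::Arc;

        #[test]
        fn equal_values_share_an_address() {
            let a = HashConsed::new("hello".to_owned());
            let b = HashConsed::new("hello".to_owned());
            assert_eq!(a, b);
            assert!(Arc::ptr_eq(&a.0, &b.0));

            // Mutating a value re-interns it, so it shares its address with other equal values.
            let mut c = a.clone();
            c.with_inner_mut(|s| s.push_str(" world"));
            let d = HashConsed::new("hello world".to_owned());
            assert!(Arc::ptr_eq(&c.0, &d.0));
        }
    }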
}

pub mod hash_by_addr {
    use std::{
        hash::{Hash, Hasher},
        ops::Deref,
    };

    /// A wrapper around a smart pointer that hashes and compares the contents by the address of
    /// the pointee.
    #[derive(Debug, Clone)]
    pub struct HashByAddr<T>(pub T);

    impl<T: Deref> HashByAddr<T> {
        fn addr(&self) -> *const T::Target {
            self.0.deref()
        }
    }

    impl<T: Eq + Deref> Eq for HashByAddr<T> {}

    impl<T: PartialEq + Deref> PartialEq for HashByAddr<T> {
        fn eq(&self, other: &Self) -> bool {
            std::ptr::addr_eq(self.addr(), other.addr())
        }
    }

    impl<T: Hash + Deref> Hash for HashByAddr<T> {
        fn hash<H: Hasher>(&self, state: &mut H) {
            self.addr().hash(state);
        }
    }
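
    // Sketch of the intended behavior: equality follows the allocation, not the contents.
    #[cfg(test)]
    mod tests {
        use super::*;
        use std::sync::Arc;

        #[test]
        fn compares_by_address() {
            let a = Arc::new("hello".to_owned());
            let b = Arc::new("hello".to_owned());
            // Same allocation: equal. Distinct allocations with equal contents: not equal.
            assert_eq!(HashByAddr(a.clone()), HashByAddr(a.clone()));
            assert_ne!(HashByAddr(a), HashByAddr(b));
        }
    }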
}

// This is the number of bytes that need to be left on the stack before increasing its size. It
// must be at least as large as the stack space required by any code that does not call
// `ensure_sufficient_stack`.
const RED_ZONE: usize = 100 * 1024; // 100k

// Only the first stack that is pushed grows exponentially (2^n * STACK_PER_RECURSION) from then
// on. Values taken from rustc.
const STACK_PER_RECURSION: usize = 1024 * 1024; // 1MB

/// Grows the stack on demand to prevent stack overflow. Call this in strategic locations to "break
/// up" recursive calls. E.g. most statement visitors can benefit from this.
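///
/// For example (the `use` path below assumes this module is `charon_lib::common`):
/// ```
/// # use charon_lib::common::ensure_sufficient_stack;
/// fn depth(n: u32) -> u32 {
///     if n == 0 {
///         0
///     } else {
///         // Grow the stack before recursing if we are close to running out.
///         ensure_sufficient_stack(|| 1 + depth(n - 1))
///     }
/// }
/// assert_eq!(depth(100_000), 100_000);
/// ```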
#[inline]
pub fn ensure_sufficient_stack<R>(f: impl FnOnce() -> R) -> R {
    stacker::maybe_grow(RED_ZONE, STACK_PER_RECURSION, f)
}