charon_lib/common.rs

use itertools::Itertools;

pub static TAB_INCR: &str = "    ";

/// Custom function to pretty-print elements from an iterator.
/// The output format is:
/// ```text
/// [
///   elem_0,
///   ...
///   elem_n,
/// ]
/// ```
pub fn pretty_display_list<T>(
    t_to_string: impl Fn(T) -> String,
    it: impl IntoIterator<Item = T>,
) -> String {
    let mut elems = it
        .into_iter()
        .map(t_to_string)
        .map(|x| format!("  {},\n", x))
        .peekable();
    if elems.peek().is_none() {
        "[]".to_owned()
    } else {
        format!("[\n{}]", elems.format(""))
    }
}
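
// A small sanity check for `pretty_display_list`; the expected strings follow from the format
// documented above (note the trailing comma after the last element).
#[test]
fn test_pretty_display_list() {
    let empty: Vec<u32> = vec![];
    assert_eq!(pretty_display_list(|x| x.to_string(), empty), "[]");
    assert_eq!(
        pretty_display_list(|x| x.to_string(), [1, 2]),
        "[\n  1,\n  2,\n]"
    );
}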

/// Implement `From` and `TryFrom` to wrap/unwrap enum variants with a single payload.
#[macro_export]
macro_rules! impl_from_enum {
    ($enum:ident::$variant:ident($ty:ty)) => {
        impl From<$ty> for $enum {
            fn from(x: $ty) -> Self {
                $enum::$variant(x)
            }
        }
        impl TryFrom<$enum> for $ty {
            type Error = ();
            fn try_from(e: $enum) -> Result<Self, Self::Error> {
                match e {
                    $enum::$variant(x) => Ok(x),
                    _ => Err(()),
                }
            }
        }
    };
}
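
// An illustrative (hypothetical) use of `impl_from_enum!`: wrap and unwrap the payloads of a
// two-variant enum.
#[test]
fn test_impl_from_enum() {
    #[derive(Debug, PartialEq)]
    enum Value {
        Int(i64),
        Bool(bool),
    }
    impl_from_enum!(Value::Int(i64));
    impl_from_enum!(Value::Bool(bool));

    let v: Value = 42i64.into();
    assert_eq!(v, Value::Int(42));
    assert_eq!(i64::try_from(v), Ok(42));
    // Unwrapping the wrong variant yields `Err(())`.
    assert!(bool::try_from(Value::Int(1)).is_err());
}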

/// Yield `None` then infinitely many `Some(x)`.
pub fn repeat_except_first<T: Clone>(x: T) -> impl Iterator<Item = Option<T>> {
    [None].into_iter().chain(std::iter::repeat(Some(x)))
}
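
// A quick check of `repeat_except_first`: the first element is `None`, every subsequent one is
// `Some(x)`.
#[test]
fn test_repeat_except_first() {
    let mut it = repeat_except_first("x");
    assert_eq!(it.next(), Some(None));
    assert_eq!(it.next(), Some(Some("x")));
    assert_eq!(it.next(), Some(Some("x")));
}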

pub mod type_map {
    use std::{
        any::{Any, TypeId},
        collections::HashMap,
        marker::PhantomData,
    };

    pub trait Mappable = Any + Send + Sync;

    pub trait Mapper {
        type Value<T: Mappable>: Mappable;
    }

    /// A map that maps types to values in a generic manner: we store for each type `T` a value of
    /// type `M::Value<T>`.
    pub struct TypeMap<M> {
        data: HashMap<TypeId, Box<dyn Mappable>>,
        phantom: PhantomData<M>,
    }

    impl<M: Mapper> TypeMap<M> {
        pub fn get<T: Mappable>(&self) -> Option<&M::Value<T>> {
            self.data
                .get(&TypeId::of::<T>())
                // We must be careful to not accidentally cast the box itself as `dyn Any`.
                .map(|val: &Box<dyn Mappable>| &**val)
                .and_then(|val: &dyn Mappable| (val as &dyn Any).downcast_ref())
        }

        pub fn get_mut<T: Mappable>(&mut self) -> Option<&mut M::Value<T>> {
            self.data
                .get_mut(&TypeId::of::<T>())
                // We must be careful to not accidentally cast the box itself as `dyn Any`.
                .map(|val: &mut Box<dyn Mappable>| &mut **val)
                .and_then(|val: &mut dyn Mappable| (val as &mut dyn Any).downcast_mut())
        }

        pub fn insert<T: Mappable>(&mut self, val: M::Value<T>) -> Option<Box<M::Value<T>>> {
            self.data
                .insert(TypeId::of::<T>(), Box::new(val))
                .and_then(|val: Box<dyn Mappable>| (val as Box<dyn Any>).downcast().ok())
        }
    }

    impl<M> Default for TypeMap<M> {
        fn default() -> Self {
            Self {
                data: Default::default(),
                phantom: Default::default(),
            }
        }
    }
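
    // A usage sketch of `TypeMap`: `VecMapper` is a mapper made up for this test that maps each
    // type `T` to a `Vec<T>`.
    #[test]
    fn test_type_map() {
        struct VecMapper;
        impl Mapper for VecMapper {
            type Value<T: Mappable> = Vec<T>;
        }
        let mut map: TypeMap<VecMapper> = TypeMap::default();
        map.insert::<u32>(vec![1, 2, 3]);
        map.insert::<String>(vec!["a".to_owned()]);
        assert_eq!(map.get::<u32>(), Some(&vec![1, 2, 3]));
        assert_eq!(map.get::<String>().map(|v| v.len()), Some(1));
        // A type that was never inserted has no entry.
        assert!(map.get::<bool>().is_none());
    }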
}

pub mod hash_consing {
    use derive_generic_visitor::{Drive, DriveMut, Visit, VisitMut};

    use super::hash_by_addr::HashByAddr;
    use super::type_map::{Mappable, Mapper, TypeMap};
    use serde::{Deserialize, Serialize};
    use std::collections::HashSet;
    use std::hash::Hash;
    use std::ops::ControlFlow;
    use std::sync::{Arc, LazyLock, RwLock};

    /// Hash-consed data structure: a reference-counted wrapper that guarantees that two equal
    /// values will be stored at the same address. This makes it possible to use the pointer
    /// address as a hash value.
    #[derive(Clone, PartialEq, Eq, Hash, Serialize)]
    pub struct HashConsed<T>(HashByAddr<Arc<T>>);

    impl<T> HashConsed<T> {
        pub fn inner(&self) -> &T {
            self.0.0.as_ref()
        }
    }

    impl<T> HashConsed<T>
    where
        T: Hash + PartialEq + Eq + Clone + Mappable,
    {
        pub fn new(inner: T) -> Self {
            Self::intern(inner)
        }

        /// Clones the inner value to get mutable access to it, then re-interns the result.
        pub fn with_inner_mut<R>(&mut self, f: impl FnOnce(&mut T) -> R) -> R {
            // The value is behind a shared `Arc`, so we clone it in order to mutate it.
            let mut value = self.inner().clone();
            let ret = f(&mut value);
            // Re-intern the new value.
            *self = Self::intern(value);
            ret
        }

        /// Deduplicate the values by interning them in a global table. This deduplication is
        /// crucial for the address-based hashing to be correct. This is the only function allowed
        /// to create `Self` values.
        fn intern(inner: T) -> Self {
            struct InternMapper;
            impl Mapper for InternMapper {
                type Value<T: Mappable> = HashSet<Arc<T>>;
            }
            // This is a global, lock-protected `HashSet<Arc<T>>` that records for each `T` value
            // a unique `Arc<T>` that contains the same value. Values inside the set are
            // hashed/compared as is normal for `T`.
            // Once we've gotten an `Arc` out of the set however, we're sure that "T-equality"
            // implies address-equality, hence the `HashByAddr` wrapper preserves correct equality
            // and hashing behavior.
            static INTERNED: LazyLock<RwLock<TypeMap<InternMapper>>> =
                LazyLock::new(|| Default::default());

            if INTERNED.read().unwrap().get::<T>().is_none() {
                let mut write_guard = INTERNED.write().unwrap();
                // Check again under the write lock: another thread may have created the set in
                // the meantime, and blindly inserting would discard its contents.
                if write_guard.get::<T>().is_none() {
                    write_guard.insert::<T>(HashSet::default());
                }
            }
            let read_guard = INTERNED.read().unwrap();
            let arc = if let Some(arc) = (*read_guard).get::<T>().unwrap().get(&inner) {
                arc.clone()
            } else {
                drop(read_guard);
                let mut write_guard = INTERNED.write().unwrap();
                let set = write_guard.get_mut::<T>().unwrap();
                // Check again under the write lock: another thread may have interned an equal
                // value since we released the read lock, and returning a fresh `Arc` for it
                // would break the address-equality invariant.
                if let Some(arc) = set.get(&inner) {
                    arc.clone()
                } else {
                    let arc: Arc<T> = Arc::new(inner);
                    set.insert(arc.clone());
                    arc
                }
            };
            Self(HashByAddr(arc))
        }
    }

    impl<T: std::fmt::Debug> std::fmt::Debug for HashConsed<T> {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            // Hide the `HashByAddr` wrapper.
            f.debug_tuple("HashConsed").field(self.inner()).finish()
        }
    }

    /// Manual impl to make sure we re-establish sharing!
    impl<'de, T> Deserialize<'de> for HashConsed<T>
    where
        T: Hash + PartialEq + Eq + Clone + Mappable,
        T: Deserialize<'de>,
    {
        fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
        where
            D: serde::Deserializer<'de>,
        {
            let x: T = T::deserialize(deserializer)?;
            Ok(Self::new(x))
        }
    }

    impl<'s, T, V: Visit<'s, T>> Drive<'s, V> for HashConsed<T> {
        fn drive_inner(&'s self, v: &mut V) -> ControlFlow<V::Break> {
            v.visit(self.inner())
        }
    }
    /// Note: this explores the inner value mutably by cloning it and re-interning it afterwards.
    impl<'s, T, V> DriveMut<'s, V> for HashConsed<T>
    where
        T: Hash + PartialEq + Eq + Clone + Mappable,
        V: for<'a> VisitMut<'a, T>,
    {
        fn drive_inner_mut(&'s mut self, v: &mut V) -> ControlFlow<V::Break> {
            self.with_inner_mut(|inner| v.visit(inner))
        }
    }

    #[test]
    fn test_hash_cons() {
        let x = HashConsed::new(42u32);
        let y = HashConsed::new(42u32);
        assert_eq!(x, y);
        let z = serde_json::from_value(serde_json::to_value(x.clone()).unwrap()).unwrap();
        assert_eq!(x, z);
    }
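
    // A sketch checking that `with_inner_mut` re-interns: mutating one handle into a value that
    // was already interned makes the two handles equal (and share an address) again.
    #[test]
    fn test_with_inner_mut() {
        let mut x = HashConsed::new(1u32);
        let y = HashConsed::new(2u32);
        x.with_inner_mut(|v| *v = 2);
        assert_eq!(x, y);
        // Equality is by address: both handles now share the same interned `Arc`.
        assert!(std::ptr::eq(x.inner(), y.inner()));
    }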
}

pub mod hash_by_addr {
    use serde::{Deserialize, Serialize};
    use std::{
        hash::{Hash, Hasher},
        ops::Deref,
    };

    /// A wrapper around a smart pointer that hashes and compares the contents by the address of
    /// the pointee.
    #[derive(Debug, Clone, Serialize, Deserialize)]
    pub struct HashByAddr<T>(pub T);

    impl<T: Deref> HashByAddr<T> {
        fn addr(&self) -> *const T::Target {
            self.0.deref()
        }
    }

    impl<T: Eq + Deref> Eq for HashByAddr<T> {}

    impl<T: PartialEq + Deref> PartialEq for HashByAddr<T> {
        fn eq(&self, other: &Self) -> bool {
            std::ptr::addr_eq(self.addr(), other.addr())
        }
    }

    impl<T: Hash + Deref> Hash for HashByAddr<T> {
        fn hash<H: Hasher>(&self, state: &mut H) {
            self.addr().hash(state);
        }
    }
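
    // A small illustration of the address-based semantics: two allocations holding equal values
    // compare unequal, while clones of the same allocation compare equal.
    #[test]
    fn test_hash_by_addr() {
        use std::rc::Rc;
        let x = HashByAddr(Rc::new(1u32));
        let y = HashByAddr(Rc::new(1u32));
        assert_ne!(x, y);
        assert_eq!(x, x.clone());
    }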
}

// This is the number of bytes that need to be left on the stack before increasing its size. It
// must be at least as large as the stack required by any code that does not call
// `ensure_sufficient_stack`.
const RED_ZONE: usize = 100 * 1024; // 100k

// Only the first stack that is pushed grows exponentially (2^n * STACK_PER_RECURSION) from then
// on. Values taken from rustc.
const STACK_PER_RECURSION: usize = 1024 * 1024; // 1MB

/// Grows the stack on demand to prevent stack overflow. Call this in strategic locations to
/// "break up" recursive calls. E.g. most statement visitors can benefit from this.
#[inline]
pub fn ensure_sufficient_stack<R>(f: impl FnOnce() -> R) -> R {
    stacker::maybe_grow(RED_ZONE, STACK_PER_RECURSION, f)
}
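
// A hedged usage sketch of `ensure_sufficient_stack`: a deliberately deep recursion (the function
// and the depth constant are made up for this test) that would risk overflowing a default thread
// stack without the guard.
#[test]
fn test_ensure_sufficient_stack() {
    fn depth(n: u64) -> u64 {
        // Grow the stack before each level of recursion if we are close to the red zone.
        ensure_sufficient_stack(|| if n == 0 { 0 } else { 1 + depth(n - 1) })
    }
    assert_eq!(depth(100_000), 100_000);
}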