charon_lib/common.rs

use itertools::Itertools;

pub static TAB_INCR: &str = "    ";

/// Pretty-prints the elements of an iterator. The output format is:
/// ```text
/// [
///   elem_0,
///   ...
///   elem_n
/// ]
/// ```
pub fn pretty_display_list<T>(
    t_to_string: impl Fn(T) -> String,
    it: impl IntoIterator<Item = T>,
) -> String {
    let mut elems = it
        .into_iter()
        .map(t_to_string)
        .map(|x| format!("  {},\n", x))
        .peekable();
    if elems.peek().is_none() {
        "[]".to_owned()
    } else {
        format!("[\n{}]", elems.format(""))
    }
}
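
// A minimal check (this test is ours, not part of the original file) showing
// the exact output shape documented above.
#[test]
fn pretty_display_list_example() {
    let printed = pretty_display_list(|x: u32| x.to_string(), [1, 2, 3]);
    assert_eq!(printed, "[\n  1,\n  2,\n  3,\n]");
    // An empty iterator prints as a compact `[]`.
    assert_eq!(pretty_display_list(|x: u32| x.to_string(), []), "[]");
}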

/// Implement `From` and `TryFrom` to wrap/unwrap enum variants with a single payload.
#[macro_export]
macro_rules! impl_from_enum {
    ($enum:ident::$variant:ident($ty:ty)) => {
        impl From<$ty> for $enum {
            fn from(x: $ty) -> Self {
                $enum::$variant(x)
            }
        }
        impl TryFrom<$enum> for $ty {
            type Error = ();
            fn try_from(e: $enum) -> Result<Self, Self::Error> {
                match e {
                    $enum::$variant(x) => Ok(x),
                    _ => Err(()),
                }
            }
        }
    };
}
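
// A hypothetical usage sketch of `impl_from_enum!`; the `Value` enum below is
// ours, not part of the original file.
#[cfg(test)]
mod impl_from_enum_example {
    #[derive(Debug, PartialEq)]
    enum Value {
        Int(i64),
        Text(String),
    }
    crate::impl_from_enum!(Value::Int(i64));

    #[test]
    fn wraps_and_unwraps_the_variant() {
        let v: Value = 42i64.into();
        assert_eq!(v, Value::Int(42));
        assert_eq!(i64::try_from(v), Ok(42));
        // Other variants fail to unwrap.
        assert_eq!(i64::try_from(Value::Text("x".into())), Err(()));
    }
}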

/// Yield `None` then infinitely many `Some(x)`.
pub fn repeat_except_first<T: Clone>(x: T) -> impl Iterator<Item = Option<T>> {
    [None].into_iter().chain(std::iter::repeat(Some(x)))
}
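
// Small illustration (ours): zipping this with another iterator yields a
// separator slot before every element except the first.
#[test]
fn repeat_except_first_example() {
    let seps: Vec<_> = repeat_except_first(", ").take(3).collect();
    assert_eq!(seps, [None, Some(", "), Some(", ")]);
}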

pub mod type_map {
    use std::{
        any::{Any, TypeId},
        collections::HashMap,
        marker::PhantomData,
    };

    pub trait Mappable = Any + Send + Sync;

    pub trait Mapper {
        type Value<T: Mappable>: Mappable;
    }

    /// A map that maps types to values in a generic manner: we store for each type `T` a value of
    /// type `M::Value<T>`.
    pub struct TypeMap<M> {
        data: HashMap<TypeId, Box<dyn Mappable>>,
        phantom: PhantomData<M>,
    }

    impl<M: Mapper> TypeMap<M> {
        pub fn get<T: Mappable>(&self) -> Option<&M::Value<T>> {
            self.data
                .get(&TypeId::of::<T>())
                // We must be careful to not accidentally cast the box itself as `dyn Any`.
                .map(|val: &Box<dyn Mappable>| &**val)
                .and_then(|val: &dyn Mappable| (val as &dyn Any).downcast_ref())
        }

        pub fn get_mut<T: Mappable>(&mut self) -> Option<&mut M::Value<T>> {
            self.data
                .get_mut(&TypeId::of::<T>())
                // We must be careful to not accidentally cast the box itself as `dyn Any`.
                .map(|val: &mut Box<dyn Mappable>| &mut **val)
                .and_then(|val: &mut dyn Mappable| (val as &mut dyn Any).downcast_mut())
        }

        pub fn insert<T: Mappable>(&mut self, val: M::Value<T>) -> Option<Box<M::Value<T>>> {
            self.data
                .insert(TypeId::of::<T>(), Box::new(val))
                .and_then(|val: Box<dyn Mappable>| (val as Box<dyn Any>).downcast().ok())
        }
    }

    impl<M> Default for TypeMap<M> {
        fn default() -> Self {
            Self {
                data: Default::default(),
                phantom: Default::default(),
            }
        }
    }
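
    // A hedged usage sketch; `VecMapper` and this test are ours, not part of
    // the original file. Each type `T` gets its own `Vec<T>` slot in the map.
    #[cfg(test)]
    mod example {
        use super::*;

        struct VecMapper;
        impl Mapper for VecMapper {
            type Value<T: Mappable> = Vec<T>;
        }

        #[test]
        fn stores_one_value_per_type() {
            let mut map: TypeMap<VecMapper> = TypeMap::default();
            map.insert::<u32>(vec![1, 2]);
            map.insert::<String>(vec!["a".to_owned()]);
            assert_eq!(map.get::<u32>(), Some(&vec![1, 2]));
            assert_eq!(map.get_mut::<String>().map(|v| v.pop()), Some(Some("a".to_owned())));
        }
    }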
}

pub mod hash_consing {
    use derive_generic_visitor::{Drive, DriveMut, Visit, VisitMut};

    use super::type_map::{Mappable, Mapper, TypeMap};
    use itertools::Either;
    use serde::{Deserialize, Serialize};
    use std::collections::HashMap;
    use std::hash::Hash;
    use std::ops::ControlFlow;
    use std::sync::{Arc, LazyLock, RwLock};

    /// Hash-consed data structure: a reference-counted wrapper that guarantees that two equal
    /// values will be stored at the same address. This makes it possible to use the pointer
    /// address as a hash value.
    #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
    pub struct HashConsed<T>(Arc<T>);

    impl<T> HashConsed<T> {
        pub fn inner(&self) -> &T {
            self.0.as_ref()
        }
    }

    impl<T> HashConsed<T>
    where
        T: Hash + PartialEq + Eq + Clone + Mappable,
    {
        pub fn new(inner: T) -> Self {
            Self::intern(Either::Left(inner))
        }

        /// Clones if needed to get mutable access to the inner value.
        pub fn with_inner_mut<R>(&mut self, f: impl FnOnce(&mut T) -> R) -> R {
            let kind = Arc::make_mut(&mut self.0);
            let ret = f(kind);
            // Re-establish sharing, crucial for the hashing function to be correct.
            *self = Self::intern(Either::Right(self.0.clone()));
            ret
        }

        /// Deduplicates the values by hashing them. This deduplication is crucial for the hashing
        /// function to be correct. This is the only function allowed to create `Self` values.
        fn intern(inner: Either<T, Arc<T>>) -> Self {
            struct InternMapper;
            impl Mapper for InternMapper {
                type Value<T: Mappable> = HashMap<T, Arc<T>>;
            }
            static INTERNED: LazyLock<RwLock<TypeMap<InternMapper>>> =
                LazyLock::new(|| Default::default());

            if INTERNED.read().unwrap().get::<T>().is_none() {
                INTERNED.write().unwrap().insert::<T>(Default::default());
            }
            let read_guard = INTERNED.read().unwrap();
            if let Some(inner) = (*read_guard)
                .get::<T>()
                .unwrap()
                .get(inner.as_ref().either(|x| x, |x| x.as_ref()))
            {
                Self(inner.clone())
            } else {
                drop(read_guard);
                // We clone the value here in the slow path, which makes it possible to avoid an
                // allocation in the fast path.
                let raw_val: T = inner.as_ref().either(T::clone, |x| x.as_ref().clone());
                let arc: Arc<T> = inner.either(Arc::new, |x| x);
                // Re-check under the write lock: another thread may have interned an equal
                // value since we released the read lock. Reusing its `Arc` (instead of
                // blindly overwriting the entry) preserves the guarantee that equal values
                // share a single allocation.
                let mut write_guard = INTERNED.write().unwrap();
                let arc = write_guard
                    .get_mut::<T>()
                    .unwrap()
                    .entry(raw_val)
                    .or_insert(arc)
                    .clone();
                Self(arc)
            }
        }
    }

    /// Hash the pointer; this is only correct if two identical values of `Self` are guaranteed to
    /// point to the same memory location, which we carefully enforce above.
    impl<T> std::hash::Hash for HashConsed<T> {
        fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
            Arc::as_ptr(&self.0).hash(state);
        }
    }

    impl<'s, T, V: Visit<'s, T>> Drive<'s, V> for HashConsed<T> {
        fn drive_inner(&'s self, v: &mut V) -> ControlFlow<V::Break> {
            v.visit(self.inner())
        }
    }
    /// Note: this visits the inner value mutably by cloning it first and re-interning it afterwards.
    impl<'s, T, V> DriveMut<'s, V> for HashConsed<T>
    where
        T: Hash + PartialEq + Eq + Clone + Mappable,
        V: for<'a> VisitMut<'a, T>,
    {
        fn drive_inner_mut(&'s mut self, v: &mut V) -> ControlFlow<V::Break> {
            self.with_inner_mut(|inner| v.visit(inner))
        }
    }
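
    // Sketch of the interning invariant (this test is ours, not part of the
    // original file): equal values built independently end up sharing one
    // allocation, so the pointer-based `Hash` above agrees with `Eq`.
    #[cfg(test)]
    mod example {
        use super::*;

        #[test]
        fn equal_values_share_storage() {
            let a = HashConsed::new("hello".to_owned());
            let b = HashConsed::new("hello".to_owned());
            assert_eq!(a, b);
            assert!(Arc::ptr_eq(&a.0, &b.0));
        }
    }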
}

pub mod hash_by_addr {
    use std::{
        hash::{Hash, Hasher},
        ops::Deref,
    };

    /// A wrapper around a smart pointer that hashes and compares the contents by the address of
    /// the pointee.
    #[derive(Debug, Clone)]
    pub struct HashByAddr<T>(pub T);

    impl<T: Deref> HashByAddr<T> {
        fn addr(&self) -> *const T::Target {
            self.0.deref()
        }
    }

    impl<T: Eq + Deref> Eq for HashByAddr<T> {}

    impl<T: PartialEq + Deref> PartialEq for HashByAddr<T> {
        fn eq(&self, other: &Self) -> bool {
            std::ptr::addr_eq(self.addr(), other.addr())
        }
    }

    impl<T: Hash + Deref> Hash for HashByAddr<T> {
        fn hash<H: Hasher>(&self, state: &mut H) {
            self.addr().hash(state);
        }
    }
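
    // Illustrative check (ours, not part of the original file): equality
    // follows the allocation, not the pointee's value.
    #[cfg(test)]
    mod example {
        use super::HashByAddr;
        use std::sync::Arc;

        #[test]
        fn compares_by_address() {
            let a = Arc::new("x".to_owned());
            let same = HashByAddr(a.clone());
            let other = HashByAddr(Arc::new("x".to_owned()));
            assert_eq!(HashByAddr(a.clone()), same);
            // Equal contents, but different allocations: not equal by address.
            assert_ne!(HashByAddr(a), other);
        }
    }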
}

// This is the number of bytes that must be left on the stack before increasing its size. It must
// be at least as large as the stack required by any code that does not call
// `ensure_sufficient_stack`.
const RED_ZONE: usize = 100 * 1024; // 100k

// Only the first stack that is pushed grows exponentially (2^n * STACK_PER_RECURSION) from then
// on. Values taken from rustc.
const STACK_PER_RECURSION: usize = 1024 * 1024; // 1MB

/// Grows the stack on demand to prevent stack overflow. Call this in strategic locations to "break
/// up" recursive calls. E.g. most statement visitors can benefit from this.
#[inline]
pub fn ensure_sufficient_stack<R>(f: impl FnOnce() -> R) -> R {
    stacker::maybe_grow(RED_ZONE, STACK_PER_RECURSION, f)
}
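
// A hedged sketch (ours, not part of the original file) of the intended usage
// pattern: wrap each level of a deep recursion so the stack grows on demand.
#[cfg(test)]
mod ensure_sufficient_stack_example {
    use super::ensure_sufficient_stack;

    fn depth(n: u64) -> u64 {
        // Without the wrapper, a debug-build recursion this deep can overflow
        // the default thread stack.
        ensure_sufficient_stack(|| if n == 0 { 0 } else { 1 + depth(n - 1) })
    }

    #[test]
    fn deep_recursion_completes() {
        assert_eq!(depth(200_000), 200_000);
    }
}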