charon_lib/
common.rs

1use itertools::Itertools;
2
/// One indentation step: four spaces.
///
/// NOTE(review): presumably prepended once per nesting level by pretty-printers elsewhere in
/// the crate — confirm against callers.
pub static TAB_INCR: &str = "    ";
4
5/// Custom function to pretty-print elements from an iterator
6/// The output format is:
7/// ```text
8/// [
9///   elem_0,
10///   ...
11///   elem_n
12/// ]
13/// ```
14pub fn pretty_display_list<T>(
15    t_to_string: impl Fn(T) -> String,
16    it: impl IntoIterator<Item = T>,
17) -> String {
18    let mut elems = it
19        .into_iter()
20        .map(t_to_string)
21        .map(|x| format!("  {},\n", x))
22        .peekable();
23    if elems.peek().is_none() {
24        "[]".to_owned()
25    } else {
26        format!("[\n{}]", elems.format(""))
27    }
28}
29
/// Implements `From` and `TryFrom` to wrap/unwrap an enum variant with a single payload.
///
/// `impl_from_enum!(MyEnum::Variant(Payload))` generates:
/// - `From<Payload> for MyEnum`, wrapping the payload into `MyEnum::Variant`;
/// - `TryFrom<MyEnum> for Payload`, returning the payload of `MyEnum::Variant`, and `Err(())`
///   for any other variant.
///
/// Standard-library items are named by absolute `::std::...` paths, so the expansion still
/// compiles at call sites that shadow prelude names (e.g. a crate-local
/// `type Result<T> = ...` alias).
#[macro_export]
macro_rules! impl_from_enum {
    ($enum:ident::$variant:ident($ty:ty)) => {
        impl ::std::convert::From<$ty> for $enum {
            fn from(x: $ty) -> Self {
                $enum::$variant(x)
            }
        }
        impl ::std::convert::TryFrom<$enum> for $ty {
            type Error = ();
            fn try_from(e: $enum) -> ::std::result::Result<Self, Self::Error> {
                match e {
                    $enum::$variant(x) => ::std::result::Result::Ok(x),
                    _ => ::std::result::Result::Err(()),
                }
            }
        }
    };
}
50
/// Returns an endless iterator whose first item is `None`; every later item is `Some` holding
/// a fresh clone of `x`.
pub fn repeat_except_first<T: Clone>(x: T) -> impl Iterator<Item = Option<T>> {
    let rest = std::iter::repeat(x).map(Some);
    std::iter::once(None).chain(rest)
}
55
56pub mod type_map {
57    use std::{
58        any::{Any, TypeId},
59        collections::HashMap,
60        marker::PhantomData,
61    };
62
63    pub trait Mappable = Any + Send + Sync;
64
65    pub trait Mapper {
66        type Value<T: Mappable>: Mappable;
67    }
68
69    /// A map that maps types to values in a generic manner: we store for each type `T` a value of
70    /// type `M::Value<T>`.
71    pub struct TypeMap<M> {
72        data: HashMap<TypeId, Box<dyn Mappable>>,
73        phantom: PhantomData<M>,
74    }
75
76    impl<M: Mapper> TypeMap<M> {
77        pub fn get<T: Mappable>(&self) -> Option<&M::Value<T>> {
78            self.data
79                .get(&TypeId::of::<T>())
80                // We must be careful to not accidentally cast the box itself as `dyn Any`.
81                .map(|val: &Box<dyn Mappable>| &**val)
82                .and_then(|val: &dyn Mappable| (val as &dyn Any).downcast_ref())
83        }
84
85        pub fn get_mut<T: Mappable>(&mut self) -> Option<&mut M::Value<T>> {
86            self.data
87                .get_mut(&TypeId::of::<T>())
88                // We must be careful to not accidentally cast the box itself as `dyn Any`.
89                .map(|val: &mut Box<dyn Mappable>| &mut **val)
90                .and_then(|val: &mut dyn Mappable| (val as &mut dyn Any).downcast_mut())
91        }
92
93        pub fn insert<T: Mappable>(&mut self, val: M::Value<T>) -> Option<Box<M::Value<T>>> {
94            self.data
95                .insert(TypeId::of::<T>(), Box::new(val))
96                .and_then(|val: Box<dyn Mappable>| (val as Box<dyn Any>).downcast().ok())
97        }
98
99        pub fn or_insert_with<T: Mappable>(
100            &mut self,
101            f: impl FnOnce() -> M::Value<T>,
102        ) -> &mut M::Value<T> {
103            if self.get::<T>().is_none() {
104                self.insert(f());
105            }
106            self.get_mut::<T>().unwrap()
107        }
108        pub fn or_default<T: Mappable>(&mut self) -> &mut M::Value<T>
109        where
110            M::Value<T>: Default,
111        {
112            self.or_insert_with(|| Default::default())
113        }
114    }
115
116    impl<M> Default for TypeMap<M> {
117        fn default() -> Self {
118            Self {
119                data: Default::default(),
120                phantom: Default::default(),
121            }
122        }
123    }
124}
125
126pub mod hash_by_addr {
127    use serde::{Deserialize, Serialize};
128    use std::{
129        hash::{Hash, Hasher},
130        ops::Deref,
131    };
132
133    /// A wrapper around a smart pointer that hashes and compares the contents by the address of
134    /// the pointee.
135    #[derive(Debug, Clone, Serialize, Deserialize)]
136    pub struct HashByAddr<T>(pub T);
137
138    impl<T: Deref> HashByAddr<T> {
139        fn addr(&self) -> *const T::Target {
140            self.0.deref()
141        }
142    }
143
144    impl<T: Eq + Deref> Eq for HashByAddr<T> {}
145
146    impl<T: PartialEq + Deref> PartialEq for HashByAddr<T> {
147        fn eq(&self, other: &Self) -> bool {
148            std::ptr::addr_eq(self.addr(), other.addr())
149        }
150    }
151
152    impl<T: Hash + Deref> Hash for HashByAddr<T> {
153        fn hash<H: Hasher>(&self, state: &mut H) {
154            self.addr().hash(state);
155        }
156    }
157}
158
159pub mod serialize_map_to_array {
160    use core::{fmt, marker::PhantomData};
161    use std::{
162        collections::hash_map::RandomState,
163        hash::{BuildHasher, Hash},
164    };
165
166    use indexmap::IndexMap;
167    use serde::{
168        Deserialize, Deserializer, Serialize,
169        de::{SeqAccess, Visitor},
170        ser::Serializer,
171    };
172    use serde_state::{DeserializeState, SerializeState};
173
174    #[derive(Serialize, Deserialize, SerializeState, DeserializeState)]
175    struct KeyValue<K, V> {
176        key: K,
177        value: V,
178    }
179
180    /// A converter between an `IndexMap` and a sequence of named key-value pairs.
181    pub struct IndexMapToArray<K, V, U = RandomState>(PhantomData<(K, V, U)>);
182
183    impl<K, V, U> IndexMapToArray<K, V, U> {
184        /// Serializes the given `map` to an array of named key-values.
185        pub fn serialize<'a, S>(
186            map: &'a IndexMap<K, V, U>,
187            serializer: S,
188        ) -> Result<S::Ok, S::Error>
189        where
190            K: Serialize,
191            V: Serialize,
192            S: Serializer,
193        {
194            serializer.collect_seq(map.into_iter().map(|(key, value)| KeyValue { key, value }))
195        }
196        pub fn serialize_state<'a, S, State: ?Sized>(
197            map: &'a IndexMap<K, V, U>,
198            state: &State,
199            serializer: S,
200        ) -> Result<S::Ok, S::Error>
201        where
202            K: SerializeState<State>,
203            V: SerializeState<State>,
204            S: Serializer,
205        {
206            serializer.collect_seq(
207                map.into_iter().map(|(key, value)| {
208                    serde_state::WithState::new(KeyValue { key, value }, state)
209                }),
210            )
211        }
212
213        /// Deserializes from an array of named key-values.
214        pub fn deserialize<'de, D>(deserializer: D) -> Result<IndexMap<K, V, U>, D::Error>
215        where
216            K: Deserialize<'de> + Eq + Hash,
217            V: Deserialize<'de>,
218            U: BuildHasher + Default,
219            D: Deserializer<'de>,
220        {
221            struct IndexMapToArrayVisitor<K, V, U>(PhantomData<(K, V, U)>);
222
223            impl<'de, K, V, U> Visitor<'de> for IndexMapToArrayVisitor<K, V, U>
224            where
225                K: Deserialize<'de> + Eq + Hash,
226                V: Deserialize<'de>,
227                U: BuildHasher + Default,
228            {
229                type Value = IndexMap<K, V, U>;
230
231                fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
232                    formatter.write_str("a list of key-value objects")
233                }
234
235                fn visit_seq<A: SeqAccess<'de>>(self, mut seq: A) -> Result<Self::Value, A::Error> {
236                    let mut map = IndexMap::<K, V, U>::default();
237                    while let Some(entry) = seq.next_element::<KeyValue<K, V>>()? {
238                        map.insert(entry.key, entry.value);
239                    }
240                    Ok(map)
241                }
242            }
243            let map =
244                deserializer.deserialize_seq(IndexMapToArrayVisitor::<K, V, U>(PhantomData))?;
245            Ok(map.into())
246        }
247        /// Deserializes from an array of named key-values.
248        pub fn deserialize_state<'de, D, State>(
249            state: &State,
250            deserializer: D,
251        ) -> Result<IndexMap<K, V, U>, D::Error>
252        where
253            K: DeserializeState<'de, State> + Eq + Hash,
254            V: DeserializeState<'de, State>,
255            U: BuildHasher + Default,
256            D: Deserializer<'de>,
257        {
258            struct IndexMapToArrayVisitor<'a, State, K, V, U>(&'a State, PhantomData<(K, V, U)>);
259
260            impl<'de, State, K, V, U> Visitor<'de> for IndexMapToArrayVisitor<'_, State, K, V, U>
261            where
262                K: DeserializeState<'de, State> + Eq + Hash,
263                V: DeserializeState<'de, State>,
264                U: BuildHasher + Default,
265            {
266                type Value = IndexMap<K, V, U>;
267
268                fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
269                    formatter.write_str("a list of key-value objects")
270                }
271
272                fn visit_seq<A: SeqAccess<'de>>(self, mut seq: A) -> Result<Self::Value, A::Error> {
273                    let mut map = IndexMap::default();
274                    let seed =
275                        serde_state::__private::wrap_deserialize_seed::<KeyValue<K, V>, _>(self.0);
276                    while let Some(entry) = seq.next_element_seed(seed)? {
277                        map.insert(entry.key, entry.value);
278                    }
279                    Ok(map)
280                }
281            }
282            let map = deserializer
283                .deserialize_seq(IndexMapToArrayVisitor::<_, K, V, U>(state, PhantomData))?;
284            Ok(map.into())
285        }
286    }
287}
288
// Minimum number of bytes that must still be free on the current stack before
// `ensure_sufficient_stack` switches to a new one. It must be at least as large as the stack
// needed by any code that runs without calling `ensure_sufficient_stack` in between.
const RED_ZONE: usize = 100 << 10; // 100 KiB

// Size in bytes of the stack allocated when the red zone is hit. Values taken from rustc.
// NOTE(review): rustc's comment says only the first pushed stack has this size and later ones
// grow exponentially (2^n * STACK_PER_RECURSION) — confirm against the `stacker` version used.
const STACK_PER_RECURSION: usize = 1 << 20; // 1 MiB
297
298/// Grows the stack on demand to prevent stack overflow. Call this in strategic locations to "break
299/// up" recursive calls. E.g. most statement visitors can benefit from this.
300#[inline]
301pub fn ensure_sufficient_stack<R>(f: impl FnOnce() -> R) -> R {
302    stacker::maybe_grow(RED_ZONE, STACK_PER_RECURSION, f)
303}