Skip to main content

charon_lib/
common.rs

1use itertools::Itertools;
2use macros::EnumAsGetters;
3
/// A string of four spaces.
///
/// NOTE(review): judging by the name, this is the increment added per indentation level when
/// pretty-printing; its users are not visible in this file — confirm against callers.
pub static TAB_INCR: &str = "    ";
5
/// Renders the elements of an iterator as a bracketed list with one element per line.
///
/// Every element is converted with `t_to_string`, indented by two spaces and followed by a
/// comma (the last element included). An empty iterator is rendered as `[]`.
///
/// The output format is:
/// ```text
/// [
///   elem_0,
///   ...
///   elem_n,
/// ]
/// ```
pub fn pretty_display_list<T>(
    t_to_string: impl Fn(T) -> String,
    it: impl IntoIterator<Item = T>,
) -> String {
    // Accumulate the `  elem,\n` lines; the buffer stays empty only if there was no element.
    let mut lines = String::new();
    for item in it {
        lines.push_str("  ");
        lines.push_str(&t_to_string(item));
        lines.push_str(",\n");
    }

    if lines.is_empty() {
        return "[]".to_owned();
    }
    format!("[\n{}]", lines)
}
30
/// Generates the conversions between an enum and the payload of one of its single-field tuple
/// variants.
///
/// `impl_from_enum!(Enum::Variant(Ty))` expands to:
/// - `impl From<Ty> for Enum`, which wraps the value into `Enum::Variant`;
/// - `impl TryFrom<Enum> for Ty`, which gives back the payload when the value is that variant
///   and `Err(())` for every other variant.
#[macro_export]
macro_rules! impl_from_enum {
    ($enum:ident::$variant:ident($ty:ty)) => {
        impl From<$ty> for $enum {
            fn from(payload: $ty) -> Self {
                Self::$variant(payload)
            }
        }
        impl TryFrom<$enum> for $ty {
            type Error = ();
            fn try_from(value: $enum) -> Result<Self, Self::Error> {
                if let $enum::$variant(payload) = value {
                    Ok(payload)
                } else {
                    Err(())
                }
            }
        }
    };
}
51
/// Builds an endless iterator whose first item is `None` and whose every later item is
/// `Some` of a clone of `x`.
pub fn repeat_except_first<T: Clone>(x: T) -> impl Iterator<Item = Option<T>> {
    let head = std::iter::once(None);
    let tail = std::iter::repeat(Some(x));
    head.chain(tail)
}
56
57/// An enum to manage potentially-cyclic computations.
58#[derive(Debug, EnumAsGetters)]
59pub enum CycleDetector<T> {
60    /// We haven't analyzed this yet.
61    Unprocessed,
62    /// Sentinel value that we set when starting the computation on an item. If we ever encounter
63    /// this, we know we encountered a loop that we can't handle.
64    Processing,
65    /// Sentinel value we put when encountering a cycle, so we can know that happened.
66    Cyclic,
67    /// The final result of the computation.
68    Processed(T),
69}
70
71impl<T> CycleDetector<T> {
72    /// If this item hadn't been processed, return `true` and record it as `Processing`, otherwise
73    /// return `false`. If this item is already processing, record a cycle.
74    pub fn start_processing(&mut self) -> bool {
75        match self {
76            CycleDetector::Unprocessed => {
77                *self = CycleDetector::Processing;
78                true
79            }
80            CycleDetector::Processing => {
81                *self = CycleDetector::Cyclic;
82                false
83            }
84            CycleDetector::Cyclic | CycleDetector::Processed(_) => false,
85        }
86    }
87
88    pub fn done_processing(&mut self, x: T) {
89        *self = CycleDetector::Processed(x)
90    }
91}
92
93impl<T> Default for CycleDetector<T> {
94    fn default() -> Self {
95        CycleDetector::Unprocessed
96    }
97}
98
99pub mod type_map {
100    use rustc_hash::FxHashMap;
101    use std::{
102        any::{Any, TypeId},
103        marker::PhantomData,
104    };
105
106    pub trait Mappable: Any + Send + Sync {}
107    impl<T> Mappable for T where T: Any + Send + Sync {}
108
109    pub trait Mapper {
110        type Value<T: Mappable>: Mappable;
111    }
112
113    /// A map that maps types to values in a generic manner: we store for each type `T` a value of
114    /// type `M::Value<T>`.
115    pub struct TypeMap<M> {
116        data: FxHashMap<TypeId, Box<dyn Mappable>>,
117        phantom: PhantomData<M>,
118    }
119
120    impl<M: Mapper> TypeMap<M> {
121        pub fn get<T: Mappable>(&self) -> Option<&M::Value<T>> {
122            self.data
123                .get(&TypeId::of::<T>())
124                // We must be careful to not accidentally cast the box itself as `dyn Any`.
125                .map(|val: &Box<dyn Mappable>| &**val)
126                .and_then(|val: &dyn Mappable| (val as &dyn Any).downcast_ref())
127        }
128
129        pub fn get_mut<T: Mappable>(&mut self) -> Option<&mut M::Value<T>> {
130            self.data
131                .get_mut(&TypeId::of::<T>())
132                // We must be careful to not accidentally cast the box itself as `dyn Any`.
133                .map(|val: &mut Box<dyn Mappable>| &mut **val)
134                .and_then(|val: &mut dyn Mappable| (val as &mut dyn Any).downcast_mut())
135        }
136
137        pub fn insert<T: Mappable>(&mut self, val: M::Value<T>) -> Option<Box<M::Value<T>>> {
138            self.data
139                .insert(TypeId::of::<T>(), Box::new(val))
140                .and_then(|val: Box<dyn Mappable>| (val as Box<dyn Any>).downcast().ok())
141        }
142
143        pub fn or_insert_with<T: Mappable>(
144            &mut self,
145            f: impl FnOnce() -> M::Value<T>,
146        ) -> &mut M::Value<T> {
147            if self.get::<T>().is_none() {
148                self.insert(f());
149            }
150            self.get_mut::<T>().unwrap()
151        }
152        pub fn or_default<T: Mappable>(&mut self) -> &mut M::Value<T>
153        where
154            M::Value<T>: Default,
155        {
156            self.or_insert_with(Default::default)
157        }
158    }
159
160    impl<M> Default for TypeMap<M> {
161        fn default() -> Self {
162            Self {
163                data: Default::default(),
164                phantom: Default::default(),
165            }
166        }
167    }
168}
169
170pub mod hash_by_addr {
171    use serde::{Deserialize, Serialize};
172    use std::{
173        hash::{Hash, Hasher},
174        ops::Deref,
175    };
176
177    /// A wrapper around a smart pointer that hashes and compares the contents by the address of
178    /// the pointee.
179    #[derive(Debug, Clone, Serialize, Deserialize)]
180    pub struct HashByAddr<T>(pub T);
181
182    impl<T: Deref> HashByAddr<T> {
183        fn addr(&self) -> *const T::Target {
184            self.0.deref()
185        }
186    }
187
188    impl<T: Eq + Deref> Eq for HashByAddr<T> {}
189
190    impl<T: PartialEq + Deref> PartialEq for HashByAddr<T> {
191        fn eq(&self, other: &Self) -> bool {
192            std::ptr::addr_eq(self.addr(), other.addr())
193        }
194    }
195
196    impl<T: Hash + Deref> Hash for HashByAddr<T> {
197        fn hash<H: Hasher>(&self, state: &mut H) {
198            self.addr().hash(state);
199        }
200    }
201}
202
203pub mod serialize_map_to_array {
204    use core::{fmt, marker::PhantomData};
205    use std::{
206        collections::hash_map::RandomState,
207        hash::{BuildHasher, Hash},
208    };
209
210    use indexmap::IndexMap as SeqHashMap;
211    use serde::{
212        Deserialize, Deserializer, Serialize,
213        de::{SeqAccess, Visitor},
214        ser::Serializer,
215    };
216    use serde_state::{DeserializeState, SerializeState};
217
218    #[derive(Serialize, Deserialize, SerializeState, DeserializeState)]
219    struct KeyValue<K, V> {
220        key: K,
221        value: V,
222    }
223
224    /// A converter between an `SeqHashMap` and a sequence of named key-value pairs.
225    pub struct SeqHashMapToArray<K, V, U = RandomState>(PhantomData<(K, V, U)>);
226
227    impl<K, V, U> SeqHashMapToArray<K, V, U> {
228        /// Serializes the given `map` to an array of named key-values.
229        pub fn serialize<S>(map: &SeqHashMap<K, V, U>, serializer: S) -> Result<S::Ok, S::Error>
230        where
231            K: Serialize,
232            V: Serialize,
233            S: Serializer,
234        {
235            serializer.collect_seq(map.into_iter().map(|(key, value)| KeyValue { key, value }))
236        }
237        pub fn serialize_state<S, State: ?Sized>(
238            map: &SeqHashMap<K, V, U>,
239            state: &State,
240            serializer: S,
241        ) -> Result<S::Ok, S::Error>
242        where
243            K: SerializeState<State>,
244            V: SerializeState<State>,
245            S: Serializer,
246        {
247            serializer.collect_seq(
248                map.into_iter().map(|(key, value)| {
249                    serde_state::WithState::new(KeyValue { key, value }, state)
250                }),
251            )
252        }
253
254        /// Deserializes from an array of named key-values.
255        pub fn deserialize<'de, D>(deserializer: D) -> Result<SeqHashMap<K, V, U>, D::Error>
256        where
257            K: Deserialize<'de> + Eq + Hash,
258            V: Deserialize<'de>,
259            U: BuildHasher + Default,
260            D: Deserializer<'de>,
261        {
262            struct SeqHashMapToArrayVisitor<K, V, U>(PhantomData<(K, V, U)>);
263
264            impl<'de, K, V, U> Visitor<'de> for SeqHashMapToArrayVisitor<K, V, U>
265            where
266                K: Deserialize<'de> + Eq + Hash,
267                V: Deserialize<'de>,
268                U: BuildHasher + Default,
269            {
270                type Value = SeqHashMap<K, V, U>;
271
272                fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
273                    formatter.write_str("a list of key-value objects")
274                }
275
276                fn visit_seq<A: SeqAccess<'de>>(self, mut seq: A) -> Result<Self::Value, A::Error> {
277                    let mut map = SeqHashMap::<K, V, U>::default();
278                    while let Some(entry) = seq.next_element::<KeyValue<K, V>>()? {
279                        map.insert(entry.key, entry.value);
280                    }
281                    Ok(map)
282                }
283            }
284            let map =
285                deserializer.deserialize_seq(SeqHashMapToArrayVisitor::<K, V, U>(PhantomData))?;
286            Ok(map)
287        }
288        /// Deserializes from an array of named key-values.
289        pub fn deserialize_state<'de, D, State: ?Sized>(
290            state: &State,
291            deserializer: D,
292        ) -> Result<SeqHashMap<K, V, U>, D::Error>
293        where
294            K: DeserializeState<'de, State> + Eq + Hash,
295            V: DeserializeState<'de, State>,
296            U: BuildHasher + Default,
297            D: Deserializer<'de>,
298        {
299            struct SeqHashMapToArrayVisitor<'a, State: ?Sized, K, V, U>(
300                &'a State,
301                PhantomData<(K, V, U)>,
302            );
303
304            impl<'de, State: ?Sized, K, V, U> Visitor<'de> for SeqHashMapToArrayVisitor<'_, State, K, V, U>
305            where
306                K: DeserializeState<'de, State> + Eq + Hash,
307                V: DeserializeState<'de, State>,
308                U: BuildHasher + Default,
309            {
310                type Value = SeqHashMap<K, V, U>;
311
312                fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
313                    formatter.write_str("a list of key-value objects")
314                }
315
316                fn visit_seq<A: SeqAccess<'de>>(self, mut seq: A) -> Result<Self::Value, A::Error> {
317                    let mut map = SeqHashMap::default();
318                    let seed =
319                        serde_state::__private::wrap_deserialize_seed::<KeyValue<K, V>, _>(self.0);
320                    while let Some(entry) = seq.next_element_seed(seed)? {
321                        map.insert(entry.key, entry.value);
322                    }
323                    Ok(map)
324                }
325            }
326            let map = deserializer
327                .deserialize_seq(SeqHashMapToArrayVisitor::<_, K, V, U>(state, PhantomData))?;
328            Ok(map)
329        }
330    }
331}
332
// This is the number of bytes that need to be left on the stack before increasing the size. It
// must be at least as large as the stack required by any code that does not call
// `ensure_sufficient_stack`.
// Passed as the first argument of `stacker::maybe_grow`.
const RED_ZONE: usize = 100 * 1024; // 100k

// Only the first stack that is pushed grows exponentially (2^n * STACK_PER_RECURSION) from then
// on. Values taken from rustc.
// Passed as the second argument of `stacker::maybe_grow`.
const STACK_PER_RECURSION: usize = 1024 * 1024; // 1MB

/// Grows the stack on demand to prevent stack overflow. Call this in strategic locations to "break
/// up" recursive calls. E.g. most statement visitors can benefit from this.
///
/// Runs `f` through `stacker::maybe_grow` with the `RED_ZONE` and `STACK_PER_RECURSION`
/// constants of this file, and returns whatever `f` returns.
#[inline]
pub fn ensure_sufficient_stack<R>(f: impl FnOnce() -> R) -> R {
    stacker::maybe_grow(RED_ZONE, STACK_PER_RECURSION, f)
}
348
349/// Returns the values of the command-line options that match `find_arg`. The options are built-in
350/// to be of the form `--arg=value` or `--arg value`.
351pub fn arg_values<'a, T: AsRef<str>>(
352    args: &'a [T],
353    needle: &'a str,
354) -> impl Iterator<Item = &'a str> {
355    struct ArgFilter<'a, T> {
356        args: std::slice::Iter<'a, T>,
357        needle: &'a str,
358    }
359    impl<'a, T: AsRef<str>> Iterator for ArgFilter<'a, T> {
360        type Item = &'a str;
361        fn next(&mut self) -> Option<Self::Item> {
362            while let Some(arg) = self.args.next() {
363                let mut split_arg = arg.as_ref().splitn(2, '=');
364                if split_arg.next() == Some(self.needle) {
365                    return match split_arg.next() {
366                        // `--arg=value` form
367                        arg @ Some(_) => arg,
368                        // `--arg value` form
369                        None => self.args.next().map(|x| x.as_ref()),
370                    };
371                }
372            }
373            None
374        }
375    }
376    ArgFilter {
377        args: args.iter(),
378        needle,
379    }
380}
381
382pub fn arg_value<'a, T: AsRef<str>>(args: &'a [T], needle: &'a str) -> Option<&'a str> {
383    arg_values(args, needle).next()
384}