Skip to main content

charon_lib/
common.rs

1use itertools::Itertools;
2use macros::EnumAsGetters;
3
/// A string of four spaces.
// NOTE(review): judging by the name this is one indentation step for pretty-printers; the users
// are not visible in this file -- confirm against callers.
pub static TAB_INCR: &str = "    ";
5
6/// Custom function to pretty-print elements from an iterator
7/// The output format is:
8/// ```text
9/// [
10///   elem_0,
11///   ...
12///   elem_n
13/// ]
14/// ```
15pub fn pretty_display_list<T>(
16    t_to_string: impl Fn(T) -> String,
17    it: impl IntoIterator<Item = T>,
18) -> String {
19    let mut elems = it
20        .into_iter()
21        .map(t_to_string)
22        .map(|x| format!("  {},\n", x))
23        .peekable();
24    if elems.peek().is_none() {
25        "[]".to_owned()
26    } else {
27        format!("[\n{}]", elems.format(""))
28    }
29}
30
/// Generate the conversions between an enum and the payload of one of its single-field variants.
///
/// `impl_from_enum!(Enum::Variant(Payload));` produces:
/// - `impl From<Payload> for Enum`, which wraps the value in `Enum::Variant`;
/// - `impl TryFrom<Enum> for Payload`, which unwraps it and fails with `Err(())` when the enum
///   holds any other variant.
#[macro_export]
macro_rules! impl_from_enum {
    ($enum:ident::$variant:ident($ty:ty)) => {
        impl From<$ty> for $enum {
            fn from(payload: $ty) -> Self {
                Self::$variant(payload)
            }
        }
        impl TryFrom<$enum> for $ty {
            type Error = ();
            fn try_from(value: $enum) -> Result<Self, Self::Error> {
                if let $enum::$variant(payload) = value {
                    Ok(payload)
                } else {
                    Err(())
                }
            }
        }
    };
}
51
/// Yield `None` once, then infinitely many `Some(x)`.
///
/// The returned iterator never terminates, so callers are expected to bound it themselves
/// (e.g. with `zip` or `take`). `x` is cloned for every `Some` that is produced.
pub fn repeat_except_first<T: Clone>(x: T) -> impl Iterator<Item = Option<T>> {
    // `iter::once` is the dedicated one-element iterator; unlike `[None].into_iter()` it yields
    // the item by value regardless of the edition the crate is compiled with.
    std::iter::once(None).chain(std::iter::repeat(Some(x)))
}
56
/// An enum to manage potentially-cyclic computations.
///
/// It records how far the computation for one item has progressed. As implemented by
/// [`CycleDetector::start_processing`] and [`CycleDetector::done_processing`], the states evolve
/// as `Unprocessed -> Processing -> Processed(_)`, with `Processing -> Cyclic` when the item is
/// started again before its computation has finished.
// NOTE(review): `EnumAsGetters` comes from the project's `macros` crate; presumably it derives
// per-variant accessors -- confirm there.
#[derive(Debug, EnumAsGetters)]
pub enum CycleDetector<T> {
    /// We haven't analyzed this yet. This is also the `Default` value.
    Unprocessed,
    /// Sentinel value that we set when starting the computation on an item. If we ever encounter
    /// this, we know we encountered a loop that we can't handle.
    Processing,
    /// Sentinel value we put when encountering a cycle, so we can know that happened.
    Cyclic,
    /// The final result of the computation.
    Processed(T),
}
70
71impl<T> CycleDetector<T> {
72    /// If this item hadn't been processed, return `true` and record it as `Processing`, otherwise
73    /// return `false`. If this item is already processing, record a cycle.
74    pub fn start_processing(&mut self) -> bool {
75        match self {
76            CycleDetector::Unprocessed => {
77                *self = CycleDetector::Processing;
78                true
79            }
80            CycleDetector::Processing => {
81                *self = CycleDetector::Cyclic;
82                false
83            }
84            CycleDetector::Cyclic | CycleDetector::Processed(_) => false,
85        }
86    }
87
88    pub fn done_processing(&mut self, x: T) {
89        *self = CycleDetector::Processed(x)
90    }
91}
92
93impl<T> Default for CycleDetector<T> {
94    fn default() -> Self {
95        Self::Unprocessed
96    }
97}
98
/// A heterogeneous map keyed by type: for each type `T` it can hold one value of type
/// `M::Value<T>`, where `M` is a [`Mapper`](type_map::Mapper).
pub mod type_map {
    use std::{
        any::{Any, TypeId},
        collections::HashMap,
        marker::PhantomData,
    };

    /// Shorthand for the bounds required of every key type and of every stored value.
    // NOTE: this is a trait alias, which relies on the unstable `trait_alias` language feature.
    pub trait Mappable = Any + Send + Sync;

    /// Describes, for each key type `T`, the type of the value stored under `T`.
    pub trait Mapper {
        /// The type of the value associated with the key type `T`.
        type Value<T: Mappable>: Mappable;
    }

    /// A map that maps types to values in a generic manner: we store for each type `T` a value of
    /// type `M::Value<T>`.
    pub struct TypeMap<M> {
        /// Type-erased storage, indexed by the `TypeId` of the key type `T`. The methods below
        /// only ever store a boxed `M::Value<T>` under the key `TypeId::of::<T>()`.
        data: HashMap<TypeId, Box<dyn Mappable>>,
        /// Ties the map to its `Mapper` without storing one.
        phantom: PhantomData<M>,
    }

    impl<M: Mapper> TypeMap<M> {
        /// Returns a shared reference to the value stored for `T`, or `None` if there is none.
        pub fn get<T: Mappable>(&self) -> Option<&M::Value<T>> {
            self.data
                .get(&TypeId::of::<T>())
                // We must be careful to not accidentally cast the box itself as `dyn Any`.
                .map(|val: &Box<dyn Mappable>| &**val)
                .and_then(|val: &dyn Mappable| (val as &dyn Any).downcast_ref())
        }

        /// Returns a mutable reference to the value stored for `T`, or `None` if there is none.
        pub fn get_mut<T: Mappable>(&mut self) -> Option<&mut M::Value<T>> {
            self.data
                .get_mut(&TypeId::of::<T>())
                // We must be careful to not accidentally cast the box itself as `dyn Any`.
                .map(|val: &mut Box<dyn Mappable>| &mut **val)
                .and_then(|val: &mut dyn Mappable| (val as &mut dyn Any).downcast_mut())
        }

        /// Stores `val` as the value for `T` and returns the value it replaced (still boxed), if
        /// there was one.
        pub fn insert<T: Mappable>(&mut self, val: M::Value<T>) -> Option<Box<M::Value<T>>> {
            self.data
                .insert(TypeId::of::<T>(), Box::new(val))
                .and_then(|val: Box<dyn Mappable>| (val as Box<dyn Any>).downcast().ok())
        }

        /// Returns a mutable reference to the value stored for `T`, first inserting the result of
        /// `f` when `get` finds nothing. `f` is only called in that case.
        pub fn or_insert_with<T: Mappable>(
            &mut self,
            f: impl FnOnce() -> M::Value<T>,
        ) -> &mut M::Value<T> {
            if self.get::<T>().is_none() {
                self.insert(f());
            }
            // Either a value was already present or we just inserted one of the right type.
            self.get_mut::<T>().unwrap()
        }
        /// Like [`TypeMap::or_insert_with`], using `Default::default()` to build a missing value.
        pub fn or_default<T: Mappable>(&mut self) -> &mut M::Value<T>
        where
            M::Value<T>: Default,
        {
            self.or_insert_with(|| Default::default())
        }
    }

    /// An empty map. This impl places no bound on `M`.
    impl<M> Default for TypeMap<M> {
        fn default() -> Self {
            Self {
                data: Default::default(),
                phantom: Default::default(),
            }
        }
    }
}
168
169pub mod hash_by_addr {
170    use serde::{Deserialize, Serialize};
171    use std::{
172        hash::{Hash, Hasher},
173        ops::Deref,
174    };
175
176    /// A wrapper around a smart pointer that hashes and compares the contents by the address of
177    /// the pointee.
178    #[derive(Debug, Clone, Serialize, Deserialize)]
179    pub struct HashByAddr<T>(pub T);
180
181    impl<T: Deref> HashByAddr<T> {
182        fn addr(&self) -> *const T::Target {
183            self.0.deref()
184        }
185    }
186
187    impl<T: Eq + Deref> Eq for HashByAddr<T> {}
188
189    impl<T: PartialEq + Deref> PartialEq for HashByAddr<T> {
190        fn eq(&self, other: &Self) -> bool {
191            std::ptr::addr_eq(self.addr(), other.addr())
192        }
193    }
194
195    impl<T: Hash + Deref> Hash for HashByAddr<T> {
196        fn hash<H: Hasher>(&self, state: &mut H) {
197            self.addr().hash(state);
198        }
199    }
200}
201
/// Helpers that (de)serialize an `IndexMap` as a sequence of `{ key, value }` entries instead of
/// as a map, in both plain `serde` and `serde_state` flavours.
// NOTE(review): the `serialize`/`deserialize` names suggest use through `#[serde(with = ...)]`;
// the call sites are not visible in this file -- confirm.
pub mod serialize_map_to_array {
    use core::{fmt, marker::PhantomData};
    use std::{
        collections::hash_map::RandomState,
        hash::{BuildHasher, Hash},
    };

    use indexmap::IndexMap as SeqHashMap;
    use serde::{
        Deserialize, Deserializer, Serialize,
        de::{SeqAccess, Visitor},
        ser::Serializer,
    };
    use serde_state::{DeserializeState, SerializeState};

    /// One entry of the serialized sequence, with the named fields `key` and `value`.
    #[derive(Serialize, Deserialize, SerializeState, DeserializeState)]
    struct KeyValue<K, V> {
        key: K,
        value: V,
    }

    /// A converter between an `SeqHashMap` and a sequence of named key-value pairs.
    ///
    /// This type is never instantiated; it only serves as a namespace for the associated
    /// functions below. `U` is the hasher type of the map.
    pub struct SeqHashMapToArray<K, V, U = RandomState>(PhantomData<(K, V, U)>);

    impl<K, V, U> SeqHashMapToArray<K, V, U> {
        /// Serializes the given `map` to an array of named key-values.
        ///
        /// Entries are emitted in the map's iteration order, each as a `KeyValue` holding
        /// references to the key and the value.
        pub fn serialize<'a, S>(
            map: &'a SeqHashMap<K, V, U>,
            serializer: S,
        ) -> Result<S::Ok, S::Error>
        where
            K: Serialize,
            V: Serialize,
            S: Serializer,
        {
            serializer.collect_seq(map.into_iter().map(|(key, value)| KeyValue { key, value }))
        }
        /// Same as [`Self::serialize`], but for keys and values that implement
        /// `SerializeState<State>`: each entry is wrapped together with `state` before being
        /// handed to the serializer.
        pub fn serialize_state<'a, S, State: ?Sized>(
            map: &'a SeqHashMap<K, V, U>,
            state: &State,
            serializer: S,
        ) -> Result<S::Ok, S::Error>
        where
            K: SerializeState<State>,
            V: SerializeState<State>,
            S: Serializer,
        {
            serializer.collect_seq(
                map.into_iter().map(|(key, value)| {
                    serde_state::WithState::new(KeyValue { key, value }, state)
                }),
            )
        }

        /// Deserializes from an array of named key-values.
        ///
        /// Entries are inserted into a default-constructed map in the order they are read; an
        /// entry whose key was already seen goes through `insert` again.
        pub fn deserialize<'de, D>(deserializer: D) -> Result<SeqHashMap<K, V, U>, D::Error>
        where
            K: Deserialize<'de> + Eq + Hash,
            V: Deserialize<'de>,
            U: BuildHasher + Default,
            D: Deserializer<'de>,
        {
            /// Visitor that rebuilds the map from a sequence of `KeyValue` entries.
            struct SeqHashMapToArrayVisitor<K, V, U>(PhantomData<(K, V, U)>);

            impl<'de, K, V, U> Visitor<'de> for SeqHashMapToArrayVisitor<K, V, U>
            where
                K: Deserialize<'de> + Eq + Hash,
                V: Deserialize<'de>,
                U: BuildHasher + Default,
            {
                type Value = SeqHashMap<K, V, U>;

                fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
                    formatter.write_str("a list of key-value objects")
                }

                fn visit_seq<A: SeqAccess<'de>>(self, mut seq: A) -> Result<Self::Value, A::Error> {
                    let mut map = SeqHashMap::<K, V, U>::default();
                    // Read entries until the sequence is exhausted; any element error aborts.
                    while let Some(entry) = seq.next_element::<KeyValue<K, V>>()? {
                        map.insert(entry.key, entry.value);
                    }
                    Ok(map)
                }
            }
            let map =
                deserializer.deserialize_seq(SeqHashMapToArrayVisitor::<K, V, U>(PhantomData))?;
            // `map` already has the return type, so this conversion is the identity.
            Ok(map.into())
        }
        /// Deserializes from an array of named key-values, for keys and values that implement
        /// `DeserializeState<'de, State>`; `state` is made available to every entry through a
        /// deserialization seed.
        pub fn deserialize_state<'de, D, State>(
            state: &State,
            deserializer: D,
        ) -> Result<SeqHashMap<K, V, U>, D::Error>
        where
            K: DeserializeState<'de, State> + Eq + Hash,
            V: DeserializeState<'de, State>,
            U: BuildHasher + Default,
            D: Deserializer<'de>,
        {
            /// Visitor that rebuilds the map from a sequence of `KeyValue` entries; field `0` is
            /// the state passed along to each entry.
            struct SeqHashMapToArrayVisitor<'a, State, K, V, U>(&'a State, PhantomData<(K, V, U)>);

            impl<'de, State, K, V, U> Visitor<'de> for SeqHashMapToArrayVisitor<'_, State, K, V, U>
            where
                K: DeserializeState<'de, State> + Eq + Hash,
                V: DeserializeState<'de, State>,
                U: BuildHasher + Default,
            {
                type Value = SeqHashMap<K, V, U>;

                fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
                    formatter.write_str("a list of key-value objects")
                }

                fn visit_seq<A: SeqAccess<'de>>(self, mut seq: A) -> Result<Self::Value, A::Error> {
                    let mut map = SeqHashMap::default();
                    // NOTE(review): `__private` is not a public `serde_state` API; presumably
                    // this builds a `DeserializeSeed` that deserializes a `KeyValue<K, V>` with
                    // access to the state -- confirm against the `serde_state` version in use.
                    let seed =
                        serde_state::__private::wrap_deserialize_seed::<KeyValue<K, V>, _>(self.0);
                    // The same seed value is passed again for every element.
                    while let Some(entry) = seq.next_element_seed(seed)? {
                        map.insert(entry.key, entry.value);
                    }
                    Ok(map)
                }
            }
            let map = deserializer
                .deserialize_seq(SeqHashMapToArrayVisitor::<_, K, V, U>(state, PhantomData))?;
            // `map` already has the return type, so this conversion is the identity.
            Ok(map.into())
        }
    }
}
331
// This is the number of bytes that need to be left on the stack before increasing its size. It
// must be at least as large as the stack required by any code that does not call
// `ensure_sufficient_stack`.
const RED_ZONE: usize = 100 * 1024; // 100k

// Size requested when the stack has to grow. Original remark, kept as written: "Only the first
// stack that is pushed, grows exponentially (2^n * STACK_PER_RECURSION) from then on."
// Values taken from rustc.
// NOTE(review): presumably this means that only the first extra stack has exactly this size and
// later ones grow exponentially -- confirm against the `stacker` documentation.
const STACK_PER_RECURSION: usize = 1024 * 1024; // 1MB

/// Grows the stack on demand to prevent stack overflow. Call this in strategic locations to "break
/// up" recursive calls. E.g. most statement visitors can benefit from this.
///
/// Runs `f` through `stacker::maybe_grow` with the `RED_ZONE` and `STACK_PER_RECURSION` settings
/// above and returns whatever `f` returns.
#[inline]
pub fn ensure_sufficient_stack<R>(f: impl FnOnce() -> R) -> R {
    stacker::maybe_grow(RED_ZONE, STACK_PER_RECURSION, f)
}
347
/// Returns the values of the command-line options in `args` whose name is `needle`.
///
/// Two forms are recognised:
/// - `--arg=value`: the value is everything after the first `=` (it may itself contain `=`);
/// - `--arg value`: the value is the argument that follows, whatever it looks like. That
///   argument is consumed and is not itself examined as an option.
///
/// When `needle` is the very last argument and has no `=value` part, there is no value to
/// report and the iteration simply ends.
pub fn arg_values<'a, T: AsRef<str>>(
    args: &'a [T],
    needle: &'a str,
) -> impl Iterator<Item = &'a str> {
    /// Iterator state: the arguments not yet examined and the option name we look for.
    struct ArgFilter<'a, T> {
        args: std::slice::Iter<'a, T>,
        needle: &'a str,
    }
    impl<'a, T: AsRef<str>> Iterator for ArgFilter<'a, T> {
        type Item = &'a str;
        fn next(&mut self) -> Option<Self::Item> {
            loop {
                // Running out of arguments ends the iteration.
                let arg: &'a str = self.args.next()?.as_ref();
                // Split the option name from an optional inline `=value`.
                let (name, inline_value) = match arg.split_once('=') {
                    Some((name, value)) => (name, Some(value)),
                    None => (arg, None),
                };
                if name == self.needle {
                    // `--arg=value` form if present, otherwise the `--arg value` form.
                    return inline_value.or_else(|| self.args.next().map(|next| next.as_ref()));
                }
            }
        }
    }
    ArgFilter {
        args: args.iter(),
        needle,
    }
}
380
381pub fn arg_value<'a, T: AsRef<str>>(args: &'a [T], needle: &'a str) -> Option<&'a str> {
382    arg_values(args, needle).next()
383}