Skip to main content

charon_lib/ast/
hash_cons.rs

1use derive_generic_visitor::{Drive, DriveMut, Visit, VisitMut};
2use serde::{Deserialize, Serialize};
3use std::hash::Hash;
4use std::ops::{ControlFlow, Deref};
5use std::sync::Arc;
6
7use crate::common::hash_by_addr::HashByAddr;
8use crate::common::type_map::Mappable;
9
10/// Hash-consed data structure: a reference-counted wrapper that guarantees that two equal
11/// value will be stored at the same address. This makes it possible to use the pointer address
12/// as a hash value.
13// Warning: a `derive` should not introduce a way to create a new `HashConsed` value without
14// going through the interning table.
15#[derive(PartialEq, Eq, Hash)]
16pub struct HashConsed<T>(HashByAddr<Arc<T>>);
17
18impl<T> Clone for HashConsed<T> {
19    fn clone(&self) -> Self {
20        Self(self.0.clone())
21    }
22}
23
24impl<T> HashConsed<T> {
25    pub fn inner(&self) -> &T {
26        self.0.0.as_ref()
27    }
28}
29
30pub trait HashConsable: Hash + PartialEq + Eq + Clone + Mappable {}
31impl<T> HashConsable for T where T: Hash + PartialEq + Eq + Clone + Mappable {}
32
33/// Unique id identifying a hashconsed value amongst all hashconsed values.
34#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
35pub struct HashConsId(u64);
36
37// Private module that contains the static we'll use as interning map. A value of type
38// `HashCons` MUST NOT be created in any other way than this table, else hashing and euqality
39// on it will be broken. Note that this likely means that if a crate uses charon both as a
40// direct dependency and as a dylib, then the static will be duplicated, causing hashing and
41// equality on `HashCons` to be broken.
42mod intern_table {
43    use rustc_hash::FxBuildHasher;
44    use std::borrow::Borrow;
45    use std::sync::atomic::{AtomicU64, Ordering};
46    use std::sync::{Arc, LazyLock, RwLock};
47
48    use super::{HashConsId, HashConsable, HashConsed};
49    use crate::common::hash_by_addr::HashByAddr;
50    use crate::common::type_map::{Mappable, Mapper, TypeMap};
51
52    type SeqHashMap<K, V> = indexmap::IndexMap<K, V, FxBuildHasher>;
53
54    // Only way we create a `HashConsId`.
55    fn fresh_id() -> HashConsId {
56        static ID: AtomicU64 = AtomicU64::new(0);
57        HashConsId(ID.fetch_add(1, Ordering::Relaxed))
58    }
59
60    // This is a static mutable `SeqHashSet<Arc<T>>` that records for each `T` value a unique
61    // `Arc<T>` that contains the same value. Values inside the set are hashed/compared
62    // as is normal for `T`.
63    // Once we've gotten an `Arc` out of the set however, we're sure that "T-equality"
64    // implies address-equality, hence the `HashByAddr` wrapper preserves correct equality
65    // and hashing behavior.
66    // Note that we also store a `HashConsId` for each item so this is a map instead of a set, but
67    // what matters is really the map keys.
68    struct InternMapper;
69    impl Mapper for InternMapper {
70        type Value<T: Mappable> = SeqHashMap<Arc<T>, HashConsId>;
71    }
72    static INTERNED: LazyLock<RwLock<TypeMap<InternMapper>>> = LazyLock::new(Default::default);
73
74    // The excessive generality is to make it work for both `U = T` and `U = Arc<T>`.
75    pub fn intern<T: HashConsable, U>(inner: U) -> HashConsed<T>
76    where
77        Arc<T>: Borrow<U>,
78        U: Into<Arc<T>> + std::hash::Hash,
79        U: indexmap::Equivalent<Arc<T>>,
80    {
81        // Fast read-only check.
82        let arc = if let read_guard = INTERNED.read().unwrap()
83            && let Some(map) = read_guard.get::<T>()
84            && let Some((arc, _id)) = map.get_key_value(&inner)
85        {
86            arc.clone()
87        } else {
88            // Concurrent access is possible right here, so we have to check everything again.
89            let mut write_guard = INTERNED.write().unwrap();
90            let map: &mut SeqHashMap<Arc<T>, _> = write_guard.or_default::<T>();
91            if let Some((arc, _id)) = map.get_key_value(&inner) {
92                arc.clone()
93            } else {
94                let arc: Arc<T> = inner.into();
95                map.insert(arc.clone(), fresh_id());
96                arc
97            }
98        };
99        HashConsed(HashByAddr(arc))
100    }
101
102    /// Mutate the contents in-place if possible.
103    pub fn mutate_in_place<T: HashConsable, R, F: FnOnce(&mut T) -> R>(
104        x: &mut HashConsed<T>,
105        f: F,
106    ) -> Result<R, F> {
107        let arc = &mut x.0.0;
108        // Every value has at least two pointers: the current value and the one stored in the
109        // global map. If there are exactly two, we may mutate directly by discarding the one in
110        // the global map temporarily.
111        if Arc::strong_count(arc) != 2 {
112            return Err(f);
113        }
114        {
115            // Take the write guard just long enough to drop the other `Arc` to this value.
116            let mut write_guard = INTERNED.write().unwrap();
117            if let Some((other_arc, _)) = write_guard.or_default::<T>().swap_remove_entry(&*arc) {
118                drop(other_arc);
119            } else {
120                // Nothing was removed, early return.
121                return Err(f);
122            }
123            // The Arc was removed from the map; `x` is invalid as interning the same value would
124            // result in a different pointer. NO MORE EARLY RETURN until we fix that.
125        }
126        // If we are still the sole owner, we can now mutate in-place.
127        let ret = match Arc::get_mut(arc) {
128            Some(val) => Ok(f(val)),
129            None => Err(f),
130        };
131        // Re-establish the interning invariant. If the same value was added to the map in the
132        // meantime, we'll get a pointer to that.
133        *x = HashConsed::from_arc(arc.clone());
134        ret
135    }
136
137    /// Identify this value uniquely amongst values of its type. The id depends on insertion
138    /// order into the interning table which makes them in principle deterministic.
139    pub fn id<T: HashConsable>(x: &HashConsed<T>) -> HashConsId {
140        // `HashConsed` can only be constructed via `intern`, so we know this value exists in the
141        // table.
142        let read_guard = INTERNED.read().unwrap();
143        let map = read_guard.get::<T>().unwrap();
144        let (_arc, id) = map.get_key_value(&x.0.0).unwrap();
145        *id
146    }
147}
148
149impl<T> HashConsed<T>
150where
151    T: HashConsable,
152{
153    /// Deduplicate the values by hashing them. This deduplication is crucial for the hashing
154    /// function to be correct. This is the only function allowed to create `Self` values.
155    pub fn new(inner: T) -> Self {
156        intern_table::intern(inner)
157    }
158    /// Rarely used: in case we already have an `Arc`, may avoid an allocation.
159    pub fn from_arc(inner: Arc<T>) -> Self {
160        intern_table::intern(inner)
161    }
162
163    /// Clones if needed to get mutable access to the inner value.
164    pub fn with_inner_mut<R>(&mut self, f: impl FnOnce(&mut T) -> R) -> R {
165        match intern_table::mutate_in_place(self, f) {
166            Ok(r) => r,
167            Err(f) => {
168                // The value is behind a shared `Arc`, we clone it in order to mutate it.
169                let mut value = self.inner().clone();
170                let ret = f(&mut value);
171                // Re-intern the new value.
172                *self = Self::new(value);
173                ret
174            }
175        }
176    }
177
178    pub fn id(&self) -> HashConsId {
179        intern_table::id(self)
180    }
181}
182
183impl<T> Deref for HashConsed<T> {
184    type Target = T;
185    fn deref(&self) -> &Self::Target {
186        self.inner()
187    }
188}
189
190impl<T: std::fmt::Debug> std::fmt::Debug for HashConsed<T> {
191    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
192        // Hide the `HashByAddr` wrapper.
193        f.debug_tuple("HashConsed").field(self.inner()).finish()
194    }
195}
196
197impl<'s, T, V: Visit<'s, T>> Drive<'s, V> for HashConsed<T> {
198    fn drive_inner(&'s self, v: &mut V) -> ControlFlow<V::Break> {
199        v.visit(self.inner())
200    }
201}
202/// Note: this explores the inner value mutably by cloning and re-hashing afterwards.
203impl<'s, T, V> DriveMut<'s, V> for HashConsed<T>
204where
205    T: HashConsable,
206    V: for<'a> VisitMut<'a, T>,
207{
208    fn drive_inner_mut(&'s mut self, v: &mut V) -> ControlFlow<V::Break> {
209        self.with_inner_mut(|inner| v.visit(inner))
210    }
211}
212
213/// `HashCons` supports serializing each value to a unique id in order to serialize
214/// highly-shared values without explosion.
215///
216/// Note that the deduplication scheme is highly order-dependent: we serialize the real value
217/// the first time it comes up, and use ids only subsequent times. This relies on the fact that
218/// `derive(Serialize, Deserialize)` traverse the value in the same order.
219pub use serialize::{HashConsDedupSerializer, HashConsSerializerState};
220mod serialize {
221    use indexmap::IndexMap as SeqHashMap;
222    use serde::{Deserialize, Serialize};
223    use serde_state::{DeserializeState, SerializeState};
224    use std::any::type_name;
225    use std::cell::RefCell;
226    use std::collections::HashSet;
227
228    use super::{HashConsId, HashConsable, HashConsed};
229    use crate::common::type_map::{Mappable, Mapper, TypeMap};
230
231    pub trait HashConsSerializerState: Sized {
232        /// Record that this type is being serialized. Return `None` if we're not deduplicating
233        /// values, otherwise return whether this item was newly recorded.
234        fn record_serialized<T: Mappable>(&self, id: HashConsId) -> Option<bool>;
235        /// Record that we deserialized this type.
236        fn record_deserialized<T: Mappable>(&self, id: HashConsId, value: HashConsed<T>);
237        /// Find the previously-deserialized type with that id.
238        fn get_deserialized_val<T: Mappable>(&self, id: HashConsId) -> Option<HashConsed<T>>;
239    }
240
241    impl HashConsSerializerState for () {
242        fn record_serialized<T: Mappable>(&self, _id: HashConsId) -> Option<bool> {
243            None
244        }
245        fn record_deserialized<T: Mappable>(&self, _id: HashConsId, _value: HashConsed<T>) {}
246        fn get_deserialized_val<T: Mappable>(&self, _id: HashConsId) -> Option<HashConsed<T>> {
247            None
248        }
249    }
250
251    struct SerializeTableMapper;
252    impl Mapper for SerializeTableMapper {
253        type Value<T: Mappable> = HashSet<HashConsId>;
254    }
255    struct DeserializeTableMapper;
256    impl Mapper for DeserializeTableMapper {
257        type Value<T: Mappable> = SeqHashMap<HashConsId, HashConsed<T>>;
258    }
259    #[derive(Default)]
260    pub struct HashConsDedupSerializer {
261        // Table used for serialization.
262        ser: RefCell<TypeMap<SerializeTableMapper>>,
263        // Table used for deserialization.
264        de: RefCell<TypeMap<DeserializeTableMapper>>,
265    }
266    impl HashConsSerializerState for HashConsDedupSerializer {
267        fn record_serialized<T: Mappable>(&self, id: HashConsId) -> Option<bool> {
268            Some(self.ser.borrow_mut().or_default::<T>().insert(id))
269        }
270        fn record_deserialized<T: Mappable>(&self, id: HashConsId, val: HashConsed<T>) {
271            self.de.borrow_mut().or_default::<T>().insert(id, val);
272        }
273        fn get_deserialized_val<T: Mappable>(&self, id: HashConsId) -> Option<HashConsed<T>> {
274            self.de
275                .borrow()
276                .get::<T>()
277                .and_then(|map| map.get(&id))
278                .cloned()
279        }
280    }
281
282    /// A dummy enum used when serializing/deserializing a `HashConsed<T>`.
283    #[derive(Serialize, Deserialize, SerializeState, DeserializeState)]
284    #[serde_state(state_implements = HashConsSerializerState)]
285    enum SerRepr<T> {
286        /// A value represented normally, accompanied by its id. This is emitted the first time
287        /// we serialize a given value: subsequent times will use `SerRepr::Deduplicate`
288        /// instead.
289        HashConsedValue(#[serde_state(stateless)] HashConsId, T),
290        /// A value represented by its id. The actual value must have been emitted as a
291        /// `SerRepr::Value` with that same id earlier.
292        #[serde_state(stateless)]
293        Deduplicated(HashConsId),
294        /// A plain value without an id.
295        Untagged(T),
296    }
297
298    impl<T> Serialize for HashConsed<T>
299    where
300        T: Serialize + HashConsable,
301    {
302        fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
303        where
304            S: serde::Serializer,
305        {
306            SerRepr::Untagged(self.inner()).serialize(serializer)
307        }
308    }
309    /// Options for the state are `()` to serialize values normally and `HashConsDedupSerializer`
310    /// to deduplicate identical values in the serialized output.
311    impl<T, State> SerializeState<State> for HashConsed<T>
312    where
313        T: SerializeState<State> + HashConsable,
314        State: HashConsSerializerState,
315    {
316        fn serialize_state<S>(&self, state: &State, serializer: S) -> Result<S::Ok, S::Error>
317        where
318            S: serde::Serializer,
319        {
320            let hash_cons_id = self.id();
321            let repr = match state.record_serialized::<T>(hash_cons_id) {
322                Some(true) => SerRepr::HashConsedValue(hash_cons_id, self.inner()),
323                Some(false) => SerRepr::Deduplicated(hash_cons_id),
324                None => SerRepr::Untagged(self.inner()),
325            };
326            repr.serialize_state(state, serializer)
327        }
328    }
329
330    impl<'de, T> Deserialize<'de> for HashConsed<T>
331    where
332        T: Deserialize<'de> + HashConsable,
333    {
334        fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
335        where
336            D: serde::Deserializer<'de>,
337        {
338            use serde::de::Error;
339            let repr: SerRepr<T> = SerRepr::deserialize(deserializer)?;
340            match repr {
341                SerRepr::HashConsedValue { .. } | SerRepr::Deduplicated { .. } => {
342                    let msg = format!(
343                        "trying to deserialize a deduplicated value using serde's `{ty}::deserialize` method. \
344                        This won't work, use serde_state's \
345                        `{ty}::deserialize_state(&HashConsDedupSerializer::default(), _)` instead",
346                        ty = type_name::<T>(),
347                    );
348                    Err(D::Error::custom(msg))
349                }
350                SerRepr::Untagged(val) => Ok(HashConsed::new(val)),
351            }
352        }
353    }
354    impl<'de, T, State> DeserializeState<'de, State> for HashConsed<T>
355    where
356        T: DeserializeState<'de, State> + HashConsable,
357        State: HashConsSerializerState,
358    {
359        fn deserialize_state<D>(state: &State, deserializer: D) -> Result<Self, D::Error>
360        where
361            D: serde::Deserializer<'de>,
362        {
363            use serde::de::Error;
364            let repr: SerRepr<T> = SerRepr::deserialize_state(state, deserializer)?;
365            Ok(match repr {
366                SerRepr::HashConsedValue(hash_cons_id, value) => {
367                    let val = HashConsed::new(value);
368                    state.record_deserialized(hash_cons_id, val.clone());
369                    val
370                }
371                SerRepr::Deduplicated(hash_cons_id) => {
372                    state.get_deserialized_val(hash_cons_id).ok_or_else(|| {
373                        let msg = format!(
374                            "can't deserialize deduplicated value of type {}; \
375                            were you careful with managing the deduplication state?",
376                            type_name::<T>()
377                        );
378                        D::Error::custom(msg)
379                    })?
380                }
381                SerRepr::Untagged(val) => HashConsed::new(val),
382            })
383        }
384    }
385}
386
#[test]
fn test_hash_cons() {
    // Interning the same value twice must yield the same handle.
    let first = HashConsed::new(42u32);
    let second = HashConsed::new(42u32);
    assert_eq!(first, second);
    // A JSON round-trip must land back on the same interned value.
    let json = serde_json::to_value(first.clone()).unwrap();
    let decoded: HashConsed<u32> = serde_json::from_value(json).unwrap();
    assert_eq!(first, decoded);
}
396
#[test]
fn test_hash_cons_concurrent() {
    // Intern the same value from several threads; all of them must agree on the handle.
    let handles: Vec<_> = (0..10)
        .map(|_| std::thread::spawn(|| std::hint::black_box(HashConsed::new(42u32))))
        .collect();
    let values: Vec<HashConsed<u32>> = handles
        .into_iter()
        .map(|handle| handle.join().unwrap())
        .collect();
    assert!(values.windows(2).all(|pair| pair[0] == pair[1]));
}
406
#[test]
fn test_hash_cons_dedup() {
    use serde_state::{DeserializeState, SerializeState};

    type Ty = HashConsed<TyKind>;
    #[derive(Debug, Clone, PartialEq, Eq, Hash, SerializeState, DeserializeState)]
    #[serde_state(state = HashConsDedupSerializer)]
    enum TyKind {
        Bool,
        Pair(Ty, Ty),
    }

    // A value with shared substructure: `Bool` occurs three times and `(Bool, Bool)` once.
    let boolean = Ty::new(TyKind::Bool);
    let inner_pair = Ty::new(TyKind::Pair(boolean.clone(), Ty::new(TyKind::Bool)));
    let outer_pair = Ty::new(TyKind::Pair(boolean, inner_pair));

    // Serialize with one state, then deserialize from scratch with a fresh one.
    let ser_state = HashConsDedupSerializer::default();
    let json = outer_pair
        .serialize_state(&ser_state, serde_json::value::Serializer)
        .unwrap();
    let de_state = HashConsDedupSerializer::default();
    let decoded = Ty::deserialize_state(&de_state, json).unwrap();

    assert_eq!(outer_pair, decoded);
}