// charon_lib/ast/hash_cons.rs

1use derive_generic_visitor::{Drive, DriveMut, Visit, VisitMut};
2use serde::{Deserialize, Serialize};
3use std::hash::Hash;
4use std::ops::{ControlFlow, Deref};
5use std::sync::Arc;
6
7use crate::common::hash_by_addr::HashByAddr;
8use crate::common::type_map::Mappable;
9
/// Hash-consed data structure: a reference-counted wrapper that guarantees that two equal
/// values will be stored at the same address. This makes it possible to use the pointer address
/// as a hash value.
///
/// The derived `PartialEq`, `Eq` and `Hash` impls delegate to the `HashByAddr` wrapper, so they
/// compare/hash the `Arc`'s address rather than the contents; this is only correct because every
/// value goes through the interning table.
// Warning: a `derive` should not introduce a way to create a new `HashConsed` value without
// going through the interning table.
#[derive(PartialEq, Eq, Hash)]
pub struct HashConsed<T>(HashByAddr<Arc<T>>);
17
18impl<T> Clone for HashConsed<T> {
19    fn clone(&self) -> Self {
20        Self(self.0.clone())
21    }
22}
23
24impl<T> HashConsed<T> {
25    pub fn inner(&self) -> &T {
26        self.0.0.as_ref()
27    }
28}
29
30pub trait HashConsable = Hash + PartialEq + Eq + Clone + Mappable;
31
32/// Unique id identifying a hashconsed value amongst all hashconsed values.
33#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
34pub struct HashConsId(u64);
35
36// Private module that contains the static we'll use as interning map. A value of type
37// `HashCons` MUST NOT be created in any other way than this table, else hashing and euqality
38// on it will be broken. Note that this likely means that if a crate uses charon both as a
39// direct dependency and as a dylib, then the static will be duplicated, causing hashing and
40// equality on `HashCons` to be broken.
41mod intern_table {
42    use indexmap::IndexMap as SeqHashMap;
43    use std::borrow::Borrow;
44    use std::sync::atomic::{AtomicU64, Ordering};
45    use std::sync::{Arc, LazyLock, RwLock};
46
47    use super::{HashConsId, HashConsable, HashConsed};
48    use crate::common::hash_by_addr::HashByAddr;
49    use crate::common::type_map::{Mappable, Mapper, TypeMap};
50
51    // Only way we create a `HashConsId`.
52    fn fresh_id() -> HashConsId {
53        static ID: AtomicU64 = AtomicU64::new(0);
54        HashConsId(ID.fetch_add(1, Ordering::Relaxed))
55    }
56
57    // This is a static mutable `SeqHashSet<Arc<T>>` that records for each `T` value a unique
58    // `Arc<T>` that contains the same value. Values inside the set are hashed/compared
59    // as is normal for `T`.
60    // Once we've gotten an `Arc` out of the set however, we're sure that "T-equality"
61    // implies address-equality, hence the `HashByAddr` wrapper preserves correct equality
62    // and hashing behavior.
63    // Note that we also store a `HashConsId` for each item so this is a map instead of a set, but
64    // what matters is really the map keys.
65    struct InternMapper;
66    impl Mapper for InternMapper {
67        type Value<T: Mappable> = SeqHashMap<Arc<T>, HashConsId>;
68    }
69    static INTERNED: LazyLock<RwLock<TypeMap<InternMapper>>> = LazyLock::new(Default::default);
70
71    // The excessive generality is to make it work for both `U = T` and `U = Arc<T>`.
72    pub fn intern<T: HashConsable, U>(inner: U) -> HashConsed<T>
73    where
74        Arc<T>: Borrow<U>,
75        U: Into<Arc<T>> + std::hash::Hash,
76        U: indexmap::Equivalent<Arc<T>>,
77    {
78        // Fast read-only check.
79        #[expect(irrefutable_let_patterns)] // https://github.com/rust-lang/rust/issues/139369
80        let arc = if let read_guard = INTERNED.read().unwrap()
81            && let Some(map) = read_guard.get::<T>()
82            && let Some((arc, _id)) = map.get_key_value(&inner)
83        {
84            arc.clone()
85        } else {
86            // Concurrent access is possible right here, so we have to check everything again.
87            let mut write_guard = INTERNED.write().unwrap();
88            let map: &mut SeqHashMap<Arc<T>, _> = write_guard.or_default::<T>();
89            if let Some((arc, _id)) = map.get_key_value(&inner) {
90                arc.clone()
91            } else {
92                let arc: Arc<T> = inner.into();
93                map.insert(arc.clone(), fresh_id());
94                arc
95            }
96        };
97        HashConsed(HashByAddr(arc))
98    }
99
100    /// Mutate the contents in-place if possible.
101    pub fn mutate_in_place<T: HashConsable, R, F: FnOnce(&mut T) -> R>(
102        x: &mut HashConsed<T>,
103        f: F,
104    ) -> Result<R, F> {
105        let arc = &mut x.0.0;
106        // Every value has at least two pointers: the current value and the one stored in the
107        // global map. If there are exactly two, we may mutate directly by discarding the one in
108        // the global map temporarily.
109        if Arc::strong_count(arc) != 2 {
110            return Err(f);
111        }
112        {
113            // Take the write guard just long enough to drop the other `Arc` to this value.
114            let mut write_guard = INTERNED.write().unwrap();
115            if let Some((other_arc, _)) = write_guard.or_default::<T>().swap_remove_entry(&*arc) {
116                drop(other_arc);
117            } else {
118                // Nothing was removed, early return.
119                return Err(f);
120            }
121            // The Arc was removed from the map; `x` is invalid as interning the same value would
122            // result in a different pointer. NO MORE EARLY RETURN until we fix that.
123        }
124        // If we are still the sole owner, we can now mutate in-place.
125        let ret = match Arc::get_mut(arc) {
126            Some(val) => Ok(f(val)),
127            None => Err(f),
128        };
129        // Re-establish the interning invariant. If the same value was added to the map in the
130        // meantime, we'll get a pointer to that.
131        *x = HashConsed::from_arc(arc.clone());
132        ret
133    }
134
135    /// Identify this value uniquely amongst values of its type. The id depends on insertion
136    /// order into the interning table which makes them in principle deterministic.
137    pub fn id<T: HashConsable>(x: &HashConsed<T>) -> HashConsId {
138        // `HashConsed` can only be constructed via `intern`, so we know this value exists in the
139        // table.
140        let read_guard = INTERNED.read().unwrap();
141        let map = read_guard.get::<T>().unwrap();
142        let (_arc, id) = map.get_key_value(&x.0.0).unwrap();
143        *id
144    }
145}
146
147impl<T> HashConsed<T>
148where
149    T: HashConsable,
150{
151    /// Deduplicate the values by hashing them. This deduplication is crucial for the hashing
152    /// function to be correct. This is the only function allowed to create `Self` values.
153    pub fn new(inner: T) -> Self {
154        intern_table::intern(inner)
155    }
156    /// Rarely used: in case we already have an `Arc`, may avoid an allocation.
157    pub fn from_arc(inner: Arc<T>) -> Self {
158        intern_table::intern(inner)
159    }
160
161    /// Clones if needed to get mutable access to the inner value.
162    pub fn with_inner_mut<R>(&mut self, f: impl FnOnce(&mut T) -> R) -> R {
163        match intern_table::mutate_in_place(self, f) {
164            Ok(r) => r,
165            Err(f) => {
166                // The value is behind a shared `Arc`, we clone it in order to mutate it.
167                let mut value = self.inner().clone();
168                let ret = f(&mut value);
169                // Re-intern the new value.
170                *self = Self::new(value);
171                ret
172            }
173        }
174    }
175
176    pub fn id(&self) -> HashConsId {
177        intern_table::id(self)
178    }
179}
180
181impl<T> Deref for HashConsed<T> {
182    type Target = T;
183    fn deref(&self) -> &Self::Target {
184        self.inner()
185    }
186}
187
188impl<T: std::fmt::Debug> std::fmt::Debug for HashConsed<T> {
189    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
190        // Hide the `HashByAddr` wrapper.
191        f.debug_tuple("HashConsed").field(self.inner()).finish()
192    }
193}
194
195impl<'s, T, V: Visit<'s, T>> Drive<'s, V> for HashConsed<T> {
196    fn drive_inner(&'s self, v: &mut V) -> ControlFlow<V::Break> {
197        v.visit(self.inner())
198    }
199}
200/// Note: this explores the inner value mutably by cloning and re-hashing afterwards.
201impl<'s, T, V> DriveMut<'s, V> for HashConsed<T>
202where
203    T: HashConsable,
204    V: for<'a> VisitMut<'a, T>,
205{
206    fn drive_inner_mut(&'s mut self, v: &mut V) -> ControlFlow<V::Break> {
207        self.with_inner_mut(|inner| v.visit(inner))
208    }
209}
210
/// `HashConsed` supports serializing each value to a unique id in order to serialize
/// highly-shared values without explosion.
///
/// Note that the deduplication scheme is highly order-dependent: we serialize the real value
/// the first time it comes up, and use ids only subsequent times. This relies on the fact that
/// the derived serialization and deserialization code traverses the value in the same order.
pub use serialize::{HashConsDedupSerializer, HashConsSerializerState};
mod serialize {
    use indexmap::IndexMap as SeqHashMap;
    use serde::{Deserialize, Serialize};
    use serde_state::{DeserializeState, SerializeState};
    use std::any::type_name;
    use std::cell::RefCell;
    use std::collections::HashSet;

    use super::{HashConsId, HashConsable, HashConsed};
    use crate::common::type_map::{Mappable, Mapper, TypeMap};

    /// State threaded through `serialize_state`/`deserialize_state` that decides whether
    /// `HashConsed` values are deduplicated by id, and remembers what was already seen.
    pub trait HashConsSerializerState: Sized {
        /// Record that the value of type `T` with this id is being serialized. Return `None` if
        /// we're not deduplicating values, otherwise return whether this id was newly recorded
        /// (`true` the first time it is seen, `false` afterwards).
        fn record_serialized<T: Mappable>(&self, id: HashConsId) -> Option<bool>;
        /// Record that we deserialized this value under this id.
        fn record_deserialized<T: Mappable>(&self, id: HashConsId, value: HashConsed<T>);
        /// Find the previously-deserialized value of type `T` with that id.
        fn get_deserialized_val<T: Mappable>(&self, id: HashConsId) -> Option<HashConsed<T>>;
    }

    /// The unit state disables deduplication: values are always serialized in full and nothing
    /// is remembered.
    impl HashConsSerializerState for () {
        fn record_serialized<T: Mappable>(&self, _id: HashConsId) -> Option<bool> {
            None
        }
        fn record_deserialized<T: Mappable>(&self, _id: HashConsId, _value: HashConsed<T>) {}
        fn get_deserialized_val<T: Mappable>(&self, _id: HashConsId) -> Option<HashConsed<T>> {
            None
        }
    }

    /// Per-type set of ids already emitted during serialization.
    struct SerializeTableMapper;
    impl Mapper for SerializeTableMapper {
        type Value<T: Mappable> = HashSet<HashConsId>;
    }
    /// Per-type map from serialized ids back to the values rebuilt during deserialization.
    struct DeserializeTableMapper;
    impl Mapper for DeserializeTableMapper {
        type Value<T: Mappable> = SeqHashMap<HashConsId, HashConsed<T>>;
    }
    /// Deduplicating state: the first occurrence of each value is serialized in full with its
    /// id, later occurrences only as the id.
    ///
    /// An id recorded in `ser` turns every later occurrence into `SerRepr::Deduplicated`, so
    /// reusing one instance for several independent outputs makes later outputs refer to ids
    /// that were only emitted in earlier ones. Interior mutability is via `RefCell`, so an
    /// instance cannot be shared across threads (`!Sync`).
    #[derive(Default)]
    pub struct HashConsDedupSerializer {
        // Table used for serialization.
        ser: RefCell<TypeMap<SerializeTableMapper>>,
        // Table used for deserialization.
        de: RefCell<TypeMap<DeserializeTableMapper>>,
    }
    impl HashConsSerializerState for HashConsDedupSerializer {
        fn record_serialized<T: Mappable>(&self, id: HashConsId) -> Option<bool> {
            // `HashSet::insert` returns `true` iff the id was not present yet.
            Some(self.ser.borrow_mut().or_default::<T>().insert(id))
        }
        fn record_deserialized<T: Mappable>(&self, id: HashConsId, val: HashConsed<T>) {
            self.de.borrow_mut().or_default::<T>().insert(id, val);
        }
        fn get_deserialized_val<T: Mappable>(&self, id: HashConsId) -> Option<HashConsed<T>> {
            self.de
                .borrow()
                .get::<T>()
                .and_then(|map| map.get(&id))
                .cloned()
        }
    }

    /// A dummy enum used when serializing/deserializing a `HashConsed<T>`.
    // The variant list is part of the serialized format; do not reorder or rename variants.
    #[derive(Serialize, Deserialize, SerializeState, DeserializeState)]
    #[serde_state(state_implements = HashConsSerializerState)]
    enum SerRepr<T> {
        /// A value represented normally, accompanied by its id. This is emitted the first time
        /// we serialize a given value: subsequent times will use `SerRepr::Deduplicated`
        /// instead.
        HashConsedValue(#[serde_state(stateless)] HashConsId, T),
        /// A value represented by its id. The actual value must have been emitted as a
        /// `SerRepr::HashConsedValue` with that same id earlier.
        #[serde_state(stateless)]
        Deduplicated(HashConsId),
        /// A plain value without an id.
        Untagged(T),
    }

    /// Plain serde serialization never deduplicates: the value is always emitted as
    /// `SerRepr::Untagged`.
    impl<T> Serialize for HashConsed<T>
    where
        T: Serialize + HashConsable,
    {
        fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
        where
            S: serde::Serializer,
        {
            SerRepr::Untagged(self.inner()).serialize(serializer)
        }
    }
    /// Options for the state are `()` to serialize values normally and `HashConsDedupSerializer`
    /// to deduplicate identical values in the serialized output.
    impl<T, State> SerializeState<State> for HashConsed<T>
    where
        T: SerializeState<State> + HashConsable,
        State: HashConsSerializerState,
    {
        fn serialize_state<S>(&self, state: &State, serializer: S) -> Result<S::Ok, S::Error>
        where
            S: serde::Serializer,
        {
            // Note: the id is looked up in the global interning table even when `state` does
            // not deduplicate.
            let hash_cons_id = self.id();
            let repr = match state.record_serialized::<T>(hash_cons_id) {
                Some(true) => SerRepr::HashConsedValue(hash_cons_id, self.inner()),
                Some(false) => SerRepr::Deduplicated(hash_cons_id),
                None => SerRepr::Untagged(self.inner()),
            };
            repr.serialize_state(state, serializer)
        }
    }

    /// Plain serde deserialization only accepts `SerRepr::Untagged`; deduplicated input
    /// requires `deserialize_state` with a deduplicating state.
    impl<'de, T> Deserialize<'de> for HashConsed<T>
    where
        T: Deserialize<'de> + HashConsable,
    {
        fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
        where
            D: serde::Deserializer<'de>,
        {
            use serde::de::Error;
            let repr: SerRepr<T> = SerRepr::deserialize(deserializer)?;
            match repr {
                SerRepr::HashConsedValue { .. } | SerRepr::Deduplicated { .. } => {
                    let msg = format!(
                        "trying to deserialize a deduplicated value using serde's `{ty}::deserialize` method. \
                        This won't work, use serde_state's \
                        `{ty}::deserialize_state(&HashConsDedupSerializer::default(), _)` instead",
                        ty = type_name::<T>(),
                    );
                    Err(D::Error::custom(msg))
                }
                SerRepr::Untagged(val) => Ok(HashConsed::new(val)),
            }
        }
    }
    /// Stateful deserialization: a `HashConsedValue` is interned and recorded under its id, and
    /// a `Deduplicated` id is resolved from the values recorded earlier in the same state.
    impl<'de, T, State> DeserializeState<'de, State> for HashConsed<T>
    where
        T: DeserializeState<'de, State> + HashConsable,
        State: HashConsSerializerState,
    {
        fn deserialize_state<D>(state: &State, deserializer: D) -> Result<Self, D::Error>
        where
            D: serde::Deserializer<'de>,
        {
            use serde::de::Error;
            let repr: SerRepr<T> = SerRepr::deserialize_state(state, deserializer)?;
            Ok(match repr {
                SerRepr::HashConsedValue(hash_cons_id, value) => {
                    let val = HashConsed::new(value);
                    state.record_deserialized(hash_cons_id, val.clone());
                    val
                }
                SerRepr::Deduplicated(hash_cons_id) => {
                    state.get_deserialized_val(hash_cons_id).ok_or_else(|| {
                        let msg = format!(
                            "can't deserialize deduplicated value of type {}; \
                            were you careful with managing the deduplication state?",
                            type_name::<T>()
                        );
                        D::Error::custom(msg)
                    })?
                }
                SerRepr::Untagged(val) => HashConsed::new(val),
            })
        }
    }
}
384
#[test]
fn test_hash_cons() {
    let first = HashConsed::new(42u32);
    let second = HashConsed::new(42u32);
    assert_eq!(first, second);
    // A JSON round-trip must hand back the very same interned pointer.
    let encoded = serde_json::to_value(&first).unwrap();
    let decoded: HashConsed<u32> = serde_json::from_value(encoded).unwrap();
    assert_eq!(first, decoded);
}
394
/// Many threads interning the same value at once must all get the same pointer and the same id.
#[test]
fn test_hash_cons_concurrent() {
    use itertools::Itertools;
    use std::sync::{Arc, Barrier};
    const THREADS: usize = 10;
    // Intern a type/value that no other test in this file interns (`test_hash_cons` already
    // interns `42u32`). That way the per-type table starts empty and the write path is raced.
    // The barrier releases all threads together to maximize contention.
    let barrier = Arc::new(Barrier::new(THREADS));
    let handles = (0..THREADS)
        .map(|_| {
            let barrier = Arc::clone(&barrier);
            std::thread::spawn(move || {
                barrier.wait();
                std::hint::black_box(HashConsed::new(42u64))
            })
        })
        .collect_vec();
    let values = handles.into_iter().map(|h| h.join().unwrap()).collect_vec();
    assert!(values.iter().all_equal());
    assert!(values.iter().map(|v| v.id()).all_equal());
}
404
#[test]
fn test_hash_cons_dedup() {
    use serde_state::{DeserializeState, SerializeState};
    type Ty = HashConsed<TyKind>;
    #[derive(Debug, Clone, PartialEq, Eq, Hash, SerializeState, DeserializeState)]
    #[serde_state(state = HashConsDedupSerializer)]
    enum TyKind {
        Bool,
        Pair(Ty, Ty),
    }

    // Build `Pair(Bool, Pair(Bool, Bool))`: every `Bool` node shares one interned pointer, so
    // the serializer has repeated values to deduplicate.
    let boolean = HashConsed::new(TyKind::Bool);
    let inner_pair = HashConsed::new(TyKind::Pair(boolean.clone(), HashConsed::new(TyKind::Bool)));
    let outer_pair = HashConsed::new(TyKind::Pair(boolean, inner_pair));

    // Serialize and deserialize with fresh, independent states.
    let ser_state = HashConsDedupSerializer::default();
    let encoded = outer_pair
        .serialize_state(&ser_state, serde_json::value::Serializer)
        .unwrap();
    let de_state = HashConsDedupSerializer::default();
    let decoded = Ty::deserialize_state(&de_state, encoded).unwrap();

    assert_eq!(outer_pair, decoded);
}