charon_lib/ast/
hash_cons.rs

1use derive_generic_visitor::{Drive, DriveMut, Visit, VisitMut};
2use serde::{Deserialize, Serialize};
3use std::hash::Hash;
4use std::ops::{ControlFlow, Deref};
5use std::sync::Arc;
6
7use crate::common::hash_by_addr::HashByAddr;
8use crate::common::type_map::Mappable;
9
10/// Hash-consed data structure: a reference-counted wrapper that guarantees that two equal
11/// value will be stored at the same address. This makes it possible to use the pointer address
12/// as a hash value.
13// Warning: a `derive` should not introduce a way to create a new `HashConsed` value without
14// going through the interning table.
15#[derive(PartialEq, Eq, Hash)]
16pub struct HashConsed<T>(HashByAddr<Arc<T>>);
17
18impl<T> Clone for HashConsed<T> {
19    fn clone(&self) -> Self {
20        Self(self.0.clone())
21    }
22}
23
24impl<T> HashConsed<T> {
25    pub fn inner(&self) -> &T {
26        self.0.0.as_ref()
27    }
28}
29
30pub trait HashConsable = Hash + PartialEq + Eq + Clone + Mappable;
31
32/// Unique id identifying a hashconsed value amongst those with the same type.
33#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
34pub struct HashConsId(usize);
35
36// Private module that contains the static we'll use as interning map. A value of type
37// `HashCons` MUST NOT be created in any other way than this table, else hashing and euqality
38// on it will be broken. Note that this likely means that if a crate uses charon both as a
39// direct dependency and as a dylib, then the static will be duplicated, causing hashing and
40// equality on `HashCons` to be broken.
41mod intern_table {
42    use indexmap::IndexSet;
43    use std::sync::{Arc, LazyLock, RwLock};
44
45    use super::{HashConsId, HashConsable, HashConsed};
46    use crate::common::hash_by_addr::HashByAddr;
47    use crate::common::type_map::{Mappable, Mapper, TypeMap};
48
49    // This is a static mutable `IndexSet<Arc<T>>` that records for each `T` value a unique
50    // `Arc<T>` that contains the same value. Values inside the set are hashed/compared
51    // as is normal for `T`.
52    // Once we've gotten an `Arc` out of the set however, we're sure that "T-equality"
53    // implies address-equality, hence the `HashByAddr` wrapper preserves correct equality
54    // and hashing behavior.
55    struct InternMapper;
56    impl Mapper for InternMapper {
57        type Value<T: Mappable> = IndexSet<Arc<T>>;
58    }
59    static INTERNED: LazyLock<RwLock<TypeMap<InternMapper>>> = LazyLock::new(|| Default::default());
60
61    pub fn intern<T: HashConsable>(inner: T) -> HashConsed<T> {
62        // Fast read-only check.
63        #[expect(irrefutable_let_patterns)] // https://github.com/rust-lang/rust/issues/139369
64        let arc = if let read_guard = INTERNED.read().unwrap()
65            && let Some(set) = read_guard.get::<T>()
66            && let Some(arc) = set.get(&inner)
67        {
68            arc.clone()
69        } else {
70            // Concurrent access is possible right here, so we have to check everything again.
71            let mut write_guard = INTERNED.write().unwrap();
72            let set: &mut IndexSet<Arc<T>> = write_guard.or_default::<T>();
73            if let Some(arc) = set.get(&inner) {
74                arc.clone()
75            } else {
76                let arc: Arc<T> = Arc::new(inner);
77                set.insert(arc.clone());
78                arc
79            }
80        };
81        HashConsed(HashByAddr(arc))
82    }
83
84    /// Identify this value uniquely amongst values of its type. The id depends on insertion
85    /// order into the interning table which makes them in principle deterministic.
86    pub fn id<T: HashConsable>(x: &HashConsed<T>) -> HashConsId {
87        // `HashConsed` can only be constructed via `intern`, so we know this value exists in
88        // the table.
89        HashConsId(
90            (*INTERNED.read().unwrap())
91                .get::<T>()
92                .unwrap()
93                .get_index_of(&x.0.0)
94                .unwrap(),
95        )
96    }
97}
98
99impl<T> HashConsed<T>
100where
101    T: HashConsable,
102{
103    /// Deduplicate the values by hashing them. This deduplication is crucial for the hashing
104    /// function to be correct. This is the only function allowed to create `Self` values.
105    pub fn new(inner: T) -> Self {
106        intern_table::intern(inner)
107    }
108
109    /// Clones if needed to get mutable access to the inner value.
110    pub fn with_inner_mut<R>(&mut self, f: impl FnOnce(&mut T) -> R) -> R {
111        // The value is behind a shared `Arc`, we clone it in order to mutate it.
112        let mut value = self.inner().clone();
113        let ret = f(&mut value);
114        // Re-intern the new value.
115        *self = Self::new(value);
116        ret
117    }
118
119    pub fn id(&self) -> HashConsId {
120        intern_table::id(self)
121    }
122}
123
124impl<T> Deref for HashConsed<T> {
125    type Target = T;
126    fn deref(&self) -> &Self::Target {
127        self.inner()
128    }
129}
130
131impl<T: std::fmt::Debug> std::fmt::Debug for HashConsed<T> {
132    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
133        // Hide the `HashByAddr` wrapper.
134        f.debug_tuple("HashConsed").field(self.inner()).finish()
135    }
136}
137
138impl<'s, T, V: Visit<'s, T>> Drive<'s, V> for HashConsed<T> {
139    fn drive_inner(&'s self, v: &mut V) -> ControlFlow<V::Break> {
140        v.visit(self.inner())
141    }
142}
143/// Note: this explores the inner value mutably by cloning and re-hashing afterwards.
144impl<'s, T, V> DriveMut<'s, V> for HashConsed<T>
145where
146    T: HashConsable,
147    V: for<'a> VisitMut<'a, T>,
148{
149    fn drive_inner_mut(&'s mut self, v: &mut V) -> ControlFlow<V::Break> {
150        self.with_inner_mut(|inner| v.visit(inner))
151    }
152}
153
/// `HashConsed` supports serializing each value together with a unique id, so that
/// highly-shared values can be serialized without blowing up the output size.
///
/// Note that the deduplication scheme is highly order-dependent: we serialize the real value
/// the first time it comes up, and use ids only on subsequent occurrences. This relies on the
/// derived serialization and deserialization impls traversing values in the same order.
pub use serialize::{HashConsDedupSerializer, HashConsSerializerState};
mod serialize {
    use indexmap::IndexMap;
    use serde::{Deserialize, Serialize};
    use serde_state::{DeserializeState, SerializeState};
    use std::any::type_name;
    use std::cell::RefCell;
    use std::collections::HashSet;

    use super::{HashConsId, HashConsable, HashConsed};
    use crate::common::type_map::{Mappable, Mapper, TypeMap};

    /// State threaded through `serde_state` (de)serialization of `HashConsed` values, used to
    /// deduplicate values that occur several times.
    pub trait HashConsSerializerState: Sized {
        /// Record that the value of type `T` with this id is being serialized. Return `None` if
        /// we're not deduplicating values, otherwise return whether this id was newly recorded
        /// (i.e. `true` the first time it is seen).
        fn record_serialized<T: Mappable>(&self, id: HashConsId) -> Option<bool>;
        /// Record the value of type `T` that was deserialized with this id.
        fn record_deserialized<T: Mappable>(&self, id: HashConsId, value: HashConsed<T>);
        /// Find the previously-deserialized value of type `T` with that id, if any.
        fn get_deserialized_val<T: Mappable>(&self, id: HashConsId) -> Option<HashConsed<T>>;
    }

    // The unit state disables deduplication: values are serialized as `SerRepr::Untagged`, and
    // nothing is recorded, so `SerRepr::Deduplicated` input cannot be resolved.
    impl HashConsSerializerState for () {
        fn record_serialized<T: Mappable>(&self, _id: HashConsId) -> Option<bool> {
            None
        }
        fn record_deserialized<T: Mappable>(&self, _id: HashConsId, _value: HashConsed<T>) {}
        fn get_deserialized_val<T: Mappable>(&self, _id: HashConsId) -> Option<HashConsed<T>> {
            None
        }
    }

    // Per-type set of ids whose value has already been emitted during serialization.
    struct SerializeTableMapper;
    impl Mapper for SerializeTableMapper {
        type Value<T: Mappable> = HashSet<HashConsId>;
    }
    // Per-type map from id to the value deserialized with that id.
    struct DeserializeTableMapper;
    impl Mapper for DeserializeTableMapper {
        type Value<T: Mappable> = IndexMap<HashConsId, HashConsed<T>>;
    }
    /// State that deduplicates `HashConsed` values: each distinct value is emitted in full
    /// once and afterwards referred to by its id. Ids recorded during one serialization are
    /// remembered, so reusing an instance for a second document would make that document
    /// refer to values emitted only in the first one.
    #[derive(Default)]
    pub struct HashConsDedupSerializer {
        // Table used for serialization.
        ser: RefCell<TypeMap<SerializeTableMapper>>,
        // Table used for deserialization.
        de: RefCell<TypeMap<DeserializeTableMapper>>,
    }
    impl HashConsSerializerState for HashConsDedupSerializer {
        fn record_serialized<T: Mappable>(&self, id: HashConsId) -> Option<bool> {
            // `HashSet::insert` returns `true` only the first time this id is inserted.
            Some(self.ser.borrow_mut().or_default::<T>().insert(id))
        }
        fn record_deserialized<T: Mappable>(&self, id: HashConsId, val: HashConsed<T>) {
            self.de.borrow_mut().or_default::<T>().insert(id, val);
        }
        fn get_deserialized_val<T: Mappable>(&self, id: HashConsId) -> Option<HashConsed<T>> {
            self.de
                .borrow()
                .get::<T>()
                .and_then(|map| map.get(&id))
                .cloned()
        }
    }

    /// Wire representation used when serializing/deserializing a `HashConsed<T>`.
    #[derive(Serialize, Deserialize, SerializeState, DeserializeState)]
    #[serde_state(state_implements = HashConsSerializerState)]
    enum SerRepr<T> {
        /// A value represented normally, accompanied by its id. This is emitted the first time
        /// we serialize a given value: subsequent times will use `SerRepr::Deduplicated`
        /// instead.
        HashConsedValue(#[serde_state(stateless)] HashConsId, T),
        /// A value represented by its id. The actual value must have been emitted as a
        /// `SerRepr::HashConsedValue` with that same id earlier.
        #[serde_state(stateless)]
        Deduplicated(HashConsId),
        /// A plain value without an id.
        Untagged(T),
    }

    impl<T> Serialize for HashConsed<T>
    where
        T: Serialize + HashConsable,
    {
        fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
        where
            S: serde::Serializer,
        {
            // Plain serde has no state in which to remember ids, so always emit the value
            // untagged.
            SerRepr::Untagged(self.inner()).serialize(serializer)
        }
    }
    /// Options for the state are `()` to serialize values normally and `HashConsDedupSerializer`
    /// to deduplicate identical values in the serialized output.
    impl<T, State> SerializeState<State> for HashConsed<T>
    where
        T: SerializeState<State> + HashConsable,
        State: HashConsSerializerState,
    {
        fn serialize_state<S>(&self, state: &State, serializer: S) -> Result<S::Ok, S::Error>
        where
            S: serde::Serializer,
        {
            let hash_cons_id = self.id();
            // First occurrence: full value plus id; later occurrences: id only; no
            // deduplication: plain value.
            let repr = match state.record_serialized::<T>(hash_cons_id) {
                Some(true) => SerRepr::HashConsedValue(hash_cons_id, self.inner()),
                Some(false) => SerRepr::Deduplicated(hash_cons_id),
                None => SerRepr::Untagged(self.inner()),
            };
            repr.serialize_state(state, serializer)
        }
    }

    impl<'de, T> Deserialize<'de> for HashConsed<T>
    where
        T: Deserialize<'de> + HashConsable,
    {
        fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
        where
            D: serde::Deserializer<'de>,
        {
            use serde::de::Error;
            let repr: SerRepr<T> = SerRepr::deserialize(deserializer)?;
            // Without a state there is nowhere to look ids up, so only untagged values can be
            // read back.
            match repr {
                SerRepr::HashConsedValue { .. } | SerRepr::Deduplicated { .. } => {
                    let msg = format!(
                        "trying to deserialize a deduplicated value using serde's `{ty}::deserialize` method. \
                        This won't work, use serde_state's \
                        `{ty}::deserialize_state(&HashConsDedupSerializer::default(), _)` instead",
                        ty = type_name::<T>(),
                    );
                    Err(D::Error::custom(msg))
                }
                SerRepr::Untagged(val) => Ok(HashConsed::new(val)),
            }
        }
    }
    impl<'de, T, State> DeserializeState<'de, State> for HashConsed<T>
    where
        T: DeserializeState<'de, State> + HashConsable,
        State: HashConsSerializerState,
    {
        fn deserialize_state<D>(state: &State, deserializer: D) -> Result<Self, D::Error>
        where
            D: serde::Deserializer<'de>,
        {
            use serde::de::Error;
            let repr: SerRepr<T> = SerRepr::deserialize_state(state, deserializer)?;
            Ok(match repr {
                // First occurrence: intern the value and remember it so that later
                // `Deduplicated` references to the same id can be resolved.
                SerRepr::HashConsedValue(hash_cons_id, value) => {
                    let val = HashConsed::new(value);
                    state.record_deserialized(hash_cons_id, val.clone());
                    val
                }
                // Reference to a value seen earlier in this same stream.
                SerRepr::Deduplicated(hash_cons_id) => {
                    state.get_deserialized_val(hash_cons_id).ok_or_else(|| {
                        let msg = format!(
                            "can't deserialize deduplicated value of type {}; \
                            were you careful with managing the deduplication state?",
                            type_name::<T>()
                        );
                        D::Error::custom(msg)
                    })?
                }
                SerRepr::Untagged(val) => HashConsed::new(val),
            })
        }
    }
}
327
#[test]
fn test_hash_cons() {
    // Interning the same value twice must give equal handles.
    let first = HashConsed::new(42u32);
    let second = HashConsed::new(42u32);
    assert_eq!(first, second);
    // A plain serde JSON round-trip must give back the same interned value.
    let json = serde_json::to_value(first.clone()).unwrap();
    let round_tripped: HashConsed<u32> = serde_json::from_value(json).unwrap();
    assert_eq!(first, round_tripped);
}
337
#[test]
fn test_hash_cons_concurrent() {
    use itertools::Itertools;
    // Intern the same value from several threads at once; every thread must end up with a
    // handle to the same interned value.
    let handles = (0..10)
        .map(|_| std::thread::spawn(|| std::hint::black_box(HashConsed::new(42u32))))
        .collect_vec();
    let values = handles
        .into_iter()
        .map(|h| h.join().unwrap())
        .collect_vec();
    assert_eq!(values.len(), 10);
    assert!(values.iter().all_equal());
    // The handles created on other threads must also match one created on this thread.
    assert_eq!(values[0], HashConsed::new(42u32));
}
348
#[test]
fn test_hash_cons_dedup() {
    use serde_state::{DeserializeState, SerializeState};
    type Ty = HashConsed<TyKind>;
    #[derive(Debug, Clone, PartialEq, Eq, Hash, SerializeState, DeserializeState)]
    #[serde_state(state = HashConsDedupSerializer)]
    enum TyKind {
        Bool,
        Pair(Ty, Ty),
    }

    // Build a value with some redundancy: `Bool` occurs three times.
    let bool1 = HashConsed::new(TyKind::Bool);
    let bool2 = HashConsed::new(TyKind::Bool);
    let pair = HashConsed::new(TyKind::Pair(bool1.clone(), bool2));
    let triple = HashConsed::new(TyKind::Pair(bool1, pair));

    let state = HashConsDedupSerializer::default();
    let json_val = triple
        .serialize_state(&state, serde_json::value::Serializer)
        .unwrap();
    // Repeated occurrences must be emitted as references to the first one; otherwise no
    // deduplication happened and the round-trip below would not exercise it.
    assert!(
        json_val.to_string().contains("Deduplicated"),
        "expected deduplicated output, got {json_val}"
    );
    // Deserialization uses its own fresh state.
    let state = HashConsDedupSerializer::default();
    let round_tripped = Ty::deserialize_state(&state, json_val).unwrap();

    assert_eq!(triple, round_tripped);
}