1use derive_generic_visitor::{Drive, DriveMut, Visit, VisitMut};
2use serde::{Deserialize, Serialize};
3use std::hash::Hash;
4use std::ops::{ControlFlow, Deref};
5use std::sync::Arc;
6
7use crate::common::hash_by_addr::HashByAddr;
8use crate::common::type_map::Mappable;
9
10#[derive(PartialEq, Eq, Hash)]
16pub struct HashConsed<T>(HashByAddr<Arc<T>>);
17
18impl<T> Clone for HashConsed<T> {
19 fn clone(&self) -> Self {
20 Self(self.0.clone())
21 }
22}
23
24impl<T> HashConsed<T> {
25 pub fn inner(&self) -> &T {
26 self.0.0.as_ref()
27 }
28}
29
30pub trait HashConsable: Hash + PartialEq + Eq + Clone + Mappable {}
31impl<T> HashConsable for T where T: Hash + PartialEq + Eq + Clone + Mappable {}
32
33#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
35pub struct HashConsId(u64);
36
37mod intern_table {
43 use rustc_hash::FxBuildHasher;
44 use std::borrow::Borrow;
45 use std::sync::atomic::{AtomicU64, Ordering};
46 use std::sync::{Arc, LazyLock, RwLock};
47
48 use super::{HashConsId, HashConsable, HashConsed};
49 use crate::common::hash_by_addr::HashByAddr;
50 use crate::common::type_map::{Mappable, Mapper, TypeMap};
51
52 type SeqHashMap<K, V> = indexmap::IndexMap<K, V, FxBuildHasher>;
53
54 fn fresh_id() -> HashConsId {
56 static ID: AtomicU64 = AtomicU64::new(0);
57 HashConsId(ID.fetch_add(1, Ordering::Relaxed))
58 }
59
60 struct InternMapper;
69 impl Mapper for InternMapper {
70 type Value<T: Mappable> = SeqHashMap<Arc<T>, HashConsId>;
71 }
72 static INTERNED: LazyLock<RwLock<TypeMap<InternMapper>>> = LazyLock::new(Default::default);
73
74 pub fn intern<T: HashConsable, U>(inner: U) -> HashConsed<T>
76 where
77 Arc<T>: Borrow<U>,
78 U: Into<Arc<T>> + std::hash::Hash,
79 U: indexmap::Equivalent<Arc<T>>,
80 {
81 let arc = if let read_guard = INTERNED.read().unwrap()
83 && let Some(map) = read_guard.get::<T>()
84 && let Some((arc, _id)) = map.get_key_value(&inner)
85 {
86 arc.clone()
87 } else {
88 let mut write_guard = INTERNED.write().unwrap();
90 let map: &mut SeqHashMap<Arc<T>, _> = write_guard.or_default::<T>();
91 if let Some((arc, _id)) = map.get_key_value(&inner) {
92 arc.clone()
93 } else {
94 let arc: Arc<T> = inner.into();
95 map.insert(arc.clone(), fresh_id());
96 arc
97 }
98 };
99 HashConsed(HashByAddr(arc))
100 }
101
102 pub fn mutate_in_place<T: HashConsable, R, F: FnOnce(&mut T) -> R>(
104 x: &mut HashConsed<T>,
105 f: F,
106 ) -> Result<R, F> {
107 let arc = &mut x.0.0;
108 if Arc::strong_count(arc) != 2 {
112 return Err(f);
113 }
114 {
115 let mut write_guard = INTERNED.write().unwrap();
117 if let Some((other_arc, _)) = write_guard.or_default::<T>().swap_remove_entry(&*arc) {
118 drop(other_arc);
119 } else {
120 return Err(f);
122 }
123 }
126 let ret = match Arc::get_mut(arc) {
128 Some(val) => Ok(f(val)),
129 None => Err(f),
130 };
131 *x = HashConsed::from_arc(arc.clone());
134 ret
135 }
136
137 pub fn id<T: HashConsable>(x: &HashConsed<T>) -> HashConsId {
140 let read_guard = INTERNED.read().unwrap();
143 let map = read_guard.get::<T>().unwrap();
144 let (_arc, id) = map.get_key_value(&x.0.0).unwrap();
145 *id
146 }
147}
148
149impl<T> HashConsed<T>
150where
151 T: HashConsable,
152{
153 pub fn new(inner: T) -> Self {
156 intern_table::intern(inner)
157 }
158 pub fn from_arc(inner: Arc<T>) -> Self {
160 intern_table::intern(inner)
161 }
162
163 pub fn with_inner_mut<R>(&mut self, f: impl FnOnce(&mut T) -> R) -> R {
165 match intern_table::mutate_in_place(self, f) {
166 Ok(r) => r,
167 Err(f) => {
168 let mut value = self.inner().clone();
170 let ret = f(&mut value);
171 *self = Self::new(value);
173 ret
174 }
175 }
176 }
177
178 pub fn id(&self) -> HashConsId {
179 intern_table::id(self)
180 }
181}
182
183impl<T> Deref for HashConsed<T> {
184 type Target = T;
185 fn deref(&self) -> &Self::Target {
186 self.inner()
187 }
188}
189
190impl<T: std::fmt::Debug> std::fmt::Debug for HashConsed<T> {
191 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
192 f.debug_tuple("HashConsed").field(self.inner()).finish()
194 }
195}
196
197impl<'s, T, V: Visit<'s, T>> Drive<'s, V> for HashConsed<T> {
198 fn drive_inner(&'s self, v: &mut V) -> ControlFlow<V::Break> {
199 v.visit(self.inner())
200 }
201}
202impl<'s, T, V> DriveMut<'s, V> for HashConsed<T>
204where
205 T: HashConsable,
206 V: for<'a> VisitMut<'a, T>,
207{
208 fn drive_inner_mut(&'s mut self, v: &mut V) -> ControlFlow<V::Break> {
209 self.with_inner_mut(|inner| v.visit(inner))
210 }
211}
212
213pub use serialize::{HashConsDedupSerializer, HashConsSerializerState};
220mod serialize {
221 use indexmap::IndexMap as SeqHashMap;
222 use serde::{Deserialize, Serialize};
223 use serde_state::{DeserializeState, SerializeState};
224 use std::any::type_name;
225 use std::cell::RefCell;
226 use std::collections::HashSet;
227
228 use super::{HashConsId, HashConsable, HashConsed};
229 use crate::common::type_map::{Mappable, Mapper, TypeMap};
230
231 pub trait HashConsSerializerState: Sized {
232 fn record_serialized<T: Mappable>(&self, id: HashConsId) -> Option<bool>;
235 fn record_deserialized<T: Mappable>(&self, id: HashConsId, value: HashConsed<T>);
237 fn get_deserialized_val<T: Mappable>(&self, id: HashConsId) -> Option<HashConsed<T>>;
239 }
240
241 impl HashConsSerializerState for () {
242 fn record_serialized<T: Mappable>(&self, _id: HashConsId) -> Option<bool> {
243 None
244 }
245 fn record_deserialized<T: Mappable>(&self, _id: HashConsId, _value: HashConsed<T>) {}
246 fn get_deserialized_val<T: Mappable>(&self, _id: HashConsId) -> Option<HashConsed<T>> {
247 None
248 }
249 }
250
251 struct SerializeTableMapper;
252 impl Mapper for SerializeTableMapper {
253 type Value<T: Mappable> = HashSet<HashConsId>;
254 }
255 struct DeserializeTableMapper;
256 impl Mapper for DeserializeTableMapper {
257 type Value<T: Mappable> = SeqHashMap<HashConsId, HashConsed<T>>;
258 }
259 #[derive(Default)]
260 pub struct HashConsDedupSerializer {
261 ser: RefCell<TypeMap<SerializeTableMapper>>,
263 de: RefCell<TypeMap<DeserializeTableMapper>>,
265 }
266 impl HashConsSerializerState for HashConsDedupSerializer {
267 fn record_serialized<T: Mappable>(&self, id: HashConsId) -> Option<bool> {
268 Some(self.ser.borrow_mut().or_default::<T>().insert(id))
269 }
270 fn record_deserialized<T: Mappable>(&self, id: HashConsId, val: HashConsed<T>) {
271 self.de.borrow_mut().or_default::<T>().insert(id, val);
272 }
273 fn get_deserialized_val<T: Mappable>(&self, id: HashConsId) -> Option<HashConsed<T>> {
274 self.de
275 .borrow()
276 .get::<T>()
277 .and_then(|map| map.get(&id))
278 .cloned()
279 }
280 }
281
282 #[derive(Serialize, Deserialize, SerializeState, DeserializeState)]
284 #[serde_state(state_implements = HashConsSerializerState)]
285 enum SerRepr<T> {
286 HashConsedValue(#[serde_state(stateless)] HashConsId, T),
290 #[serde_state(stateless)]
293 Deduplicated(HashConsId),
294 Untagged(T),
296 }
297
298 impl<T> Serialize for HashConsed<T>
299 where
300 T: Serialize + HashConsable,
301 {
302 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
303 where
304 S: serde::Serializer,
305 {
306 SerRepr::Untagged(self.inner()).serialize(serializer)
307 }
308 }
309 impl<T, State> SerializeState<State> for HashConsed<T>
312 where
313 T: SerializeState<State> + HashConsable,
314 State: HashConsSerializerState,
315 {
316 fn serialize_state<S>(&self, state: &State, serializer: S) -> Result<S::Ok, S::Error>
317 where
318 S: serde::Serializer,
319 {
320 let hash_cons_id = self.id();
321 let repr = match state.record_serialized::<T>(hash_cons_id) {
322 Some(true) => SerRepr::HashConsedValue(hash_cons_id, self.inner()),
323 Some(false) => SerRepr::Deduplicated(hash_cons_id),
324 None => SerRepr::Untagged(self.inner()),
325 };
326 repr.serialize_state(state, serializer)
327 }
328 }
329
330 impl<'de, T> Deserialize<'de> for HashConsed<T>
331 where
332 T: Deserialize<'de> + HashConsable,
333 {
334 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
335 where
336 D: serde::Deserializer<'de>,
337 {
338 use serde::de::Error;
339 let repr: SerRepr<T> = SerRepr::deserialize(deserializer)?;
340 match repr {
341 SerRepr::HashConsedValue { .. } | SerRepr::Deduplicated { .. } => {
342 let msg = format!(
343 "trying to deserialize a deduplicated value using serde's `{ty}::deserialize` method. \
344 This won't work, use serde_state's \
345 `{ty}::deserialize_state(&HashConsDedupSerializer::default(), _)` instead",
346 ty = type_name::<T>(),
347 );
348 Err(D::Error::custom(msg))
349 }
350 SerRepr::Untagged(val) => Ok(HashConsed::new(val)),
351 }
352 }
353 }
354 impl<'de, T, State> DeserializeState<'de, State> for HashConsed<T>
355 where
356 T: DeserializeState<'de, State> + HashConsable,
357 State: HashConsSerializerState,
358 {
359 fn deserialize_state<D>(state: &State, deserializer: D) -> Result<Self, D::Error>
360 where
361 D: serde::Deserializer<'de>,
362 {
363 use serde::de::Error;
364 let repr: SerRepr<T> = SerRepr::deserialize_state(state, deserializer)?;
365 Ok(match repr {
366 SerRepr::HashConsedValue(hash_cons_id, value) => {
367 let val = HashConsed::new(value);
368 state.record_deserialized(hash_cons_id, val.clone());
369 val
370 }
371 SerRepr::Deduplicated(hash_cons_id) => {
372 state.get_deserialized_val(hash_cons_id).ok_or_else(|| {
373 let msg = format!(
374 "can't deserialize deduplicated value of type {}; \
375 were you careful with managing the deduplication state?",
376 type_name::<T>()
377 );
378 D::Error::custom(msg)
379 })?
380 }
381 SerRepr::Untagged(val) => HashConsed::new(val),
382 })
383 }
384 }
385}
386
/// Two handles built from equal values compare equal, and so does a handle that went through a
/// stateless JSON round trip.
#[test]
fn test_hash_cons() {
    let first = HashConsed::new(42u32);
    let second = HashConsed::new(42u32);
    assert_eq!(first, second);

    let json = serde_json::to_value(first.clone()).unwrap();
    let decoded: HashConsed<u32> = serde_json::from_value(json).unwrap();
    assert_eq!(first, decoded);
}
396
/// Interning the same value from several threads at once yields handles that all compare equal.
#[test]
fn test_hash_cons_concurrent() {
    use itertools::Itertools;
    // Spawn every thread before joining any of them, so that the interning calls can overlap.
    let workers: Vec<_> = (0..10)
        .map(|_| std::thread::spawn(|| std::hint::black_box(HashConsed::new(42u32))))
        .collect();
    let results: Vec<_> = workers
        .into_iter()
        .map(|worker| worker.join().unwrap())
        .collect();
    assert!(results.iter().all_equal())
}
406
/// A nested value that contains the same interned node several times survives a round trip
/// through the deduplicating (de)serializer.
#[test]
fn test_hash_cons_dedup() {
    use serde_state::{DeserializeState, SerializeState};
    type Ty = HashConsed<TyKind>;
    #[derive(Debug, Clone, PartialEq, Eq, Hash, SerializeState, DeserializeState)]
    #[serde_state(state = HashConsDedupSerializer)]
    enum TyKind {
        Bool,
        Pair(Ty, Ty),
    }

    // `Bool` occurs three times in the final value.
    let bool_a = HashConsed::new(TyKind::Bool);
    let bool_b = HashConsed::new(TyKind::Bool);
    let inner_pair = HashConsed::new(TyKind::Pair(bool_a.clone(), bool_b));
    let nested = HashConsed::new(TyKind::Pair(bool_a, inner_pair));

    // Serialize and deserialize with separate, fresh states.
    let ser_state = HashConsDedupSerializer::default();
    let json = nested
        .serialize_state(&ser_state, serde_json::value::Serializer)
        .unwrap();
    let de_state = HashConsDedupSerializer::default();
    let decoded = Ty::deserialize_state(&de_state, json).unwrap();

    assert_eq!(nested, decoded);
}