1use derive_generic_visitor::{Drive, DriveMut, Visit, VisitMut};
2use serde::{Deserialize, Serialize};
3use std::hash::Hash;
4use std::ops::{ControlFlow, Deref};
5use std::sync::Arc;
6
7use crate::common::hash_by_addr::HashByAddr;
8use crate::common::type_map::Mappable;
9
10#[derive(PartialEq, Eq, Hash)]
16pub struct HashConsed<T>(HashByAddr<Arc<T>>);
17
18impl<T> Clone for HashConsed<T> {
19 fn clone(&self) -> Self {
20 Self(self.0.clone())
21 }
22}
23
24impl<T> HashConsed<T> {
25 pub fn inner(&self) -> &T {
26 self.0.0.as_ref()
27 }
28}
29
30pub trait HashConsable = Hash + PartialEq + Eq + Clone + Mappable;
31
32#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
34pub struct HashConsId(u64);
35
36mod intern_table {
42 use indexmap::IndexMap as SeqHashMap;
43 use std::borrow::Borrow;
44 use std::sync::atomic::{AtomicU64, Ordering};
45 use std::sync::{Arc, LazyLock, RwLock};
46
47 use super::{HashConsId, HashConsable, HashConsed};
48 use crate::common::hash_by_addr::HashByAddr;
49 use crate::common::type_map::{Mappable, Mapper, TypeMap};
50
51 fn fresh_id() -> HashConsId {
53 static ID: AtomicU64 = AtomicU64::new(0);
54 HashConsId(ID.fetch_add(1, Ordering::Relaxed))
55 }
56
57 struct InternMapper;
66 impl Mapper for InternMapper {
67 type Value<T: Mappable> = SeqHashMap<Arc<T>, HashConsId>;
68 }
69 static INTERNED: LazyLock<RwLock<TypeMap<InternMapper>>> = LazyLock::new(Default::default);
70
71 pub fn intern<T: HashConsable, U>(inner: U) -> HashConsed<T>
73 where
74 Arc<T>: Borrow<U>,
75 U: Into<Arc<T>> + std::hash::Hash,
76 U: indexmap::Equivalent<Arc<T>>,
77 {
78 #[expect(irrefutable_let_patterns)] let arc = if let read_guard = INTERNED.read().unwrap()
81 && let Some(map) = read_guard.get::<T>()
82 && let Some((arc, _id)) = map.get_key_value(&inner)
83 {
84 arc.clone()
85 } else {
86 let mut write_guard = INTERNED.write().unwrap();
88 let map: &mut SeqHashMap<Arc<T>, _> = write_guard.or_default::<T>();
89 if let Some((arc, _id)) = map.get_key_value(&inner) {
90 arc.clone()
91 } else {
92 let arc: Arc<T> = inner.into();
93 map.insert(arc.clone(), fresh_id());
94 arc
95 }
96 };
97 HashConsed(HashByAddr(arc))
98 }
99
100 pub fn mutate_in_place<T: HashConsable, R, F: FnOnce(&mut T) -> R>(
102 x: &mut HashConsed<T>,
103 f: F,
104 ) -> Result<R, F> {
105 let arc = &mut x.0.0;
106 if Arc::strong_count(arc) != 2 {
110 return Err(f);
111 }
112 {
113 let mut write_guard = INTERNED.write().unwrap();
115 if let Some((other_arc, _)) = write_guard.or_default::<T>().swap_remove_entry(&*arc) {
116 drop(other_arc);
117 } else {
118 return Err(f);
120 }
121 }
124 let ret = match Arc::get_mut(arc) {
126 Some(val) => Ok(f(val)),
127 None => Err(f),
128 };
129 *x = HashConsed::from_arc(arc.clone());
132 ret
133 }
134
135 pub fn id<T: HashConsable>(x: &HashConsed<T>) -> HashConsId {
138 let read_guard = INTERNED.read().unwrap();
141 let map = read_guard.get::<T>().unwrap();
142 let (_arc, id) = map.get_key_value(&x.0.0).unwrap();
143 *id
144 }
145}
146
147impl<T> HashConsed<T>
148where
149 T: HashConsable,
150{
151 pub fn new(inner: T) -> Self {
154 intern_table::intern(inner)
155 }
156 pub fn from_arc(inner: Arc<T>) -> Self {
158 intern_table::intern(inner)
159 }
160
161 pub fn with_inner_mut<R>(&mut self, f: impl FnOnce(&mut T) -> R) -> R {
163 match intern_table::mutate_in_place(self, f) {
164 Ok(r) => r,
165 Err(f) => {
166 let mut value = self.inner().clone();
168 let ret = f(&mut value);
169 *self = Self::new(value);
171 ret
172 }
173 }
174 }
175
176 pub fn id(&self) -> HashConsId {
177 intern_table::id(self)
178 }
179}
180
181impl<T> Deref for HashConsed<T> {
182 type Target = T;
183 fn deref(&self) -> &Self::Target {
184 self.inner()
185 }
186}
187
188impl<T: std::fmt::Debug> std::fmt::Debug for HashConsed<T> {
189 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
190 f.debug_tuple("HashConsed").field(self.inner()).finish()
192 }
193}
194
195impl<'s, T, V: Visit<'s, T>> Drive<'s, V> for HashConsed<T> {
196 fn drive_inner(&'s self, v: &mut V) -> ControlFlow<V::Break> {
197 v.visit(self.inner())
198 }
199}
200impl<'s, T, V> DriveMut<'s, V> for HashConsed<T>
202where
203 T: HashConsable,
204 V: for<'a> VisitMut<'a, T>,
205{
206 fn drive_inner_mut(&'s mut self, v: &mut V) -> ControlFlow<V::Break> {
207 self.with_inner_mut(|inner| v.visit(inner))
208 }
209}
210
211pub use serialize::{HashConsDedupSerializer, HashConsSerializerState};
218mod serialize {
219 use indexmap::IndexMap as SeqHashMap;
220 use serde::{Deserialize, Serialize};
221 use serde_state::{DeserializeState, SerializeState};
222 use std::any::type_name;
223 use std::cell::RefCell;
224 use std::collections::HashSet;
225
226 use super::{HashConsId, HashConsable, HashConsed};
227 use crate::common::type_map::{Mappable, Mapper, TypeMap};
228
229 pub trait HashConsSerializerState: Sized {
230 fn record_serialized<T: Mappable>(&self, id: HashConsId) -> Option<bool>;
233 fn record_deserialized<T: Mappable>(&self, id: HashConsId, value: HashConsed<T>);
235 fn get_deserialized_val<T: Mappable>(&self, id: HashConsId) -> Option<HashConsed<T>>;
237 }
238
239 impl HashConsSerializerState for () {
240 fn record_serialized<T: Mappable>(&self, _id: HashConsId) -> Option<bool> {
241 None
242 }
243 fn record_deserialized<T: Mappable>(&self, _id: HashConsId, _value: HashConsed<T>) {}
244 fn get_deserialized_val<T: Mappable>(&self, _id: HashConsId) -> Option<HashConsed<T>> {
245 None
246 }
247 }
248
249 struct SerializeTableMapper;
250 impl Mapper for SerializeTableMapper {
251 type Value<T: Mappable> = HashSet<HashConsId>;
252 }
253 struct DeserializeTableMapper;
254 impl Mapper for DeserializeTableMapper {
255 type Value<T: Mappable> = SeqHashMap<HashConsId, HashConsed<T>>;
256 }
257 #[derive(Default)]
258 pub struct HashConsDedupSerializer {
259 ser: RefCell<TypeMap<SerializeTableMapper>>,
261 de: RefCell<TypeMap<DeserializeTableMapper>>,
263 }
264 impl HashConsSerializerState for HashConsDedupSerializer {
265 fn record_serialized<T: Mappable>(&self, id: HashConsId) -> Option<bool> {
266 Some(self.ser.borrow_mut().or_default::<T>().insert(id))
267 }
268 fn record_deserialized<T: Mappable>(&self, id: HashConsId, val: HashConsed<T>) {
269 self.de.borrow_mut().or_default::<T>().insert(id, val);
270 }
271 fn get_deserialized_val<T: Mappable>(&self, id: HashConsId) -> Option<HashConsed<T>> {
272 self.de
273 .borrow()
274 .get::<T>()
275 .and_then(|map| map.get(&id))
276 .cloned()
277 }
278 }
279
280 #[derive(Serialize, Deserialize, SerializeState, DeserializeState)]
282 #[serde_state(state_implements = HashConsSerializerState)]
283 enum SerRepr<T> {
284 HashConsedValue(#[serde_state(stateless)] HashConsId, T),
288 #[serde_state(stateless)]
291 Deduplicated(HashConsId),
292 Untagged(T),
294 }
295
296 impl<T> Serialize for HashConsed<T>
297 where
298 T: Serialize + HashConsable,
299 {
300 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
301 where
302 S: serde::Serializer,
303 {
304 SerRepr::Untagged(self.inner()).serialize(serializer)
305 }
306 }
307 impl<T, State> SerializeState<State> for HashConsed<T>
310 where
311 T: SerializeState<State> + HashConsable,
312 State: HashConsSerializerState,
313 {
314 fn serialize_state<S>(&self, state: &State, serializer: S) -> Result<S::Ok, S::Error>
315 where
316 S: serde::Serializer,
317 {
318 let hash_cons_id = self.id();
319 let repr = match state.record_serialized::<T>(hash_cons_id) {
320 Some(true) => SerRepr::HashConsedValue(hash_cons_id, self.inner()),
321 Some(false) => SerRepr::Deduplicated(hash_cons_id),
322 None => SerRepr::Untagged(self.inner()),
323 };
324 repr.serialize_state(state, serializer)
325 }
326 }
327
328 impl<'de, T> Deserialize<'de> for HashConsed<T>
329 where
330 T: Deserialize<'de> + HashConsable,
331 {
332 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
333 where
334 D: serde::Deserializer<'de>,
335 {
336 use serde::de::Error;
337 let repr: SerRepr<T> = SerRepr::deserialize(deserializer)?;
338 match repr {
339 SerRepr::HashConsedValue { .. } | SerRepr::Deduplicated { .. } => {
340 let msg = format!(
341 "trying to deserialize a deduplicated value using serde's `{ty}::deserialize` method. \
342 This won't work, use serde_state's \
343 `{ty}::deserialize_state(&HashConsDedupSerializer::default(), _)` instead",
344 ty = type_name::<T>(),
345 );
346 Err(D::Error::custom(msg))
347 }
348 SerRepr::Untagged(val) => Ok(HashConsed::new(val)),
349 }
350 }
351 }
352 impl<'de, T, State> DeserializeState<'de, State> for HashConsed<T>
353 where
354 T: DeserializeState<'de, State> + HashConsable,
355 State: HashConsSerializerState,
356 {
357 fn deserialize_state<D>(state: &State, deserializer: D) -> Result<Self, D::Error>
358 where
359 D: serde::Deserializer<'de>,
360 {
361 use serde::de::Error;
362 let repr: SerRepr<T> = SerRepr::deserialize_state(state, deserializer)?;
363 Ok(match repr {
364 SerRepr::HashConsedValue(hash_cons_id, value) => {
365 let val = HashConsed::new(value);
366 state.record_deserialized(hash_cons_id, val.clone());
367 val
368 }
369 SerRepr::Deduplicated(hash_cons_id) => {
370 state.get_deserialized_val(hash_cons_id).ok_or_else(|| {
371 let msg = format!(
372 "can't deserialize deduplicated value of type {}; \
373 were you careful with managing the deduplication state?",
374 type_name::<T>()
375 );
376 D::Error::custom(msg)
377 })?
378 }
379 SerRepr::Untagged(val) => HashConsed::new(val),
380 })
381 }
382 }
383}
384
#[test]
fn test_hash_cons() {
    // Interning equal values twice must give equal handles.
    let first = HashConsed::new(42u32);
    let second = HashConsed::new(42u32);
    assert_eq!(first, second);
    // A stateless serde round trip goes through the untagged representation and re-interns.
    let json = serde_json::to_value(first.clone()).unwrap();
    let round_tripped: HashConsed<u32> = serde_json::from_value(json).unwrap();
    assert_eq!(first, round_tripped);
}
394
/// Concurrent interning of the same value must yield equal handles on every thread.
#[test]
fn test_hash_cons_concurrent() {
    use itertools::Itertools;
    use std::sync::Barrier;

    const THREADS: usize = 10;
    // A value no other test in this file interns. The threads then race on the insertion path
    // (read miss, then write lock and re-check) instead of all hitting an existing entry.
    const VALUE: u32 = 0x5EED_CAFE;
    // Release every thread at once to maximize contention on the intern table.
    let barrier = Barrier::new(THREADS);
    let values = std::thread::scope(|scope| {
        let handles = (0..THREADS)
            .map(|_| {
                scope.spawn(|| {
                    barrier.wait();
                    std::hint::black_box(HashConsed::new(VALUE))
                })
            })
            .collect_vec();
        handles
            .into_iter()
            .map(|h| h.join().unwrap())
            .collect_vec()
    });
    assert_eq!(values.len(), THREADS);
    assert!(values.iter().all_equal());
}
404
#[test]
fn test_hash_cons_dedup() {
    use serde_state::{DeserializeState, SerializeState};

    type Ty = HashConsed<TyKind>;
    #[derive(Debug, Clone, PartialEq, Eq, Hash, SerializeState, DeserializeState)]
    #[serde_state(state = HashConsDedupSerializer)]
    enum TyKind {
        Bool,
        Pair(Ty, Ty),
    }

    // `Bool` is interned twice but shared. The nested structure refers to it several times,
    // so the deduplicating serializer writes it in full only on first occurrence.
    let boolean = Ty::new(TyKind::Bool);
    let inner_pair = Ty::new(TyKind::Pair(boolean.clone(), Ty::new(TyKind::Bool)));
    let outer_pair = Ty::new(TyKind::Pair(boolean, inner_pair));

    let ser_state = HashConsDedupSerializer::default();
    let json_val = outer_pair
        .serialize_state(&ser_state, serde_json::value::Serializer)
        .unwrap();

    // Deserialize with a separate, fresh state.
    let de_state = HashConsDedupSerializer::default();
    let round_tripped = Ty::deserialize_state(&de_state, json_val).unwrap();

    assert_eq!(outer_pair, round_tripped);
}