1use derive_generic_visitor::{Drive, DriveMut, Visit, VisitMut};
2use serde::{Deserialize, Serialize};
3use std::hash::Hash;
4use std::ops::{ControlFlow, Deref};
5use std::sync::Arc;
6
7use crate::common::hash_by_addr::HashByAddr;
8use crate::common::type_map::Mappable;
9
10#[derive(PartialEq, Eq, Hash)]
16pub struct HashConsed<T>(HashByAddr<Arc<T>>);
17
18impl<T> Clone for HashConsed<T> {
19 fn clone(&self) -> Self {
20 Self(self.0.clone())
21 }
22}
23
24impl<T> HashConsed<T> {
25 pub fn inner(&self) -> &T {
26 self.0.0.as_ref()
27 }
28}
29
30pub trait HashConsable = Hash + PartialEq + Eq + Clone + Mappable;
31
32#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
34pub struct HashConsId(usize);
35
36mod intern_table {
42 use indexmap::IndexSet;
43 use std::sync::{Arc, LazyLock, RwLock};
44
45 use super::{HashConsId, HashConsable, HashConsed};
46 use crate::common::hash_by_addr::HashByAddr;
47 use crate::common::type_map::{Mappable, Mapper, TypeMap};
48
49 struct InternMapper;
56 impl Mapper for InternMapper {
57 type Value<T: Mappable> = IndexSet<Arc<T>>;
58 }
59 static INTERNED: LazyLock<RwLock<TypeMap<InternMapper>>> = LazyLock::new(|| Default::default());
60
61 pub fn intern<T: HashConsable>(inner: T) -> HashConsed<T> {
62 #[expect(irrefutable_let_patterns)] let arc = if let read_guard = INTERNED.read().unwrap()
65 && let Some(set) = read_guard.get::<T>()
66 && let Some(arc) = set.get(&inner)
67 {
68 arc.clone()
69 } else {
70 let mut write_guard = INTERNED.write().unwrap();
72 let set: &mut IndexSet<Arc<T>> = write_guard.or_default::<T>();
73 if let Some(arc) = set.get(&inner) {
74 arc.clone()
75 } else {
76 let arc: Arc<T> = Arc::new(inner);
77 set.insert(arc.clone());
78 arc
79 }
80 };
81 HashConsed(HashByAddr(arc))
82 }
83
84 pub fn id<T: HashConsable>(x: &HashConsed<T>) -> HashConsId {
87 HashConsId(
90 (*INTERNED.read().unwrap())
91 .get::<T>()
92 .unwrap()
93 .get_index_of(&x.0.0)
94 .unwrap(),
95 )
96 }
97}
98
99impl<T> HashConsed<T>
100where
101 T: HashConsable,
102{
103 pub fn new(inner: T) -> Self {
106 intern_table::intern(inner)
107 }
108
109 pub fn with_inner_mut<R>(&mut self, f: impl FnOnce(&mut T) -> R) -> R {
111 let mut value = self.inner().clone();
113 let ret = f(&mut value);
114 *self = Self::new(value);
116 ret
117 }
118
119 pub fn id(&self) -> HashConsId {
120 intern_table::id(self)
121 }
122}
123
124impl<T> Deref for HashConsed<T> {
125 type Target = T;
126 fn deref(&self) -> &Self::Target {
127 self.inner()
128 }
129}
130
131impl<T: std::fmt::Debug> std::fmt::Debug for HashConsed<T> {
132 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
133 f.debug_tuple("HashConsed").field(self.inner()).finish()
135 }
136}
137
138impl<'s, T, V: Visit<'s, T>> Drive<'s, V> for HashConsed<T> {
139 fn drive_inner(&'s self, v: &mut V) -> ControlFlow<V::Break> {
140 v.visit(self.inner())
141 }
142}
143impl<'s, T, V> DriveMut<'s, V> for HashConsed<T>
145where
146 T: HashConsable,
147 V: for<'a> VisitMut<'a, T>,
148{
149 fn drive_inner_mut(&'s mut self, v: &mut V) -> ControlFlow<V::Break> {
150 self.with_inner_mut(|inner| v.visit(inner))
151 }
152}
153
154pub use serialize::{HashConsDedupSerializer, HashConsSerializerState};
161mod serialize {
162 use indexmap::IndexMap;
163 use serde::{Deserialize, Serialize};
164 use serde_state::{DeserializeState, SerializeState};
165 use std::any::type_name;
166 use std::cell::RefCell;
167 use std::collections::HashSet;
168
169 use super::{HashConsId, HashConsable, HashConsed};
170 use crate::common::type_map::{Mappable, Mapper, TypeMap};
171
172 pub trait HashConsSerializerState: Sized {
173 fn record_serialized<T: Mappable>(&self, id: HashConsId) -> Option<bool>;
176 fn record_deserialized<T: Mappable>(&self, id: HashConsId, value: HashConsed<T>);
178 fn get_deserialized_val<T: Mappable>(&self, id: HashConsId) -> Option<HashConsed<T>>;
180 }
181
182 impl HashConsSerializerState for () {
183 fn record_serialized<T: Mappable>(&self, _id: HashConsId) -> Option<bool> {
184 None
185 }
186 fn record_deserialized<T: Mappable>(&self, _id: HashConsId, _value: HashConsed<T>) {}
187 fn get_deserialized_val<T: Mappable>(&self, _id: HashConsId) -> Option<HashConsed<T>> {
188 None
189 }
190 }
191
192 struct SerializeTableMapper;
193 impl Mapper for SerializeTableMapper {
194 type Value<T: Mappable> = HashSet<HashConsId>;
195 }
196 struct DeserializeTableMapper;
197 impl Mapper for DeserializeTableMapper {
198 type Value<T: Mappable> = IndexMap<HashConsId, HashConsed<T>>;
199 }
200 #[derive(Default)]
201 pub struct HashConsDedupSerializer {
202 ser: RefCell<TypeMap<SerializeTableMapper>>,
204 de: RefCell<TypeMap<DeserializeTableMapper>>,
206 }
207 impl HashConsSerializerState for HashConsDedupSerializer {
208 fn record_serialized<T: Mappable>(&self, id: HashConsId) -> Option<bool> {
209 Some(self.ser.borrow_mut().or_default::<T>().insert(id))
210 }
211 fn record_deserialized<T: Mappable>(&self, id: HashConsId, val: HashConsed<T>) {
212 self.de.borrow_mut().or_default::<T>().insert(id, val);
213 }
214 fn get_deserialized_val<T: Mappable>(&self, id: HashConsId) -> Option<HashConsed<T>> {
215 self.de
216 .borrow()
217 .get::<T>()
218 .and_then(|map| map.get(&id))
219 .cloned()
220 }
221 }
222
223 #[derive(Serialize, Deserialize, SerializeState, DeserializeState)]
225 #[serde_state(state_implements = HashConsSerializerState)]
226 enum SerRepr<T> {
227 HashConsedValue(#[serde_state(stateless)] HashConsId, T),
231 #[serde_state(stateless)]
234 Deduplicated(HashConsId),
235 Untagged(T),
237 }
238
239 impl<T> Serialize for HashConsed<T>
240 where
241 T: Serialize + HashConsable,
242 {
243 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
244 where
245 S: serde::Serializer,
246 {
247 SerRepr::Untagged(self.inner()).serialize(serializer)
248 }
249 }
250 impl<T, State> SerializeState<State> for HashConsed<T>
253 where
254 T: SerializeState<State> + HashConsable,
255 State: HashConsSerializerState,
256 {
257 fn serialize_state<S>(&self, state: &State, serializer: S) -> Result<S::Ok, S::Error>
258 where
259 S: serde::Serializer,
260 {
261 let hash_cons_id = self.id();
262 let repr = match state.record_serialized::<T>(hash_cons_id) {
263 Some(true) => SerRepr::HashConsedValue(hash_cons_id, self.inner()),
264 Some(false) => SerRepr::Deduplicated(hash_cons_id),
265 None => SerRepr::Untagged(self.inner()),
266 };
267 repr.serialize_state(state, serializer)
268 }
269 }
270
271 impl<'de, T> Deserialize<'de> for HashConsed<T>
272 where
273 T: Deserialize<'de> + HashConsable,
274 {
275 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
276 where
277 D: serde::Deserializer<'de>,
278 {
279 use serde::de::Error;
280 let repr: SerRepr<T> = SerRepr::deserialize(deserializer)?;
281 match repr {
282 SerRepr::HashConsedValue { .. } | SerRepr::Deduplicated { .. } => {
283 let msg = format!(
284 "trying to deserialize a deduplicated value using serde's `{ty}::deserialize` method. \
285 This won't work, use serde_state's \
286 `{ty}::deserialize_state(&HashConsDedupSerializer::default(), _)` instead",
287 ty = type_name::<T>(),
288 );
289 Err(D::Error::custom(msg))
290 }
291 SerRepr::Untagged(val) => Ok(HashConsed::new(val)),
292 }
293 }
294 }
295 impl<'de, T, State> DeserializeState<'de, State> for HashConsed<T>
296 where
297 T: DeserializeState<'de, State> + HashConsable,
298 State: HashConsSerializerState,
299 {
300 fn deserialize_state<D>(state: &State, deserializer: D) -> Result<Self, D::Error>
301 where
302 D: serde::Deserializer<'de>,
303 {
304 use serde::de::Error;
305 let repr: SerRepr<T> = SerRepr::deserialize_state(state, deserializer)?;
306 Ok(match repr {
307 SerRepr::HashConsedValue(hash_cons_id, value) => {
308 let val = HashConsed::new(value);
309 state.record_deserialized(hash_cons_id, val.clone());
310 val
311 }
312 SerRepr::Deduplicated(hash_cons_id) => {
313 state.get_deserialized_val(hash_cons_id).ok_or_else(|| {
314 let msg = format!(
315 "can't deserialize deduplicated value of type {}; \
316 were you careful with managing the deduplication state?",
317 type_name::<T>()
318 );
319 D::Error::custom(msg)
320 })?
321 }
322 SerRepr::Untagged(val) => HashConsed::new(val),
323 })
324 }
325 }
326}
327
#[test]
fn test_hash_cons() {
    // Interning the same value twice must yield equal handles.
    let first = HashConsed::new(42u32);
    let second = HashConsed::new(42u32);
    assert_eq!(first, second);

    // A stateless serde round-trip must yield an equal handle as well.
    let json = serde_json::to_value(first.clone()).unwrap();
    let decoded: HashConsed<u32> = serde_json::from_value(json).unwrap();
    assert_eq!(first, decoded);
}
337
#[test]
fn test_hash_cons_concurrent() {
    use itertools::Itertools;
    // Spawn every thread before joining any, so that the interning calls can overlap.
    let workers = (0..10)
        .map(|_| std::thread::spawn(|| std::hint::black_box(HashConsed::new(42u32))))
        .collect_vec();
    let interned = workers
        .into_iter()
        .map(|worker| worker.join().unwrap())
        .collect_vec();
    // All threads must have ended up with the same handle.
    assert!(interned.iter().all_equal());
}
348
#[test]
fn test_hash_cons_dedup() {
    use serde_state::{DeserializeState, SerializeState};
    type Ty = HashConsed<TyKind>;
    #[derive(Debug, Clone, PartialEq, Eq, Hash, SerializeState, DeserializeState)]
    #[serde_state(state = HashConsDedupSerializer)]
    enum TyKind {
        Bool,
        Pair(Ty, Ty),
    }

    // Build `Pair(Bool, Pair(Bool, Bool))`, in which `Bool` occurs several times.
    let leaf_a = HashConsed::new(TyKind::Bool);
    let leaf_b = HashConsed::new(TyKind::Bool);
    let inner_pair = HashConsed::new(TyKind::Pair(leaf_a.clone(), leaf_b));
    let nested = HashConsed::new(TyKind::Pair(leaf_a, inner_pair));

    // Serialize with one deduplicating state, then deserialize with a fresh one.
    let ser_state = HashConsDedupSerializer::default();
    let json_val = nested
        .serialize_state(&ser_state, serde_json::value::Serializer)
        .unwrap();
    let de_state = HashConsDedupSerializer::default();
    let round_tripped = Ty::deserialize_state(&de_state, json_val).unwrap();

    assert_eq!(nested, round_tripped);
}