// charon_driver/hax/utils/type_map.rs

1use std::{
2    any::{Any, TypeId},
3    collections::HashMap,
4    marker::PhantomData,
5};

/// Bound required of every type used with a [`TypeMap`], both as a key type `T` and as a stored
/// value type `M::Value<T>`: it must be `'static` (implied by `Any`, which is what makes
/// `TypeId::of` keying and downcasting possible) and thread-safe (`Send + Sync`).
///
/// This is a trait alias (unstable `trait_alias` feature), so `dyn TypeMappable` is exactly
/// `dyn Any + Send + Sync`.
pub trait TypeMappable = Any + Send + Sync;

9/// Defines a mapping from types to types.
10pub trait TypeMapper {
11    type Value<T: TypeMappable>: TypeMappable;
12}
13
14/// A map that maps types to values in a generic manner: we store for each type `T` a value of
15/// type `M::Value<T>`.
16pub struct TypeMap<M> {
17    data: HashMap<TypeId, Box<dyn TypeMappable>>,
18    phantom: PhantomData<M>,
19}
20
21impl<M: TypeMapper> TypeMap<M> {
22    pub fn get<T: TypeMappable>(&self) -> Option<&M::Value<T>> {
23        self.data
24            .get(&TypeId::of::<T>())
25            // We must be careful to not accidentally cast the box itself as `dyn Any`.
26            .map(|val: &Box<dyn TypeMappable>| &**val)
27            .and_then(|val: &dyn TypeMappable| (val as &dyn Any).downcast_ref())
28    }
29
30    pub fn get_mut<T: TypeMappable>(&mut self) -> Option<&mut M::Value<T>> {
31        self.data
32            .get_mut(&TypeId::of::<T>())
33            // We must be careful to not accidentally cast the box itself as `dyn Any`.
34            .map(|val: &mut Box<dyn TypeMappable>| &mut **val)
35            .and_then(|val: &mut dyn TypeMappable| (val as &mut dyn Any).downcast_mut())
36    }
37    pub fn or_default<T: TypeMappable>(&mut self) -> &mut M::Value<T>
38    where
39        M::Value<T>: Default,
40    {
41        if self.get::<T>().is_none() {
42            self.insert::<T>(Default::default());
43        }
44        self.get_mut().unwrap()
45    }
46
47    pub fn insert<T: TypeMappable>(&mut self, val: M::Value<T>) -> Option<Box<M::Value<T>>> {
48        self.data
49            .insert(TypeId::of::<T>(), Box::new(val))
50            .and_then(|val: Box<dyn TypeMappable>| (val as Box<dyn Any>).downcast().ok())
51    }
52}
53
54impl<M> Default for TypeMap<M> {
55    fn default() -> Self {
56        Self {
57            data: Default::default(),
58            phantom: Default::default(),
59        }
60    }
61}