charon_lib/ast/values_utils.rs

1//! Implementations for [crate::values]
2use crate::ast::*;
3use serde::{Deserialize, Serialize, Serializer};
4
/// Errors raised by the fallible conversions between scalar values and raw
/// `u128`/`i128` integers.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ScalarError {
    /// Attempt to use a signed scalar as an unsigned scalar or vice-versa
    IncorrectSign,
    /// Out of bounds scalar
    OutOfBounds,
}
/// Result of a fallible scalar conversion; the error is always a [ScalarError].
pub type ScalarResult<T> = std::result::Result<T, ScalarError>;
14
15impl ScalarValue {
16    pub fn get_integer_ty(&self) -> IntegerTy {
17        match self {
18            ScalarValue::Isize(_) => IntegerTy::Isize,
19            ScalarValue::I8(_) => IntegerTy::I8,
20            ScalarValue::I16(_) => IntegerTy::I16,
21            ScalarValue::I32(_) => IntegerTy::I32,
22            ScalarValue::I64(_) => IntegerTy::I64,
23            ScalarValue::I128(_) => IntegerTy::I128,
24            ScalarValue::Usize(_) => IntegerTy::Usize,
25            ScalarValue::U8(_) => IntegerTy::U8,
26            ScalarValue::U16(_) => IntegerTy::U16,
27            ScalarValue::U32(_) => IntegerTy::U32,
28            ScalarValue::U64(_) => IntegerTy::U64,
29            ScalarValue::U128(_) => IntegerTy::U128,
30        }
31    }
32
33    pub fn is_int(&self) -> bool {
34        matches!(
35            self,
36            ScalarValue::Isize(_)
37                | ScalarValue::I8(_)
38                | ScalarValue::I16(_)
39                | ScalarValue::I32(_)
40                | ScalarValue::I64(_)
41                | ScalarValue::I128(_)
42        )
43    }
44
45    pub fn is_uint(&self) -> bool {
46        matches!(
47            self,
48            ScalarValue::Usize(_)
49                | ScalarValue::U8(_)
50                | ScalarValue::U16(_)
51                | ScalarValue::U32(_)
52                | ScalarValue::U64(_)
53                | ScalarValue::U128(_)
54        )
55    }
56
57    /// When computing the result of binary operations, we convert the values
58    /// to u128 then back to the target type (while performing dynamic checks
59    /// of course).
60    pub fn as_uint(&self) -> ScalarResult<u128> {
61        match self {
62            ScalarValue::Usize(v) => Ok(*v as u128),
63            ScalarValue::U8(v) => Ok(*v as u128),
64            ScalarValue::U16(v) => Ok(*v as u128),
65            ScalarValue::U32(v) => Ok(*v as u128),
66            ScalarValue::U64(v) => Ok(*v as u128),
67            ScalarValue::U128(v) => Ok(*v),
68            _ => Err(ScalarError::IncorrectSign),
69        }
70    }
71
72    pub fn uint_is_in_bounds(ty: IntegerTy, v: u128) -> bool {
73        match ty {
74            IntegerTy::Usize => v <= (usize::MAX as u128),
75            IntegerTy::U8 => v <= (u8::MAX as u128),
76            IntegerTy::U16 => v <= (u16::MAX as u128),
77            IntegerTy::U32 => v <= (u32::MAX as u128),
78            IntegerTy::U64 => v <= (u64::MAX as u128),
79            IntegerTy::U128 => true,
80            _ => false,
81        }
82    }
83
84    pub fn from_unchecked_uint(ty: IntegerTy, v: u128) -> ScalarValue {
85        match ty {
86            IntegerTy::Usize => ScalarValue::Usize(v as u64),
87            IntegerTy::U8 => ScalarValue::U8(v as u8),
88            IntegerTy::U16 => ScalarValue::U16(v as u16),
89            IntegerTy::U32 => ScalarValue::U32(v as u32),
90            IntegerTy::U64 => ScalarValue::U64(v as u64),
91            IntegerTy::U128 => ScalarValue::U128(v),
92            _ => panic!("Expected an unsigned integer kind"),
93        }
94    }
95
96    pub fn from_uint(ty: IntegerTy, v: u128) -> ScalarResult<ScalarValue> {
97        if !ScalarValue::uint_is_in_bounds(ty, v) {
98            trace!("Not in bounds for {:?}: {}", ty, v);
99            Err(ScalarError::OutOfBounds)
100        } else {
101            Ok(ScalarValue::from_unchecked_uint(ty, v))
102        }
103    }
104
105    /// When computing the result of binary operations, we convert the values
106    /// to i128 then back to the target type (while performing dynamic checks
107    /// of course).
108    pub fn as_int(&self) -> ScalarResult<i128> {
109        match self {
110            ScalarValue::Isize(v) => Ok(*v as i128),
111            ScalarValue::I8(v) => Ok(*v as i128),
112            ScalarValue::I16(v) => Ok(*v as i128),
113            ScalarValue::I32(v) => Ok(*v as i128),
114            ScalarValue::I64(v) => Ok(*v as i128),
115            ScalarValue::I128(v) => Ok(*v),
116            _ => Err(ScalarError::IncorrectSign),
117        }
118    }
119
120    pub fn int_is_in_bounds(ty: IntegerTy, v: i128) -> bool {
121        match ty {
122            IntegerTy::Isize => v >= (isize::MIN as i128) && v <= (isize::MAX as i128),
123            IntegerTy::I8 => v >= (i8::MIN as i128) && v <= (i8::MAX as i128),
124            IntegerTy::I16 => v >= (i16::MIN as i128) && v <= (i16::MAX as i128),
125            IntegerTy::I32 => v >= (i32::MIN as i128) && v <= (i32::MAX as i128),
126            IntegerTy::I64 => v >= (i64::MIN as i128) && v <= (i64::MAX as i128),
127            IntegerTy::I128 => true,
128            _ => false,
129        }
130    }
131
132    pub fn from_unchecked_int(ty: IntegerTy, v: i128) -> ScalarValue {
133        match ty {
134            IntegerTy::Isize => ScalarValue::Isize(v as i64),
135            IntegerTy::I8 => ScalarValue::I8(v as i8),
136            IntegerTy::I16 => ScalarValue::I16(v as i16),
137            IntegerTy::I32 => ScalarValue::I32(v as i32),
138            IntegerTy::I64 => ScalarValue::I64(v as i64),
139            IntegerTy::I128 => ScalarValue::I128(v),
140            _ => panic!("Expected a signed integer kind"),
141        }
142    }
143
144    pub fn from_le_bytes(ty: IntegerTy, b: [u8; 16]) -> ScalarValue {
145        match ty {
146            IntegerTy::Isize => {
147                let b: [u8; 8] = b[0..8].try_into().unwrap();
148                ScalarValue::Isize(i64::from_le_bytes(b))
149            }
150            IntegerTy::I8 => {
151                let b: [u8; 1] = b[0..1].try_into().unwrap();
152                ScalarValue::I8(i8::from_le_bytes(b))
153            }
154            IntegerTy::I16 => {
155                let b: [u8; 2] = b[0..2].try_into().unwrap();
156                ScalarValue::I16(i16::from_le_bytes(b))
157            }
158            IntegerTy::I32 => {
159                let b: [u8; 4] = b[0..4].try_into().unwrap();
160                ScalarValue::I32(i32::from_le_bytes(b))
161            }
162            IntegerTy::I64 => {
163                let b: [u8; 8] = b[0..8].try_into().unwrap();
164                ScalarValue::I64(i64::from_le_bytes(b))
165            }
166            IntegerTy::I128 => {
167                let b: [u8; 16] = b[0..16].try_into().unwrap();
168                ScalarValue::I128(i128::from_le_bytes(b))
169            }
170            IntegerTy::Usize => {
171                let b: [u8; 8] = b[0..8].try_into().unwrap();
172                ScalarValue::Usize(u64::from_le_bytes(b))
173            }
174            IntegerTy::U8 => {
175                let b: [u8; 1] = b[0..1].try_into().unwrap();
176                ScalarValue::U8(u8::from_le_bytes(b))
177            }
178            IntegerTy::U16 => {
179                let b: [u8; 2] = b[0..2].try_into().unwrap();
180                ScalarValue::U16(u16::from_le_bytes(b))
181            }
182            IntegerTy::U32 => {
183                let b: [u8; 4] = b[0..4].try_into().unwrap();
184                ScalarValue::U32(u32::from_le_bytes(b))
185            }
186            IntegerTy::U64 => {
187                let b: [u8; 8] = b[0..8].try_into().unwrap();
188                ScalarValue::U64(u64::from_le_bytes(b))
189            }
190            IntegerTy::U128 => {
191                let b: [u8; 16] = b[0..16].try_into().unwrap();
192                ScalarValue::U128(u128::from_le_bytes(b))
193            }
194        }
195    }
196
197    /// Most integers are represented as `u128` by rustc. We must be careful not to sign-extend.
198    pub fn to_bits(&self) -> u128 {
199        match *self {
200            ScalarValue::Usize(v) => v as u128,
201            ScalarValue::U8(v) => v as u128,
202            ScalarValue::U16(v) => v as u128,
203            ScalarValue::U32(v) => v as u128,
204            ScalarValue::U64(v) => v as u128,
205            ScalarValue::U128(v) => v,
206            ScalarValue::Isize(v) => v as usize as u128,
207            ScalarValue::I8(v) => v as u8 as u128,
208            ScalarValue::I16(v) => v as u16 as u128,
209            ScalarValue::I32(v) => v as u32 as u128,
210            ScalarValue::I64(v) => v as u64 as u128,
211            ScalarValue::I128(v) => v as u128,
212        }
213    }
214
215    pub fn from_bits(ty: IntegerTy, bits: u128) -> Self {
216        Self::from_le_bytes(ty, bits.to_le_bytes())
217    }
218
219    /// **Warning**: most constants are stored as u128 by rustc. When converting
220    /// to i128, it is not correct to do `v as i128`, we must reinterpret the
221    /// bits (see [ScalarValue::from_le_bytes]).
222    pub fn from_int(ty: IntegerTy, v: i128) -> ScalarResult<ScalarValue> {
223        if !ScalarValue::int_is_in_bounds(ty, v) {
224            Err(ScalarError::OutOfBounds)
225        } else {
226            Ok(ScalarValue::from_unchecked_int(ty, v))
227        }
228    }
229
230    pub fn to_constant(self) -> ConstantExpr {
231        ConstantExpr {
232            value: RawConstantExpr::Literal(Literal::Scalar(self)),
233            ty: TyKind::Literal(LiteralTy::Integer(self.get_integer_ty())).into_ty(),
234        }
235    }
236}
237
238/// Custom serializer that stores integers as strings to avoid overflow.
239impl Serialize for ScalarValue {
240    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
241    where
242        S: Serializer,
243    {
244        let enum_name = "ScalarValue";
245        let variant_name = self.variant_name();
246        let (variant_index, _variant_arity) = self.variant_index_arity();
247        let v = match self {
248            ScalarValue::Isize(i) => i.to_string(),
249            ScalarValue::I8(i) => i.to_string(),
250            ScalarValue::I16(i) => i.to_string(),
251            ScalarValue::I32(i) => i.to_string(),
252            ScalarValue::I64(i) => i.to_string(),
253            ScalarValue::I128(i) => i.to_string(),
254            ScalarValue::Usize(i) => i.to_string(),
255            ScalarValue::U8(i) => i.to_string(),
256            ScalarValue::U16(i) => i.to_string(),
257            ScalarValue::U32(i) => i.to_string(),
258            ScalarValue::U64(i) => i.to_string(),
259            ScalarValue::U128(i) => i.to_string(),
260        };
261        serializer.serialize_newtype_variant(enum_name, variant_index, variant_name, &v)
262    }
263}
264
265impl<'de> Deserialize<'de> for ScalarValue {
266    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
267    where
268        D: serde::Deserializer<'de>,
269    {
270        struct Visitor;
271        impl<'de> serde::de::Visitor<'de> for Visitor {
272            type Value = ScalarValue;
273            fn expecting(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
274                write!(f, "ScalarValue")
275            }
276            fn visit_map<A: serde::de::MapAccess<'de>>(
277                self,
278                mut map: A,
279            ) -> Result<Self::Value, A::Error> {
280                use serde::de::Error;
281                let (k, v): (String, String) = map.next_entry()?.expect("Malformed ScalarValue");
282                Ok(match k.as_str() {
283                    "Isize" => ScalarValue::Isize(v.parse().unwrap()),
284                    "I8" => ScalarValue::I8(v.parse().unwrap()),
285                    "I16" => ScalarValue::I16(v.parse().unwrap()),
286                    "I32" => ScalarValue::I32(v.parse().unwrap()),
287                    "I64" => ScalarValue::I64(v.parse().unwrap()),
288                    "I128" => ScalarValue::I128(v.parse().unwrap()),
289                    "Usize" => ScalarValue::Usize(v.parse().unwrap()),
290                    "U8" => ScalarValue::U8(v.parse().unwrap()),
291                    "U16" => ScalarValue::U16(v.parse().unwrap()),
292                    "U32" => ScalarValue::U32(v.parse().unwrap()),
293                    "U64" => ScalarValue::U64(v.parse().unwrap()),
294                    "U128" => ScalarValue::U128(v.parse().unwrap()),
295                    _ => {
296                        return Err(A::Error::custom(format!(
297                            "{k} is not a valid type for a ScalarValue"
298                        )))
299                    }
300                })
301            }
302        }
303        deserializer.deserialize_map(Visitor)
304    }
305}