1use crate::ast::*;
3
/// Errors produced when building or inspecting a [`ScalarValue`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ScalarError {
    /// A signed scalar was accessed as an unsigned one, or vice versa.
    IncorrectSign,
    /// The value does not fit in the requested integer type.
    OutOfBounds,
    /// The pointer size is not one of the supported widths (2, 4 or 8 bytes).
    UnsupportedPtrSize,
}

/// Result type used by the scalar helpers.
pub type ScalarResult<T> = std::result::Result<T, ScalarError>;
14
/// Expands to a `match` on the integer type `$m` that decodes the
/// little-endian byte array `$b` into a `ScalarValue`.
///
/// Each tuple `(IntTy|UIntTy, Variant, Signed|Unsigned, native_ty, wide_ty)`
/// produces one arm: the first `size_of::<native_ty>()` bytes of `$b` are read
/// as a little-endian `native_ty`, which is then widened with `as` to `wide_ty`.
macro_rules! from_le_bytes {
    ($m:ident, $b:ident, [$(($i_ty: ty, $i:ident, $s:ident, $n_ty:ty, $t:ty)),*]) => {
        match $m {
            $(
                IntegerTy::$s(<$i_ty>::$i) => {
                    let width = std::mem::size_of::<$n_ty>();
                    // The array length is inferred from `from_le_bytes`'s parameter type.
                    let decoded = <$n_ty>::from_le_bytes($b[..width].try_into().unwrap());
                    ScalarValue::$s(<$i_ty>::$i, decoded as $t)
                }
            )*
        }
    };
}
28
29impl Literal {
30 pub fn char_from_le_bytes(bits: u128) -> Self {
31 let b: [u8; 4] = bits.to_le_bytes()[0..4].try_into().unwrap();
32 Literal::Char(std::char::from_u32(u32::from_le_bytes(b)).unwrap())
33 }
34
35 pub fn from_bits(lit_ty: &LiteralTy, bits: u128) -> Option<Self> {
36 match *lit_ty {
37 LiteralTy::Int(int_ty) => Some(Literal::Scalar(ScalarValue::from_bits(
38 IntegerTy::Signed(int_ty),
39 bits,
40 ))),
41 LiteralTy::UInt(uint_ty) => Some(Literal::Scalar(ScalarValue::from_bits(
42 IntegerTy::Unsigned(uint_ty),
43 bits,
44 ))),
45 LiteralTy::Char => Some(Literal::char_from_le_bytes(bits)),
46 _ => None,
47 }
48 }
49}
50
51impl ScalarValue {
52 fn ptr_size_max(ptr_size: ByteCount, signed: bool) -> ScalarResult<u128> {
53 match ptr_size {
54 2 => Ok(if signed {
55 i16::MAX as u128
56 } else {
57 u16::MAX as u128
58 }),
59 4 => Ok(if signed {
60 i32::MAX as u128
61 } else {
62 u32::MAX as u128
63 }),
64 8 => Ok(if signed {
65 i64::MAX as u128
66 } else {
67 u64::MAX as u128
68 }),
69 _ => Err(ScalarError::UnsupportedPtrSize),
70 }
71 }
72
73 fn ptr_size_min(ptr_size: ByteCount, signed: bool) -> ScalarResult<i128> {
74 match ptr_size {
75 2 => Ok(if signed {
76 i16::MIN as i128
77 } else {
78 u16::MIN as i128
79 }),
80 4 => Ok(if signed {
81 i32::MIN as i128
82 } else {
83 u32::MIN as i128
84 }),
85 8 => Ok(if signed {
86 i64::MIN as i128
87 } else {
88 u64::MIN as i128
89 }),
90 _ => Err(ScalarError::UnsupportedPtrSize),
91 }
92 }
93
94 pub fn get_integer_ty(&self) -> IntegerTy {
95 match self {
96 ScalarValue::Signed(ty, _) => IntegerTy::Signed(*ty),
97 ScalarValue::Unsigned(ty, _) => IntegerTy::Unsigned(*ty),
98 }
99 }
100
101 pub fn is_int(&self) -> bool {
102 matches!(self, ScalarValue::Signed(_, _))
103 }
104
105 pub fn is_uint(&self) -> bool {
106 matches!(self, ScalarValue::Unsigned(_, _))
107 }
108
109 pub fn as_uint(&self) -> ScalarResult<u128> {
113 match self {
114 ScalarValue::Unsigned(_, v) => Ok(*v),
115 _ => Err(ScalarError::IncorrectSign),
116 }
117 }
118
119 pub fn uint_is_in_bounds(ptr_size: ByteCount, ty: UIntTy, v: u128) -> bool {
120 match ty {
121 UIntTy::Usize => v <= Self::ptr_size_max(ptr_size, false).unwrap(),
122 UIntTy::U8 => v <= (u8::MAX as u128),
123 UIntTy::U16 => v <= (u16::MAX as u128),
124 UIntTy::U32 => v <= (u32::MAX as u128),
125 UIntTy::U64 => v <= (u64::MAX as u128),
126 UIntTy::U128 => true,
127 }
128 }
129
130 pub fn from_unchecked_uint(ty: UIntTy, v: u128) -> ScalarValue {
131 ScalarValue::Unsigned(ty, v)
132 }
133
134 pub fn from_uint(ptr_size: ByteCount, ty: UIntTy, v: u128) -> ScalarResult<ScalarValue> {
135 if !ScalarValue::uint_is_in_bounds(ptr_size, ty, v) {
136 trace!("Not in bounds for {:?}: {}", ty, v);
137 Err(ScalarError::OutOfBounds)
138 } else {
139 Ok(ScalarValue::from_unchecked_uint(ty, v))
140 }
141 }
142
143 pub fn as_int(&self) -> ScalarResult<i128> {
147 match self {
148 ScalarValue::Signed(_, v) => Ok(*v),
149 _ => Err(ScalarError::IncorrectSign),
150 }
151 }
152
153 pub fn int_is_in_bounds(ptr_size: ByteCount, ty: IntTy, v: i128) -> bool {
154 match ty {
155 IntTy::Isize => {
156 v >= Self::ptr_size_min(ptr_size, true).unwrap()
157 && v <= Self::ptr_size_max(ptr_size, true).unwrap() as i128
158 }
159 IntTy::I8 => v >= (i8::MIN as i128) && v <= (i8::MAX as i128),
160 IntTy::I16 => v >= (i16::MIN as i128) && v <= (i16::MAX as i128),
161 IntTy::I32 => v >= (i32::MIN as i128) && v <= (i32::MAX as i128),
162 IntTy::I64 => v >= (i64::MIN as i128) && v <= (i64::MAX as i128),
163 IntTy::I128 => true,
164 }
165 }
166
167 pub fn from_unchecked_int(ty: IntTy, v: i128) -> ScalarValue {
168 ScalarValue::Signed(ty, v)
169 }
170
171 pub fn to_bits(&self) -> u128 {
173 match *self {
174 ScalarValue::Unsigned(_, v) => v,
175 ScalarValue::Signed(_, v) => u128::from_le_bytes(v.to_le_bytes()),
176 }
177 }
178
179 pub fn from_le_bytes(ty: IntegerTy, bytes: [u8; 16]) -> Self {
184 from_le_bytes!(
185 ty,
186 bytes,
187 [
188 (IntTy, Isize, Signed, isize, i128),
189 (IntTy, I8, Signed, i8, i128),
190 (IntTy, I16, Signed, i16, i128),
191 (IntTy, I32, Signed, i32, i128),
192 (IntTy, I64, Signed, i64, i128),
193 (IntTy, I128, Signed, i128, i128),
194 (UIntTy, Usize, Unsigned, usize, u128),
195 (UIntTy, U8, Unsigned, u8, u128),
196 (UIntTy, U16, Unsigned, u16, u128),
197 (UIntTy, U32, Unsigned, u32, u128),
198 (UIntTy, U64, Unsigned, u64, u128),
199 (UIntTy, U128, Unsigned, u128, u128)
200 ]
201 )
202 }
203
204 pub fn from_bits(ty: IntegerTy, bits: u128) -> Self {
205 let bytes = bits.to_le_bytes();
206 Self::from_le_bytes(ty, bytes)
207 }
208
209 pub fn from_int(ptr_size: ByteCount, ty: IntTy, v: i128) -> ScalarResult<ScalarValue> {
213 if !ScalarValue::int_is_in_bounds(ptr_size, ty, v) {
214 Err(ScalarError::OutOfBounds)
215 } else {
216 Ok(ScalarValue::from_unchecked_int(ty, v))
217 }
218 }
219
220 pub fn to_constant(self) -> ConstantExpr {
221 let literal_ty = match self {
222 ScalarValue::Signed(int_ty, _) => LiteralTy::Int(int_ty),
223 ScalarValue::Unsigned(uint_ty, _) => LiteralTy::UInt(uint_ty),
224 };
225 ConstantExpr {
226 kind: ConstantExprKind::Literal(Literal::Scalar(self)),
227 ty: TyKind::Literal(literal_ty).into_ty(),
228 }
229 }
230}
231
232pub(crate) mod scalar_value_ser_de {
234 use std::{marker::PhantomData, str::FromStr};
235
236 use serde::de::{Deserializer, Error};
237
238 pub fn serialize<S, V>(val: &V, serializer: S) -> Result<S::Ok, S::Error>
239 where
240 S: serde::ser::Serializer,
241 V: ToString,
242 {
243 serializer.serialize_str(&val.to_string())
244 }
245
246 pub fn deserialize<'de, D, V>(deserializer: D) -> Result<V, D::Error>
247 where
248 D: Deserializer<'de>,
249 V: FromStr,
250 {
251 struct Visitor<V> {
252 _val: PhantomData<V>,
253 }
254 impl<'de, V> serde::de::Visitor<'de> for Visitor<V>
255 where
256 V: FromStr,
257 {
258 type Value = V;
259 fn expecting(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
260 write!(f, "ScalarValue value")
261 }
262 fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
263 where
264 E: Error,
265 {
266 v.parse()
267 .map_err(|_| E::custom("Could not parse 128 bit integer!"))
268 }
269 }
270 deserializer.deserialize_str(Visitor { _val: PhantomData })
271 }
272}
273
#[cfg(test)]
mod test {
    use super::*;

    /// `from_le_bytes` reads the low-order bytes of a little-endian buffer.
    #[test]
    fn test_little_endian_scalars() -> ScalarResult<()> {
        let wide = 0x12345678901234567890123456789012u128;
        let scalar =
            ScalarValue::from_le_bytes(IntegerTy::Unsigned(UIntTy::U128), wide.to_le_bytes());
        assert_eq!(scalar, ScalarValue::Unsigned(UIntTy::U128, wide));

        let signed = 0x1234567890123456i64;
        let scalar = ScalarValue::from_le_bytes(
            IntegerTy::Signed(IntTy::I64),
            (signed as i128).to_le_bytes(),
        );
        assert_eq!(scalar, ScalarValue::Signed(IntTy::I64, signed as i128));

        Ok(())
    }

    /// Narrow signed values are sign-extended and `to_bits`/`from_bits` round-trip.
    #[test]
    fn test_bits_round_trip() {
        let minus_one = ScalarValue::from_bits(IntegerTy::Signed(IntTy::I8), 0xFF);
        assert_eq!(minus_one, ScalarValue::Signed(IntTy::I8, -1));
        assert_eq!(minus_one.to_bits(), u128::MAX);
        assert_eq!(
            ScalarValue::from_bits(IntegerTy::Signed(IntTy::I8), u128::MAX),
            minus_one
        );
        // Only the low byte is read for 8-bit types.
        assert_eq!(
            ScalarValue::from_bits(IntegerTy::Unsigned(UIntTy::U8), 0x123),
            ScalarValue::Unsigned(UIntTy::U8, 0x23)
        );
    }

    /// Checked constructors reject out-of-range values; accessors check the sign.
    #[test]
    fn test_bounds_and_sign() {
        assert!(ScalarValue::from_uint(8, UIntTy::U8, 255).is_ok());
        assert!(matches!(
            ScalarValue::from_uint(8, UIntTy::U8, 256),
            Err(ScalarError::OutOfBounds)
        ));
        assert!(ScalarValue::from_int(4, IntTy::Isize, i32::MIN as i128).is_ok());
        assert!(matches!(
            ScalarValue::from_int(4, IntTy::Isize, i32::MAX as i128 + 1),
            Err(ScalarError::OutOfBounds)
        ));
        assert!(matches!(
            ScalarValue::from_int(8, IntTy::I8, 0).and_then(|s| s.as_uint()),
            Err(ScalarError::IncorrectSign)
        ));
    }

    /// `Literal::from_bits` decodes chars and integers.
    #[test]
    fn test_literal_from_bits() {
        assert!(matches!(
            Literal::from_bits(&LiteralTy::Char, 'a' as u128),
            Some(Literal::Char('a'))
        ));
        assert!(matches!(
            Literal::from_bits(&LiteralTy::Int(IntTy::I16), 0xFFFF),
            Some(Literal::Scalar(ScalarValue::Signed(IntTy::I16, -1)))
        ));
    }
}