1use crate::ast::*;
3use serde::{Deserialize, Serialize, Serializer};
4
/// Errors produced when converting to or from a [`ScalarValue`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ScalarError {
    /// The value does not have the requested signedness (e.g. asking for an
    /// unsigned view of a signed scalar via `as_uint`).
    IncorrectSign,
    /// The integer does not fit in the requested integer type.
    OutOfBounds,
}

/// Result type for scalar conversions that may fail with a [`ScalarError`].
pub type ScalarResult<T> = std::result::Result<T, ScalarError>;
14
15impl ScalarValue {
16 pub fn get_integer_ty(&self) -> IntegerTy {
17 match self {
18 ScalarValue::Isize(_) => IntegerTy::Isize,
19 ScalarValue::I8(_) => IntegerTy::I8,
20 ScalarValue::I16(_) => IntegerTy::I16,
21 ScalarValue::I32(_) => IntegerTy::I32,
22 ScalarValue::I64(_) => IntegerTy::I64,
23 ScalarValue::I128(_) => IntegerTy::I128,
24 ScalarValue::Usize(_) => IntegerTy::Usize,
25 ScalarValue::U8(_) => IntegerTy::U8,
26 ScalarValue::U16(_) => IntegerTy::U16,
27 ScalarValue::U32(_) => IntegerTy::U32,
28 ScalarValue::U64(_) => IntegerTy::U64,
29 ScalarValue::U128(_) => IntegerTy::U128,
30 }
31 }
32
33 pub fn is_int(&self) -> bool {
34 matches!(
35 self,
36 ScalarValue::Isize(_)
37 | ScalarValue::I8(_)
38 | ScalarValue::I16(_)
39 | ScalarValue::I32(_)
40 | ScalarValue::I64(_)
41 | ScalarValue::I128(_)
42 )
43 }
44
45 pub fn is_uint(&self) -> bool {
46 matches!(
47 self,
48 ScalarValue::Usize(_)
49 | ScalarValue::U8(_)
50 | ScalarValue::U16(_)
51 | ScalarValue::U32(_)
52 | ScalarValue::U64(_)
53 | ScalarValue::U128(_)
54 )
55 }
56
57 pub fn as_uint(&self) -> ScalarResult<u128> {
61 match self {
62 ScalarValue::Usize(v) => Ok(*v as u128),
63 ScalarValue::U8(v) => Ok(*v as u128),
64 ScalarValue::U16(v) => Ok(*v as u128),
65 ScalarValue::U32(v) => Ok(*v as u128),
66 ScalarValue::U64(v) => Ok(*v as u128),
67 ScalarValue::U128(v) => Ok(*v),
68 _ => Err(ScalarError::IncorrectSign),
69 }
70 }
71
72 pub fn uint_is_in_bounds(ty: IntegerTy, v: u128) -> bool {
73 match ty {
74 IntegerTy::Usize => v <= (usize::MAX as u128),
75 IntegerTy::U8 => v <= (u8::MAX as u128),
76 IntegerTy::U16 => v <= (u16::MAX as u128),
77 IntegerTy::U32 => v <= (u32::MAX as u128),
78 IntegerTy::U64 => v <= (u64::MAX as u128),
79 IntegerTy::U128 => true,
80 _ => false,
81 }
82 }
83
84 pub fn from_unchecked_uint(ty: IntegerTy, v: u128) -> ScalarValue {
85 match ty {
86 IntegerTy::Usize => ScalarValue::Usize(v as u64),
87 IntegerTy::U8 => ScalarValue::U8(v as u8),
88 IntegerTy::U16 => ScalarValue::U16(v as u16),
89 IntegerTy::U32 => ScalarValue::U32(v as u32),
90 IntegerTy::U64 => ScalarValue::U64(v as u64),
91 IntegerTy::U128 => ScalarValue::U128(v),
92 _ => panic!("Expected an unsigned integer kind"),
93 }
94 }
95
96 pub fn from_uint(ty: IntegerTy, v: u128) -> ScalarResult<ScalarValue> {
97 if !ScalarValue::uint_is_in_bounds(ty, v) {
98 trace!("Not in bounds for {:?}: {}", ty, v);
99 Err(ScalarError::OutOfBounds)
100 } else {
101 Ok(ScalarValue::from_unchecked_uint(ty, v))
102 }
103 }
104
105 pub fn as_int(&self) -> ScalarResult<i128> {
109 match self {
110 ScalarValue::Isize(v) => Ok(*v as i128),
111 ScalarValue::I8(v) => Ok(*v as i128),
112 ScalarValue::I16(v) => Ok(*v as i128),
113 ScalarValue::I32(v) => Ok(*v as i128),
114 ScalarValue::I64(v) => Ok(*v as i128),
115 ScalarValue::I128(v) => Ok(*v),
116 _ => Err(ScalarError::IncorrectSign),
117 }
118 }
119
120 pub fn int_is_in_bounds(ty: IntegerTy, v: i128) -> bool {
121 match ty {
122 IntegerTy::Isize => v >= (isize::MIN as i128) && v <= (isize::MAX as i128),
123 IntegerTy::I8 => v >= (i8::MIN as i128) && v <= (i8::MAX as i128),
124 IntegerTy::I16 => v >= (i16::MIN as i128) && v <= (i16::MAX as i128),
125 IntegerTy::I32 => v >= (i32::MIN as i128) && v <= (i32::MAX as i128),
126 IntegerTy::I64 => v >= (i64::MIN as i128) && v <= (i64::MAX as i128),
127 IntegerTy::I128 => true,
128 _ => false,
129 }
130 }
131
132 pub fn from_unchecked_int(ty: IntegerTy, v: i128) -> ScalarValue {
133 match ty {
134 IntegerTy::Isize => ScalarValue::Isize(v as i64),
135 IntegerTy::I8 => ScalarValue::I8(v as i8),
136 IntegerTy::I16 => ScalarValue::I16(v as i16),
137 IntegerTy::I32 => ScalarValue::I32(v as i32),
138 IntegerTy::I64 => ScalarValue::I64(v as i64),
139 IntegerTy::I128 => ScalarValue::I128(v),
140 _ => panic!("Expected a signed integer kind"),
141 }
142 }
143
144 pub fn from_le_bytes(ty: IntegerTy, b: [u8; 16]) -> ScalarValue {
145 match ty {
146 IntegerTy::Isize => {
147 let b: [u8; 8] = b[0..8].try_into().unwrap();
148 ScalarValue::Isize(i64::from_le_bytes(b))
149 }
150 IntegerTy::I8 => {
151 let b: [u8; 1] = b[0..1].try_into().unwrap();
152 ScalarValue::I8(i8::from_le_bytes(b))
153 }
154 IntegerTy::I16 => {
155 let b: [u8; 2] = b[0..2].try_into().unwrap();
156 ScalarValue::I16(i16::from_le_bytes(b))
157 }
158 IntegerTy::I32 => {
159 let b: [u8; 4] = b[0..4].try_into().unwrap();
160 ScalarValue::I32(i32::from_le_bytes(b))
161 }
162 IntegerTy::I64 => {
163 let b: [u8; 8] = b[0..8].try_into().unwrap();
164 ScalarValue::I64(i64::from_le_bytes(b))
165 }
166 IntegerTy::I128 => {
167 let b: [u8; 16] = b[0..16].try_into().unwrap();
168 ScalarValue::I128(i128::from_le_bytes(b))
169 }
170 IntegerTy::Usize => {
171 let b: [u8; 8] = b[0..8].try_into().unwrap();
172 ScalarValue::Usize(u64::from_le_bytes(b))
173 }
174 IntegerTy::U8 => {
175 let b: [u8; 1] = b[0..1].try_into().unwrap();
176 ScalarValue::U8(u8::from_le_bytes(b))
177 }
178 IntegerTy::U16 => {
179 let b: [u8; 2] = b[0..2].try_into().unwrap();
180 ScalarValue::U16(u16::from_le_bytes(b))
181 }
182 IntegerTy::U32 => {
183 let b: [u8; 4] = b[0..4].try_into().unwrap();
184 ScalarValue::U32(u32::from_le_bytes(b))
185 }
186 IntegerTy::U64 => {
187 let b: [u8; 8] = b[0..8].try_into().unwrap();
188 ScalarValue::U64(u64::from_le_bytes(b))
189 }
190 IntegerTy::U128 => {
191 let b: [u8; 16] = b[0..16].try_into().unwrap();
192 ScalarValue::U128(u128::from_le_bytes(b))
193 }
194 }
195 }
196
197 pub fn to_bits(&self) -> u128 {
199 match *self {
200 ScalarValue::Usize(v) => v as u128,
201 ScalarValue::U8(v) => v as u128,
202 ScalarValue::U16(v) => v as u128,
203 ScalarValue::U32(v) => v as u128,
204 ScalarValue::U64(v) => v as u128,
205 ScalarValue::U128(v) => v,
206 ScalarValue::Isize(v) => v as usize as u128,
207 ScalarValue::I8(v) => v as u8 as u128,
208 ScalarValue::I16(v) => v as u16 as u128,
209 ScalarValue::I32(v) => v as u32 as u128,
210 ScalarValue::I64(v) => v as u64 as u128,
211 ScalarValue::I128(v) => v as u128,
212 }
213 }
214
215 pub fn from_bits(ty: IntegerTy, bits: u128) -> Self {
216 Self::from_le_bytes(ty, bits.to_le_bytes())
217 }
218
219 pub fn from_int(ty: IntegerTy, v: i128) -> ScalarResult<ScalarValue> {
223 if !ScalarValue::int_is_in_bounds(ty, v) {
224 Err(ScalarError::OutOfBounds)
225 } else {
226 Ok(ScalarValue::from_unchecked_int(ty, v))
227 }
228 }
229
230 pub fn to_constant(self) -> ConstantExpr {
231 ConstantExpr {
232 value: RawConstantExpr::Literal(Literal::Scalar(self)),
233 ty: TyKind::Literal(LiteralTy::Integer(self.get_integer_ty())).into_ty(),
234 }
235 }
236}
237
238impl Serialize for ScalarValue {
240 fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
241 where
242 S: Serializer,
243 {
244 let enum_name = "ScalarValue";
245 let variant_name = self.variant_name();
246 let (variant_index, _variant_arity) = self.variant_index_arity();
247 let v = match self {
248 ScalarValue::Isize(i) => i.to_string(),
249 ScalarValue::I8(i) => i.to_string(),
250 ScalarValue::I16(i) => i.to_string(),
251 ScalarValue::I32(i) => i.to_string(),
252 ScalarValue::I64(i) => i.to_string(),
253 ScalarValue::I128(i) => i.to_string(),
254 ScalarValue::Usize(i) => i.to_string(),
255 ScalarValue::U8(i) => i.to_string(),
256 ScalarValue::U16(i) => i.to_string(),
257 ScalarValue::U32(i) => i.to_string(),
258 ScalarValue::U64(i) => i.to_string(),
259 ScalarValue::U128(i) => i.to_string(),
260 };
261 serializer.serialize_newtype_variant(enum_name, variant_index, variant_name, &v)
262 }
263}
264
265impl<'de> Deserialize<'de> for ScalarValue {
266 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
267 where
268 D: serde::Deserializer<'de>,
269 {
270 struct Visitor;
271 impl<'de> serde::de::Visitor<'de> for Visitor {
272 type Value = ScalarValue;
273 fn expecting(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
274 write!(f, "ScalarValue")
275 }
276 fn visit_map<A: serde::de::MapAccess<'de>>(
277 self,
278 mut map: A,
279 ) -> Result<Self::Value, A::Error> {
280 use serde::de::Error;
281 let (k, v): (String, String) = map.next_entry()?.expect("Malformed ScalarValue");
282 Ok(match k.as_str() {
283 "Isize" => ScalarValue::Isize(v.parse().unwrap()),
284 "I8" => ScalarValue::I8(v.parse().unwrap()),
285 "I16" => ScalarValue::I16(v.parse().unwrap()),
286 "I32" => ScalarValue::I32(v.parse().unwrap()),
287 "I64" => ScalarValue::I64(v.parse().unwrap()),
288 "I128" => ScalarValue::I128(v.parse().unwrap()),
289 "Usize" => ScalarValue::Usize(v.parse().unwrap()),
290 "U8" => ScalarValue::U8(v.parse().unwrap()),
291 "U16" => ScalarValue::U16(v.parse().unwrap()),
292 "U32" => ScalarValue::U32(v.parse().unwrap()),
293 "U64" => ScalarValue::U64(v.parse().unwrap()),
294 "U128" => ScalarValue::U128(v.parse().unwrap()),
295 _ => {
296 return Err(A::Error::custom(format!(
297 "{k} is not a valid type for a ScalarValue"
298 )))
299 }
300 })
301 }
302 }
303 deserializer.deserialize_map(Visitor)
304 }
305}