charon_lib/transform/
simplify_constants.rs

1//! The MIR constant expressions lead to a lot of duplication: there are
2//! for instance constant ADTs which duplicate the "regular" aggregated
3//! ADTs in the operands, constant references, etc. This reduces the number
4//! of cases to handle and eases the function translation in Aeneas.
5//!
//! This pass removes all those occurrences so that only simple constants (such as
//! `ConstantExprKind::Literal`) remain in operands. It does so by introducing
//! intermediate statements.
8//!
9//! A small remark about the intermediate statements we introduce for the globals:
//! we do so because, when evaluating the code in "concrete" mode, it allows us to
//! handle the globals like function calls.
12
13use itertools::Itertools;
14use std::assert_matches::assert_matches;
15
16use crate::transform::TransformCtx;
17use crate::ullbc_ast::*;
18
19use super::ctx::UllbcPass;
20
21/// If the constant value is a constant ADT, push `Assign::Aggregate` statements
22/// to the vector of statements, that bind new variables to the ADT parts and
23/// the variable assigned to the complete ADT.
24///
25/// Goes fom e.g. `f(T::A(x, y))` to `let a = T::A(x, y); f(a)`.
26/// The function is recursively called on the aggregate fields (e.g. here x and y).
27fn transform_constant_expr(
28    span: &Span,
29    val: Box<ConstantExpr>,
30    new_var: &mut impl FnMut(Rvalue, Ty) -> Place,
31) -> Operand {
32    match val.value {
33        ConstantExprKind::Literal(_)
34        | ConstantExprKind::Var(_)
35        | ConstantExprKind::RawMemory(..)
36        | ConstantExprKind::TraitConst(..)
37        | ConstantExprKind::FnPtr(..)
38        | ConstantExprKind::Opaque(_) => {
39            // Nothing to do
40            // TODO: for trait const: might come from a top-level impl, so we might
41            // want to introduce an intermediate statement to be able to evaluate
42            // it as a function call, like for globals.
43            Operand::Const(val)
44        }
45        // Here we use a copy, rather than a move -- moving a global would leave it uninitialized,
46        // which would e.g. make the following code fail:
47        //     const GLOBAL: usize = 0;
48        //     let x = GLOBAL;
49        //     let y = GLOBAL; // if moving, at this point GLOBAL would be uninitialized
50        ConstantExprKind::Global(global_ref) => {
51            Operand::Copy(Place::new_global(global_ref, val.ty))
52        }
53        ConstantExprKind::PtrNoProvenance(ptr) => {
54            let usize_ty = TyKind::Literal(LiteralTy::UInt(UIntTy::Usize)).into_ty();
55            let ptr_usize = ConstantExprKind::Literal(Literal::Scalar(ScalarValue::Unsigned(
56                UIntTy::Usize,
57                ptr,
58            )));
59            let cast = UnOp::Cast(CastKind::RawPtr(usize_ty.clone(), val.ty.clone()));
60            let uvar = new_var(
61                Rvalue::UnaryOp(
62                    cast,
63                    Operand::Const(Box::new(ConstantExpr {
64                        value: ptr_usize,
65                        ty: usize_ty,
66                    })),
67                ),
68                val.ty,
69            );
70            Operand::Move(uvar)
71        }
72        ConstantExprKind::Ref(bval) => {
73            match bval.value {
74                ConstantExprKind::Global(global_ref) => Operand::Move(new_var(
75                    Rvalue::Ref(Place::new_global(global_ref, bval.ty), BorrowKind::Shared),
76                    val.ty,
77                )),
78                _ => {
79                    // Recurse on the borrowed value
80                    let bval_ty = bval.ty.clone();
81                    let bval = transform_constant_expr(span, bval, new_var);
82
83                    // Evaluate the referenced value
84                    let bvar = new_var(Rvalue::Use(bval), bval_ty);
85
86                    // Borrow the value
87                    let ref_var = new_var(Rvalue::Ref(bvar, BorrowKind::Shared), val.ty);
88
89                    Operand::Move(ref_var)
90                }
91            }
92        }
93        ConstantExprKind::Ptr(rk, bval) => {
94            match bval.value {
95                ConstantExprKind::Global(global_ref) => Operand::Move(new_var(
96                    Rvalue::RawPtr(Place::new_global(global_ref, bval.ty), rk),
97                    val.ty,
98                )),
99                _ => {
100                    // Recurse on the borrowed value
101                    let bval_ty = bval.ty.clone();
102                    let bval = transform_constant_expr(span, bval, new_var);
103
104                    // Evaluate the referenced value
105                    let bvar = new_var(Rvalue::Use(bval), bval_ty);
106
107                    // Borrow the value
108                    let ref_var = new_var(Rvalue::RawPtr(bvar, rk), val.ty);
109
110                    Operand::Move(ref_var)
111                }
112            }
113        }
114        ConstantExprKind::Adt(variant, fields) => {
115            let fields = fields
116                .into_iter()
117                .map(|x| transform_constant_expr(span, Box::new(x), new_var))
118                .collect();
119
120            // Build an `Aggregate` rvalue.
121            let rval = {
122                let tref = val.ty.kind().as_adt().unwrap();
123                let aggregate_kind = AggregateKind::Adt(tref.clone(), variant, None);
124                Rvalue::Aggregate(aggregate_kind, fields)
125            };
126            let var = new_var(rval, val.ty);
127
128            Operand::Move(var)
129        }
130        ConstantExprKind::Array(fields) => {
131            let fields = fields
132                .into_iter()
133                .map(|x| transform_constant_expr(span, Box::new(x), new_var))
134                .collect_vec();
135
136            let len = ConstGeneric::Value(Literal::Scalar(ScalarValue::Unsigned(
137                UIntTy::Usize,
138                fields.len() as u128,
139            )));
140            let tref = val.ty.kind().as_adt().unwrap();
141            assert_matches!(
142                *tref.id.as_builtin().unwrap(),
143                BuiltinTy::Array | BuiltinTy::Slice
144            );
145            let ty = tref.generics.types[0].clone();
146            let rval = Rvalue::Aggregate(AggregateKind::Array(ty, len), fields);
147            let var = new_var(rval, val.ty);
148
149            Operand::Move(var)
150        }
151    }
152}
153
154fn transform_operand(span: &Span, locals: &mut Locals, nst: &mut Vec<Statement>, op: &mut Operand) {
155    // Transform the constant operands (otherwise do nothing)
156    take_mut::take(op, |op| {
157        if let Operand::Const(val) = op {
158            let mut new_var = |rvalue, ty| {
159                if let Rvalue::Use(Operand::Move(place)) = rvalue {
160                    place
161                } else {
162                    let var = locals.new_var(None, ty);
163                    nst.push(Statement::new(
164                        *span,
165                        StatementKind::Assign(var.clone(), rvalue),
166                    ));
167                    var
168                }
169            };
170            transform_constant_expr(span, val, &mut new_var)
171        } else {
172            op
173        }
174    })
175}
176
177pub struct Transform;
178impl UllbcPass for Transform {
179    fn transform_body(&self, _ctx: &mut TransformCtx, body: &mut ExprBody) {
180        for block in body.body.iter_mut() {
181            // Deconstruct some constants into series of MIR assignments.
182            block.transform_operands(|span, nst, op| {
183                transform_operand(span, &mut body.locals, nst, op)
184            });
185
186            // Simplify array with repeated constants into array repeats.
187            block.dyn_visit_in_body_mut(|rvalue: &mut Rvalue| {
188                take_mut::take(rvalue, |rvalue| match rvalue {
189                    Rvalue::Aggregate(AggregateKind::Array(ty, len), ref fields)
190                        if fields.len() >= 2
191                            && fields.iter().all(|x| x.is_const())
192                            && let Ok(op) = fields.iter().dedup().exactly_one() =>
193                    {
194                        Rvalue::Repeat(op.clone(), ty.clone(), len)
195                    }
196                    _ => rvalue,
197                });
198            });
199        }
200    }
201}