charon_lib/transform/
simplify_constants.rs

//! The MIR constant expressions lead to a lot of duplication: there are,
//! for instance, constant ADTs which duplicate the "regular" aggregated
//! ADTs in the operands, constant references, etc. Removing them reduces the
//! number of cases to handle and eases the function translation in Aeneas.
//!
//! This pass removes all those occurrences so that only simple constants
//! (such as `RawConstantExpr::Literal`) remain as constant operands. It does so
//! by introducing intermediate statements.
//!
//! A small remark about the intermediate statements we introduce for the globals:
//! we do so because, when evaluating the code in "concrete" mode, it allows us to
//! handle the globals like function calls.
12
13use itertools::Itertools;
14use std::assert_matches::assert_matches;
15
16use crate::transform::TransformCtx;
17use crate::ullbc_ast::*;
18
19use super::ctx::UllbcPass;
20
21/// If the constant value is a constant ADT, push `Assign::Aggregate` statements
22/// to the vector of statements, that bind new variables to the ADT parts and
23/// the variable assigned to the complete ADT.
24///
25/// Goes fom e.g. `f(T::A(x, y))` to `let a = T::A(x, y); f(a)`.
26/// The function is recursively called on the aggregate fields (e.g. here x and y).
27fn transform_constant_expr(
28    span: &Span,
29    val: Box<ConstantExpr>,
30    new_var: &mut impl FnMut(Rvalue, Ty) -> Place,
31) -> Operand {
32    match val.value {
33        RawConstantExpr::Literal(_)
34        | RawConstantExpr::Var(_)
35        | RawConstantExpr::RawMemory(..)
36        | RawConstantExpr::TraitConst(..)
37        | RawConstantExpr::FnPtr(..)
38        | RawConstantExpr::Opaque(_) => {
39            // Nothing to do
40            // TODO: for trait const: might come from a top-level impl, so we might
41            // want to introduce an intermediate statement to be able to evaluate
42            // it as a function call, like for globals.
43            Operand::Const(val)
44        }
45        // Here we use a copy, rather than a move -- moving a global would leave it uninitialized,
46        // which would e.g. make the following code fail:
47        //     const GLOBAL: usize = 0;
48        //     let x = GLOBAL;
49        //     let y = GLOBAL; // if moving, at this point GLOBAL would be uninitialized
50        RawConstantExpr::Global(global_ref) => Operand::Copy(Place::new_global(global_ref, val.ty)),
51        RawConstantExpr::PtrNoProvenance(ptr) => {
52            let usize_ty = TyKind::Literal(LiteralTy::UInt(UIntTy::Usize)).into_ty();
53            let ptr_usize = RawConstantExpr::Literal(Literal::Scalar(ScalarValue::Unsigned(
54                UIntTy::Usize,
55                ptr,
56            )));
57            let cast = UnOp::Cast(CastKind::RawPtr(usize_ty.clone(), val.ty.clone()));
58            let uvar = new_var(
59                Rvalue::UnaryOp(
60                    cast,
61                    Operand::Const(Box::new(ConstantExpr {
62                        value: ptr_usize,
63                        ty: usize_ty,
64                    })),
65                ),
66                val.ty,
67            );
68            Operand::Move(uvar)
69        }
70        RawConstantExpr::Ref(bval) => {
71            match bval.value {
72                RawConstantExpr::Global(global_ref) => Operand::Move(new_var(
73                    Rvalue::Ref(Place::new_global(global_ref, bval.ty), BorrowKind::Shared),
74                    val.ty,
75                )),
76                _ => {
77                    // Recurse on the borrowed value
78                    let bval_ty = bval.ty.clone();
79                    let bval = transform_constant_expr(span, bval, new_var);
80
81                    // Evaluate the referenced value
82                    let bvar = new_var(Rvalue::Use(bval), bval_ty);
83
84                    // Borrow the value
85                    let ref_var = new_var(Rvalue::Ref(bvar, BorrowKind::Shared), val.ty);
86
87                    Operand::Move(ref_var)
88                }
89            }
90        }
91        RawConstantExpr::Ptr(rk, bval) => {
92            match bval.value {
93                RawConstantExpr::Global(global_ref) => Operand::Move(new_var(
94                    Rvalue::RawPtr(Place::new_global(global_ref, bval.ty), rk),
95                    val.ty,
96                )),
97                _ => {
98                    // Recurse on the borrowed value
99                    let bval_ty = bval.ty.clone();
100                    let bval = transform_constant_expr(span, bval, new_var);
101
102                    // Evaluate the referenced value
103                    let bvar = new_var(Rvalue::Use(bval), bval_ty);
104
105                    // Borrow the value
106                    let ref_var = new_var(Rvalue::RawPtr(bvar, rk), val.ty);
107
108                    Operand::Move(ref_var)
109                }
110            }
111        }
112        RawConstantExpr::Adt(variant, fields) => {
113            let fields = fields
114                .into_iter()
115                .map(|x| transform_constant_expr(span, Box::new(x), new_var))
116                .collect();
117
118            // Build an `Aggregate` rvalue.
119            let rval = {
120                let tref = val.ty.kind().as_adt().unwrap();
121                let aggregate_kind = AggregateKind::Adt(tref.clone(), variant, None);
122                Rvalue::Aggregate(aggregate_kind, fields)
123            };
124            let var = new_var(rval, val.ty);
125
126            Operand::Move(var)
127        }
128        RawConstantExpr::Array(fields) => {
129            let fields = fields
130                .into_iter()
131                .map(|x| transform_constant_expr(span, Box::new(x), new_var))
132                .collect_vec();
133
134            let len = ConstGeneric::Value(Literal::Scalar(ScalarValue::Unsigned(
135                UIntTy::Usize,
136                fields.len() as u128,
137            )));
138            let tref = val.ty.kind().as_adt().unwrap();
139            assert_matches!(
140                *tref.id.as_builtin().unwrap(),
141                BuiltinTy::Array | BuiltinTy::Slice
142            );
143            let ty = tref.generics.types[0].clone();
144            let rval = Rvalue::Aggregate(AggregateKind::Array(ty, len), fields);
145            let var = new_var(rval, val.ty);
146
147            Operand::Move(var)
148        }
149    }
150}
151
152fn transform_operand(span: &Span, locals: &mut Locals, nst: &mut Vec<Statement>, op: &mut Operand) {
153    // Transform the constant operands (otherwise do nothing)
154    take_mut::take(op, |op| {
155        if let Operand::Const(val) = op {
156            let mut new_var = |rvalue, ty| {
157                if let Rvalue::Use(Operand::Move(place)) = rvalue {
158                    place
159                } else {
160                    let var = locals.new_var(None, ty);
161                    nst.push(Statement::new(
162                        *span,
163                        RawStatement::Assign(var.clone(), rvalue),
164                    ));
165                    var
166                }
167            };
168            transform_constant_expr(span, val, &mut new_var)
169        } else {
170            op
171        }
172    })
173}
174
175pub struct Transform;
176impl UllbcPass for Transform {
177    fn transform_body(&self, _ctx: &mut TransformCtx, body: &mut ExprBody) {
178        for block in body.body.iter_mut() {
179            // Deconstruct some constants into series of MIR assignments.
180            block.transform_operands(|span, nst, op| {
181                transform_operand(span, &mut body.locals, nst, op)
182            });
183
184            // Simplify array with repeated constants into array repeats.
185            block.dyn_visit_in_body_mut(|rvalue: &mut Rvalue| {
186                take_mut::take(rvalue, |rvalue| match rvalue {
187                    Rvalue::Aggregate(AggregateKind::Array(ty, len), ref fields)
188                        if fields.len() >= 2
189                            && fields.iter().all(|x| x.is_const())
190                            && let Ok(op) = fields.iter().dedup().exactly_one() =>
191                    {
192                        Rvalue::Repeat(op.clone(), ty.clone(), len)
193                    }
194                    _ => rvalue,
195                });
196            });
197        }
198    }
199}