charon_lib/transform/
simplify_constants.rs

//! The MIR constant expressions lead to a lot of duplication: there are, for
//! instance, constant ADTs which duplicate the "regular" aggregated ADTs in the
//! operands, constant references, etc. Simplifying them reduces the number of
//! cases to handle and eases the function translation in Aeneas.
//!
//! This pass removes all those occurrences so that only "leaf" constants (such
//! as [RawConstantExpr::Literal]) remain. It does so by introducing intermediate
//! statements.
//!
//! A small remark about the intermediate statements we introduce for the globals:
//! we do so because, when evaluating the code in "concrete" mode, it allows us to
//! handle the globals like function calls.
12
13use itertools::Itertools;
14use std::assert_matches::assert_matches;
15
16use crate::transform::TransformCtx;
17use crate::ullbc_ast::*;
18
19use super::ctx::UllbcPass;
20
21/// If the constant value is a constant ADT, push `Assign::Aggregate` statements
22/// to the vector of statements, that bind new variables to the ADT parts and
23/// the variable assigned to the complete ADT.
24///
25/// Goes fom e.g. `f(T::A(x, y))` to `let a = T::A(x, y); f(a)`.
26/// The function is recursively called on the aggregate fields (e.g. here x and y).
27fn transform_constant_expr(
28    span: &Span,
29    val: ConstantExpr,
30    new_var: &mut impl FnMut(Rvalue, Ty) -> Place,
31) -> Operand {
32    match val.value {
33        RawConstantExpr::Literal(_)
34        | RawConstantExpr::Var(_)
35        | RawConstantExpr::RawMemory(..)
36        | RawConstantExpr::TraitConst(..)
37        | RawConstantExpr::FnPtr(..)
38        | RawConstantExpr::Opaque(_) => {
39            // Nothing to do
40            // TODO: for trait const: might come from a top-level impl, so we might
41            // want to introduce an intermediate statement to be able to evaluate
42            // it as a function call, like for globals.
43            Operand::Const(val)
44        }
45        RawConstantExpr::Global(global_ref) => {
46            Operand::Move(new_var(Rvalue::Global(global_ref), val.ty.clone()))
47        }
48        RawConstantExpr::Ref(box bval) => {
49            match bval.value {
50                RawConstantExpr::Global(global_ref) => Operand::Move(new_var(
51                    Rvalue::GlobalRef(global_ref, RefKind::Shared),
52                    val.ty,
53                )),
54                _ => {
55                    // Recurse on the borrowed value
56                    let bval_ty = bval.ty.clone();
57                    let bval = transform_constant_expr(span, bval, new_var);
58
59                    // Evaluate the referenced value
60                    let bvar = new_var(Rvalue::Use(bval), bval_ty);
61
62                    // Borrow the value
63                    let ref_var = new_var(Rvalue::Ref(bvar, BorrowKind::Shared), val.ty);
64
65                    Operand::Move(ref_var)
66                }
67            }
68        }
69        RawConstantExpr::MutPtr(box bval) => {
70            match bval.value {
71                RawConstantExpr::Global(global_ref) => {
72                    Operand::Move(new_var(Rvalue::GlobalRef(global_ref, RefKind::Mut), val.ty))
73                }
74                _ => {
75                    // Recurse on the borrowed value
76                    let bval_ty = bval.ty.clone();
77                    let bval = transform_constant_expr(span, bval, new_var);
78
79                    // Evaluate the referenced value
80                    let bvar = new_var(Rvalue::Use(bval), bval_ty);
81
82                    // Borrow the value
83                    let ref_var = new_var(Rvalue::RawPtr(bvar, RefKind::Mut), val.ty);
84
85                    Operand::Move(ref_var)
86                }
87            }
88        }
89        RawConstantExpr::Adt(variant, fields) => {
90            let fields = fields
91                .into_iter()
92                .map(|x| transform_constant_expr(span, x, new_var))
93                .collect();
94
95            // Build an `Aggregate` rvalue.
96            let rval = {
97                let (adt_kind, generics) = val.ty.kind().as_adt().unwrap();
98                let aggregate_kind = AggregateKind::Adt(*adt_kind, variant, None, generics.clone());
99                Rvalue::Aggregate(aggregate_kind, fields)
100            };
101            let var = new_var(rval, val.ty);
102
103            Operand::Move(var)
104        }
105        RawConstantExpr::Array(fields) => {
106            let fields = fields
107                .into_iter()
108                .map(|x| transform_constant_expr(span, x, new_var))
109                .collect_vec();
110
111            let len = ConstGeneric::Value(Literal::Scalar(ScalarValue::Usize(fields.len() as u64)));
112            let (adt_kind, generics) = val.ty.kind().as_adt().unwrap();
113            assert_matches!(
114                *adt_kind.as_builtin().unwrap(),
115                BuiltinTy::Array | BuiltinTy::Slice
116            );
117            let ty = generics.types[0].clone();
118            let rval = Rvalue::Aggregate(AggregateKind::Array(ty, len), fields);
119            let var = new_var(rval, val.ty);
120
121            Operand::Move(var)
122        }
123    }
124}
125
126fn transform_operand(span: &Span, locals: &mut Locals, nst: &mut Vec<Statement>, op: &mut Operand) {
127    // Transform the constant operands (otherwise do nothing)
128    take_mut::take(op, |op| {
129        if let Operand::Const(val) = op {
130            let mut new_var = |rvalue, ty| {
131                if let Rvalue::Use(Operand::Move(place)) = rvalue {
132                    place
133                } else {
134                    let var = locals.new_var(None, ty);
135                    nst.push(Statement::new(
136                        *span,
137                        RawStatement::Assign(var.clone(), rvalue),
138                    ));
139                    var
140                }
141            };
142            transform_constant_expr(span, val, &mut new_var)
143        } else {
144            op
145        }
146    })
147}
148
149pub struct Transform;
150impl UllbcPass for Transform {
151    fn transform_body(&self, _ctx: &mut TransformCtx, body: &mut ExprBody) {
152        for block in body.body.iter_mut() {
153            // Deconstruct some constants into series of MIR assignments.
154            block.transform_operands(|span, nst, op| {
155                transform_operand(span, &mut body.locals, nst, op)
156            });
157
158            // Simplify array with repeated constants into array repeats.
159            block.dyn_visit_in_body_mut(|rvalue: &mut Rvalue| {
160                take_mut::take(rvalue, |rvalue| match rvalue {
161                    Rvalue::Aggregate(AggregateKind::Array(ty, len), ref fields)
162                        if fields.len() >= 2
163                            && fields.iter().all(|x| x.is_const())
164                            && let Ok(op) = fields.iter().dedup().exactly_one() =>
165                    {
166                        Rvalue::Repeat(op.clone(), ty.clone(), len)
167                    }
168                    _ => rvalue,
169                });
170            });
171        }
172    }
173}