// charon_lib/transform/remove_dynamic_checks.rs

//! # Micro-pass: remove the dynamic checks for array/slice bounds, overflow, and division by zero.
//!
//! Note that from a semantic point of view, an out-of-bounds access or a division by zero
//! must lead to a panic in Rust (which is why those checks are always present, even when
//! compiling for release). In our case, we take this into account in the semantics of our
//! array/slice manipulation and arithmetic functions, on the verification side. When a removed
//! check guards an arithmetic operation, that operation is switched to panic-on-overflow
//! semantics so that the failure behavior is preserved.
6
7use derive_generic_visitor::*;
8
9use crate::ast::*;
10use crate::transform::TransformCtx;
11use crate::ullbc_ast::{ExprBody, RawStatement, Statement};
12
13use super::ctx::UllbcPass;
14
15/// Whether the value uses the given local in a place.
16fn uses_local<T: BodyVisitable>(x: &T, local: LocalId) -> bool {
17    struct FoundIt;
18    struct UsesLocalVisitor(LocalId);
19
20    impl Visitor for UsesLocalVisitor {
21        type Break = FoundIt;
22    }
23    impl VisitBody for UsesLocalVisitor {
24        fn visit_place(&mut self, x: &Place) -> ::std::ops::ControlFlow<Self::Break> {
25            if let Some(local_id) = x.as_local() {
26                if local_id == self.0 {
27                    return ControlFlow::Break(FoundIt);
28                }
29            }
30            self.visit_inner(x)
31        }
32    }
33
34    x.drive_body(&mut UsesLocalVisitor(local)).is_break()
35}
36
37fn make_binop_overflow_panic<T: BodyVisitable>(
38    x: &mut [T],
39    matches: impl Fn(&BinOp, &Operand, &Operand) -> bool,
40) -> bool {
41    let mut found = false;
42    for y in x.iter_mut() {
43        y.dyn_visit_in_body_mut(|rv: &mut Rvalue| {
44            if let Rvalue::BinaryOp(binop, op_l, op_r) = rv
45                && matches(binop, op_l, op_r)
46            {
47                *binop = binop.with_overflow(OverflowMode::Panic);
48                found = true;
49            }
50        });
51    }
52    found
53}
54
55fn make_unop_overflow_panic<T: BodyVisitable>(
56    x: &mut [T],
57    matches: impl Fn(&UnOp, &Operand) -> bool,
58) -> bool {
59    let mut found = false;
60    for y in x.iter_mut() {
61        y.dyn_visit_in_body_mut(|rv: &mut Rvalue| {
62            if let Rvalue::UnaryOp(unop, op) = rv
63                && matches(unop, op)
64            {
65                *unop = unop.with_overflow(OverflowMode::Panic);
66                found = true;
67            }
68        });
69    }
70    found
71}
72
73/// Check if the two operands are equivalent: either they're the same constant, or they represent
74/// the same place (regardless of whether the operand is a move or a copy)
75fn equiv_op(op_l: &Operand, op_r: &Operand) -> bool {
76    match (op_l, op_r) {
77        (Operand::Copy(l) | Operand::Move(l), Operand::Copy(r) | Operand::Move(r)) => l == r,
78        (Operand::Const(l), Operand::Const(r)) => l == r,
79        _ => false,
80    }
81}
82
83/// Rustc inserts dynamic checks during MIR lowering. They all end in an `Assert` statement (and
84/// this is the only use of this statement).
85fn remove_dynamic_checks(
86    _ctx: &mut TransformCtx,
87    locals: &mut Locals,
88    statements: &mut [Statement],
89) {
90    // We return the statements we want to keep, which must be a prefix of `block.statements`.
91    let statements_to_keep = match statements {
92        // Bounds checks for slices. They look like:
93        //   l := ptr_metadata(copy a)
94        //   b := copy x < copy l
95        //   assert(move b == true)
96        [
97            Statement {
98                content:
99                    RawStatement::Assign(len, Rvalue::UnaryOp(UnOp::PtrMetadata, Operand::Copy(len_op))),
100                ..
101            },
102            Statement {
103                content:
104                    RawStatement::Assign(
105                        is_in_bounds,
106                        Rvalue::BinaryOp(BinOp::Lt, _, Operand::Copy(lt_op2)),
107                    ),
108                ..
109            },
110            Statement {
111                content:
112                    RawStatement::Assert(Assert {
113                        cond: Operand::Move(cond),
114                        expected: true,
115                        ..
116                    }),
117                ..
118            },
119            rest @ ..,
120        ] if lt_op2 == len && cond == is_in_bounds && len_op.ty().is_ref() => rest,
121        // Sometimes that instead looks like:
122        //   a := &raw const *z
123        //   l := ptr_metadata(move a)
124        //   b := copy x < copy l
125        //   assert(move b == true)
126        [
127            Statement {
128                content: RawStatement::Assign(reborrow, Rvalue::RawPtr(_, RefKind::Shared)),
129                ..
130            },
131            Statement {
132                content:
133                    RawStatement::Assign(len, Rvalue::UnaryOp(UnOp::PtrMetadata, Operand::Move(len_op))),
134                ..
135            },
136            Statement {
137                content:
138                    RawStatement::Assign(
139                        is_in_bounds,
140                        Rvalue::BinaryOp(BinOp::Lt, _, Operand::Copy(lt_op2)),
141                    ),
142                ..
143            },
144            Statement {
145                content:
146                    RawStatement::Assert(Assert {
147                        cond: Operand::Move(cond),
148                        expected: true,
149                        ..
150                    }),
151                ..
152            },
153            rest @ ..,
154        ] if reborrow == len_op && lt_op2 == len && cond == is_in_bounds => rest,
155
156        // Zero checks for division and remainder. They look like:
157        //   b := copy y == const 0
158        //   assert(move b == false)
159        //   ...
160        //   res := x {/,%} move y;
161        //   ... or ...
162        //   b := const y == const 0
163        //   assert(move b == false)
164        //   ...
165        //   res := x {/,%} const y;
166        //
167        // This also overlaps with overflow checks for negation, which looks like:
168        //   is_min := x == INT::min
169        //   assert(move is_min == false)
170        //   ...
171        //   res := -x;
172        [
173            Statement {
174                content:
175                    RawStatement::Assign(
176                        is_zero,
177                        Rvalue::BinaryOp(BinOp::Eq, y_op, Operand::Const(_zero)),
178                    ),
179                ..
180            },
181            Statement {
182                content:
183                    RawStatement::Assert(Assert {
184                        cond: Operand::Move(cond),
185                        expected: false,
186                        ..
187                    }),
188                ..
189            },
190            rest @ ..,
191        ] if cond == is_zero => {
192            let found = make_binop_overflow_panic(rest, |bop, _, r| {
193                matches!(bop, BinOp::Div(_) | BinOp::Rem(_)) && equiv_op(r, y_op)
194            }) || make_unop_overflow_panic(rest, |unop, o| {
195                matches!(unop, UnOp::Neg(_)) && equiv_op(o, y_op)
196            });
197            if found {
198                rest
199            } else {
200                return;
201            }
202        }
203
204        // Overflow checks for signed division and remainder. They look like:
205        //   is_neg_1 := y == (-1)
206        //   is_min := x == INT::min
207        //   has_overflow := move (is_neg_1) & move (is_min)
208        //   assert(move has_overflow == false)
209        // Note here we don't need to update the operand to panic, as this was already done
210        // by the previous pass for division by zero.
211        [
212            Statement {
213                content:
214                    RawStatement::Assign(is_neg_1, Rvalue::BinaryOp(BinOp::Eq, _y_op, _minus_1)),
215                ..
216            },
217            Statement {
218                content: RawStatement::Assign(is_min, Rvalue::BinaryOp(BinOp::Eq, _x_op, _int_min)),
219                ..
220            },
221            Statement {
222                content:
223                    RawStatement::Assign(
224                        has_overflow,
225                        Rvalue::BinaryOp(
226                            BinOp::BitAnd,
227                            Operand::Move(and_op1),
228                            Operand::Move(and_op2),
229                        ),
230                    ),
231                ..
232            },
233            Statement {
234                content:
235                    RawStatement::Assert(Assert {
236                        cond: Operand::Move(cond),
237                        expected: false,
238                        ..
239                    }),
240                ..
241            },
242            rest @ ..,
243        ] if and_op1 == is_neg_1 && and_op2 == is_min && cond == has_overflow => rest,
244
245        // Overflow checks for right/left shift. They can look like:
246        //   a := y as u32; // or another type
247        //   b := move a < const 32; // or another constant
248        //   assert(move b == true);
249        //   ...
250        //   res := x {<<,>>} y;
251        [
252            Statement {
253                content: RawStatement::Assign(cast, Rvalue::UnaryOp(UnOp::Cast(_), y_op)),
254                ..
255            },
256            Statement {
257                content:
258                    RawStatement::Assign(
259                        has_overflow,
260                        Rvalue::BinaryOp(BinOp::Lt, Operand::Move(lhs), Operand::Const(..)),
261                    ),
262                ..
263            },
264            Statement {
265                content:
266                    RawStatement::Assert(Assert {
267                        cond: Operand::Move(cond),
268                        expected: true,
269                        ..
270                    }),
271                ..
272            },
273            rest @ ..,
274        ] if cond == has_overflow
275            && lhs == cast
276            && let Some(cast_local) = cast.as_local()
277            && !rest.iter().any(|st| uses_local(st, cast_local)) =>
278        {
279            let found = make_binop_overflow_panic(rest, |bop, _, r| {
280                matches!(bop, BinOp::Shl(_) | BinOp::Shr(_)) && equiv_op(r, y_op)
281            });
282            if found {
283                rest
284            } else {
285                return;
286            }
287        }
288        // or like:
289        //   b := y < const 32; // or another constant
290        //   assert(move b == true);
291        //   ...
292        //   res := x {<<,>>} y;
293        //
294        // this also overlaps with out of bounds checks for arrays, so we check for either;
295        // these look like:
296        //   b := copy y < const _
297        //   assert(move b == true)
298        //   ...
299        //   res := a[y];
300        [
301            Statement {
302                content:
303                    RawStatement::Assign(
304                        has_overflow,
305                        Rvalue::BinaryOp(BinOp::Lt, y_op, Operand::Const(..)),
306                    ),
307                ..
308            },
309            Statement {
310                content:
311                    RawStatement::Assert(Assert {
312                        cond: Operand::Move(cond),
313                        expected: true,
314                        ..
315                    }),
316                ..
317            },
318            rest @ ..,
319        ] if cond == has_overflow => {
320            // look for a shift operation
321            let mut found = make_binop_overflow_panic(rest, |bop, _, r| {
322                matches!(bop, BinOp::Shl(_) | BinOp::Shr(_)) && equiv_op(r, y_op)
323            });
324            if !found {
325                // otherwise, look for an array access
326                for stmt in rest.iter_mut() {
327                    stmt.dyn_visit_in_body(|p: &Place| {
328                        if let Some((_, ProjectionElem::Index { offset, .. })) = p.as_projection()
329                            && equiv_op(offset, y_op)
330                        {
331                            found = true;
332                        }
333                    });
334                }
335            }
336
337            if found {
338                rest
339            } else {
340                return;
341            }
342        }
343
344        // Overflow checks for addition/subtraction/multiplication. They look like:
345        // ```text
346        //   r := x checked.+ y;
347        //   assert(move r.1 == false);
348        //   ...
349        //   z := move r.0;
350        // ```
351        // We replace that with:
352        // ```text
353        // z := x + y;
354        // ```
355        //
356        // But sometimes, because of constant promotion, we end up with a lone checked operation
357        // without assert. In that case we replace it with its wrapping equivalent.
358        [
359            Statement {
360                content:
361                    RawStatement::Assign(
362                        result,
363                        Rvalue::BinaryOp(
364                            binop @ (BinOp::AddChecked | BinOp::SubChecked | BinOp::MulChecked),
365                            _,
366                            _,
367                        ),
368                    ),
369                ..
370            },
371            rest @ ..,
372        ] if let Some(result_local_id) = result.as_local() => {
373            // Look for uses of the overflow boolean.
374            let mut overflow_is_used = false;
375            for stmt in rest.iter_mut() {
376                stmt.dyn_visit_in_body(|p: &Place| {
377                    if let Some((sub, ProjectionElem::Field(FieldProjKind::Tuple(..), fid))) =
378                        p.as_projection()
379                        && fid.index() == 1
380                        && sub == result
381                    {
382                        overflow_is_used = true;
383                    }
384                });
385            }
386            // Check if the operation is followed by an assert.
387            let followed_by_assert = if let [
388                Statement {
389                    content:
390                        RawStatement::Assert(Assert {
391                            cond: Operand::Move(assert_cond),
392                            expected: false,
393                            ..
394                        }),
395                    ..
396                },
397                ..,
398            ] = rest
399                && let Some((sub, ProjectionElem::Field(FieldProjKind::Tuple(..), fid))) =
400                    assert_cond.as_projection()
401                && fid.index() == 1
402                && sub == result
403            {
404                true
405            } else {
406                false
407            };
408            if overflow_is_used && !followed_by_assert {
409                // The overflow boolean is used in a way that isn't a builtin overflow check; we
410                // change nothing.
411                return;
412            }
413
414            if followed_by_assert {
415                // We have a compiler-emitted assert. We replace the operation with one that has
416                // panic-on-overflow semantics.
417                *binop = binop.with_overflow(OverflowMode::Panic);
418                // The failure behavior is part of the binop now, so we remove the assert.
419                rest[0].content = RawStatement::Nop;
420            } else {
421                // The overflow boolean is not used, we replace the operations with wrapping
422                // semantics.
423                *binop = binop.with_overflow(OverflowMode::Wrap);
424            }
425            // Fixup the local type.
426            let result_local = &mut locals.locals[result_local_id];
427            result_local.ty = result_local.ty.as_tuple().unwrap()[0].clone();
428            // Fixup the place type.
429            let new_result_place = locals.place_for_var(result_local_id);
430            // Replace uses of `r.0` with `r`.
431            for stmt in rest.iter_mut() {
432                stmt.dyn_visit_in_body_mut(|p: &mut Place| {
433                    if let Some((sub, ProjectionElem::Field(FieldProjKind::Tuple(..), fid))) =
434                        p.as_projection()
435                        && sub == result
436                    {
437                        assert_eq!(fid.index(), 0);
438                        *p = new_result_place.clone()
439                    }
440                });
441            }
442            *result = new_result_place;
443            return;
444        }
445
446        _ => return,
447    };
448
449    // Remove the statements we're not keeping.
450    let keep_len = statements_to_keep.len();
451    for i in 0..statements.len() - keep_len {
452        statements[i].content = RawStatement::Nop;
453    }
454}
455
456pub struct Transform;
457impl UllbcPass for Transform {
458    fn transform_body(&self, ctx: &mut TransformCtx, b: &mut ExprBody) {
459        b.transform_sequences_fwd(|locals, seq| {
460            remove_dynamic_checks(ctx, locals, seq);
461            Vec::new()
462        });
463    }
464}