charon_lib/transform/
remove_dynamic_checks.rs

1//! # Micro-pass: remove the dynamic checks for array/slice bounds, overflow, and division by zero.
2//! Note that from a semantic point of view, an out-of-bound access or a division by zero
3//! must lead to a panic in Rust (which is why those checks are always present, even when
4//! compiling for release). In our case, we take this into account in the semantics of our
5//! array/slice manipulation and arithmetic functions, on the verification side.
6
7use derive_generic_visitor::*;
8
9use crate::ast::*;
10use crate::transform::TransformCtx;
11use crate::ullbc_ast::{ExprBody, Statement, StatementKind};
12
13use super::ctx::UllbcPass;
14
15/// Whether the value uses the given local in a place.
16fn uses_local<T: BodyVisitable>(x: &T, local: LocalId) -> bool {
17    struct FoundIt;
18    struct UsesLocalVisitor(LocalId);
19
20    impl Visitor for UsesLocalVisitor {
21        type Break = FoundIt;
22    }
23    impl VisitBody for UsesLocalVisitor {
24        fn visit_place(&mut self, x: &Place) -> ::std::ops::ControlFlow<Self::Break> {
25            if let Some(local_id) = x.as_local() {
26                if local_id == self.0 {
27                    return ControlFlow::Break(FoundIt);
28                }
29            }
30            self.visit_inner(x)
31        }
32    }
33
34    x.drive_body(&mut UsesLocalVisitor(local)).is_break()
35}
36
37fn make_binop_overflow_panic<T: BodyVisitable>(
38    x: &mut [T],
39    matches: impl Fn(&BinOp, &Operand, &Operand) -> bool,
40) -> bool {
41    let mut found = false;
42    for y in x.iter_mut() {
43        y.dyn_visit_in_body_mut(|rv: &mut Rvalue| {
44            if let Rvalue::BinaryOp(binop, op_l, op_r) = rv
45                && matches(binop, op_l, op_r)
46            {
47                *binop = binop.with_overflow(OverflowMode::Panic);
48                found = true;
49            }
50        });
51    }
52    found
53}
54
55fn make_unop_overflow_panic<T: BodyVisitable>(
56    x: &mut [T],
57    matches: impl Fn(&UnOp, &Operand) -> bool,
58) -> bool {
59    let mut found = false;
60    for y in x.iter_mut() {
61        y.dyn_visit_in_body_mut(|rv: &mut Rvalue| {
62            if let Rvalue::UnaryOp(unop, op) = rv
63                && matches(unop, op)
64            {
65                *unop = unop.with_overflow(OverflowMode::Panic);
66                found = true;
67            }
68        });
69    }
70    found
71}
72
73/// Check if the two operands are equivalent: either they're the same constant, or they represent
74/// the same place (regardless of whether the operand is a move or a copy)
75fn equiv_op(op_l: &Operand, op_r: &Operand) -> bool {
76    match (op_l, op_r) {
77        (Operand::Copy(l) | Operand::Move(l), Operand::Copy(r) | Operand::Move(r)) => l == r,
78        (Operand::Const(l), Operand::Const(r)) => l == r,
79        _ => false,
80    }
81}
82
83/// Rustc inserts dynamic checks during MIR lowering. They all end in an `Assert` statement (and
84/// this is the only use of this statement).
85fn remove_dynamic_checks(
86    _ctx: &mut TransformCtx,
87    locals: &mut Locals,
88    statements: &mut [Statement],
89) {
90    // We return the statements we want to keep, which must be a prefix of `block.statements`.
91    let statements_to_keep = match statements {
92        // Bounds checks for slices. They look like:
93        //   l := ptr_metadata(copy a)
94        //   b := copy x < copy l
95        //   assert(move b == true)
96        [
97            Statement {
98                content:
99                    StatementKind::Assign(
100                        len,
101                        Rvalue::UnaryOp(UnOp::PtrMetadata, Operand::Copy(len_op)),
102                    ),
103                ..
104            },
105            Statement {
106                content:
107                    StatementKind::Assign(
108                        is_in_bounds,
109                        Rvalue::BinaryOp(BinOp::Lt, _, Operand::Copy(lt_op2)),
110                    ),
111                ..
112            },
113            Statement {
114                content:
115                    StatementKind::Assert(Assert {
116                        cond: Operand::Move(cond),
117                        expected: true,
118                        ..
119                    }),
120                ..
121            },
122            rest @ ..,
123        ] if lt_op2 == len && cond == is_in_bounds && len_op.ty().is_ref() => rest,
124        // Sometimes that instead looks like:
125        //   a := &raw const *z
126        //   l := ptr_metadata(move a)
127        //   b := copy x < copy l
128        //   assert(move b == true)
129        [
130            Statement {
131                content: StatementKind::Assign(reborrow, Rvalue::RawPtr(_, RefKind::Shared)),
132                ..
133            },
134            Statement {
135                content:
136                    StatementKind::Assign(
137                        len,
138                        Rvalue::UnaryOp(UnOp::PtrMetadata, Operand::Move(len_op)),
139                    ),
140                ..
141            },
142            Statement {
143                content:
144                    StatementKind::Assign(
145                        is_in_bounds,
146                        Rvalue::BinaryOp(BinOp::Lt, _, Operand::Copy(lt_op2)),
147                    ),
148                ..
149            },
150            Statement {
151                content:
152                    StatementKind::Assert(Assert {
153                        cond: Operand::Move(cond),
154                        expected: true,
155                        ..
156                    }),
157                ..
158            },
159            rest @ ..,
160        ] if reborrow == len_op && lt_op2 == len && cond == is_in_bounds => rest,
161
162        // Zero checks for division and remainder. They look like:
163        //   b := copy y == const 0
164        //   assert(move b == false)
165        //   ...
166        //   res := x {/,%} move y;
167        //   ... or ...
168        //   b := const y == const 0
169        //   assert(move b == false)
170        //   ...
171        //   res := x {/,%} const y;
172        //
173        // This also overlaps with overflow checks for negation, which looks like:
174        //   is_min := x == INT::min
175        //   assert(move is_min == false)
176        //   ...
177        //   res := -x;
178        [
179            Statement {
180                content:
181                    StatementKind::Assign(
182                        is_zero,
183                        Rvalue::BinaryOp(BinOp::Eq, y_op, Operand::Const(_zero)),
184                    ),
185                ..
186            },
187            Statement {
188                content:
189                    StatementKind::Assert(Assert {
190                        cond: Operand::Move(cond),
191                        expected: false,
192                        ..
193                    }),
194                ..
195            },
196            rest @ ..,
197        ] if cond == is_zero => {
198            let found = make_binop_overflow_panic(rest, |bop, _, r| {
199                matches!(bop, BinOp::Div(_) | BinOp::Rem(_)) && equiv_op(r, y_op)
200            }) || make_unop_overflow_panic(rest, |unop, o| {
201                matches!(unop, UnOp::Neg(_)) && equiv_op(o, y_op)
202            });
203            if found {
204                rest
205            } else {
206                return;
207            }
208        }
209
210        // Overflow checks for signed division and remainder. They look like:
211        //   is_neg_1 := y == (-1)
212        //   is_min := x == INT::min
213        //   has_overflow := move (is_neg_1) & move (is_min)
214        //   assert(move has_overflow == false)
215        // Note here we don't need to update the operand to panic, as this was already done
216        // by the previous pass for division by zero.
217        [
218            Statement {
219                content:
220                    StatementKind::Assign(is_neg_1, Rvalue::BinaryOp(BinOp::Eq, _y_op, _minus_1)),
221                ..
222            },
223            Statement {
224                content: StatementKind::Assign(is_min, Rvalue::BinaryOp(BinOp::Eq, _x_op, _int_min)),
225                ..
226            },
227            Statement {
228                content:
229                    StatementKind::Assign(
230                        has_overflow,
231                        Rvalue::BinaryOp(
232                            BinOp::BitAnd,
233                            Operand::Move(and_op1),
234                            Operand::Move(and_op2),
235                        ),
236                    ),
237                ..
238            },
239            Statement {
240                content:
241                    StatementKind::Assert(Assert {
242                        cond: Operand::Move(cond),
243                        expected: false,
244                        ..
245                    }),
246                ..
247            },
248            rest @ ..,
249        ] if and_op1 == is_neg_1 && and_op2 == is_min && cond == has_overflow => rest,
250
251        // Overflow checks for right/left shift. They can look like:
252        //   a := y as u32; // or another type
253        //   b := move a < const 32; // or another constant
254        //   assert(move b == true);
255        //   ...
256        //   res := x {<<,>>} y;
257        [
258            Statement {
259                content: StatementKind::Assign(cast, Rvalue::UnaryOp(UnOp::Cast(_), y_op)),
260                ..
261            },
262            Statement {
263                content:
264                    StatementKind::Assign(
265                        has_overflow,
266                        Rvalue::BinaryOp(BinOp::Lt, Operand::Move(lhs), Operand::Const(..)),
267                    ),
268                ..
269            },
270            Statement {
271                content:
272                    StatementKind::Assert(Assert {
273                        cond: Operand::Move(cond),
274                        expected: true,
275                        ..
276                    }),
277                ..
278            },
279            rest @ ..,
280        ] if cond == has_overflow
281            && lhs == cast
282            && let Some(cast_local) = cast.as_local()
283            && !rest.iter().any(|st| uses_local(st, cast_local)) =>
284        {
285            let found = make_binop_overflow_panic(rest, |bop, _, r| {
286                matches!(bop, BinOp::Shl(_) | BinOp::Shr(_)) && equiv_op(r, y_op)
287            });
288            if found {
289                rest
290            } else {
291                return;
292            }
293        }
294        // or like:
295        //   b := y < const 32; // or another constant
296        //   assert(move b == true);
297        //   ...
298        //   res := x {<<,>>} y;
299        //
300        // this also overlaps with out of bounds checks for arrays, so we check for either;
301        // these look like:
302        //   b := copy y < const _
303        //   assert(move b == true)
304        //   ...
305        //   res := a[y];
306        [
307            Statement {
308                content:
309                    StatementKind::Assign(
310                        has_overflow,
311                        Rvalue::BinaryOp(BinOp::Lt, y_op, Operand::Const(..)),
312                    ),
313                ..
314            },
315            Statement {
316                content:
317                    StatementKind::Assert(Assert {
318                        cond: Operand::Move(cond),
319                        expected: true,
320                        ..
321                    }),
322                ..
323            },
324            rest @ ..,
325        ] if cond == has_overflow => {
326            // look for a shift operation
327            let mut found = make_binop_overflow_panic(rest, |bop, _, r| {
328                matches!(bop, BinOp::Shl(_) | BinOp::Shr(_)) && equiv_op(r, y_op)
329            });
330            if !found {
331                // otherwise, look for an array access
332                for stmt in rest.iter_mut() {
333                    stmt.dyn_visit_in_body(|p: &Place| {
334                        if let Some((_, ProjectionElem::Index { offset, .. })) = p.as_projection()
335                            && equiv_op(offset, y_op)
336                        {
337                            found = true;
338                        }
339                    });
340                }
341            }
342
343            if found {
344                rest
345            } else {
346                return;
347            }
348        }
349
350        // Overflow checks for addition/subtraction/multiplication. They look like:
351        // ```text
352        //   r := x checked.+ y;
353        //   assert(move r.1 == false);
354        //   ...
355        //   z := move r.0;
356        // ```
357        // We replace that with:
358        // ```text
359        // z := x + y;
360        // ```
361        //
362        // But sometimes, because of constant promotion, we end up with a lone checked operation
363        // without assert. In that case we replace it with its wrapping equivalent.
364        [
365            Statement {
366                content:
367                    StatementKind::Assign(
368                        result,
369                        Rvalue::BinaryOp(
370                            binop @ (BinOp::AddChecked | BinOp::SubChecked | BinOp::MulChecked),
371                            _,
372                            _,
373                        ),
374                    ),
375                ..
376            },
377            rest @ ..,
378        ] if let Some(result_local_id) = result.as_local() => {
379            // Look for uses of the overflow boolean.
380            let mut overflow_is_used = false;
381            for stmt in rest.iter_mut() {
382                stmt.dyn_visit_in_body(|p: &Place| {
383                    if let Some((sub, ProjectionElem::Field(FieldProjKind::Tuple(..), fid))) =
384                        p.as_projection()
385                        && fid.index() == 1
386                        && sub == result
387                    {
388                        overflow_is_used = true;
389                    }
390                });
391            }
392            // Check if the operation is followed by an assert.
393            let followed_by_assert = if let [
394                Statement {
395                    content:
396                        StatementKind::Assert(Assert {
397                            cond: Operand::Move(assert_cond),
398                            expected: false,
399                            ..
400                        }),
401                    ..
402                },
403                ..,
404            ] = rest
405                && let Some((sub, ProjectionElem::Field(FieldProjKind::Tuple(..), fid))) =
406                    assert_cond.as_projection()
407                && fid.index() == 1
408                && sub == result
409            {
410                true
411            } else {
412                false
413            };
414            if overflow_is_used && !followed_by_assert {
415                // The overflow boolean is used in a way that isn't a builtin overflow check; we
416                // change nothing.
417                return;
418            }
419
420            if followed_by_assert {
421                // We have a compiler-emitted assert. We replace the operation with one that has
422                // panic-on-overflow semantics.
423                *binop = binop.with_overflow(OverflowMode::Panic);
424                // The failure behavior is part of the binop now, so we remove the assert.
425                rest[0].content = StatementKind::Nop;
426            } else {
427                // The overflow boolean is not used, we replace the operations with wrapping
428                // semantics.
429                *binop = binop.with_overflow(OverflowMode::Wrap);
430            }
431            // Fixup the local type.
432            let result_local = &mut locals.locals[result_local_id];
433            result_local.ty = result_local.ty.as_tuple().unwrap()[0].clone();
434            // Fixup the place type.
435            let new_result_place = locals.place_for_var(result_local_id);
436            // Replace uses of `r.0` with `r`.
437            for stmt in rest.iter_mut() {
438                stmt.dyn_visit_in_body_mut(|p: &mut Place| {
439                    if let Some((sub, ProjectionElem::Field(FieldProjKind::Tuple(..), fid))) =
440                        p.as_projection()
441                        && sub == result
442                    {
443                        assert_eq!(fid.index(), 0);
444                        *p = new_result_place.clone()
445                    }
446                });
447            }
448            *result = new_result_place;
449            return;
450        }
451
452        _ => return,
453    };
454
455    // Remove the statements we're not keeping.
456    let keep_len = statements_to_keep.len();
457    for i in 0..statements.len() - keep_len {
458        statements[i].content = StatementKind::Nop;
459    }
460}
461
462pub struct Transform;
463impl UllbcPass for Transform {
464    fn transform_body(&self, ctx: &mut TransformCtx, b: &mut ExprBody) {
465        b.transform_sequences_fwd(|locals, seq| {
466            remove_dynamic_checks(ctx, locals, seq);
467            Vec::new()
468        });
469    }
470}