// charon_lib/transform/resugar/reconstruct_fallible_operations.rs

//! # Micro-pass: remove the dynamic checks for array/slice bounds, overflow, and division by zero.
//! Note that from a semantic point of view, an out-of-bounds access or a division by zero
//! must lead to a panic in Rust (which is why those checks are always present, even when
//! compiling for release). In our case, we account for this in the semantics of our
//! array/slice manipulation and arithmetic functions, on the verification side.
6
7use std::collections::HashSet;
8
9use derive_generic_visitor::*;
10
11use crate::ast::*;
12use crate::ids::IndexVec;
13use crate::transform::TransformCtx;
14use crate::ullbc_ast::{BlockId, ExprBody, Statement, StatementKind};
15
16use crate::transform::ctx::UllbcPass;
17
18type LocalUses = IndexVec<BlockId, HashSet<LocalId>>;
19
20/// Compute for each block the locals that are assumed to have been initialized/defined before entering it.
21fn compute_uses(body: &ExprBody) -> LocalUses {
22    #[derive(Visitor)]
23    struct UsedLocalsVisitor<'a>(&'a mut HashSet<LocalId>);
24
25    impl VisitBody for UsedLocalsVisitor<'_> {
26        fn visit_place(&mut self, x: &Place) -> ::std::ops::ControlFlow<Self::Break> {
27            if let Some(local_id) = x.as_local() {
28                self.0.insert(local_id);
29            }
30            self.visit_inner(x)
31        }
32    }
33
34    body.body.map_ref(|block| {
35        let mut uses = HashSet::new();
36        let mut visitor = UsedLocalsVisitor(&mut uses);
37
38        // do a simple live variable analysis by walking the block backwards
39        for statement in block.statements.iter().rev() {
40            match &statement.kind {
41                StatementKind::Assign(place, rval) => {
42                    // We clear the assigned place, but it may be re-added
43                    // if it's used in rval
44                    if let Some(local_id) = place.as_local() {
45                        visitor.0.remove(&local_id);
46                    }
47                    let _ = rval.drive_body(&mut visitor);
48                }
49                StatementKind::StorageLive(local) | StatementKind::StorageDead(local) => {
50                    // A `StorageLive` re-sets the local to be uninitialised,
51                    // so any usage after this point doesn't matter
52                    // Similarly, a `StorageDead` means the local is de-initialised,
53                    // so we can ignore any usage after this point
54                    visitor.0.remove(&local);
55                }
56                _ => {
57                    let _ = statement.drive_body(&mut visitor);
58                }
59            }
60        }
61
62        uses
63    })
64}
65
66/// Whether the value uses the given local in a place.
67fn uses_local<T: BodyVisitable>(x: &T, local: LocalId) -> bool {
68    struct FoundIt;
69    struct UsesLocalVisitor(LocalId);
70
71    impl Visitor for UsesLocalVisitor {
72        type Break = FoundIt;
73    }
74    impl VisitBody for UsesLocalVisitor {
75        fn visit_place(&mut self, x: &Place) -> ::std::ops::ControlFlow<Self::Break> {
76            if let Some(local_id) = x.as_local() {
77                if local_id == self.0 {
78                    return ControlFlow::Break(FoundIt);
79                }
80            }
81            self.visit_inner(x)
82        }
83
84        fn visit_ullbc_statement(
85            &mut self,
86            x: &ullbc_ast::Statement,
87        ) -> ::std::ops::ControlFlow<Self::Break> {
88            match x.kind {
89                StatementKind::StorageDead(_) | StatementKind::StorageLive(_) => {
90                    ControlFlow::Continue(())
91                }
92                _ => self.visit_inner(x),
93            }
94        }
95    }
96
97    x.drive_body(&mut UsesLocalVisitor(local)).is_break()
98}
99
100fn make_binop_overflow_panic<T: BodyVisitable>(
101    x: &mut [T],
102    matches: impl Fn(&BinOp, &Operand, &Operand) -> bool,
103) -> bool {
104    let mut found = false;
105    for y in x.iter_mut() {
106        y.dyn_visit_in_body_mut(|rv: &mut Rvalue| {
107            if let Rvalue::BinaryOp(binop, op_l, op_r) = rv
108                && matches(binop, op_l, op_r)
109            {
110                *binop = binop.with_overflow(OverflowMode::Panic);
111                found = true;
112            }
113        });
114    }
115    found
116}
117
118fn make_unop_overflow_panic<T: BodyVisitable>(
119    x: &mut [T],
120    matches: impl Fn(&UnOp, &Operand) -> bool,
121) -> bool {
122    let mut found = false;
123    for y in x.iter_mut() {
124        y.dyn_visit_in_body_mut(|rv: &mut Rvalue| {
125            if let Rvalue::UnaryOp(unop, op) = rv
126                && matches(unop, op)
127            {
128                *unop = unop.with_overflow(OverflowMode::Panic);
129                found = true;
130            }
131        });
132    }
133    found
134}
135
136/// Check if the two operands are equivalent: either they're the same constant, or they represent
137/// the same place (regardless of whether the operand is a move or a copy)
138fn equiv_op(op_l: &Operand, op_r: &Operand) -> bool {
139    match (op_l, op_r) {
140        (Operand::Copy(l) | Operand::Move(l), Operand::Copy(r) | Operand::Move(r)) => l == r,
141        (Operand::Const(l), Operand::Const(r)) => l == r,
142        _ => false,
143    }
144}
145
146/// Rustc inserts dynamic checks during MIR lowering. They all end in an `Assert` statement (and
147/// this is the only use of this statement).
148fn remove_dynamic_checks(
149    _ctx: &mut TransformCtx,
150    uses: &LocalUses,
151    block_id: BlockId,
152    locals: &mut Locals,
153    statements: &mut [Statement],
154) {
155    // We return the statements we want to keep, which must be a prefix of `block.statements`.
156    let statements_to_keep = match statements {
157        // Bounds checks for slices. They look like:
158        //   l := use(copy a.metadata)
159        //   b := copy x < copy l
160        //   assert(move b == true)
161        [
162            Statement {
163                kind: StatementKind::Assign(len, Rvalue::Use(Operand::Copy(len_op))),
164                ..
165            },
166            Statement {
167                kind:
168                    StatementKind::Assign(
169                        is_in_bounds,
170                        Rvalue::BinaryOp(BinOp::Lt, _, Operand::Copy(lt_op2)),
171                    ),
172                ..
173            },
174            Statement {
175                kind:
176                    StatementKind::Assert {
177                        assert:
178                            Assert {
179                                cond: Operand::Move(cond),
180                                expected: true,
181                                ..
182                            },
183                        ..
184                    },
185                ..
186            },
187            rest @ ..,
188        ] if lt_op2 == len
189            && cond == is_in_bounds
190            && let Some((_, ProjectionElem::PtrMetadata)) = len_op.as_projection() =>
191        {
192            rest
193        }
194        // Sometimes that instead looks like:
195        //   a := &raw const *z
196        //   l := use(copy a.metadata)
197        //   b := copy x < copy l
198        //   assert(move b == true)
199        [
200            Statement {
201                kind:
202                    StatementKind::Assign(
203                        reborrow,
204                        Rvalue::RawPtr {
205                            kind: RefKind::Shared,
206                            ..
207                        },
208                    ),
209                ..
210            },
211            Statement {
212                kind: StatementKind::Assign(len, Rvalue::Use(Operand::Copy(len_op))),
213                ..
214            },
215            Statement {
216                kind:
217                    StatementKind::Assign(
218                        is_in_bounds,
219                        Rvalue::BinaryOp(BinOp::Lt, _, Operand::Copy(lt_op2)),
220                    ),
221                ..
222            },
223            Statement {
224                kind:
225                    StatementKind::Assert {
226                        assert:
227                            Assert {
228                                cond: Operand::Move(cond),
229                                expected: true,
230                                check_kind: Some(BuiltinAssertKind::BoundsCheck { .. }),
231                            },
232                        ..
233                    },
234                ..
235            },
236            rest @ ..,
237        ] if lt_op2 == len
238            && cond == is_in_bounds
239            && let Some((slice_place, ProjectionElem::PtrMetadata)) = len_op.as_projection()
240            && reborrow == slice_place =>
241        {
242            rest
243        }
244
245        // Zero checks for division and remainder. They look like:
246        //   b := copy y == const 0
247        //   assert(move b == false)
248        //   ...
249        //   res := x {/,%} move y;
250        //   ... or ...
251        //   b := const y == const 0
252        //   assert(move b == false)
253        //   ...
254        //   res := x {/,%} const y;
255        //
256        // This also overlaps with overflow checks for negation, which looks like:
257        //   is_min := x == INT::min
258        //   assert(move is_min == false)
259        //   ...
260        //   res := -x;
261        [
262            Statement {
263                kind:
264                    StatementKind::Assign(
265                        is_zero,
266                        Rvalue::BinaryOp(BinOp::Eq, y_op, Operand::Const(_zero)),
267                    ),
268                ..
269            },
270            Statement {
271                kind:
272                    StatementKind::Assert {
273                        assert:
274                            Assert {
275                                cond: Operand::Move(cond),
276                                expected: false,
277                                check_kind:
278                                    Some(
279                                        BuiltinAssertKind::DivisionByZero(_)
280                                        | BuiltinAssertKind::RemainderByZero(_)
281                                        | BuiltinAssertKind::OverflowNeg(_),
282                                    ),
283                            },
284                        ..
285                    },
286                ..
287            },
288            rest @ ..,
289        ] if cond == is_zero => {
290            let found = make_binop_overflow_panic(rest, |bop, _, r| {
291                matches!(bop, BinOp::Div(_) | BinOp::Rem(_)) && equiv_op(r, y_op)
292            }) || make_unop_overflow_panic(rest, |unop, o| {
293                matches!(unop, UnOp::Neg(_)) && equiv_op(o, y_op)
294            });
295            if found {
296                rest
297            } else {
298                return;
299            }
300        }
301
302        // Overflow checks for signed division and remainder. They look like:
303        //   is_neg_1 := y == (-1)
304        //   is_min := x == INT::min
305        //   has_overflow := move (is_neg_1) & move (is_min)
306        //   assert(move has_overflow == false)
307        // Note here we don't need to update the operand to panic, as this was already done
308        // by the previous pass for division by zero.
309        [
310            Statement {
311                kind: StatementKind::Assign(is_neg_1, Rvalue::BinaryOp(BinOp::Eq, _y_op, _minus_1)),
312                ..
313            },
314            Statement {
315                kind: StatementKind::Assign(is_min, Rvalue::BinaryOp(BinOp::Eq, _x_op, _int_min)),
316                ..
317            },
318            Statement {
319                kind:
320                    StatementKind::Assign(
321                        has_overflow,
322                        Rvalue::BinaryOp(
323                            BinOp::BitAnd,
324                            Operand::Move(and_op1),
325                            Operand::Move(and_op2),
326                        ),
327                    ),
328                ..
329            },
330            Statement {
331                kind:
332                    StatementKind::Assert {
333                        assert:
334                            Assert {
335                                cond: Operand::Move(cond),
336                                expected: false,
337                                check_kind: Some(BuiltinAssertKind::Overflow(..)),
338                            },
339                        ..
340                    },
341                ..
342            },
343            rest @ ..,
344        ] if and_op1 == is_neg_1 && and_op2 == is_min && cond == has_overflow => rest,
345
346        // Overflow checks for right/left shift. They can look like:
347        //   a := y as u32; // or another type
348        //   b := move a < const 32; // or another constant
349        //   assert(move b == true);
350        //   ...
351        //   res := x {<<,>>} y;
352        [
353            Statement {
354                kind: StatementKind::Assign(cast, Rvalue::UnaryOp(UnOp::Cast(_), y_op)),
355                ..
356            },
357            Statement {
358                kind:
359                    StatementKind::Assign(
360                        has_overflow,
361                        Rvalue::BinaryOp(BinOp::Lt, Operand::Move(lhs), Operand::Const(..)),
362                    ),
363                ..
364            },
365            Statement {
366                kind:
367                    StatementKind::Assert {
368                        assert:
369                            Assert {
370                                cond: Operand::Move(cond),
371                                expected: true,
372                                check_kind: Some(BuiltinAssertKind::Overflow(..)),
373                            },
374                        ..
375                    },
376                ..
377            },
378            rest @ ..,
379        ] if cond == has_overflow
380            && lhs == cast
381            && let Some(cast_local) = cast.as_local()
382            && !rest.iter().any(|st| uses_local(st, cast_local)) =>
383        {
384            let found = make_binop_overflow_panic(rest, |bop, _, r| {
385                matches!(bop, BinOp::Shl(_) | BinOp::Shr(_)) && equiv_op(r, y_op)
386            });
387            if found {
388                rest
389            } else {
390                return;
391            }
392        }
393        // or like:
394        //   b := y < const 32; // or another constant
395        //   assert(move b == true);
396        //   ...
397        //   res := x {<<,>>} y;
398        //
399        // this also overlaps with out of bounds checks for arrays, so we check for either;
400        // these look like:
401        //   b := copy y < const _
402        //   assert(move b == true)
403        //   ...
404        //   res := a[y];
405        [
406            Statement {
407                kind:
408                    StatementKind::Assign(
409                        has_overflow,
410                        Rvalue::BinaryOp(BinOp::Lt, y_op, Operand::Const(..)),
411                    ),
412                ..
413            },
414            Statement {
415                kind:
416                    StatementKind::Assert {
417                        assert:
418                            Assert {
419                                cond: Operand::Move(cond),
420                                expected: true,
421                                check_kind:
422                                    Some(
423                                        BuiltinAssertKind::Overflow(..)
424                                        | BuiltinAssertKind::BoundsCheck { .. },
425                                    ),
426                            },
427                        ..
428                    },
429                ..
430            },
431            rest @ ..,
432        ] if cond == has_overflow => {
433            // look for a shift operation
434            let mut found = make_binop_overflow_panic(rest, |bop, _, r| {
435                matches!(bop, BinOp::Shl(_) | BinOp::Shr(_)) && equiv_op(r, y_op)
436            });
437            if !found {
438                // otherwise, look for an array access
439                for stmt in rest.iter_mut() {
440                    stmt.dyn_visit_in_body(|p: &Place| {
441                        if let Some((_, ProjectionElem::Index { offset, .. })) = p.as_projection()
442                            && equiv_op(offset, y_op)
443                        {
444                            found = true;
445                        }
446                    });
447                }
448            }
449
450            if found {
451                rest
452            } else {
453                return;
454            }
455        }
456
457        // Overflow checks for addition/subtraction/multiplication. They look like:
458        // ```text
459        //   r := x checked.+ y;
460        //   assert(move r.1 == false);
461        //   ...
462        //   z := move r.0;
463        // ```
464        // We replace that with:
465        // ```text
466        // z := x + y;
467        // ```
468        //
469        // But sometimes, because of constant promotion, we end up with a lone checked operation
470        // without assert. In that case we replace it with its wrapping equivalent.
471        [
472            Statement {
473                kind:
474                    StatementKind::Assign(
475                        tuple,
476                        Rvalue::BinaryOp(
477                            binop @ (BinOp::AddChecked | BinOp::SubChecked | BinOp::MulChecked),
478                            _,
479                            _,
480                        ),
481                    ),
482                ..
483            },
484            rest @ ..,
485        ] if let Some(tuple_local_id) = tuple.as_local() => {
486            // Check if the result boolean is used in any other way than just getting the integer
487            // result.
488            let mut uses_of_tuple = 0;
489            let mut uses_of_integer = 0;
490            if *tuple == locals.return_place() {
491                uses_of_tuple += 1; // The return place counts as a use.
492            }
493            for stmt in rest.iter_mut() {
494                stmt.dyn_visit_in_body(|p: &Place| {
495                    if p == tuple {
496                        uses_of_tuple += 1;
497                    }
498                    if let Some((sub, ProjectionElem::Field(FieldProjKind::Tuple(..), fid))) =
499                        p.as_projection()
500                        && fid.index() == 0
501                        && sub == tuple
502                    {
503                        uses_of_integer += 1;
504                    }
505                });
506            }
507            // Check if the operation is followed by an assert.
508            let followed_by_assert = if let [
509                Statement {
510                    kind:
511                        StatementKind::Assert {
512                            assert:
513                                Assert {
514                                    cond: Operand::Move(assert_cond),
515                                    expected: false,
516                                    check_kind: Some(BuiltinAssertKind::Overflow(..)),
517                                },
518                            ..
519                        },
520                    ..
521                },
522                ..,
523            ] = rest
524                && let Some((sub, ProjectionElem::Field(FieldProjKind::Tuple(..), fid))) =
525                    assert_cond.as_projection()
526                && fid.index() == 1
527                && sub == tuple
528            {
529                true
530            } else {
531                false
532            };
533            if uses_of_tuple != uses_of_integer && !followed_by_assert {
534                // The tuple is used either directly or for the overflow check; we change nothing.
535                return;
536            }
537
538            if followed_by_assert {
539                // We have a compiler-emitted assert. We replace the operation with one that has
540                // panic-on-overflow semantics.
541                *binop = binop.with_overflow(OverflowMode::Panic);
542                // The failure behavior is part of the binop now, so we remove the assert.
543                rest[0].kind = StatementKind::Nop;
544            } else {
545                // The tuple is used exclusively to access the integer result, so we replace the
546                // operation with wrapping semantics.
547                *binop = binop.with_overflow(OverflowMode::Wrap);
548            }
549            // Fixup the local type.
550            let result_local = &mut locals.locals[tuple_local_id];
551            result_local.ty = result_local.ty.as_tuple().unwrap()[0].clone();
552            // Fixup the place type.
553            let new_result_place = locals.place_for_var(tuple_local_id);
554            // Replace uses of `r.0` with `r`.
555            for stmt in rest.iter_mut() {
556                stmt.dyn_visit_in_body_mut(|p: &mut Place| {
557                    if let Some((sub, ProjectionElem::Field(FieldProjKind::Tuple(..), fid))) =
558                        p.as_projection()
559                        && sub == tuple
560                    {
561                        assert_eq!(fid.index(), 0);
562                        *p = new_result_place.clone()
563                    }
564                });
565            }
566            *tuple = new_result_place;
567            return;
568        }
569
570        _ => return,
571    };
572
573    // Remove the statements we're not keeping.
574    let keep_len = statements_to_keep.len();
575    let removed_len = statements.len() - keep_len;
576    for i in 0..removed_len {
577        // If the statement we're removing assigns to a local that
578        // is used elsewhere (in the leftover statements or in another block),
579        // we don't remove it.
580        if let StatementKind::Assign(place, _) = &statements[i].kind
581            && let Some(local) = place.as_local()
582            && let mut statements_to_keep = statements[removed_len..].as_ref().iter()
583            && let mut other_blocks = uses.iter_indexed().filter(|(bid, _)| *bid != block_id)
584            && (other_blocks.any(|(_, used)| used.contains(&local))
585                || statements_to_keep.any(|st| uses_local(st, local)))
586        {
587            continue;
588        };
589        statements[i].kind = StatementKind::Nop;
590    }
591}
592
593pub struct Transform;
594impl UllbcPass for Transform {
595    fn should_run(&self, options: &crate::options::TranslateOptions) -> bool {
596        options.reconstruct_fallible_operations
597    }
598
599    fn transform_body(&self, ctx: &mut TransformCtx, b: &mut ExprBody) {
600        let local_uses: LocalUses = compute_uses(b);
601        b.transform_sequences_fwd(|id, locals, seq| {
602            remove_dynamic_checks(ctx, &local_uses, id, locals, seq);
603            Vec::new()
604        });
605    }
606}