1use std::mem;
8
9use derive_generic_visitor::*;
10
11use crate::transform::TransformCtx;
12use crate::ullbc_ast::{ExprBody, RawStatement, Statement};
13use crate::{ast::*, register_error};
14
15use super::ctx::UllbcPass;
16
17fn uses_local<T: BodyVisitable>(x: &T, local: LocalId) -> bool {
19 struct FoundIt;
20 struct UsesLocalVisitor(LocalId);
21
22 impl Visitor for UsesLocalVisitor {
23 type Break = FoundIt;
24 }
25 impl VisitBody for UsesLocalVisitor {
26 fn visit_place(&mut self, x: &Place) -> ::std::ops::ControlFlow<Self::Break> {
27 if let Some(local_id) = x.as_local() {
28 if local_id == self.0 {
29 return ControlFlow::Break(FoundIt);
30 }
31 }
32 self.visit_inner(x)
33 }
34 }
35
36 x.drive_body(&mut UsesLocalVisitor(local)).is_break()
37}
38
39fn remove_dynamic_checks(
42 ctx: &mut TransformCtx,
43 _locals: &mut Locals,
44 statements: &mut [Statement],
45) {
46 let statements_to_keep = match statements {
48 [Statement {
53 content:
54 RawStatement::Assign(len, Rvalue::UnaryOp(UnOp::PtrMetadata, Operand::Copy(len_op))),
55 ..
56 }, Statement {
57 content:
58 RawStatement::Assign(
59 is_in_bounds,
60 Rvalue::BinaryOp(BinOp::Lt, _, Operand::Copy(lt_op2)),
61 ),
62 ..
63 }, Statement {
64 content:
65 RawStatement::Assert(Assert {
66 cond: Operand::Move(cond),
67 expected: true,
68 ..
69 }),
70 ..
71 }, rest @ ..]
72 if lt_op2 == len && cond == is_in_bounds && len_op.ty().is_ref() =>
73 {
74 rest
75 }
76 [Statement {
82 content: RawStatement::Assign(reborrow, Rvalue::RawPtr(_, RefKind::Shared)),
83 ..
84 }, Statement {
85 content:
86 RawStatement::Assign(len, Rvalue::UnaryOp(UnOp::PtrMetadata, Operand::Move(len_op))),
87 ..
88 }, Statement {
89 content:
90 RawStatement::Assign(
91 is_in_bounds,
92 Rvalue::BinaryOp(BinOp::Lt, _, Operand::Copy(lt_op2)),
93 ),
94 ..
95 }, Statement {
96 content:
97 RawStatement::Assert(Assert {
98 cond: Operand::Move(cond),
99 expected: true,
100 ..
101 }),
102 ..
103 }, rest @ ..]
104 if reborrow == len_op && lt_op2 == len && cond == is_in_bounds =>
105 {
106 rest
107 }
108
109 [Statement {
113 content:
114 RawStatement::Assign(is_in_bounds, Rvalue::BinaryOp(BinOp::Lt, _, Operand::Const(_))),
115 ..
116 }, Statement {
117 content:
118 RawStatement::Assert(Assert {
119 cond: Operand::Move(cond),
120 expected: true,
121 ..
122 }),
123 ..
124 }, rest @ ..]
125 if cond == is_in_bounds =>
126 {
127 rest
128 }
129
130 [Statement {
134 content:
135 RawStatement::Assign(is_zero, Rvalue::BinaryOp(BinOp::Eq, _, Operand::Const(_zero))),
136 ..
137 }, Statement {
138 content:
139 RawStatement::Assert(Assert {
140 cond: Operand::Move(cond),
141 expected: false,
142 ..
143 }),
144 ..
145 }, rest @ ..]
146 if cond == is_zero =>
147 {
148 rest
149 }
150
151 [Statement {
157 content: RawStatement::Assign(is_neg_1, Rvalue::BinaryOp(BinOp::Eq, _y_op, _minus_1)),
158 ..
159 }, Statement {
160 content: RawStatement::Assign(is_min, Rvalue::BinaryOp(BinOp::Eq, _x_op, _int_min)),
161 ..
162 }, Statement {
163 content:
164 RawStatement::Assign(
165 has_overflow,
166 Rvalue::BinaryOp(BinOp::BitAnd, Operand::Move(and_op1), Operand::Move(and_op2)),
167 ),
168 ..
169 }, Statement {
170 content:
171 RawStatement::Assert(Assert {
172 cond: Operand::Move(cond),
173 expected: false,
174 ..
175 }),
176 ..
177 }, rest @ ..]
178 if and_op1 == is_neg_1 && and_op2 == is_min && cond == has_overflow =>
179 {
180 rest
181 }
182
183 [Statement {
188 content: RawStatement::Assign(cast, Rvalue::UnaryOp(UnOp::Cast(_), _)),
189 ..
190 }, Statement {
191 content:
192 RawStatement::Assign(
193 has_overflow,
194 Rvalue::BinaryOp(BinOp::Lt, Operand::Move(lhs), Operand::Const(..)),
195 ),
196 ..
197 }, Statement {
198 content:
199 RawStatement::Assert(Assert {
200 cond: Operand::Move(cond),
201 expected: true,
202 ..
203 }),
204 ..
205 }, rest @ ..]
206 if cond == has_overflow
207 && lhs == cast
208 && let Some(cast_local) = cast.as_local()
209 && !rest.iter().any(|st| uses_local(st, cast_local)) =>
210 {
211 rest
212 }
213 [Statement {
217 content:
218 RawStatement::Assign(has_overflow, Rvalue::BinaryOp(BinOp::Lt, _, Operand::Const(..))),
219 ..
220 }, Statement {
221 content:
222 RawStatement::Assert(Assert {
223 cond: Operand::Move(cond),
224 expected: true,
225 ..
226 }),
227 ..
228 }, rest @ ..]
229 if cond == has_overflow =>
230 {
231 rest
232 }
233
234 [Statement {
246 content:
247 RawStatement::Assign(
248 result,
249 rval_op @ Rvalue::BinaryOp(
250 BinOp::CheckedAdd | BinOp::CheckedSub | BinOp::CheckedMul,
251 _,
252 _,
253 ),
254 ),
255 ..
256 }, Statement {
257 content:
258 RawStatement::Assert(Assert {
259 cond: Operand::Move(assert_cond),
260 expected: false,
261 ..
262 }),
263 ..
264 }, rest @ ..]
265 if let Some((sub1, ProjectionElem::Field(FieldProjKind::Tuple(..), fid1))) =
266 assert_cond.as_projection()
267 && fid1.index() == 1
268 && result.is_local()
269 && sub1 == result =>
270 {
271 let Rvalue::BinaryOp(binop, ..) = rval_op else {
273 unreachable!()
274 };
275 *binop = match binop {
276 BinOp::CheckedAdd => BinOp::Add,
277 BinOp::CheckedSub => BinOp::Sub,
278 BinOp::CheckedMul => BinOp::Mul,
279 _ => unreachable!(),
280 };
281
282 let mut found_use = false;
285 for st in rest.iter_mut() {
286 if let RawStatement::Assign(_, Rvalue::Use(Operand::Move(assigned))) =
287 &mut st.content
288 && let Some((sub0, ProjectionElem::Field(FieldProjKind::Tuple(..), fid0))) =
289 assigned.as_projection()
290 && fid0.index() == 0
291 && sub0 == result
292 {
293 if found_use {
294 register_error!(
295 ctx,
296 st.span,
297 "Double use of a checked binary operation; \
298 the MIR is not in the shape we expected."
299 );
300 }
301 found_use = true;
302 let RawStatement::Assign(_, rval) = &mut st.content else {
303 unreachable!()
304 };
305 mem::swap(rval_op, rval);
307 }
308 }
309 rest
311 }
312
313 _ => return,
314 };
315
316 let keep_len = statements_to_keep.len();
318 for i in 0..statements.len() - keep_len {
319 statements[i].content = RawStatement::Nop;
320 }
321}
322
323pub struct Transform;
324impl UllbcPass for Transform {
325 fn transform_body(&self, ctx: &mut TransformCtx, b: &mut ExprBody) {
326 b.transform_sequences_fwd(|locals, seq| {
327 remove_dynamic_checks(ctx, locals, seq);
328 Vec::new()
329 });
330 }
331}