1use itertools::Itertools;
23use std::collections::HashSet;
24
25use crate::ast::ullbc_ast_utils::StmtLoc;
26use crate::name_matcher::NamePattern;
27use crate::transform::ctx::UllbcPass;
28use crate::transform::{CowBox, TransformCtx};
29use crate::ullbc_ast::*;
30
31pub struct Transform {
32 box_write: Option<FunDeclId>,
33}
34
35struct Rewrite {
36 new_uninit_bid: BlockId,
37 new_uninit_target: BlockId,
38 drop_bid: BlockId,
39 target_bid: BlockId,
40 payload_loc: StmtLoc,
41 move_loc: StmtLoc,
42 arg_move_loc: StmtLoc,
43 span: Span,
44 payload_elems: Vec<Operand>,
45 elem_ty: Ty,
46 len: Box<ConstantExpr>,
47 uninit_box: Place,
48 branched_before_payload: bool,
49 box_array: Place,
50 box_array_generics: GenericArgs,
51 assume_init_generics: GenericArgs,
52 drop_on_unwind: BlockId,
53 assume_init_target: BlockId,
54}
55
56struct PayloadAssign {
57 loc: StmtLoc,
58 span: Span,
59 payload_elems: Vec<Operand>,
60 elem_ty: Ty,
61 len: Box<ConstantExpr>,
62 branched_before_payload: bool,
63}
64
65struct AssumeInitTail {
66 move_loc: StmtLoc,
67 drop_bid: BlockId,
68 target_bid: BlockId,
69 arg_move_loc: StmtLoc,
70 box_array: Place,
71 assume_init_generics: GenericArgs,
72 drop_on_unwind: BlockId,
73 assume_init_target: BlockId,
74}
75
76fn assume_init_fn_ptr<'a>(ctx: &TransformCtx, call: &'a Call) -> Option<&'a FnPtr> {
77 if let FnOperand::Regular(fn_ptr) = &call.func
78 && let FnPtrKind::Fun(FunId::Regular(fid)) = *fn_ptr.kind
79 && ctx.translated.item_name(fid).short_str() == Some("assume_init")
80 {
81 Some(fn_ptr)
82 } else {
83 None
84 }
85}
86
87fn box_inner(ty: &Ty) -> Option<Ty> {
88 let TyKind::Adt(TypeDeclRef {
89 id: TypeId::Builtin(BuiltinTy::Box),
90 generics,
91 }) = ty.kind()
92 else {
93 return None;
94 };
95 Some(generics.types[TypeVarId::from_usize(0)].clone())
96}
97
98fn box_generics(ty: &Ty) -> Option<GenericArgs> {
99 let TyKind::Adt(TypeDeclRef {
100 id: TypeId::Builtin(BuiltinTy::Box),
101 generics,
102 }) = ty.kind()
103 else {
104 return None;
105 };
106 Some((**generics).clone())
107}
108
109fn find_array_assign(body: &ExprBody, start: BlockId, src_local: LocalId) -> Option<PayloadAssign> {
116 let mut out = None;
117 for (bid, block) in body.body.iter_enumerated() {
118 for (idx, st) in block.statements.iter().enumerate() {
119 let Some((place, Rvalue::Aggregate(AggregateKind::Array(elem_ty, len), elems))) =
120 st.kind.as_assign()
121 else {
122 continue;
123 };
124 if place.local_id() != Some(src_local) {
125 continue;
126 }
127 if out.is_some() {
128 return None;
129 }
130
131 let loc = StmtLoc::new(bid, idx);
132 out = Some(PayloadAssign {
133 loc,
134 span: st.span,
135 payload_elems: elems.clone(),
136 elem_ty: elem_ty.clone(),
137 len: len.clone(),
138 branched_before_payload: branched_before(body, start, loc.block)?,
139 });
140 }
141 }
142 out
143}
144
145fn branched_before(body: &ExprBody, start: BlockId, target: BlockId) -> Option<bool> {
146 let mut block_id = start;
147 let mut visited = HashSet::new();
148 while visited.insert(block_id) {
149 if block_id == target {
150 return Some(false);
151 }
152
153 let block = &body.body[block_id];
154 let targets = block.targets_ignoring_unwind();
155 if targets.len() > 1 {
156 return Some(true);
157 }
158 block_id = targets.into_iter().exactly_one().ok()?;
159 }
160 None
161}
162
163fn unique_target(term: &Terminator) -> Option<BlockId> {
164 term.targets_ignoring_unwind()
165 .into_iter()
166 .exactly_one()
167 .ok()
168}
169
170fn find_next_move_of(
171 body: &ExprBody,
172 mut cursor: StmtLoc,
173 src: &Place,
174) -> Option<(StmtLoc, Place)> {
175 let mut visited = HashSet::new();
176 while visited.insert(cursor) {
177 let block = &body.body[cursor.block];
178 while cursor.statement < block.statements.len() {
179 let st = &body[cursor];
180 if let Some((dst_place, Rvalue::Use(Operand::Move(src_place), _))) = st.kind.as_assign()
181 && src_place == src
182 {
183 return Some((cursor, dst_place.clone()));
184 }
185 cursor = cursor.after();
186 }
187 cursor = StmtLoc::block_start(unique_target(&block.terminator)?);
188 }
189 None
190}
191
192fn find_next_drop(
193 body: &ExprBody,
194 mut block_id: BlockId,
195 dropped_place: &Place,
196) -> Option<(BlockId, BlockId, BlockId)> {
197 let mut visited = HashSet::new();
198 while visited.insert(block_id) {
199 let block = &body.body[block_id];
200 if let TerminatorKind::Drop {
201 place: dropped,
202 target,
203 on_unwind,
204 ..
205 } = &block.terminator.kind
206 && dropped == dropped_place
207 {
208 return Some((block_id, *target, *on_unwind));
209 }
210 block_id = unique_target(&block.terminator)?;
211 }
212 None
213}
214
215fn find_move_in_block(
216 body: &ExprBody,
217 block: BlockId,
218 src: &Place,
219 dst: &Place,
220) -> Option<StmtLoc> {
221 let mut out = None;
222 for (statement, st) in body.body[block].statements.iter().enumerate() {
223 if let Some((dst_place, Rvalue::Use(Operand::Move(src_place), _))) = st.kind.as_assign()
224 && dst_place == dst
225 && src_place == src
226 {
227 if out.is_some() {
228 return None;
229 }
230 out = Some(StmtLoc::new(block, statement));
231 }
232 }
233 out
234}
235
236fn find_assume_init_tail(
237 ctx: &TransformCtx,
238 body: &ExprBody,
239 cursor: StmtLoc,
240 uninit_box: &Place,
241) -> Option<AssumeInitTail> {
242 let (move_loc, moved_box) = find_next_move_of(body, cursor, uninit_box)?;
243 let (drop_bid, target_bid, drop_on_unwind) = find_next_drop(body, move_loc.block, uninit_box)?;
244
245 let target_block = &body.body[target_bid];
246 let (call, assume_init_target, _unwind) = target_block.terminator.kind.as_call()?;
247 let assume_init_fn = assume_init_fn_ptr(ctx, call)?;
248 let [Operand::Move(init_box)] = call.args.as_slice() else {
249 return None;
250 };
251 let arg_move_loc = find_move_in_block(body, target_bid, &moved_box, init_box)?;
252
253 Some(AssumeInitTail {
254 move_loc,
255 drop_bid,
256 target_bid,
257 arg_move_loc,
258 box_array: call.dest.clone(),
259 assume_init_generics: assume_init_fn.generics.as_ref().clone(),
260 drop_on_unwind,
261 assume_init_target: *assume_init_target,
262 })
263}
264
265fn is_new_uninit_call(ctx: &TransformCtx, call: &Call) -> bool {
266 if !call.args.is_empty() {
267 return false;
268 }
269
270 let FnOperand::Regular(fn_ptr) = &call.func else {
271 return false;
272 };
273 let FnPtrKind::Fun(FunId::Regular(fid)) = *fn_ptr.kind else {
274 return false;
275 };
276 ctx.translated.item_name(fid).short_str() == Some("new_uninit")
277}
278
279impl Transform {
280 pub fn new(ctx: &mut TransformCtx) -> CowBox<dyn UllbcPass> {
281 let pat = NamePattern::parse(crate::builtins::BOX_WRITE_PATTERN).unwrap();
282 let box_write = ctx
283 .translated
284 .item_names
285 .iter()
286 .filter(|(_, name)| pat.matches(&ctx.translated, name))
287 .filter_map(|(id, _)| id.as_fun())
288 .copied()
289 .exactly_one()
290 .ok();
291 CowBox::Owned(Box::new(Transform { box_write }))
292 }
293}
294
295impl UllbcPass for Transform {
296 fn should_run(&self, options: &crate::options::TranslateOptions) -> bool {
297 options.treat_box_as_builtin && !options.monomorphize_with_hax && self.box_write.is_some()
298 }
299
300 fn transform_body(&self, ctx: &mut TransformCtx, body: &mut ExprBody) {
301 let box_write = self.box_write.unwrap();
303
304 let rewrites = body
314 .body
315 .iter_enumerated()
316 .filter_map(|(new_uninit_bid, block)| {
317 let TerminatorKind::Call {
318 call,
319 target: new_uninit_target,
320 ..
321 } = &block.terminator.kind
322 else {
323 return None;
324 };
325
326 if !is_new_uninit_call(ctx, call) {
327 return None;
328 }
329 let uninit_box = call.dest.clone();
331 let maybe_uninit_array_ty = box_inner(uninit_box.ty())?;
332 let mu_decl = &ctx.translated.type_decls[maybe_uninit_array_ty.as_adt_id()?];
333 if mu_decl.item_meta.lang_item.as_deref() != Some("maybe_uninit") {
334 return None;
335 };
336 let uninit_box_l = uninit_box.local_id()?;
337
338 let payload = find_array_assign(body, *new_uninit_target, uninit_box_l)?;
340
341 let tail = find_assume_init_tail(ctx, body, payload.loc.after(), &uninit_box)?;
343 let box_array_generics = box_generics(tail.box_array.ty())?;
344
345 Some(Rewrite {
346 new_uninit_bid,
347 new_uninit_target: *new_uninit_target,
348 drop_bid: tail.drop_bid,
349 target_bid: tail.target_bid,
350 payload_loc: payload.loc,
351 move_loc: tail.move_loc,
352 arg_move_loc: tail.arg_move_loc,
353 span: payload.span,
354 payload_elems: payload.payload_elems,
355 elem_ty: payload.elem_ty,
356 len: payload.len,
357 uninit_box,
358 branched_before_payload: payload.branched_before_payload,
359 box_array_generics,
360 assume_init_generics: tail.assume_init_generics,
361 drop_on_unwind: tail.drop_on_unwind,
362 box_array: tail.box_array,
363 assume_init_target: tail.assume_init_target,
364 })
365 });
366
367 for rw in rewrites.collect::<Vec<_>>() {
368 let array_ty = Ty::mk_array(rw.elem_ty.clone(), *rw.len.clone());
369 let array_local = body.locals.new_var(None, array_ty.clone());
370 let box_array_ty = rw.box_array.ty().clone();
371 let box_array_local = body.locals.new_var(None, box_array_ty.clone());
372
373 let array_lid = array_local.as_local().unwrap();
374 let box_array_lid = box_array_local.as_local().unwrap();
375
376 body[rw.move_loc].kind = StatementKind::Nop;
377 body[rw.arg_move_loc].kind = StatementKind::Nop;
378
379 body.body[rw.payload_loc.block].statements.splice(
380 rw.payload_loc.statement..=rw.payload_loc.statement,
381 [
382 StatementKind::StorageLive(array_lid),
383 StatementKind::Assign(
384 array_local.clone(),
385 Rvalue::Aggregate(
386 AggregateKind::Array(rw.elem_ty.clone(), rw.len.clone()),
387 rw.payload_elems,
388 ),
389 ),
390 ]
391 .map(|k| Statement::new(rw.span, k)),
392 );
393
394 let (fn_ptr, args) = if rw.branched_before_payload {
395 (
396 FnPtr::new(
397 FnPtrKind::Fun(FunId::Regular(box_write)),
398 rw.assume_init_generics,
399 ),
400 vec![
401 Operand::Move(rw.uninit_box),
402 Operand::Move(array_local.clone()),
403 ],
404 )
405 } else {
406 body.body[rw.new_uninit_bid].terminator.kind = TerminatorKind::Goto {
407 target: rw.new_uninit_target,
408 };
409 (
410 FnPtr::new(
411 FnPtrKind::Fun(FunId::Builtin(BuiltinFunId::BoxNew)),
412 rw.box_array_generics,
413 ),
414 vec![Operand::Move(array_local.clone())],
415 )
416 };
417
418 let drop_block = &mut body.body[rw.drop_bid];
419 drop_block.statements.push(Statement::new(
420 rw.span,
421 StatementKind::StorageLive(box_array_lid),
422 ));
423 drop_block.terminator.kind = TerminatorKind::Call {
424 call: Call {
425 func: FnOperand::Regular(fn_ptr),
426 args,
427 dest: box_array_local.clone(),
428 },
429 target: rw.target_bid,
430 on_unwind: rw.drop_on_unwind,
431 };
432
433 let target_block = &mut body.body[rw.target_bid];
434 target_block.statements.push(Statement::new(
435 rw.span,
436 StatementKind::StorageDead(array_lid),
437 ));
438 target_block.statements.push(Statement::new(
439 rw.span,
440 StatementKind::Assign(
441 rw.box_array,
442 Rvalue::Use(Operand::Move(box_array_local), WithRetag::No),
443 ),
444 ));
445 target_block.statements.push(Statement::new(
446 rw.span,
447 StatementKind::StorageDead(box_array_lid),
448 ));
449 target_block.terminator.kind = TerminatorKind::Goto {
450 target: rw.assume_init_target,
451 };
452 }
453 }
454}