Skip to main content

charon_lib/transform/resugar/
reconstruct_vec_boxes.rs

//! Reconstruct rustc's `vec![..]` lowering based on `Box<MaybeUninit<[T; N]>>`.
//!
//! In `inline_selected_functions`, we inline the special `box_assume_init_into_vec_unsafe`
//! function. After that, a `vec![elems...]` expression ends up looking something like:
//! ```ignore
//! let mut box = Box::new_uninit::<[T; N]>();
//! (((*box).1).0).0 = [elems...];
//! let box = Box::assume_init(box);
//! ..
//! ```
//! The assignment is split from the `assume_init` call for performance. Since `assume_init` is
//! unsafe, we rewrite this to use `Box::write` instead, or even `Box::new` when no branch
//! separates the allocation from the initialization:
//! ```ignore
//! let box_uninit = Box::new_uninit::<[T; N]>();
//! let arr = [elems...];
//! let box = Box::write::<[T; N]>(box_uninit, arr);
//! ..
//! ```
//!
//! See also: <https://github.com/rust-lang/rust/pull/148190>
21
22use itertools::Itertools;
23use std::collections::HashSet;
24
25use crate::ast::ullbc_ast_utils::StmtLoc;
26use crate::name_matcher::NamePattern;
27use crate::transform::ctx::UllbcPass;
28use crate::transform::{CowBox, TransformCtx};
29use crate::ullbc_ast::*;
30
31pub struct Transform {
32    box_write: Option<FunDeclId>,
33}
34
35struct Rewrite {
36    new_uninit_bid: BlockId,
37    new_uninit_target: BlockId,
38    drop_bid: BlockId,
39    target_bid: BlockId,
40    payload_loc: StmtLoc,
41    move_loc: StmtLoc,
42    arg_move_loc: StmtLoc,
43    span: Span,
44    payload_elems: Vec<Operand>,
45    elem_ty: Ty,
46    len: Box<ConstantExpr>,
47    uninit_box: Place,
48    branched_before_payload: bool,
49    box_array: Place,
50    box_array_generics: GenericArgs,
51    assume_init_generics: GenericArgs,
52    drop_on_unwind: BlockId,
53    assume_init_target: BlockId,
54}
55
56struct PayloadAssign {
57    loc: StmtLoc,
58    span: Span,
59    payload_elems: Vec<Operand>,
60    elem_ty: Ty,
61    len: Box<ConstantExpr>,
62    branched_before_payload: bool,
63}
64
65struct AssumeInitTail {
66    move_loc: StmtLoc,
67    drop_bid: BlockId,
68    target_bid: BlockId,
69    arg_move_loc: StmtLoc,
70    box_array: Place,
71    assume_init_generics: GenericArgs,
72    drop_on_unwind: BlockId,
73    assume_init_target: BlockId,
74}
75
76fn assume_init_fn_ptr<'a>(ctx: &TransformCtx, call: &'a Call) -> Option<&'a FnPtr> {
77    if let FnOperand::Regular(fn_ptr) = &call.func
78        && let FnPtrKind::Fun(FunId::Regular(fid)) = *fn_ptr.kind
79        && ctx.translated.item_name(fid).short_str() == Some("assume_init")
80    {
81        Some(fn_ptr)
82    } else {
83        None
84    }
85}
86
87fn box_inner(ty: &Ty) -> Option<Ty> {
88    let TyKind::Adt(TypeDeclRef {
89        id: TypeId::Builtin(BuiltinTy::Box),
90        generics,
91    }) = ty.kind()
92    else {
93        return None;
94    };
95    Some(generics.types[TypeVarId::from_usize(0)].clone())
96}
97
98fn box_generics(ty: &Ty) -> Option<GenericArgs> {
99    let TyKind::Adt(TypeDeclRef {
100        id: TypeId::Builtin(BuiltinTy::Box),
101        generics,
102    }) = ty.kind()
103    else {
104        return None;
105    };
106    Some((**generics).clone())
107}
108
109/// Given `src`, find the unique statement of the form `src = [elems...]`
110/// where the rvalue is an array aggregate.
111///
112/// Also returns whether a straight-line path from `start` hits a branch before reaching the
113/// assignment. We ignore unwind edges; this matches the later rewrite's ability to erase the
114/// allocation when the normal path to initialization is linear.
115fn find_array_assign(body: &ExprBody, start: BlockId, src_local: LocalId) -> Option<PayloadAssign> {
116    let mut out = None;
117    for (bid, block) in body.body.iter_enumerated() {
118        for (idx, st) in block.statements.iter().enumerate() {
119            let Some((place, Rvalue::Aggregate(AggregateKind::Array(elem_ty, len), elems))) =
120                st.kind.as_assign()
121            else {
122                continue;
123            };
124            if place.local_id() != Some(src_local) {
125                continue;
126            }
127            if out.is_some() {
128                return None;
129            }
130
131            let loc = StmtLoc::new(bid, idx);
132            out = Some(PayloadAssign {
133                loc,
134                span: st.span,
135                payload_elems: elems.clone(),
136                elem_ty: elem_ty.clone(),
137                len: len.clone(),
138                branched_before_payload: branched_before(body, start, loc.block)?,
139            });
140        }
141    }
142    out
143}
144
145fn branched_before(body: &ExprBody, start: BlockId, target: BlockId) -> Option<bool> {
146    let mut block_id = start;
147    let mut visited = HashSet::new();
148    while visited.insert(block_id) {
149        if block_id == target {
150            return Some(false);
151        }
152
153        let block = &body.body[block_id];
154        let targets = block.targets_ignoring_unwind();
155        if targets.len() > 1 {
156            return Some(true);
157        }
158        block_id = targets.into_iter().exactly_one().ok()?;
159    }
160    None
161}
162
163fn unique_target(term: &Terminator) -> Option<BlockId> {
164    term.targets_ignoring_unwind()
165        .into_iter()
166        .exactly_one()
167        .ok()
168}
169
170fn find_next_move_of(
171    body: &ExprBody,
172    mut cursor: StmtLoc,
173    src: &Place,
174) -> Option<(StmtLoc, Place)> {
175    let mut visited = HashSet::new();
176    while visited.insert(cursor) {
177        let block = &body.body[cursor.block];
178        while cursor.statement < block.statements.len() {
179            let st = &body[cursor];
180            if let Some((dst_place, Rvalue::Use(Operand::Move(src_place), _))) = st.kind.as_assign()
181                && src_place == src
182            {
183                return Some((cursor, dst_place.clone()));
184            }
185            cursor = cursor.after();
186        }
187        cursor = StmtLoc::block_start(unique_target(&block.terminator)?);
188    }
189    None
190}
191
192fn find_next_drop(
193    body: &ExprBody,
194    mut block_id: BlockId,
195    dropped_place: &Place,
196) -> Option<(BlockId, BlockId, BlockId)> {
197    let mut visited = HashSet::new();
198    while visited.insert(block_id) {
199        let block = &body.body[block_id];
200        if let TerminatorKind::Drop {
201            place: dropped,
202            target,
203            on_unwind,
204            ..
205        } = &block.terminator.kind
206            && dropped == dropped_place
207        {
208            return Some((block_id, *target, *on_unwind));
209        }
210        block_id = unique_target(&block.terminator)?;
211    }
212    None
213}
214
215fn find_move_in_block(
216    body: &ExprBody,
217    block: BlockId,
218    src: &Place,
219    dst: &Place,
220) -> Option<StmtLoc> {
221    let mut out = None;
222    for (statement, st) in body.body[block].statements.iter().enumerate() {
223        if let Some((dst_place, Rvalue::Use(Operand::Move(src_place), _))) = st.kind.as_assign()
224            && dst_place == dst
225            && src_place == src
226        {
227            if out.is_some() {
228                return None;
229            }
230            out = Some(StmtLoc::new(block, statement));
231        }
232    }
233    out
234}
235
236fn find_assume_init_tail(
237    ctx: &TransformCtx,
238    body: &ExprBody,
239    cursor: StmtLoc,
240    uninit_box: &Place,
241) -> Option<AssumeInitTail> {
242    let (move_loc, moved_box) = find_next_move_of(body, cursor, uninit_box)?;
243    let (drop_bid, target_bid, drop_on_unwind) = find_next_drop(body, move_loc.block, uninit_box)?;
244
245    let target_block = &body.body[target_bid];
246    let (call, assume_init_target, _unwind) = target_block.terminator.kind.as_call()?;
247    let assume_init_fn = assume_init_fn_ptr(ctx, call)?;
248    let [Operand::Move(init_box)] = call.args.as_slice() else {
249        return None;
250    };
251    let arg_move_loc = find_move_in_block(body, target_bid, &moved_box, init_box)?;
252
253    Some(AssumeInitTail {
254        move_loc,
255        drop_bid,
256        target_bid,
257        arg_move_loc,
258        box_array: call.dest.clone(),
259        assume_init_generics: assume_init_fn.generics.as_ref().clone(),
260        drop_on_unwind,
261        assume_init_target: *assume_init_target,
262    })
263}
264
265fn is_new_uninit_call(ctx: &TransformCtx, call: &Call) -> bool {
266    if !call.args.is_empty() {
267        return false;
268    }
269
270    let FnOperand::Regular(fn_ptr) = &call.func else {
271        return false;
272    };
273    let FnPtrKind::Fun(FunId::Regular(fid)) = *fn_ptr.kind else {
274        return false;
275    };
276    ctx.translated.item_name(fid).short_str() == Some("new_uninit")
277}
278
279impl Transform {
280    pub fn new(ctx: &mut TransformCtx) -> CowBox<dyn UllbcPass> {
281        let pat = NamePattern::parse(crate::builtins::BOX_WRITE_PATTERN).unwrap();
282        let box_write = ctx
283            .translated
284            .item_names
285            .iter()
286            .filter(|(_, name)| pat.matches(&ctx.translated, name))
287            .filter_map(|(id, _)| id.as_fun())
288            .copied()
289            .exactly_one()
290            .ok();
291        CowBox::Owned(Box::new(Transform { box_write }))
292    }
293}
294
295impl UllbcPass for Transform {
296    fn should_run(&self, options: &crate::options::TranslateOptions) -> bool {
297        options.treat_box_as_builtin && !options.monomorphize_with_hax && self.box_write.is_some()
298    }
299
300    fn transform_body(&self, ctx: &mut TransformCtx, body: &mut ExprBody) {
301        // Checked in `should_run`
302        let box_write = self.box_write.unwrap();
303
304        // We are looking for, in (flattened) ULLBC:
305        //
306        // box1 = new_uninit()
307        // ...
308        // ((((*box1)).1).0).0 = [move _4]
309        // box2 = move box1
310        // conditional_drop box1
311        // box3 = move box2
312        // box4 = assume_init(move box3)
313        let rewrites = body
314            .body
315            .iter_enumerated()
316            .filter_map(|(new_uninit_bid, block)| {
317                let TerminatorKind::Call {
318                    call,
319                    target: new_uninit_target,
320                    ..
321                } = &block.terminator.kind
322                else {
323                    return None;
324                };
325
326                if !is_new_uninit_call(ctx, call) {
327                    return None;
328                }
329                // check uninit_box: Box<MaybeUninit<_>>
330                let uninit_box = call.dest.clone();
331                let maybe_uninit_array_ty = box_inner(uninit_box.ty())?;
332                let mu_decl = &ctx.translated.type_decls[maybe_uninit_array_ty.as_adt_id()?];
333                if mu_decl.item_meta.lang_item.as_deref() != Some("maybe_uninit") {
334                    return None;
335                };
336                let uninit_box_l = uninit_box.local_id()?;
337
338                // (*uninit_box).1.0.0 = [payload_elems...]: [elem_ty; len]
339                let payload = find_array_assign(body, *new_uninit_target, uninit_box_l)?;
340
341                // assume_init(uninit_box2)
342                let tail = find_assume_init_tail(ctx, body, payload.loc.after(), &uninit_box)?;
343                let box_array_generics = box_generics(tail.box_array.ty())?;
344
345                Some(Rewrite {
346                    new_uninit_bid,
347                    new_uninit_target: *new_uninit_target,
348                    drop_bid: tail.drop_bid,
349                    target_bid: tail.target_bid,
350                    payload_loc: payload.loc,
351                    move_loc: tail.move_loc,
352                    arg_move_loc: tail.arg_move_loc,
353                    span: payload.span,
354                    payload_elems: payload.payload_elems,
355                    elem_ty: payload.elem_ty,
356                    len: payload.len,
357                    uninit_box,
358                    branched_before_payload: payload.branched_before_payload,
359                    box_array_generics,
360                    assume_init_generics: tail.assume_init_generics,
361                    drop_on_unwind: tail.drop_on_unwind,
362                    box_array: tail.box_array,
363                    assume_init_target: tail.assume_init_target,
364                })
365            });
366
367        for rw in rewrites.collect::<Vec<_>>() {
368            let array_ty = Ty::mk_array(rw.elem_ty.clone(), *rw.len.clone());
369            let array_local = body.locals.new_var(None, array_ty.clone());
370            let box_array_ty = rw.box_array.ty().clone();
371            let box_array_local = body.locals.new_var(None, box_array_ty.clone());
372
373            let array_lid = array_local.as_local().unwrap();
374            let box_array_lid = box_array_local.as_local().unwrap();
375
376            body[rw.move_loc].kind = StatementKind::Nop;
377            body[rw.arg_move_loc].kind = StatementKind::Nop;
378
379            body.body[rw.payload_loc.block].statements.splice(
380                rw.payload_loc.statement..=rw.payload_loc.statement,
381                [
382                    StatementKind::StorageLive(array_lid),
383                    StatementKind::Assign(
384                        array_local.clone(),
385                        Rvalue::Aggregate(
386                            AggregateKind::Array(rw.elem_ty.clone(), rw.len.clone()),
387                            rw.payload_elems,
388                        ),
389                    ),
390                ]
391                .map(|k| Statement::new(rw.span, k)),
392            );
393
394            let (fn_ptr, args) = if rw.branched_before_payload {
395                (
396                    FnPtr::new(
397                        FnPtrKind::Fun(FunId::Regular(box_write)),
398                        rw.assume_init_generics,
399                    ),
400                    vec![
401                        Operand::Move(rw.uninit_box),
402                        Operand::Move(array_local.clone()),
403                    ],
404                )
405            } else {
406                body.body[rw.new_uninit_bid].terminator.kind = TerminatorKind::Goto {
407                    target: rw.new_uninit_target,
408                };
409                (
410                    FnPtr::new(
411                        FnPtrKind::Fun(FunId::Builtin(BuiltinFunId::BoxNew)),
412                        rw.box_array_generics,
413                    ),
414                    vec![Operand::Move(array_local.clone())],
415                )
416            };
417
418            let drop_block = &mut body.body[rw.drop_bid];
419            drop_block.statements.push(Statement::new(
420                rw.span,
421                StatementKind::StorageLive(box_array_lid),
422            ));
423            drop_block.terminator.kind = TerminatorKind::Call {
424                call: Call {
425                    func: FnOperand::Regular(fn_ptr),
426                    args,
427                    dest: box_array_local.clone(),
428                },
429                target: rw.target_bid,
430                on_unwind: rw.drop_on_unwind,
431            };
432
433            let target_block = &mut body.body[rw.target_bid];
434            target_block.statements.push(Statement::new(
435                rw.span,
436                StatementKind::StorageDead(array_lid),
437            ));
438            target_block.statements.push(Statement::new(
439                rw.span,
440                StatementKind::Assign(
441                    rw.box_array,
442                    Rvalue::Use(Operand::Move(box_array_local), WithRetag::No),
443                ),
444            ));
445            target_block.statements.push(Statement::new(
446                rw.span,
447                StatementKind::StorageDead(box_array_lid),
448            ));
449            target_block.terminator.kind = TerminatorKind::Goto {
450                target: rw.assume_init_target,
451            };
452        }
453    }
454}