Skip to main content

charon_lib/transform/simplify_output/
inline_selected_functions.rs

1use std::{collections::HashMap, mem};
2
3use crate::transform::CowBox;
4use crate::transform::{TransformCtx, ctx::UllbcPass};
5use crate::ullbc_ast::*;
6
7pub struct Transform {
8    to_inline: HashMap<FunDeclId, FunDecl>,
9}
10
11impl Transform {
12    pub fn new(ctx: &mut TransformCtx) -> CowBox<dyn UllbcPass> {
13        let panic_name = Name::from_path(builtins::EXPLICIT_PANIC_NAME);
14        let panic_terminator = TerminatorKind::Abort(AbortKind::Panic(Some(panic_name)));
15
16        // Collect and remove the functions that we want to inline.
17        let to_inline = ctx
18            .translated
19            .fun_decls
20            .extract(|_, decl| {
21                decl.body.as_unstructured().is_some_and(|body| {
22                    // If the whole body is only a call to this specific panic function.
23                    // FIXME: also check that the name of the function is `panic_cold_explicit`?
24                    let is_local_panic_fn = body.body.len() == 1 && {
25                        let block = &body.body[0];
26                        block.statements.is_empty() && block.terminator.kind == panic_terminator
27                    };
28                    // The `anon_consts_to_call` pass already transformed references to anon consts
29                    // into calls to their initializers so we only have to inline these.
30                    let is_anon_const_initializer = decl
31                        .is_global_initializer
32                        .and_then(|gid| ctx.translated.global_decls.get(gid))
33                        .is_some_and(|gdecl| matches!(gdecl.global_kind, GlobalKind::AnonConst));
34                    let is_vec_construction_fn = decl.item_meta.lang_item.as_deref()
35                        == Some(builtins::BOX_ASSUME_INIT_INTO_VEC_UNSAFE);
36                    is_local_panic_fn
37                        || (is_anon_const_initializer && !ctx.options.raw_consts)
38                        || (is_vec_construction_fn && ctx.options.treat_box_as_builtin)
39                })
40            })
41            .collect();
42
43        CowBox::Owned(Box::new(Transform { to_inline }))
44    }
45}
46impl UllbcPass for Transform {
47    fn should_run(&self, _options: &crate::options::TranslateOptions) -> bool {
48        !self.to_inline.is_empty()
49    }
50    fn apply_preceding_passes(&mut self, ctx: &mut TransformCtx, passes: &[CowBox<dyn UllbcPass>]) {
51        for decl in self.to_inline.values_mut() {
52            for pass in passes {
53                pass.transform_item(ctx, ItemRefMut::Fun(decl));
54            }
55        }
56    }
57    fn transform_body(&self, _ctx: &mut TransformCtx, outer_body: &mut ullbc_ast::ExprBody) {
58        for block_id in outer_body.body.indices() {
59            let Some(block) = outer_body.body.get_mut(block_id) else {
60                continue;
61            };
62            let TerminatorKind::Call {
63                call: Call { func, args, dest },
64                target,
65                on_unwind,
66            } = &mut block.terminator.kind
67            else {
68                continue;
69            };
70            let target = *target;
71            let on_unwind = *on_unwind;
72            let dest_place = dest.clone();
73            let args = args.clone();
74            let FnOperand::Regular(fn_ptr) = &func else {
75                continue;
76            };
77            let FnPtrKind::Fun(FunId::Regular(fun_id)) = fn_ptr.kind.as_ref() else {
78                continue;
79            };
80            let Some(initializer) = self.to_inline.get(fun_id) else {
81                continue;
82            };
83            let span = initializer.item_meta.span;
84            let Some(inner_body) = initializer.body.as_unstructured() else {
85                continue;
86            };
87
88            // We inline the required body by shifting its local ids and block ids
89            // and adding its blocks to the outer body. The inner body's return
90            // local becomes a normal local that we can read from. We redirect some
91            // gotos so that the inner body is executed before the current block.
92            let mut inner_body = {
93                let mut inner_body = inner_body.clone();
94                let inner_bound = inner_body.bound_body_regions;
95
96                // Shift all the body regions in the inner body BEFORE substitution,
97                // so that we only shift the inner body's own regions.
98                inner_body.dyn_visit_mut(|r: &mut Region| {
99                    if let Region::Body(v) = r {
100                        *v += outer_body.bound_body_regions;
101                    }
102                });
103                outer_body.bound_body_regions += inner_bound;
104
105                // Now substitute generics. This may inject outer-body Region::Body
106                // IDs, which is correct since they don't need shifting.
107                inner_body.substitute(&fn_ptr.generics)
108            };
109
110            let return_local = outer_body.locals.locals.next_idx();
111            inner_body.dyn_visit_in_body_mut(|l: &mut LocalId| {
112                *l += return_local;
113            });
114            outer_body
115                .locals
116                .locals
117                .extend(mem::take(&mut inner_body.locals.locals));
118
119            // The inner body assumes the return and arg places are live; allocate them, and
120            // initialize the args.
121            inner_body.body[0].statements.splice(
122                0..0,
123                [StatementKind::StorageLive(return_local)]
124                    .into_iter()
125                    .chain(args.into_iter().enumerate().flat_map(|(i, arg)| {
126                        let arg_local = return_local + i + 1;
127                        let arg_place = outer_body.locals.place_for_var(arg_local);
128                        [
129                            StatementKind::StorageLive(arg_local),
130                            StatementKind::Assign(arg_place, Rvalue::Use(arg, WithRetag::Yes)),
131                        ]
132                    }))
133                    .map(|kind| Statement::new(span, kind)),
134            );
135
136            let mut final_block = BlockData::new_goto(span, target);
137
138            // The inner body will write to `return_place`, but the outer body expects the value at
139            // `dest_place`.
140            let return_place = outer_body.locals.place_for_var(return_local);
141            final_block.statements.push(Statement::new(
142                span,
143                StatementKind::Assign(
144                    dest_place,
145                    Rvalue::Use(Operand::Move(return_place), WithRetag::Yes),
146                ),
147            ));
148            let final_block = outer_body.body.push(final_block);
149
150            // Shift all block ids in the inner body and point return/unwind to where they should.
151            let start_block = outer_body.body.next_idx();
152            inner_body.dyn_visit_in_body_mut(|b: &mut BlockId| {
153                *b += start_block;
154            });
155            inner_body
156                .body
157                .dyn_visit_in_body_mut(|t: &mut Terminator| match t.kind {
158                    TerminatorKind::Return => {
159                        t.kind = TerminatorKind::Goto {
160                            target: final_block,
161                        };
162                    }
163                    TerminatorKind::UnwindResume => {
164                        t.kind = TerminatorKind::Goto { target: on_unwind };
165                    }
166                    _ => (),
167                });
168            // At the end of the current block, start evaluating the inner body.
169            outer_body.body[block_id].terminator.kind = TerminatorKind::Goto {
170                target: start_block,
171            };
172            // Add the blocks for the inner body.
173            outer_body.body.extend(inner_body.body);
174        }
175    }
176}