charon_lib/transform/
inline_local_panic_functions.rs

//! `panic!()` expands to:
//! ```ignore
//! fn panic_cold_explicit() -> ! {
//!     core::panicking::panic_explicit()
//! }
//! panic_cold_explicit()
//! ```
//! This defines a new function at every use site. This pass recognizes these functions and
//! replaces calls to them with a `Panic` terminator.
10use std::collections::HashSet;
11
12use super::{ctx::UllbcPass, TransformCtx};
13use crate::{builtins, names::Name, ullbc_ast::*};
14
15pub struct Transform;
16impl UllbcPass for Transform {
17    fn transform_ctx(&self, ctx: &mut TransformCtx) {
18        // Collect the functions that were generated by the `panic!` macro.
19        let mut panic_fns = HashSet::new();
20        ctx.for_each_fun_decl(|_ctx, decl| {
21            if let Ok(body) = &mut decl.body {
22                let body = body.as_unstructured().unwrap();
23                // If the whole body is only a call to this specific panic function.
24                if body.body.elem_count() == 1
25                    && let Some(block) = body.body.iter().next()
26                    && block.statements.is_empty()
27                    && let RawTerminator::Abort(AbortKind::Panic(Some(name))) =
28                        &block.terminator.content
29                {
30                    if name.equals_ref_name(builtins::EXPLICIT_PANIC_NAME) {
31                        // FIXME: also check that the name of the function is
32                        // `panic_cold_explicit`?
33                        panic_fns.insert(decl.def_id);
34                    }
35                }
36            }
37        });
38
39        let panic_name = Name::from_path(builtins::EXPLICIT_PANIC_NAME);
40        let panic_terminator = RawTerminator::Abort(AbortKind::Panic(Some(panic_name)));
41
42        // Replace each call to one such function with a `Panic`.
43        ctx.for_each_fun_decl(|_ctx, decl| {
44            if let Ok(body) = &mut decl.body {
45                let body = body.as_unstructured_mut().unwrap();
46                for block_id in body.body.all_indices() {
47                    let Some(block) = body.body.get_mut(block_id) else {
48                        continue;
49                    };
50                    for i in 0..block.statements.len() {
51                        let st = &block.statements[i];
52                        if let RawStatement::Call(Call {
53                            func:
54                                FnOperand::Regular(FnPtr {
55                                    func: FunIdOrTraitMethodRef::Fun(FunId::Regular(fun_id)),
56                                    ..
57                                }),
58                            ..
59                        }) = &st.content
60                            && panic_fns.contains(fun_id)
61                        {
62                            block.statements.drain(i..);
63                            block.terminator.content = panic_terminator.clone();
64                            break;
65                        }
66                    }
67                }
68            }
69        });
70
71        // Remove these functions from the context.
72        for id in &panic_fns {
73            ctx.translated.fun_decls.remove(*id);
74        }
75    }
76}