Skip to main content

charon_lib/transform/resugar/
reconstruct_matches.rs

1//! The way to match on enums in MIR is in two steps: first read the discriminant, then switch on
2//! the resulting integer. This pass merges the two into a `SwitchKind::Match` that directly
3//! mentions enum variants.
4use crate::formatter::IntoFormatter;
5use crate::llbc_ast::*;
6use crate::name_matcher::NamePattern;
7use crate::pretty::FmtWithCtx;
8use crate::transform::TransformCtx;
9use crate::{errors::register_error, transform::CowBox};
10use itertools::Itertools;
11use std::collections::{HashMap, HashSet};
12
13use crate::transform::ctx::LlbcPass;
14
15pub struct Transform {
16    discriminant_intrinsics: HashSet<FunDeclId>,
17}
18
19impl Transform {
20    fn update_block(&self, ctx: &mut TransformCtx, block: &mut Block) {
21        // Iterate through the statements.
22        for i in 0..block.statements.len() {
23            let suffix = &mut block.statements[i..];
24            match suffix {
25                [
26                    Statement {
27                        kind: StatementKind::Assign(dest, Rvalue::Discriminant(p)),
28                        ..
29                    },
30                    rest @ ..,
31                ] => {
32                    // The destination should be a variable
33                    assert!(dest.is_local());
34                    let TyKind::Adt(tdecl_ref) = p.ty().kind() else {
35                        continue;
36                    };
37                    let TypeId::Adt(adt_id) = tdecl_ref.id else {
38                        continue;
39                    };
40
41                    // Lookup the type of the scrutinee
42                    let tkind = ctx.translated.type_decls.get(adt_id).map(|x| &x.kind);
43                    let Some(TypeDeclKind::Enum(variants)) = tkind else {
44                        match tkind {
45                            // This can happen if the type was declared as invisible or opaque.
46                            None | Some(TypeDeclKind::Opaque) => {
47                                let name = ctx.translated.item_name(adt_id);
48                                register_error!(
49                                    ctx,
50                                    block.span,
51                                    "reading the discriminant of an opaque enum. \
52                                    Add `--include {}` to the `charon` arguments \
53                                    to translate this enum.",
54                                    name.with_ctx(&ctx.into_fmt())
55                                );
56                            }
57                            // Don't double-error
58                            Some(TypeDeclKind::Error(..)) => {}
59                            Some(_) => {
60                                register_error!(
61                                    ctx,
62                                    block.span,
63                                    "reading the discriminant of a non-enum type"
64                                );
65                            }
66                        }
67                        block.statements[i].kind = StatementKind::Error(
68                            "error reading the discriminant of this type".to_owned(),
69                        );
70                        return;
71                    };
72
73                    // We look for a `SwitchInt` just after the discriminant read.
74                    match rest {
75                        [
76                            Statement {
77                                kind:
78                                    StatementKind::Switch(
79                                        switch @ Switch::SwitchInt(Operand::Move(_), ..),
80                                    ),
81                                ..
82                            },
83                            ..,
84                        ] => {
85                            // Convert between discriminants and variant indices. Remark: the discriminant can
86                            // be of any *signed* integer type (`isize`, `i8`, etc.).
87                            let discr_to_id: HashMap<Literal, VariantId> = variants
88                                .iter_enumerated()
89                                .map(|(id, variant)| (variant.discriminant.clone(), id))
90                                .collect();
91
92                            take_mut::take(switch, |switch| {
93                                let (Operand::Move(op_p), _, targets, otherwise) =
94                                    switch.to_switch_int().unwrap()
95                                else {
96                                    unreachable!()
97                                };
98                                assert!(op_p.is_local() && op_p.local_id() == dest.local_id());
99
100                                let mut covered_discriminants: HashSet<Literal> =
101                                    HashSet::default();
102                                let targets = targets
103                                    .into_iter()
104                                    .map(|(v, e)| {
105                                        let targets = v
106                                            .into_iter()
107                                            .filter_map(|discr| {
108                                                covered_discriminants.insert(discr.clone());
109                                                discr_to_id.get(&discr).or_else(|| {
110                                                    register_error!(
111                                                        ctx,
112                                                        block.span,
113                                                        "Found incorrect discriminant \
114                                                        {discr} for enum {adt_id}"
115                                                    );
116                                                    None
117                                                })
118                                            })
119                                            .copied()
120                                            .collect_vec();
121                                        (targets, e)
122                                    })
123                                    .collect_vec();
124                                // Filter the otherwise branch if it is not necessary.
125                                let covers_all = covered_discriminants.len() == discr_to_id.len();
126                                let otherwise = if covers_all { None } else { Some(otherwise) };
127
128                                // Replace the old switch with a match.
129                                Switch::Match(p.clone(), targets, otherwise)
130                            });
131                            // `Nop` the discriminant read.
132                            block.statements[i].kind = StatementKind::Nop;
133                        }
134                        _ => {
135                            // The discriminant read is not followed by a `SwitchInt`. This can happen
136                            // in optimized MIR.
137                            continue;
138                        }
139                    }
140                }
141                // Replace calls of `core::intrinsics::discriminant_value` on a known enum with the
142                // appropriate MIR.
143                [
144                    Statement {
145                        kind: StatementKind::Call(call),
146                        ..
147                    },
148                    ..,
149                ] if let FnOperand::Regular(fn_ptr) = &call.func
150                        && let FnPtrKind::Fun(FunId::Regular(fun_id)) = fn_ptr.kind.as_ref()
151                        // Detect a call to the intrinsic...
152                        && self.discriminant_intrinsics.contains(fun_id)
153                        // passing it a reference.
154                        && let Operand::Move(p) = &call.args[0]
155                        && let TyKind::Ref(_, sub_ty, _) = p.ty().kind() =>
156                {
157                    let p = p.clone().project(ProjectionElem::Deref, sub_ty.clone());
158                    block.statements[i].kind =
159                        StatementKind::Assign(call.dest.clone(), Rvalue::Discriminant(p.clone()))
160                }
161                _ => {}
162            }
163        }
164    }
165}
166
167const DISCRIMINANT_INTRINSIC: &str = "core::intrinsics::discriminant_value";
168
169impl Transform {
170    pub fn new(ctx: &mut TransformCtx) -> CowBox<dyn LlbcPass> {
171        let pat = NamePattern::parse(DISCRIMINANT_INTRINSIC).unwrap();
172        // There can be many if we're in mono mode.
173        let discriminant_intrinsics = ctx
174            .translated
175            .item_names
176            .iter()
177            .filter(|(_, name)| pat.matches(&ctx.translated, name))
178            .filter_map(|(id, _)| id.as_fun())
179            .copied()
180            .collect();
181        CowBox::Owned(Box::new(Transform {
182            discriminant_intrinsics,
183        }))
184    }
185}
186
187impl LlbcPass for Transform {
188    fn transform_body(&self, ctx: &mut TransformCtx, body: &mut llbc_ast::ExprBody) {
189        body.body.visit_blocks_bwd(|block: &mut Block| {
190            self.update_block(ctx, block);
191        })
192    }
193}