charon_lib/transform/resugar/
reconstruct_matches.rs

1//! The way to match on enums in MIR is in two steps: first read the discriminant, then switch on
2//! the resulting integer. This pass merges the two into a `SwitchKind::Match` that directly
3//! mentions enum variants.
4use crate::errors::register_error;
5use crate::formatter::IntoFormatter;
6use crate::llbc_ast::*;
7use crate::name_matcher::NamePattern;
8use crate::pretty::FmtWithCtx;
9use crate::transform::TransformCtx;
10use itertools::Itertools;
11use std::collections::{HashMap, HashSet};
12
13use crate::transform::ctx::LlbcPass;
14
15pub struct Transform;
16impl Transform {
17    fn update_block(
18        ctx: &mut TransformCtx,
19        block: &mut Block,
20        discriminant_intrinsics: &HashSet<FunDeclId>,
21    ) {
22        // Iterate through the statements.
23        for i in 0..block.statements.len() {
24            let suffix = &mut block.statements[i..];
25            match suffix {
26                [
27                    Statement {
28                        kind: StatementKind::Assign(dest, Rvalue::Discriminant(p)),
29                        ..
30                    },
31                    rest @ ..,
32                ] => {
33                    // The destination should be a variable
34                    assert!(dest.is_local());
35                    let TyKind::Adt(tdecl_ref) = p.ty().kind() else {
36                        continue;
37                    };
38                    let TypeId::Adt(adt_id) = tdecl_ref.id else {
39                        continue;
40                    };
41
42                    // Lookup the type of the scrutinee
43                    let tkind = ctx.translated.type_decls.get(adt_id).map(|x| &x.kind);
44                    let Some(TypeDeclKind::Enum(variants)) = tkind else {
45                        match tkind {
46                            // This can happen if the type was declared as invisible or opaque.
47                            None | Some(TypeDeclKind::Opaque) => {
48                                let name = ctx.translated.item_name(adt_id).unwrap();
49                                register_error!(
50                                    ctx,
51                                    block.span,
52                                    "reading the discriminant of an opaque enum. \
53                                    Add `--include {}` to the `charon` arguments \
54                                    to translate this enum.",
55                                    name.with_ctx(&ctx.into_fmt())
56                                );
57                            }
58                            // Don't double-error
59                            Some(TypeDeclKind::Error(..)) => {}
60                            Some(_) => {
61                                register_error!(
62                                    ctx,
63                                    block.span,
64                                    "reading the discriminant of a non-enum type"
65                                );
66                            }
67                        }
68                        block.statements[i].kind = StatementKind::Error(
69                            "error reading the discriminant of this type".to_owned(),
70                        );
71                        return;
72                    };
73
74                    // We look for a `SwitchInt` just after the discriminant read.
75                    match rest {
76                        [
77                            Statement {
78                                kind:
79                                    StatementKind::Switch(
80                                        switch @ Switch::SwitchInt(Operand::Move(_), ..),
81                                    ),
82                                ..
83                            },
84                            ..,
85                        ] => {
86                            // Convert between discriminants and variant indices. Remark: the discriminant can
87                            // be of any *signed* integer type (`isize`, `i8`, etc.).
88                            let discr_to_id: HashMap<Literal, VariantId> = variants
89                                .iter_indexed_values()
90                                .map(|(id, variant)| (variant.discriminant.clone(), id))
91                                .collect();
92
93                            take_mut::take(switch, |switch| {
94                                let (Operand::Move(op_p), _, targets, otherwise) =
95                                    switch.to_switch_int().unwrap()
96                                else {
97                                    unreachable!()
98                                };
99                                assert!(op_p.is_local() && op_p.local_id() == dest.local_id());
100
101                                let mut covered_discriminants: HashSet<Literal> =
102                                    HashSet::default();
103                                let targets = targets
104                                    .into_iter()
105                                    .map(|(v, e)| {
106                                        let targets = v
107                                            .into_iter()
108                                            .filter_map(|discr| {
109                                                covered_discriminants.insert(discr.clone());
110                                                discr_to_id.get(&discr).or_else(|| {
111                                                    register_error!(
112                                                        ctx,
113                                                        block.span,
114                                                        "Found incorrect discriminant \
115                                                        {discr} for enum {adt_id}"
116                                                    );
117                                                    None
118                                                })
119                                            })
120                                            .copied()
121                                            .collect_vec();
122                                        (targets, e)
123                                    })
124                                    .collect_vec();
125                                // Filter the otherwise branch if it is not necessary.
126                                let covers_all = covered_discriminants.len() == discr_to_id.len();
127                                let otherwise = if covers_all { None } else { Some(otherwise) };
128
129                                // Replace the old switch with a match.
130                                Switch::Match(p.clone(), targets, otherwise)
131                            });
132                            // `Nop` the discriminant read.
133                            block.statements[i].kind = StatementKind::Nop;
134                        }
135                        _ => {
136                            // The discriminant read is not followed by a `SwitchInt`. This can happen
137                            // in optimized MIR.
138                            continue;
139                        }
140                    }
141                }
142                // Replace calls of `core::intrinsics::discriminant_value` on a known enum with the
143                // appropriate MIR.
144                [
145                    Statement {
146                        kind: StatementKind::Call(call),
147                        ..
148                    },
149                    ..,
150                ] if let FnOperand::Regular(fn_ptr) = &call.func
151                        && let FnPtrKind::Fun(FunId::Regular(fun_id)) = fn_ptr.kind.as_ref()
152                        // Detect a call to the intrinsic...
153                        && discriminant_intrinsics.contains(fun_id)
154                        // passing it a reference.
155                        && let Operand::Move(p) = &call.args[0]
156                        && let TyKind::Ref(_, sub_ty, _) = p.ty().kind() =>
157                {
158                    let p = p.clone().project(ProjectionElem::Deref, sub_ty.clone());
159                    block.statements[i].kind =
160                        StatementKind::Assign(call.dest.clone(), Rvalue::Discriminant(p.clone()))
161                }
162                _ => {}
163            }
164        }
165    }
166}
167
168const DISCRIMINANT_INTRINSIC: &str = "core::intrinsics::discriminant_value";
169
170impl LlbcPass for Transform {
171    fn transform_ctx(&self, ctx: &mut TransformCtx) {
172        let pat = NamePattern::parse(DISCRIMINANT_INTRINSIC).unwrap();
173        // There can be many if we're in mono mode.
174        let discriminant_intrinsic: HashSet<FunDeclId> = ctx
175            .translated
176            .item_names
177            .iter()
178            .filter(|(_, name)| pat.matches(&ctx.translated, name))
179            .filter_map(|(id, _)| id.as_fun())
180            .copied()
181            .collect();
182
183        ctx.for_each_fun_decl(|ctx, decl| {
184            if let Ok(body) = &mut decl.body {
185                let body = body.as_structured_mut().unwrap();
186                self.log_before_body(ctx, &decl.item_meta.name, Ok(&*body));
187                body.body.visit_blocks_bwd(|block: &mut Block| {
188                    Transform::update_block(ctx, block, &discriminant_intrinsic);
189                });
190            };
191        });
192    }
193}