charon_lib/transform/
remove_read_discriminant.rs

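//! MIR reads the discriminant of an enum into a local (`_x = Discr(_y)`) and then
//! typically switches on that integer. This pass rewrites such reads in terms of the
//! enum's variants. When the read is immediately followed by a `SwitchInt` on `move _x`,
//! the read becomes a `Nop` and the switch becomes a `match _y { .. }` over variants.
//! Otherwise the read is replaced with `match _y { 0 => { _x = 0 }, 1 => { _x = 1 }, .. }`.
//! Calls to `core::intrinsics::discriminant_value` on a known enum are rewritten the same way.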
5
6use crate::errors::register_error;
7use crate::formatter::IntoFormatter;
8use crate::llbc_ast::*;
9use crate::name_matcher::NamePattern;
10use crate::pretty::FmtWithCtx;
11use crate::transform::TransformCtx;
12use itertools::Itertools;
13use std::collections::{HashMap, HashSet};
14
15use super::ctx::LlbcPass;
16
17/// Generate `match _y { 0 => { _x = 0 }, 1 => { _x = 1; }, .. }`.
18fn generate_discr_assignment(
19    span: Span,
20    variants: &Vector<VariantId, Variant>,
21    scrutinee: &Place,
22    dest: &Place,
23) -> RawStatement {
24    let targets = variants
25        .iter_indexed_values()
26        .map(|(id, variant)| {
27            let discr_value =
28                Rvalue::Use(Operand::Const(Box::new(variant.discriminant.to_constant())));
29            let statement = Statement::new(span, RawStatement::Assign(dest.clone(), discr_value));
30            (vec![id], statement.into_block())
31        })
32        .collect();
33    RawStatement::Switch(Switch::Match(scrutinee.clone(), targets, None))
34}
35
36pub struct Transform;
37impl Transform {
38    fn update_block(
39        ctx: &mut TransformCtx,
40        block: &mut Block,
41        discriminant_intrinsic: Option<FunDeclId>,
42    ) {
43        // Iterate through the statements.
44        for i in 0..block.statements.len() {
45            let suffix = &mut block.statements[i..];
46            match suffix {
47                [Statement {
48                    content: RawStatement::Assign(dest, Rvalue::Discriminant(p, adt_id)),
49                    span: span1,
50                    ..
51                }, rest @ ..] => {
52                    // The destination should be a variable
53                    assert!(dest.is_local());
54
55                    // Lookup the type of the scrutinee
56                    let tkind = ctx.translated.type_decls.get(*adt_id).map(|x| &x.kind);
57                    let Some(TypeDeclKind::Enum(variants)) = tkind else {
58                        match tkind {
59                            // This can happen if the type was declared as invisible or opaque.
60                            None | Some(TypeDeclKind::Opaque) => {
61                                let name = ctx.translated.item_name(*adt_id).unwrap();
62                                register_error!(
63                                    ctx,
64                                    block.span,
65                                    "reading the discriminant of an opaque enum. \
66                                    Add `--include {}` to the `charon` arguments \
67                                    to translate this enum.",
68                                    name.with_ctx(&ctx.into_fmt())
69                                );
70                            }
71                            // Don't double-error
72                            Some(TypeDeclKind::Error(..)) => {}
73                            Some(_) => {
74                                register_error!(
75                                    ctx,
76                                    block.span,
77                                    "reading the discriminant of a non-enum type"
78                                );
79                            }
80                        }
81                        block.statements[i].content = RawStatement::Error(
82                            "error reading the discriminant of this type".to_owned(),
83                        );
84                        return;
85                    };
86
87                    // We look for a `SwitchInt` just after the discriminant read.
88                    match rest {
89                        [Statement {
90                            content:
91                                RawStatement::Switch(switch @ Switch::SwitchInt(Operand::Move(_), ..)),
92                            ..
93                        }, ..] => {
94                            // Convert between discriminants and variant indices. Remark: the discriminant can
95                            // be of any *signed* integer type (`isize`, `i8`, etc.).
96                            let discr_to_id: HashMap<ScalarValue, VariantId> = variants
97                                .iter_indexed_values()
98                                .map(|(id, variant)| (variant.discriminant, id))
99                                .collect();
100
101                            take_mut::take(switch, |switch| {
102                                let (Operand::Move(op_p), _, targets, otherwise) =
103                                    switch.to_switch_int().unwrap()
104                                else {
105                                    unreachable!()
106                                };
107                                assert!(op_p.is_local() && op_p.local_id() == dest.local_id());
108
109                                let mut covered_discriminants: HashSet<ScalarValue> =
110                                    HashSet::default();
111                                let targets = targets
112                                    .into_iter()
113                                    .map(|(v, e)| {
114                                        let targets = v
115                                            .into_iter()
116                                            .filter_map(|discr| {
117                                                covered_discriminants.insert(discr);
118                                                discr_to_id.get(&discr).or_else(|| {
119                                                    register_error!(
120                                                        ctx,
121                                                        block.span,
122                                                        "Found incorrect discriminant \
123                                                        {discr} for enum {adt_id}"
124                                                    );
125                                                    None
126                                                })
127                                            })
128                                            .copied()
129                                            .collect_vec();
130                                        (targets, e)
131                                    })
132                                    .collect_vec();
133                                // Filter the otherwise branch if it is not necessary.
134                                let covers_all = covered_discriminants.len() == discr_to_id.len();
135                                let otherwise = if covers_all { None } else { Some(otherwise) };
136
137                                // Replace the old switch with a match.
138                                Switch::Match(p.clone(), targets, otherwise)
139                            });
140                            // `Nop` the discriminant read.
141                            block.statements[i].content = RawStatement::Nop;
142                        }
143                        _ => {
144                            // The discriminant read is not followed by a `SwitchInt`. This can happen
145                            // in optimized MIR. We replace `_x = Discr(_y)` with `match _y { 0 => { _x
146                            // = 0 }, 1 => { _x = 1; }, .. }`.
147                            block.statements[i].content =
148                                generate_discr_assignment(*span1, variants, p, dest)
149                        }
150                    }
151                }
152                // Replace calls of `core::intrinsics::discriminant_value` on a known enum with the
153                // appropriate MIR.
154                [Statement {
155                    content: RawStatement::Call(call),
156                    span: span1,
157                    ..
158                }, ..]
159                    if let Some(discriminant_intrinsic) = discriminant_intrinsic
160                        // Detect a call to the intrinsic...
161                        && let FnOperand::Regular(fn_ptr) = &call.func
162                        && let FunIdOrTraitMethodRef::Fun(FunId::Regular(fun_id)) = fn_ptr.func.as_ref()
163                        && *fun_id == discriminant_intrinsic
164                        // on a known enum...
165                        && let ty = &fn_ptr.generics.types[0]
166                        && let TyKind::Adt(TypeId::Adt(type_id), _) = *ty.kind()
167                        && let Some(TypeDeclKind::Enum(variants)) =
168                            ctx.translated.type_decls.get(type_id).map(|x| &x.kind)
169                        // passing it a reference.
170                        && let Operand::Move(p) = &call.args[0]
171                        && let TyKind::Ref(_, sub_ty, _) = p.ty().kind() =>
172                {
173                    let p = p.clone().project(ProjectionElem::Deref, sub_ty.clone());
174                    block.statements[i].content =
175                        generate_discr_assignment(*span1, variants, &p, &call.dest)
176                }
177                _ => {}
178            }
179        }
180    }
181}
182
183const DISCRIMINANT_INTRINSIC: &str = "core::intrinsics::discriminant_value";
184
185impl LlbcPass for Transform {
186    fn transform_ctx(&self, ctx: &mut TransformCtx) {
187        let pat = NamePattern::parse(DISCRIMINANT_INTRINSIC).unwrap();
188        let discriminant_intrinsic: Option<FunDeclId> = ctx
189            .translated
190            .item_names
191            .iter()
192            .find(|(_, name)| pat.matches(&ctx.translated, name))
193            .and_then(|(id, _)| id.as_fun())
194            .copied();
195
196        ctx.for_each_fun_decl(|ctx, decl| {
197            if let Ok(body) = &mut decl.body {
198                let body = body.as_structured_mut().unwrap();
199                self.log_before_body(ctx, &decl.item_meta.name, Ok(&*body));
200                body.body.visit_blocks_bwd(|block: &mut Block| {
201                    Transform::update_block(ctx, block, discriminant_intrinsic);
202                });
203            };
204        });
205    }
206}