charon_lib/transform/remove_read_discriminant.rs
//! MIR lowers a `match` on an enum into a read of the scrutinee's discriminant
//! (`_d = discriminant(p)`) immediately followed by a `SwitchInt` on `_d`. This pass
//! recognizes that pattern and replaces it with a structured `match p { ... }` whose
//! arms are indexed by variant ids instead of raw discriminant values. It also rewrites
//! calls to `core::intrinsics::discriminant_value` on known enums into plain
//! discriminant reads.
5
6use crate::errors::register_error;
7use crate::formatter::IntoFormatter;
8use crate::llbc_ast::*;
9use crate::name_matcher::NamePattern;
10use crate::pretty::FmtWithCtx;
11use crate::transform::TransformCtx;
12use itertools::Itertools;
13use std::collections::{HashMap, HashSet};
14
15use super::ctx::LlbcPass;
16
17pub struct Transform;
18impl Transform {
19 fn update_block(
20 ctx: &mut TransformCtx,
21 block: &mut Block,
22 discriminant_intrinsic: Option<FunDeclId>,
23 ) {
24 // Iterate through the statements.
25 for i in 0..block.statements.len() {
26 let suffix = &mut block.statements[i..];
27 match suffix {
28 [
29 Statement {
30 content: RawStatement::Assign(dest, Rvalue::Discriminant(p)),
31 ..
32 },
33 rest @ ..,
34 ] => {
35 // The destination should be a variable
36 assert!(dest.is_local());
37 let TyKind::Adt(tdecl_ref) = p.ty().kind() else {
38 continue;
39 };
40 let TypeId::Adt(adt_id) = tdecl_ref.id else {
41 continue;
42 };
43
44 // Lookup the type of the scrutinee
45 let tkind = ctx.translated.type_decls.get(adt_id).map(|x| &x.kind);
46 let Some(TypeDeclKind::Enum(variants)) = tkind else {
47 match tkind {
48 // This can happen if the type was declared as invisible or opaque.
49 None | Some(TypeDeclKind::Opaque) => {
50 let name = ctx.translated.item_name(adt_id).unwrap();
51 register_error!(
52 ctx,
53 block.span,
54 "reading the discriminant of an opaque enum. \
55 Add `--include {}` to the `charon` arguments \
56 to translate this enum.",
57 name.with_ctx(&ctx.into_fmt())
58 );
59 }
60 // Don't double-error
61 Some(TypeDeclKind::Error(..)) => {}
62 Some(_) => {
63 register_error!(
64 ctx,
65 block.span,
66 "reading the discriminant of a non-enum type"
67 );
68 }
69 }
70 block.statements[i].content = RawStatement::Error(
71 "error reading the discriminant of this type".to_owned(),
72 );
73 return;
74 };
75
76 // We look for a `SwitchInt` just after the discriminant read.
77 match rest {
78 [
79 Statement {
80 content:
81 RawStatement::Switch(
82 switch @ Switch::SwitchInt(Operand::Move(_), ..),
83 ),
84 ..
85 },
86 ..,
87 ] => {
88 // Convert between discriminants and variant indices. Remark: the discriminant can
89 // be of any *signed* integer type (`isize`, `i8`, etc.).
90 let discr_to_id: HashMap<ScalarValue, VariantId> = variants
91 .iter_indexed_values()
92 .map(|(id, variant)| (variant.discriminant, id))
93 .collect();
94
95 take_mut::take(switch, |switch| {
96 let (Operand::Move(op_p), _, targets, otherwise) =
97 switch.to_switch_int().unwrap()
98 else {
99 unreachable!()
100 };
101 assert!(op_p.is_local() && op_p.local_id() == dest.local_id());
102
103 let mut covered_discriminants: HashSet<ScalarValue> =
104 HashSet::default();
105 let targets = targets
106 .into_iter()
107 .map(|(v, e)| {
108 let targets = v
109 .into_iter()
110 .filter_map(|discr| {
111 covered_discriminants.insert(discr);
112 discr_to_id.get(&discr).or_else(|| {
113 register_error!(
114 ctx,
115 block.span,
116 "Found incorrect discriminant \
117 {discr} for enum {adt_id}"
118 );
119 None
120 })
121 })
122 .copied()
123 .collect_vec();
124 (targets, e)
125 })
126 .collect_vec();
127 // Filter the otherwise branch if it is not necessary.
128 let covers_all = covered_discriminants.len() == discr_to_id.len();
129 let otherwise = if covers_all { None } else { Some(otherwise) };
130
131 // Replace the old switch with a match.
132 Switch::Match(p.clone(), targets, otherwise)
133 });
134 // `Nop` the discriminant read.
135 block.statements[i].content = RawStatement::Nop;
136 }
137 _ => {
138 // The discriminant read is not followed by a `SwitchInt`. This can happen
139 // in optimized MIR.
140 continue;
141 }
142 }
143 }
144 // Replace calls of `core::intrinsics::discriminant_value` on a known enum with the
145 // appropriate MIR.
146 [
147 Statement {
148 content: RawStatement::Call(call),
149 ..
150 },
151 ..,
152 ] if let Some(discriminant_intrinsic) = discriminant_intrinsic
153 // Detect a call to the intrinsic...
154 && let FnOperand::Regular(fn_ptr) = &call.func
155 && let FunIdOrTraitMethodRef::Fun(FunId::Regular(fun_id)) = fn_ptr.func.as_ref()
156 && *fun_id == discriminant_intrinsic
157 // on a known enum...
158 && let ty = &fn_ptr.generics.types[0]
159 && let TyKind::Adt(ty_ref) = ty.kind()
160 && let TypeId::Adt(type_id) = ty_ref.id
161 && let Some(TypeDeclKind::Enum(_)) =
162 ctx.translated.type_decls.get(type_id).map(|x| &x.kind)
163 // passing it a reference.
164 && let Operand::Move(p) = &call.args[0]
165 && let TyKind::Ref(_, sub_ty, _) = p.ty().kind() =>
166 {
167 let p = p.clone().project(ProjectionElem::Deref, sub_ty.clone());
168 block.statements[i].content =
169 RawStatement::Assign(call.dest.clone(), Rvalue::Discriminant(p.clone()))
170 }
171 _ => {}
172 }
173 }
174 }
175}
176
177const DISCRIMINANT_INTRINSIC: &str = "core::intrinsics::discriminant_value";
178
179impl LlbcPass for Transform {
180 fn transform_ctx(&self, ctx: &mut TransformCtx) {
181 let pat = NamePattern::parse(DISCRIMINANT_INTRINSIC).unwrap();
182 let discriminant_intrinsic: Option<FunDeclId> = ctx
183 .translated
184 .item_names
185 .iter()
186 .find(|(_, name)| pat.matches(&ctx.translated, name))
187 .and_then(|(id, _)| id.as_fun())
188 .copied();
189
190 ctx.for_each_fun_decl(|ctx, decl| {
191 if let Ok(body) = &mut decl.body {
192 let body = body.as_structured_mut().unwrap();
193 self.log_before_body(ctx, &decl.item_meta.name, Ok(&*body));
194 body.body.visit_blocks_bwd(|block: &mut Block| {
195 Transform::update_block(ctx, block, discriminant_intrinsic);
196 });
197 };
198 });
199 }
200}