charon_lib/transform/resugar/reconstruct_matches.rs
//! The way to match on enums in MIR is in two steps: first read the discriminant, then switch on
//! the resulting integer. This pass merges the two into a `Switch::Match` that directly
//! mentions enum variants. It also turns calls to `core::intrinsics::discriminant_value` on a
//! reference into plain discriminant reads.
4use crate::errors::register_error;
5use crate::formatter::IntoFormatter;
6use crate::llbc_ast::*;
7use crate::name_matcher::NamePattern;
8use crate::pretty::FmtWithCtx;
9use crate::transform::TransformCtx;
10use itertools::Itertools;
11use std::collections::{HashMap, HashSet};
12
13use crate::transform::ctx::LlbcPass;
14
/// The pass that rewrites "read discriminant, then `SwitchInt`" sequences into `Switch::Match`
/// and calls to the discriminant intrinsic into discriminant reads.
pub struct Transform;
16impl Transform {
17 fn update_block(
18 ctx: &mut TransformCtx,
19 block: &mut Block,
20 discriminant_intrinsics: &HashSet<FunDeclId>,
21 ) {
22 // Iterate through the statements.
23 for i in 0..block.statements.len() {
24 let suffix = &mut block.statements[i..];
25 match suffix {
26 [
27 Statement {
28 kind: StatementKind::Assign(dest, Rvalue::Discriminant(p)),
29 ..
30 },
31 rest @ ..,
32 ] => {
33 // The destination should be a variable
34 assert!(dest.is_local());
35 let TyKind::Adt(tdecl_ref) = p.ty().kind() else {
36 continue;
37 };
38 let TypeId::Adt(adt_id) = tdecl_ref.id else {
39 continue;
40 };
41
42 // Lookup the type of the scrutinee
43 let tkind = ctx.translated.type_decls.get(adt_id).map(|x| &x.kind);
44 let Some(TypeDeclKind::Enum(variants)) = tkind else {
45 match tkind {
46 // This can happen if the type was declared as invisible or opaque.
47 None | Some(TypeDeclKind::Opaque) => {
48 let name = ctx.translated.item_name(adt_id).unwrap();
49 register_error!(
50 ctx,
51 block.span,
52 "reading the discriminant of an opaque enum. \
53 Add `--include {}` to the `charon` arguments \
54 to translate this enum.",
55 name.with_ctx(&ctx.into_fmt())
56 );
57 }
58 // Don't double-error
59 Some(TypeDeclKind::Error(..)) => {}
60 Some(_) => {
61 register_error!(
62 ctx,
63 block.span,
64 "reading the discriminant of a non-enum type"
65 );
66 }
67 }
68 block.statements[i].kind = StatementKind::Error(
69 "error reading the discriminant of this type".to_owned(),
70 );
71 return;
72 };
73
74 // We look for a `SwitchInt` just after the discriminant read.
75 match rest {
76 [
77 Statement {
78 kind:
79 StatementKind::Switch(
80 switch @ Switch::SwitchInt(Operand::Move(_), ..),
81 ),
82 ..
83 },
84 ..,
85 ] => {
86 // Convert between discriminants and variant indices. Remark: the discriminant can
87 // be of any *signed* integer type (`isize`, `i8`, etc.).
88 let discr_to_id: HashMap<Literal, VariantId> = variants
89 .iter_indexed_values()
90 .map(|(id, variant)| (variant.discriminant.clone(), id))
91 .collect();
92
93 take_mut::take(switch, |switch| {
94 let (Operand::Move(op_p), _, targets, otherwise) =
95 switch.to_switch_int().unwrap()
96 else {
97 unreachable!()
98 };
99 assert!(op_p.is_local() && op_p.local_id() == dest.local_id());
100
101 let mut covered_discriminants: HashSet<Literal> =
102 HashSet::default();
103 let targets = targets
104 .into_iter()
105 .map(|(v, e)| {
106 let targets = v
107 .into_iter()
108 .filter_map(|discr| {
109 covered_discriminants.insert(discr.clone());
110 discr_to_id.get(&discr).or_else(|| {
111 register_error!(
112 ctx,
113 block.span,
114 "Found incorrect discriminant \
115 {discr} for enum {adt_id}"
116 );
117 None
118 })
119 })
120 .copied()
121 .collect_vec();
122 (targets, e)
123 })
124 .collect_vec();
125 // Filter the otherwise branch if it is not necessary.
126 let covers_all = covered_discriminants.len() == discr_to_id.len();
127 let otherwise = if covers_all { None } else { Some(otherwise) };
128
129 // Replace the old switch with a match.
130 Switch::Match(p.clone(), targets, otherwise)
131 });
132 // `Nop` the discriminant read.
133 block.statements[i].kind = StatementKind::Nop;
134 }
135 _ => {
136 // The discriminant read is not followed by a `SwitchInt`. This can happen
137 // in optimized MIR.
138 continue;
139 }
140 }
141 }
142 // Replace calls of `core::intrinsics::discriminant_value` on a known enum with the
143 // appropriate MIR.
144 [
145 Statement {
146 kind: StatementKind::Call(call),
147 ..
148 },
149 ..,
150 ] if let FnOperand::Regular(fn_ptr) = &call.func
151 && let FnPtrKind::Fun(FunId::Regular(fun_id)) = fn_ptr.kind.as_ref()
152 // Detect a call to the intrinsic...
153 && discriminant_intrinsics.contains(fun_id)
154 // passing it a reference.
155 && let Operand::Move(p) = &call.args[0]
156 && let TyKind::Ref(_, sub_ty, _) = p.ty().kind() =>
157 {
158 let p = p.clone().project(ProjectionElem::Deref, sub_ty.clone());
159 block.statements[i].kind =
160 StatementKind::Assign(call.dest.clone(), Rvalue::Discriminant(p.clone()))
161 }
162 _ => {}
163 }
164 }
165 }
166}
167
/// Name pattern (parsed with `NamePattern::parse`) of the intrinsic whose calls this pass
/// rewrites into plain discriminant reads.
const DISCRIMINANT_INTRINSIC: &str = "core::intrinsics::discriminant_value";
169
170impl LlbcPass for Transform {
171 fn transform_ctx(&self, ctx: &mut TransformCtx) {
172 let pat = NamePattern::parse(DISCRIMINANT_INTRINSIC).unwrap();
173 // There can be many if we're in mono mode.
174 let discriminant_intrinsic: HashSet<FunDeclId> = ctx
175 .translated
176 .item_names
177 .iter()
178 .filter(|(_, name)| pat.matches(&ctx.translated, name))
179 .filter_map(|(id, _)| id.as_fun())
180 .copied()
181 .collect();
182
183 ctx.for_each_fun_decl(|ctx, decl| {
184 if let Ok(body) = &mut decl.body {
185 let body = body.as_structured_mut().unwrap();
186 self.log_before_body(ctx, &decl.item_meta.name, Ok(&*body));
187 body.body.visit_blocks_bwd(|block: &mut Block| {
188 Transform::update_block(ctx, block, &discriminant_intrinsic);
189 });
190 };
191 });
192 }
193}