charon_lib/transform/resugar/reconstruct_matches.rs
1//! The way to match on enums in MIR is in two steps: first read the discriminant, then switch on
2//! the resulting integer. This pass merges the two into a `SwitchKind::Match` that directly
3//! mentions enum variants.
4use crate::formatter::IntoFormatter;
5use crate::llbc_ast::*;
6use crate::name_matcher::NamePattern;
7use crate::pretty::FmtWithCtx;
8use crate::transform::TransformCtx;
9use crate::{errors::register_error, transform::CowBox};
10use itertools::Itertools;
11use std::collections::{HashMap, HashSet};
12
13use crate::transform::ctx::LlbcPass;
14
15pub struct Transform {
16 discriminant_intrinsics: HashSet<FunDeclId>,
17}
18
19impl Transform {
20 fn update_block(&self, ctx: &mut TransformCtx, block: &mut Block) {
21 // Iterate through the statements.
22 for i in 0..block.statements.len() {
23 let suffix = &mut block.statements[i..];
24 match suffix {
25 [
26 Statement {
27 kind: StatementKind::Assign(dest, Rvalue::Discriminant(p)),
28 ..
29 },
30 rest @ ..,
31 ] => {
32 // The destination should be a variable
33 assert!(dest.is_local());
34 let TyKind::Adt(tdecl_ref) = p.ty().kind() else {
35 continue;
36 };
37 let TypeId::Adt(adt_id) = tdecl_ref.id else {
38 continue;
39 };
40
41 // Lookup the type of the scrutinee
42 let tkind = ctx.translated.type_decls.get(adt_id).map(|x| &x.kind);
43 let Some(TypeDeclKind::Enum(variants)) = tkind else {
44 match tkind {
45 // This can happen if the type was declared as invisible or opaque.
46 None | Some(TypeDeclKind::Opaque) => {
47 let name = ctx.translated.item_name(adt_id);
48 register_error!(
49 ctx,
50 block.span,
51 "reading the discriminant of an opaque enum. \
52 Add `--include {}` to the `charon` arguments \
53 to translate this enum.",
54 name.with_ctx(&ctx.into_fmt())
55 );
56 }
57 // Don't double-error
58 Some(TypeDeclKind::Error(..)) => {}
59 Some(_) => {
60 register_error!(
61 ctx,
62 block.span,
63 "reading the discriminant of a non-enum type"
64 );
65 }
66 }
67 block.statements[i].kind = StatementKind::Error(
68 "error reading the discriminant of this type".to_owned(),
69 );
70 return;
71 };
72
73 // We look for a `SwitchInt` just after the discriminant read.
74 match rest {
75 [
76 Statement {
77 kind:
78 StatementKind::Switch(
79 switch @ Switch::SwitchInt(Operand::Move(_), ..),
80 ),
81 ..
82 },
83 ..,
84 ] => {
85 // Convert between discriminants and variant indices. Remark: the discriminant can
86 // be of any *signed* integer type (`isize`, `i8`, etc.).
87 let discr_to_id: HashMap<Literal, VariantId> = variants
88 .iter_enumerated()
89 .map(|(id, variant)| (variant.discriminant.clone(), id))
90 .collect();
91
92 take_mut::take(switch, |switch| {
93 let (Operand::Move(op_p), _, targets, otherwise) =
94 switch.to_switch_int().unwrap()
95 else {
96 unreachable!()
97 };
98 assert!(op_p.is_local() && op_p.local_id() == dest.local_id());
99
100 let mut covered_discriminants: HashSet<Literal> =
101 HashSet::default();
102 let targets = targets
103 .into_iter()
104 .map(|(v, e)| {
105 let targets = v
106 .into_iter()
107 .filter_map(|discr| {
108 covered_discriminants.insert(discr.clone());
109 discr_to_id.get(&discr).or_else(|| {
110 register_error!(
111 ctx,
112 block.span,
113 "Found incorrect discriminant \
114 {discr} for enum {adt_id}"
115 );
116 None
117 })
118 })
119 .copied()
120 .collect_vec();
121 (targets, e)
122 })
123 .collect_vec();
124 // Filter the otherwise branch if it is not necessary.
125 let covers_all = covered_discriminants.len() == discr_to_id.len();
126 let otherwise = if covers_all { None } else { Some(otherwise) };
127
128 // Replace the old switch with a match.
129 Switch::Match(p.clone(), targets, otherwise)
130 });
131 // `Nop` the discriminant read.
132 block.statements[i].kind = StatementKind::Nop;
133 }
134 _ => {
135 // The discriminant read is not followed by a `SwitchInt`. This can happen
136 // in optimized MIR.
137 continue;
138 }
139 }
140 }
141 // Replace calls of `core::intrinsics::discriminant_value` on a known enum with the
142 // appropriate MIR.
143 [
144 Statement {
145 kind: StatementKind::Call(call),
146 ..
147 },
148 ..,
149 ] if let FnOperand::Regular(fn_ptr) = &call.func
150 && let FnPtrKind::Fun(FunId::Regular(fun_id)) = fn_ptr.kind.as_ref()
151 // Detect a call to the intrinsic...
152 && self.discriminant_intrinsics.contains(fun_id)
153 // passing it a reference.
154 && let Operand::Move(p) = &call.args[0]
155 && let TyKind::Ref(_, sub_ty, _) = p.ty().kind() =>
156 {
157 let p = p.clone().project(ProjectionElem::Deref, sub_ty.clone());
158 block.statements[i].kind =
159 StatementKind::Assign(call.dest.clone(), Rvalue::Discriminant(p.clone()))
160 }
161 _ => {}
162 }
163 }
164 }
165}
166
167const DISCRIMINANT_INTRINSIC: &str = "core::intrinsics::discriminant_value";
168
169impl Transform {
170 pub fn new(ctx: &mut TransformCtx) -> CowBox<dyn LlbcPass> {
171 let pat = NamePattern::parse(DISCRIMINANT_INTRINSIC).unwrap();
172 // There can be many if we're in mono mode.
173 let discriminant_intrinsics = ctx
174 .translated
175 .item_names
176 .iter()
177 .filter(|(_, name)| pat.matches(&ctx.translated, name))
178 .filter_map(|(id, _)| id.as_fun())
179 .copied()
180 .collect();
181 CowBox::Owned(Box::new(Transform {
182 discriminant_intrinsics,
183 }))
184 }
185}
186
187impl LlbcPass for Transform {
188 fn transform_body(&self, ctx: &mut TransformCtx, body: &mut llbc_ast::ExprBody) {
189 body.body.visit_blocks_bwd(|block: &mut Block| {
190 self.update_block(ctx, block);
191 })
192 }
193}