rustc_mir_transform/early_otherwise_branch.rs
1use std::fmt::Debug;
2
3use rustc_middle::mir::*;
4use rustc_middle::ty::{Ty, TyCtxt};
5use tracing::trace;
6
7use super::simplify::simplify_cfg;
8use crate::patch::MirPatch;
9
/// This pass optimizes something like
/// ```ignore (syntax-highlighting-only)
/// let x: Option<()>;
/// let y: Option<()>;
/// match (x,y) {
///     (Some(_), Some(_)) => {0},
///     (None, None) => {2},
///     _ => {1}
/// }
/// ```
/// into something like
/// ```ignore (syntax-highlighting-only)
/// let x: Option<()>;
/// let y: Option<()>;
/// let discriminant_x = std::mem::discriminant(x);
/// let discriminant_y = std::mem::discriminant(y);
/// if discriminant_x == discriminant_y {
///     match x {
///         Some(_) => 0,
///         None => 2,
///     }
/// } else {
///     1
/// }
/// ```
///
/// Specifically, it looks for instances of control flow like this:
/// ```text
///
/// =================
/// |      BB1      |
/// |---------------|                  ============================
/// |      ...      |         /------> |           BBC            |
/// |---------------|         |        |--------------------------|
/// |  switchInt(Q) |         |        |   _cl = discriminant(P)  |
/// |       c       | --------/        |--------------------------|
/// |       d       | -------\         |      switchInt(_cl)      |
/// |      ...      |        |         |            c             | ---> BBC.2
/// |   otherwise   | --\    |    /--- |        otherwise         |
/// =================   |    |    |    ============================
///                     |    |    |
/// =================   |    |    |
/// |      BBU      | <-/    |    |    ============================
/// |---------------|        \-------> |           BBD            |
/// |---------------|             |    |--------------------------|
/// |  unreachable  |             |    |   _dl = discriminant(P)  |
/// =================             |    |--------------------------|
///                               |    |      switchInt(_dl)      |
/// =================             |    |            d             | ---> BBD.2
/// |      BB9      | <-----------+--- |        otherwise         |
/// |---------------|                  ============================
/// |      ...      |
/// =================
/// ```
/// Here the `otherwise` branch of `BB1` goes to `BBU`, an empty block that ends in
/// `unreachable`. When the children switch on `P` directly (with no `discriminant` statement),
/// the `otherwise` branch of `BB1` may instead go to `BB9` itself. The `otherwise` branches of
/// all children must go to the same block (`BB9`). In the code:
/// - `BB1` is `parent` and `BBC, BBD` are children
/// - `P` is `child_place`
/// - `child_ty` is the type of `_cl`; it must equal `parent_ty`.
/// - `Q` is `parent_op`.
/// - `parent_ty` is the type of `Q`.
/// - `BB9` is `destination`
///
/// All this is then transformed into:
/// ```text
///
/// =======================
/// |         BB1         |
/// |---------------------|                  ============================
/// |         ...         |         /------> |           BBEq           |
/// | _s = discriminant(P)|         |        |--------------------------|
/// | _t = Ne(Q, _s)      |         |        |--------------------------|
/// |---------------------|         |        |       switchInt(Q)       |
/// |    switchInt(_t)    |         |        |            c             | ---> BBC.2
/// |        false        | --------/        |            d             | ---> BBD.2
/// |      otherwise      | --\              |        otherwise         | ---> BBU
/// =======================   |              ============================
///                           |
/// =================         |
/// |      BB9      | <-------/
/// |---------------|
/// |      ...      |
/// =================
/// ```
/// `BBEq` keeps the original `otherwise` target of `BB1` (`BBU` in this example). When the
/// children switch on `P` directly, no `_s` is created and the comparison is `_t = Ne(Q, P)`.
pub(super) struct EarlyOtherwiseBranch;
94
impl<'tcx> crate::MirPass<'tcx> for EarlyOtherwiseBranch {
    /// Enabled only when the session's MIR optimization level is 2 or higher.
    fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
        sess.mir_opt_level() >= 2
    }

    /// Rewrites every parent block accepted by `evaluate_candidate`.
    ///
    /// The parent's `SwitchInt(Q)` is replaced by `_t = Ne(Q, _s)` followed by a branch on `_t`:
    /// - when the operands differ, control goes straight to `destination`;
    /// - otherwise it goes to a new block `BBEq`, which switches on `Q` directly to the blocks
    ///   the children would have chosen.
    ///
    /// If at least one parent was rewritten, `simplify_cfg` runs afterwards.
    fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
        trace!("running EarlyOtherwiseBranch on {:?}", body.source);

        // Set as soon as any parent is rewritten; gates the CFG cleanup at the end.
        let mut should_cleanup = false;

        // Each candidate's patch is applied before moving on to the next index, so later
        // candidates are evaluated against the already-rewritten body.
        // NOTE(review): blocks appended by these patches (the new `BBEq` blocks) are presumably
        // NOT visited as parents in this run. The index range cannot borrow `body`, which
        // `patch.apply` mutates below, so it is likely fixed to the initial block count.
        // Confirm against `IndexSlice::indices`.
        for parent in body.basic_blocks.indices() {
            let bbs = &*body.basic_blocks;
            // Skip blocks that do not have the required shape.
            let Some(opt_data) = evaluate_candidate(tcx, body, parent) else { continue };

            trace!("SUCCESS: found optimization possibility to apply: {opt_data:?}");

            should_cleanup = true;

            // `evaluate_candidate` only succeeds for parents terminated by `SwitchInt`.
            let TerminatorKind::SwitchInt { discr: parent_op, targets: parent_targets } =
                &bbs[parent].terminator().kind
            else {
                unreachable!()
            };
            // `Q` is used twice below (in the comparison and in `BBEq`'s switch), so a `Move`
            // is turned into a `Copy`. Always correct since we can only switch on `Copy` types.
            let parent_op = match parent_op {
                Operand::Move(x) => Operand::Copy(*x),
                Operand::Copy(x) => Operand::Copy(*x),
                Operand::Constant(x) => Operand::Constant(x.clone()),
            };
            let parent_ty = parent_op.ty(body.local_decls(), tcx);
            // New statements are inserted at the parent's terminator location, i.e. after all
            // of its existing statements.
            let statements_before = bbs[parent].statements.len();
            let parent_end = Location { block: parent, statement_index: statements_before };

            let mut patch = MirPatch::new(body);

            // The value compared against `Q` has one of two forms:
            // - `discriminant(P)`, computed in the parent into a fresh temp when the children
            //   used a `discriminant` statement;
            // - a copy of `P` itself, when the children switch on `P` directly.
            let second_operand = if opt_data.need_hoist_discriminant {
                // Create the temp that holds the second discriminant, `_s` in the example above.
                let second_discriminant_temp =
                    patch.new_temp(opt_data.child_ty, opt_data.child_source.span);

                // Emit `_s = discriminant(P)` at the end of the parent.
                patch.add_assign(
                    parent_end,
                    Place::from(second_discriminant_temp),
                    Rvalue::Discriminant(opt_data.child_place),
                );
                Operand::Move(Place::from(second_discriminant_temp))
            } else {
                Operand::Copy(opt_data.child_place)
            };

            // Create the temp that holds the inequality comparison between the two
            // discriminants, `_t` in the example above.
            let nequal = BinOp::Ne;
            let comp_res_type = nequal.ty(tcx, parent_ty, opt_data.child_ty);
            let comp_temp = patch.new_temp(comp_res_type, opt_data.child_source.span);

            // Emit `_t = Ne(Q, _s)`. It is added at the same location as `_s = ...` above and
            // must come after it. The patch presumably keeps insertion order for statements at
            // the same location.
            let comp_rvalue =
                Rvalue::BinaryOp(nequal, Box::new((parent_op.clone(), second_operand)));
            patch.add_statement(
                parent_end,
                StatementKind::Assign(Box::new((Place::from(comp_temp), comp_rvalue))),
            );

            // For each value the parent tested, jump directly to the block the matching child
            // would choose for that same value (`BBC.2`, `BBD.2` in the example above). Every
            // child was checked by `verify_candidate_branch` to end in `SwitchInt`.
            let eq_new_targets = parent_targets.iter().map(|(value, child)| {
                let TerminatorKind::SwitchInt { targets, .. } = &bbs[child].terminator().kind
                else {
                    unreachable!()
                };
                (value, targets.target_for_value(value))
            });
            // `BBEq`'s `otherwise` keeps the parent's `otherwise`. That target is either the
            // same block as `destination` or an empty `unreachable` block.
            let eq_targets = SwitchTargets::new(eq_new_targets, parent_targets.otherwise());

            // Create `BBEq` from the example above, reusing the parent's source info and cleanup
            // flag.
            let eq_switch = BasicBlockData::new(
                Some(Terminator {
                    source_info: bbs[parent].terminator().source_info,
                    kind: TerminatorKind::SwitchInt {
                        // Switch on `Q` rather than on `_s`, so `_s` is not needed after the
                        // comparison.
                        discr: parent_op,
                        targets: eq_targets,
                    },
                }),
                bbs[parent].is_cleanup,
            );

            let eq_bb = patch.new_block(eq_switch);

            // Branch on `_t`: go to `destination` when the operands differ, else to `BBEq`.
            let true_case = opt_data.destination;
            let false_case = eq_bb;
            patch.patch_terminator(
                parent,
                TerminatorKind::if_(Operand::Move(Place::from(comp_temp)), true_case, false_case),
            );

            // Apply now, so that later iterations see the rewritten body.
            patch.apply(body);
        }

        // Since this optimization adds new basic blocks and invalidates others,
        // clean up the cfg to make it nicer for other passes
        if should_cleanup {
            simplify_cfg(tcx, body);
        }
    }

    /// Returns `false`. This presumably marks the pass as an optional optimization rather than
    /// one needed for correctness; see `MirPass::is_required`.
    fn is_required(&self) -> bool {
        false
    }
}
208
/// What `evaluate_candidate` found out about a rewritable parent block; consumed by `run_pass`.
#[derive(Debug)]
struct OptimizationData<'tcx> {
    /// Block the rewritten parent jumps to when its operand differs from the children's operand
    /// (the `true` target of the new `if`).
    destination: BasicBlock,
    /// The place `P` the children inspect. It is either the argument of their
    /// `discriminant(P)` statement, or the place they switch on directly.
    child_place: Place<'tcx>,
    /// Type of the first child's `SwitchInt` operand. It is checked to equal the type of the
    /// parent's operand and is used as the type of the hoisted discriminant temp.
    child_ty: Ty<'tcx>,
    /// `SourceInfo` of the first child's terminator. Its span is used for the temporaries the
    /// rewrite creates.
    child_source: SourceInfo,
    /// `true` when the children compute `discriminant(P)` in a statement before switching, so
    /// the rewrite must compute it in the parent. `false` when the children switch on `P`
    /// directly.
    need_hoist_discriminant: bool,
}
217
218fn evaluate_candidate<'tcx>(
219 tcx: TyCtxt<'tcx>,
220 body: &Body<'tcx>,
221 parent: BasicBlock,
222) -> Option<OptimizationData<'tcx>> {
223 let bbs = &body.basic_blocks;
224 // NB: If this BB is a cleanup, we may need to figure out what else needs to be handled.
225 if bbs[parent].is_cleanup {
226 return None;
227 }
228 let TerminatorKind::SwitchInt { targets, discr: parent_discr } = &bbs[parent].terminator().kind
229 else {
230 return None;
231 };
232 let parent_ty = parent_discr.ty(body.local_decls(), tcx);
233 let (_, child) = targets.iter().next()?;
234
235 let Terminator {
236 kind: TerminatorKind::SwitchInt { targets: child_targets, discr: child_discr },
237 source_info,
238 } = bbs[child].terminator()
239 else {
240 return None;
241 };
242 let child_ty = child_discr.ty(body.local_decls(), tcx);
243 if child_ty != parent_ty {
244 return None;
245 }
246
247 // We only handle:
248 // ```
249 // bb4: {
250 // _8 = discriminant((_3.1: Enum1));
251 // switchInt(move _8) -> [2: bb7, otherwise: bb1];
252 // }
253 // ```
254 // and
255 // ```
256 // bb2: {
257 // switchInt((_3.1: u64)) -> [1: bb5, otherwise: bb1];
258 // }
259 // ```
260 if bbs[child].statements.len() > 1 {
261 return None;
262 }
263
264 // When thie BB has exactly one statement, this statement should be discriminant.
265 let need_hoist_discriminant = bbs[child].statements.len() == 1;
266 let child_place = if need_hoist_discriminant {
267 if !bbs[targets.otherwise()].is_empty_unreachable() {
268 // Someone could write code like this:
269 // ```rust
270 // let Q = val;
271 // if discriminant(P) == otherwise {
272 // let ptr = &mut Q as *mut _ as *mut u8;
273 // // It may be difficult for us to effectively determine whether values are valid.
274 // // Invalid values can come from all sorts of corners.
275 // unsafe { *ptr = 10; }
276 // }
277 //
278 // match P {
279 // A => match Q {
280 // A => {
281 // // code
282 // }
283 // _ => {
284 // // don't use Q
285 // }
286 // }
287 // _ => {
288 // // don't use Q
289 // }
290 // };
291 // ```
292 //
293 // Hoisting the `discriminant(Q)` out of the `A` arm causes us to compute the discriminant of an
294 // invalid value, which is UB.
295 // In order to fix this, **we would either need to show that the discriminant computation of
296 // `place` is computed in all branches**.
297 // FIXME(#95162) For the moment, we adopt a conservative approach and
298 // consider only the `otherwise` branch has no statements and an unreachable terminator.
299 return None;
300 }
301 // Handle:
302 // ```
303 // bb4: {
304 // _8 = discriminant((_3.1: Enum1));
305 // switchInt(move _8) -> [2: bb7, otherwise: bb1];
306 // }
307 // ```
308 let [
309 Statement {
310 kind: StatementKind::Assign(box (_, Rvalue::Discriminant(child_place))),
311 ..
312 },
313 ] = bbs[child].statements.as_slice()
314 else {
315 return None;
316 };
317 *child_place
318 } else {
319 // Handle:
320 // ```
321 // bb2: {
322 // switchInt((_3.1: u64)) -> [1: bb5, otherwise: bb1];
323 // }
324 // ```
325 let Operand::Copy(child_place) = child_discr else {
326 return None;
327 };
328 *child_place
329 };
330 let destination = if need_hoist_discriminant || bbs[targets.otherwise()].is_empty_unreachable()
331 {
332 child_targets.otherwise()
333 } else {
334 targets.otherwise()
335 };
336
337 // Verify that the optimization is legal for each branch
338 for (value, child) in targets.iter() {
339 if !verify_candidate_branch(
340 &bbs[child],
341 value,
342 child_place,
343 destination,
344 need_hoist_discriminant,
345 ) {
346 return None;
347 }
348 }
349 Some(OptimizationData {
350 destination,
351 child_place,
352 child_ty,
353 child_source: *source_info,
354 need_hoist_discriminant,
355 })
356}
357
358fn verify_candidate_branch<'tcx>(
359 branch: &BasicBlockData<'tcx>,
360 value: u128,
361 place: Place<'tcx>,
362 destination: BasicBlock,
363 need_hoist_discriminant: bool,
364) -> bool {
365 // In order for the optimization to be correct, the terminator must be a `SwitchInt`.
366 let TerminatorKind::SwitchInt { discr: switch_op, targets } = &branch.terminator().kind else {
367 return false;
368 };
369 if need_hoist_discriminant {
370 // If we need hoist discriminant, the branch must have exactly one statement.
371 let [statement] = branch.statements.as_slice() else {
372 return false;
373 };
374 // The statement must assign the discriminant of `place`.
375 let StatementKind::Assign(box (discr_place, Rvalue::Discriminant(from_place))) =
376 statement.kind
377 else {
378 return false;
379 };
380 if from_place != place {
381 return false;
382 }
383 // The assignment must invalidate a local that terminate on a `SwitchInt`.
384 if !discr_place.projection.is_empty() || *switch_op != Operand::Move(discr_place) {
385 return false;
386 }
387 } else {
388 // If we don't need hoist discriminant, the branch must not have any statements.
389 if !branch.statements.is_empty() {
390 return false;
391 }
392 // The place on `SwitchInt` must be the same.
393 if *switch_op != Operand::Copy(place) {
394 return false;
395 }
396 }
397 // It must fall through to `destination` if the switch misses.
398 if destination != targets.otherwise() {
399 return false;
400 }
401 // It must have exactly one branch for value `value` and have no more branches.
402 let mut iter = targets.iter();
403 let (Some((target_value, _)), None) = (iter.next(), iter.next()) else {
404 return false;
405 };
406 target_value == value
407}