rustc_mir_transform/
early_otherwise_branch.rs

1use std::fmt::Debug;
2
3use rustc_middle::mir::*;
4use rustc_middle::ty::{Ty, TyCtxt};
5use tracing::trace;
6
7use super::simplify::simplify_cfg;
8use crate::patch::MirPatch;
9
10/// This pass optimizes something like
11/// ```ignore (syntax-highlighting-only)
12/// let x: Option<()>;
13/// let y: Option<()>;
14/// match (x,y) {
15///     (Some(_), Some(_)) => {0},
16///     (None, None) => {2},
17///     _ => {1}
18/// }
19/// ```
20/// into something like
21/// ```ignore (syntax-highlighting-only)
22/// let x: Option<()>;
23/// let y: Option<()>;
24/// let discriminant_x = std::mem::discriminant(x);
25/// let discriminant_y = std::mem::discriminant(y);
26/// if discriminant_x == discriminant_y {
27///     match x {
28///         Some(_) => 0,
29///         None => 2,
30///     }
31/// } else {
32///     1
33/// }
34/// ```
35///
36/// Specifically, it looks for instances of control flow like this:
37/// ```text
38///
39///     =================
40///     |      BB1      |
41///     |---------------|                  ============================
42///     |     ...       |         /------> |            BBC           |
43///     |---------------|         |        |--------------------------|
44///     |  switchInt(Q) |         |        |   _cl = discriminant(P)  |
45///     |       c       | --------/        |--------------------------|
46///     |       d       | -------\         |       switchInt(_cl)     |
47///     |      ...      |        |         |            c             | ---> BBC.2
48///     |    otherwise  | --\    |    /--- |         otherwise        |
49///     =================   |    |    |    ============================
50///                         |    |    |
51///     =================   |    |    |
52///     |      BBU      | <-|    |    |    ============================
53///     |---------------|        \-------> |            BBD           |
54///     |---------------|             |    |--------------------------|
55///     |  unreachable  |             |    |   _dl = discriminant(P)  |
56///     =================             |    |--------------------------|
57///                                   |    |       switchInt(_dl)     |
58///     =================             |    |            d             | ---> BBD.2
59///     |      BB9      | <--------------- |         otherwise        |
60///     |---------------|                  ============================
61///     |      ...      |
62///     =================
63/// ```
64/// Where the `otherwise` branch on `BB1` is permitted to either go to `BBU`. In the
65/// code:
66///  - `BB1` is `parent` and `BBC, BBD` are children
67///  - `P` is `child_place`
68///  - `child_ty` is the type of `_cl`.
69///  - `Q` is `parent_op`.
70///  - `parent_ty` is the type of `Q`.
71///  - `BB9` is `destination`
72/// All this is then transformed into:
73/// ```text
74///
75///     =======================
76///     |          BB1        |
77///     |---------------------|                  ============================
78///     |          ...        |         /------> |           BBEq           |
79///     | _s = discriminant(P)|         |        |--------------------------|
80///     | _t = Ne(Q, _s)      |         |        |--------------------------|
81///     |---------------------|         |        |       switchInt(Q)       |
82///     |     switchInt(_t)   |         |        |            c             | ---> BBC.2
83///     |        false        | --------/        |            d             | ---> BBD.2
84///     |       otherwise     |       /--------- |         otherwise        |
85///     =======================       |          ============================
86///                                   |
87///     =================             |
88///     |      BB9      | <-----------/
89///     |---------------|
90///     |      ...      |
91///     =================
92/// ```
93pub(super) struct EarlyOtherwiseBranch;
94
95impl<'tcx> crate::MirPass<'tcx> for EarlyOtherwiseBranch {
96    fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
97        sess.mir_opt_level() >= 2
98    }
99
100    fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
101        trace!("running EarlyOtherwiseBranch on {:?}", body.source);
102
103        let mut should_cleanup = false;
104
105        // Also consider newly generated bbs in the same pass
106        for parent in body.basic_blocks.indices() {
107            let bbs = &*body.basic_blocks;
108            let Some(opt_data) = evaluate_candidate(tcx, body, parent) else { continue };
109
110            trace!("SUCCESS: found optimization possibility to apply: {opt_data:?}");
111
112            should_cleanup = true;
113
114            let TerminatorKind::SwitchInt { discr: parent_op, targets: parent_targets } =
115                &bbs[parent].terminator().kind
116            else {
117                unreachable!()
118            };
119            // Always correct since we can only switch on `Copy` types
120            let parent_op = match parent_op {
121                Operand::Move(x) => Operand::Copy(*x),
122                Operand::Copy(x) => Operand::Copy(*x),
123                Operand::Constant(x) => Operand::Constant(x.clone()),
124            };
125            let parent_ty = parent_op.ty(body.local_decls(), tcx);
126            let statements_before = bbs[parent].statements.len();
127            let parent_end = Location { block: parent, statement_index: statements_before };
128
129            let mut patch = MirPatch::new(body);
130
131            let second_operand = if opt_data.need_hoist_discriminant {
132                // create temp to store second discriminant in, `_s` in example above
133                let second_discriminant_temp =
134                    patch.new_temp(opt_data.child_ty, opt_data.child_source.span);
135
136                // create assignment of discriminant
137                patch.add_assign(
138                    parent_end,
139                    Place::from(second_discriminant_temp),
140                    Rvalue::Discriminant(opt_data.child_place),
141                );
142                Operand::Move(Place::from(second_discriminant_temp))
143            } else {
144                Operand::Copy(opt_data.child_place)
145            };
146
147            // create temp to store inequality comparison between the two discriminants, `_t` in
148            // example above
149            let nequal = BinOp::Ne;
150            let comp_res_type = nequal.ty(tcx, parent_ty, opt_data.child_ty);
151            let comp_temp = patch.new_temp(comp_res_type, opt_data.child_source.span);
152
153            // create inequality comparison
154            let comp_rvalue =
155                Rvalue::BinaryOp(nequal, Box::new((parent_op.clone(), second_operand)));
156            patch.add_statement(
157                parent_end,
158                StatementKind::Assign(Box::new((Place::from(comp_temp), comp_rvalue))),
159            );
160
161            let eq_new_targets = parent_targets.iter().map(|(value, child)| {
162                let TerminatorKind::SwitchInt { targets, .. } = &bbs[child].terminator().kind
163                else {
164                    unreachable!()
165                };
166                (value, targets.target_for_value(value))
167            });
168            // The otherwise either is the same target branch or an unreachable.
169            let eq_targets = SwitchTargets::new(eq_new_targets, parent_targets.otherwise());
170
171            // Create `bbEq` in example above
172            let eq_switch = BasicBlockData::new(
173                Some(Terminator {
174                    source_info: bbs[parent].terminator().source_info,
175                    kind: TerminatorKind::SwitchInt {
176                        // switch on the first discriminant, so we can mark the second one as dead
177                        discr: parent_op,
178                        targets: eq_targets,
179                    },
180                }),
181                bbs[parent].is_cleanup,
182            );
183
184            let eq_bb = patch.new_block(eq_switch);
185
186            // Jump to it on the basis of the inequality comparison
187            let true_case = opt_data.destination;
188            let false_case = eq_bb;
189            patch.patch_terminator(
190                parent,
191                TerminatorKind::if_(Operand::Move(Place::from(comp_temp)), true_case, false_case),
192            );
193
194            patch.apply(body);
195        }
196
197        // Since this optimization adds new basic blocks and invalidates others,
198        // clean up the cfg to make it nicer for other passes
199        if should_cleanup {
200            simplify_cfg(tcx, body);
201        }
202    }
203
204    fn is_required(&self) -> bool {
205        false
206    }
207}
208
209#[derive(Debug)]
210struct OptimizationData<'tcx> {
211    destination: BasicBlock,
212    child_place: Place<'tcx>,
213    child_ty: Ty<'tcx>,
214    child_source: SourceInfo,
215    need_hoist_discriminant: bool,
216}
217
218fn evaluate_candidate<'tcx>(
219    tcx: TyCtxt<'tcx>,
220    body: &Body<'tcx>,
221    parent: BasicBlock,
222) -> Option<OptimizationData<'tcx>> {
223    let bbs = &body.basic_blocks;
224    // NB: If this BB is a cleanup, we may need to figure out what else needs to be handled.
225    if bbs[parent].is_cleanup {
226        return None;
227    }
228    let TerminatorKind::SwitchInt { targets, discr: parent_discr } = &bbs[parent].terminator().kind
229    else {
230        return None;
231    };
232    let parent_ty = parent_discr.ty(body.local_decls(), tcx);
233    let (_, child) = targets.iter().next()?;
234
235    let Terminator {
236        kind: TerminatorKind::SwitchInt { targets: child_targets, discr: child_discr },
237        source_info,
238    } = bbs[child].terminator()
239    else {
240        return None;
241    };
242    let child_ty = child_discr.ty(body.local_decls(), tcx);
243    if child_ty != parent_ty {
244        return None;
245    }
246
247    // We only handle:
248    // ```
249    // bb4: {
250    //     _8 = discriminant((_3.1: Enum1));
251    //    switchInt(move _8) -> [2: bb7, otherwise: bb1];
252    // }
253    // ```
254    // and
255    // ```
256    // bb2: {
257    //     switchInt((_3.1: u64)) -> [1: bb5, otherwise: bb1];
258    // }
259    // ```
260    if bbs[child].statements.len() > 1 {
261        return None;
262    }
263
264    // When thie BB has exactly one statement, this statement should be discriminant.
265    let need_hoist_discriminant = bbs[child].statements.len() == 1;
266    let child_place = if need_hoist_discriminant {
267        if !bbs[targets.otherwise()].is_empty_unreachable() {
268            // Someone could write code like this:
269            // ```rust
270            // let Q = val;
271            // if discriminant(P) == otherwise {
272            //     let ptr = &mut Q as *mut _ as *mut u8;
273            //     // It may be difficult for us to effectively determine whether values are valid.
274            //     // Invalid values can come from all sorts of corners.
275            //     unsafe { *ptr = 10; }
276            // }
277            //
278            // match P {
279            //    A => match Q {
280            //        A => {
281            //            // code
282            //        }
283            //        _ => {
284            //            // don't use Q
285            //        }
286            //    }
287            //    _ => {
288            //        // don't use Q
289            //    }
290            // };
291            // ```
292            //
293            // Hoisting the `discriminant(Q)` out of the `A` arm causes us to compute the discriminant of an
294            // invalid value, which is UB.
295            // In order to fix this, **we would either need to show that the discriminant computation of
296            // `place` is computed in all branches**.
297            // FIXME(#95162) For the moment, we adopt a conservative approach and
298            // consider only the `otherwise` branch has no statements and an unreachable terminator.
299            return None;
300        }
301        // Handle:
302        // ```
303        // bb4: {
304        //     _8 = discriminant((_3.1: Enum1));
305        //    switchInt(move _8) -> [2: bb7, otherwise: bb1];
306        // }
307        // ```
308        let [
309            Statement {
310                kind: StatementKind::Assign(box (_, Rvalue::Discriminant(child_place))),
311                ..
312            },
313        ] = bbs[child].statements.as_slice()
314        else {
315            return None;
316        };
317        *child_place
318    } else {
319        // Handle:
320        // ```
321        // bb2: {
322        //     switchInt((_3.1: u64)) -> [1: bb5, otherwise: bb1];
323        // }
324        // ```
325        let Operand::Copy(child_place) = child_discr else {
326            return None;
327        };
328        *child_place
329    };
330    let destination = if need_hoist_discriminant || bbs[targets.otherwise()].is_empty_unreachable()
331    {
332        child_targets.otherwise()
333    } else {
334        targets.otherwise()
335    };
336
337    // Verify that the optimization is legal for each branch
338    for (value, child) in targets.iter() {
339        if !verify_candidate_branch(
340            &bbs[child],
341            value,
342            child_place,
343            destination,
344            need_hoist_discriminant,
345        ) {
346            return None;
347        }
348    }
349    Some(OptimizationData {
350        destination,
351        child_place,
352        child_ty,
353        child_source: *source_info,
354        need_hoist_discriminant,
355    })
356}
357
358fn verify_candidate_branch<'tcx>(
359    branch: &BasicBlockData<'tcx>,
360    value: u128,
361    place: Place<'tcx>,
362    destination: BasicBlock,
363    need_hoist_discriminant: bool,
364) -> bool {
365    // In order for the optimization to be correct, the terminator must be a `SwitchInt`.
366    let TerminatorKind::SwitchInt { discr: switch_op, targets } = &branch.terminator().kind else {
367        return false;
368    };
369    if need_hoist_discriminant {
370        // If we need hoist discriminant, the branch must have exactly one statement.
371        let [statement] = branch.statements.as_slice() else {
372            return false;
373        };
374        // The statement must assign the discriminant of `place`.
375        let StatementKind::Assign(box (discr_place, Rvalue::Discriminant(from_place))) =
376            statement.kind
377        else {
378            return false;
379        };
380        if from_place != place {
381            return false;
382        }
383        // The assignment must invalidate a local that terminate on a `SwitchInt`.
384        if !discr_place.projection.is_empty() || *switch_op != Operand::Move(discr_place) {
385            return false;
386        }
387    } else {
388        // If we don't need hoist discriminant, the branch must not have any statements.
389        if !branch.statements.is_empty() {
390            return false;
391        }
392        // The place on `SwitchInt` must be the same.
393        if *switch_op != Operand::Copy(place) {
394            return false;
395        }
396    }
397    // It must fall through to `destination` if the switch misses.
398    if destination != targets.otherwise() {
399        return false;
400    }
401    // It must have exactly one branch for value `value` and have no more branches.
402    let mut iter = targets.iter();
403    let (Some((target_value, _)), None) = (iter.next(), iter.next()) else {
404        return false;
405    };
406    target_value == value
407}