// rustc_mir_transform/match_branches.rs

1use std::iter;
2
3use rustc_abi::Integer;
4use rustc_index::IndexSlice;
5use rustc_middle::mir::*;
6use rustc_middle::ty::layout::{IntegerExt, TyAndLayout};
7use rustc_middle::ty::{self, ScalarInt, Ty, TyCtxt};
8use tracing::instrument;
9
10use super::simplify::simplify_cfg;
11use crate::patch::MirPatch;
12
13pub(super) struct MatchBranchSimplification;
14
15impl<'tcx> crate::MirPass<'tcx> for MatchBranchSimplification {
16    fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
17        sess.mir_opt_level() >= 1
18    }
19
20    fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
21        let typing_env = body.typing_env(tcx);
22        let mut apply_patch = false;
23        let mut patch = MirPatch::new(body);
24        for (bb, bb_data) in body.basic_blocks.iter_enumerated() {
25            match &bb_data.terminator().kind {
26                TerminatorKind::SwitchInt {
27                    discr: Operand::Copy(_) | Operand::Move(_),
28                    targets,
29                    ..
30                    // We require that the possible target blocks don't contain this block.
31                } if !targets.all_targets().contains(&bb) => {}
32                // Only optimize switch int statements
33                _ => continue,
34            };
35
36            if SimplifyToIf.simplify(tcx, body, &mut patch, bb, typing_env).is_some() {
37                apply_patch = true;
38                continue;
39            }
40            if SimplifyToExp::default().simplify(tcx, body, &mut patch, bb, typing_env).is_some() {
41                apply_patch = true;
42                continue;
43            }
44        }
45
46        if apply_patch {
47            patch.apply(body);
48            simplify_cfg(tcx, body);
49        }
50    }
51
52    fn is_required(&self) -> bool {
53        false
54    }
55}
56
57trait SimplifyMatch<'tcx> {
58    /// Simplifies a match statement, returning `Some` if the simplification succeeds, `None`
59    /// otherwise. Generic code is written here, and we generally don't need a custom
60    /// implementation.
61    fn simplify(
62        &mut self,
63        tcx: TyCtxt<'tcx>,
64        body: &Body<'tcx>,
65        patch: &mut MirPatch<'tcx>,
66        switch_bb_idx: BasicBlock,
67        typing_env: ty::TypingEnv<'tcx>,
68    ) -> Option<()> {
69        let bbs = &body.basic_blocks;
70        let TerminatorKind::SwitchInt { discr, targets, .. } =
71            &bbs[switch_bb_idx].terminator().kind
72        else {
73            unreachable!();
74        };
75
76        let discr_ty = discr.ty(body.local_decls(), tcx);
77        self.can_simplify(tcx, targets, typing_env, bbs, discr_ty)?;
78
79        // Take ownership of items now that we know we can optimize.
80        let discr = discr.clone();
81
82        // Introduce a temporary for the discriminant value.
83        let source_info = bbs[switch_bb_idx].terminator().source_info;
84        let discr_local = patch.new_temp(discr_ty, source_info.span);
85
86        let (_, first) = targets.iter().next().unwrap();
87        let statement_index = bbs[switch_bb_idx].statements.len();
88        let parent_end = Location { block: switch_bb_idx, statement_index };
89        patch.add_statement(parent_end, StatementKind::StorageLive(discr_local));
90        patch.add_assign(parent_end, Place::from(discr_local), Rvalue::Use(discr));
91        self.new_stmts(tcx, targets, typing_env, patch, parent_end, bbs, discr_local, discr_ty);
92        patch.add_statement(parent_end, StatementKind::StorageDead(discr_local));
93        patch.patch_terminator(switch_bb_idx, bbs[first].terminator().kind.clone());
94        Some(())
95    }
96
97    /// Check that the BBs to be simplified satisfies all distinct and
98    /// that the terminator are the same.
99    /// There are also conditions for different ways of simplification.
100    fn can_simplify(
101        &mut self,
102        tcx: TyCtxt<'tcx>,
103        targets: &SwitchTargets,
104        typing_env: ty::TypingEnv<'tcx>,
105        bbs: &IndexSlice<BasicBlock, BasicBlockData<'tcx>>,
106        discr_ty: Ty<'tcx>,
107    ) -> Option<()>;
108
109    fn new_stmts(
110        &self,
111        tcx: TyCtxt<'tcx>,
112        targets: &SwitchTargets,
113        typing_env: ty::TypingEnv<'tcx>,
114        patch: &mut MirPatch<'tcx>,
115        parent_end: Location,
116        bbs: &IndexSlice<BasicBlock, BasicBlockData<'tcx>>,
117        discr_local: Local,
118        discr_ty: Ty<'tcx>,
119    );
120}
121
122struct SimplifyToIf;
123
124/// If a source block is found that switches between two blocks that are exactly
125/// the same modulo const bool assignments (e.g., one assigns true another false
126/// to the same place), merge a target block statements into the source block,
127/// using Eq / Ne comparison with switch value where const bools value differ.
128///
129/// For example:
130///
131/// ```ignore (MIR)
132/// bb0: {
133///     switchInt(move _3) -> [42_isize: bb1, otherwise: bb2];
134/// }
135///
136/// bb1: {
137///     _2 = const true;
138///     goto -> bb3;
139/// }
140///
141/// bb2: {
142///     _2 = const false;
143///     goto -> bb3;
144/// }
145/// ```
146///
147/// into:
148///
149/// ```ignore (MIR)
150/// bb0: {
151///    _2 = Eq(move _3, const 42_isize);
152///    goto -> bb3;
153/// }
154/// ```
155impl<'tcx> SimplifyMatch<'tcx> for SimplifyToIf {
156    #[instrument(level = "debug", skip(self, tcx), ret)]
157    fn can_simplify(
158        &mut self,
159        tcx: TyCtxt<'tcx>,
160        targets: &SwitchTargets,
161        typing_env: ty::TypingEnv<'tcx>,
162        bbs: &IndexSlice<BasicBlock, BasicBlockData<'tcx>>,
163        _discr_ty: Ty<'tcx>,
164    ) -> Option<()> {
165        let (first, second) = match targets.all_targets() {
166            &[first, otherwise] => (first, otherwise),
167            &[first, second, otherwise] if bbs[otherwise].is_empty_unreachable() => (first, second),
168            _ => {
169                return None;
170            }
171        };
172
173        // We require that the possible target blocks all be distinct.
174        if first == second {
175            return None;
176        }
177        // Check that destinations are identical, and if not, then don't optimize this block
178        if bbs[first].terminator().kind != bbs[second].terminator().kind {
179            return None;
180        }
181
182        // Check that blocks are assignments of consts to the same place or same statement,
183        // and match up 1-1, if not don't optimize this block.
184        let first_stmts = &bbs[first].statements;
185        let second_stmts = &bbs[second].statements;
186        if first_stmts.len() != second_stmts.len() {
187            return None;
188        }
189        for (f, s) in iter::zip(first_stmts, second_stmts) {
190            match (&f.kind, &s.kind) {
191                // If two statements are exactly the same, we can optimize.
192                (f_s, s_s) if f_s == s_s => {}
193
194                // If two statements are const bool assignments to the same place, we can optimize.
195                (
196                    StatementKind::Assign(box (lhs_f, Rvalue::Use(Operand::Constant(f_c)))),
197                    StatementKind::Assign(box (lhs_s, Rvalue::Use(Operand::Constant(s_c)))),
198                ) if lhs_f == lhs_s
199                    && f_c.const_.ty().is_bool()
200                    && s_c.const_.ty().is_bool()
201                    && f_c.const_.try_eval_bool(tcx, typing_env).is_some()
202                    && s_c.const_.try_eval_bool(tcx, typing_env).is_some() => {}
203
204                // Otherwise we cannot optimize. Try another block.
205                _ => return None,
206            }
207        }
208        Some(())
209    }
210
211    fn new_stmts(
212        &self,
213        tcx: TyCtxt<'tcx>,
214        targets: &SwitchTargets,
215        typing_env: ty::TypingEnv<'tcx>,
216        patch: &mut MirPatch<'tcx>,
217        parent_end: Location,
218        bbs: &IndexSlice<BasicBlock, BasicBlockData<'tcx>>,
219        discr_local: Local,
220        discr_ty: Ty<'tcx>,
221    ) {
222        let ((val, first), second) = match (targets.all_targets(), targets.all_values()) {
223            (&[first, otherwise], &[val]) => ((val, first), otherwise),
224            (&[first, second, otherwise], &[val, _]) if bbs[otherwise].is_empty_unreachable() => {
225                ((val, first), second)
226            }
227            _ => unreachable!(),
228        };
229
230        // We already checked that first and second are different blocks,
231        // and bb_idx has a different terminator from both of them.
232        let first = &bbs[first];
233        let second = &bbs[second];
234        for (f, s) in iter::zip(&first.statements, &second.statements) {
235            match (&f.kind, &s.kind) {
236                (f_s, s_s) if f_s == s_s => {
237                    patch.add_statement(parent_end, f.kind.clone());
238                }
239
240                (
241                    StatementKind::Assign(box (lhs, Rvalue::Use(Operand::Constant(f_c)))),
242                    StatementKind::Assign(box (_, Rvalue::Use(Operand::Constant(s_c)))),
243                ) => {
244                    // From earlier loop we know that we are dealing with bool constants only:
245                    let f_b = f_c.const_.try_eval_bool(tcx, typing_env).unwrap();
246                    let s_b = s_c.const_.try_eval_bool(tcx, typing_env).unwrap();
247                    if f_b == s_b {
248                        // Same value in both blocks. Use statement as is.
249                        patch.add_statement(parent_end, f.kind.clone());
250                    } else {
251                        // Different value between blocks. Make value conditional on switch
252                        // condition.
253                        let size = tcx.layout_of(typing_env.as_query_input(discr_ty)).unwrap().size;
254                        let const_cmp = Operand::const_from_scalar(
255                            tcx,
256                            discr_ty,
257                            rustc_const_eval::interpret::Scalar::from_uint(val, size),
258                            rustc_span::DUMMY_SP,
259                        );
260                        let op = if f_b { BinOp::Eq } else { BinOp::Ne };
261                        let rhs = Rvalue::BinaryOp(
262                            op,
263                            Box::new((Operand::Copy(Place::from(discr_local)), const_cmp)),
264                        );
265                        patch.add_assign(parent_end, *lhs, rhs);
266                    }
267                }
268
269                _ => unreachable!(),
270            }
271        }
272    }
273}
274
275/// Check if the cast constant using `IntToInt` is equal to the target constant.
276fn can_cast(
277    tcx: TyCtxt<'_>,
278    src_val: impl Into<u128>,
279    src_layout: TyAndLayout<'_>,
280    cast_ty: Ty<'_>,
281    target_scalar: ScalarInt,
282) -> bool {
283    let from_scalar = ScalarInt::try_from_uint(src_val.into(), src_layout.size).unwrap();
284    let v = match src_layout.ty.kind() {
285        ty::Uint(_) => from_scalar.to_uint(src_layout.size),
286        ty::Int(_) => from_scalar.to_int(src_layout.size) as u128,
287        // We can also transform the values of other integer representations (such as char),
288        // although this may not be practical in real-world scenarios.
289        _ => return false,
290    };
291    let size = match *cast_ty.kind() {
292        ty::Int(t) => Integer::from_int_ty(&tcx, t).size(),
293        ty::Uint(t) => Integer::from_uint_ty(&tcx, t).size(),
294        _ => return false,
295    };
296    let v = size.truncate(v);
297    let cast_scalar = ScalarInt::try_from_uint(v, size).unwrap();
298    cast_scalar == target_scalar
299}
300
301#[derive(Default)]
302struct SimplifyToExp {
303    transform_kinds: Vec<TransformKind>,
304}
305
306#[derive(Clone, Copy, Debug)]
307enum ExpectedTransformKind<'a, 'tcx> {
308    /// Identical statements.
309    Same(&'a StatementKind<'tcx>),
310    /// Assignment statements have the same value.
311    SameByEq { place: &'a Place<'tcx>, ty: Ty<'tcx>, scalar: ScalarInt },
312    /// Enum variant comparison type.
313    Cast { place: &'a Place<'tcx>, ty: Ty<'tcx> },
314}
315
316enum TransformKind {
317    Same,
318    Cast,
319}
320
321impl From<ExpectedTransformKind<'_, '_>> for TransformKind {
322    fn from(compare_type: ExpectedTransformKind<'_, '_>) -> Self {
323        match compare_type {
324            ExpectedTransformKind::Same(_) => TransformKind::Same,
325            ExpectedTransformKind::SameByEq { .. } => TransformKind::Same,
326            ExpectedTransformKind::Cast { .. } => TransformKind::Cast,
327        }
328    }
329}
330
331/// If we find that the value of match is the same as the assignment,
332/// merge a target block statements into the source block,
333/// using cast to transform different integer types.
334///
335/// For example:
336///
337/// ```ignore (MIR)
338/// bb0: {
339///     switchInt(_1) -> [1: bb2, 2: bb3, 3: bb4, otherwise: bb1];
340/// }
341///
342/// bb1: {
343///     unreachable;
344/// }
345///
346/// bb2: {
347///     _0 = const 1_i16;
348///     goto -> bb5;
349/// }
350///
351/// bb3: {
352///     _0 = const 2_i16;
353///     goto -> bb5;
354/// }
355///
356/// bb4: {
357///     _0 = const 3_i16;
358///     goto -> bb5;
359/// }
360/// ```
361///
362/// into:
363///
364/// ```ignore (MIR)
365/// bb0: {
366///    _0 = _3 as i16 (IntToInt);
367///    goto -> bb5;
368/// }
369/// ```
370impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp {
371    #[instrument(level = "debug", skip(self, tcx), ret)]
372    fn can_simplify(
373        &mut self,
374        tcx: TyCtxt<'tcx>,
375        targets: &SwitchTargets,
376        typing_env: ty::TypingEnv<'tcx>,
377        bbs: &IndexSlice<BasicBlock, BasicBlockData<'tcx>>,
378        discr_ty: Ty<'tcx>,
379    ) -> Option<()> {
380        if targets.iter().len() < 2 || targets.iter().len() > 64 {
381            return None;
382        }
383        // We require that the possible target blocks all be distinct.
384        if !targets.is_distinct() {
385            return None;
386        }
387        if !bbs[targets.otherwise()].is_empty_unreachable() {
388            return None;
389        }
390        let mut target_iter = targets.iter();
391        let (first_case_val, first_target) = target_iter.next().unwrap();
392        let first_terminator_kind = &bbs[first_target].terminator().kind;
393        // Check that destinations are identical, and if not, then don't optimize this block
394        if !targets
395            .iter()
396            .all(|(_, other_target)| first_terminator_kind == &bbs[other_target].terminator().kind)
397        {
398            return None;
399        }
400
401        let discr_layout = tcx.layout_of(typing_env.as_query_input(discr_ty)).unwrap();
402        let first_stmts = &bbs[first_target].statements;
403        let (second_case_val, second_target) = target_iter.next().unwrap();
404        let second_stmts = &bbs[second_target].statements;
405        if first_stmts.len() != second_stmts.len() {
406            return None;
407        }
408
409        // We first compare the two branches, and then the other branches need to fulfill the same
410        // conditions.
411        let mut expected_transform_kinds = Vec::new();
412        for (f, s) in iter::zip(first_stmts, second_stmts) {
413            let compare_type = match (&f.kind, &s.kind) {
414                // If two statements are exactly the same, we can optimize.
415                (f_s, s_s) if f_s == s_s => ExpectedTransformKind::Same(f_s),
416
417                // If two statements are assignments with the match values to the same place, we
418                // can optimize.
419                (
420                    StatementKind::Assign(box (lhs_f, Rvalue::Use(Operand::Constant(f_c)))),
421                    StatementKind::Assign(box (lhs_s, Rvalue::Use(Operand::Constant(s_c)))),
422                ) if lhs_f == lhs_s
423                    && f_c.const_.ty() == s_c.const_.ty()
424                    && f_c.const_.ty().is_integral() =>
425                {
426                    match (
427                        f_c.const_.try_eval_scalar_int(tcx, typing_env),
428                        s_c.const_.try_eval_scalar_int(tcx, typing_env),
429                    ) {
430                        (Some(f), Some(s)) if f == s => ExpectedTransformKind::SameByEq {
431                            place: lhs_f,
432                            ty: f_c.const_.ty(),
433                            scalar: f,
434                        },
435                        // Enum variants can also be simplified to an assignment statement,
436                        // if we can use `IntToInt` cast to get an equal value.
437                        (Some(f), Some(s))
438                            if (can_cast(
439                                tcx,
440                                first_case_val,
441                                discr_layout,
442                                f_c.const_.ty(),
443                                f,
444                            ) && can_cast(
445                                tcx,
446                                second_case_val,
447                                discr_layout,
448                                f_c.const_.ty(),
449                                s,
450                            )) =>
451                        {
452                            ExpectedTransformKind::Cast { place: lhs_f, ty: f_c.const_.ty() }
453                        }
454                        _ => {
455                            return None;
456                        }
457                    }
458                }
459
460                // Otherwise we cannot optimize. Try another block.
461                _ => return None,
462            };
463            expected_transform_kinds.push(compare_type);
464        }
465
466        // All remaining BBs need to fulfill the same pattern as the two BBs from the previous step.
467        for (other_val, other_target) in target_iter {
468            let other_stmts = &bbs[other_target].statements;
469            if expected_transform_kinds.len() != other_stmts.len() {
470                return None;
471            }
472            for (f, s) in iter::zip(&expected_transform_kinds, other_stmts) {
473                match (*f, &s.kind) {
474                    (ExpectedTransformKind::Same(f_s), s_s) if f_s == s_s => {}
475                    (
476                        ExpectedTransformKind::SameByEq { place: lhs_f, ty: f_ty, scalar },
477                        StatementKind::Assign(box (lhs_s, Rvalue::Use(Operand::Constant(s_c)))),
478                    ) if lhs_f == lhs_s
479                        && s_c.const_.ty() == f_ty
480                        && s_c.const_.try_eval_scalar_int(tcx, typing_env) == Some(scalar) => {}
481                    (
482                        ExpectedTransformKind::Cast { place: lhs_f, ty: f_ty },
483                        StatementKind::Assign(box (lhs_s, Rvalue::Use(Operand::Constant(s_c)))),
484                    ) if let Some(f) = s_c.const_.try_eval_scalar_int(tcx, typing_env)
485                        && lhs_f == lhs_s
486                        && s_c.const_.ty() == f_ty
487                        && can_cast(tcx, other_val, discr_layout, f_ty, f) => {}
488                    _ => return None,
489                }
490            }
491        }
492        self.transform_kinds = expected_transform_kinds.into_iter().map(|c| c.into()).collect();
493        Some(())
494    }
495
496    fn new_stmts(
497        &self,
498        _tcx: TyCtxt<'tcx>,
499        targets: &SwitchTargets,
500        _typing_env: ty::TypingEnv<'tcx>,
501        patch: &mut MirPatch<'tcx>,
502        parent_end: Location,
503        bbs: &IndexSlice<BasicBlock, BasicBlockData<'tcx>>,
504        discr_local: Local,
505        discr_ty: Ty<'tcx>,
506    ) {
507        let (_, first) = targets.iter().next().unwrap();
508        let first = &bbs[first];
509
510        for (t, s) in iter::zip(&self.transform_kinds, &first.statements) {
511            match (t, &s.kind) {
512                (TransformKind::Same, _) => {
513                    patch.add_statement(parent_end, s.kind.clone());
514                }
515                (
516                    TransformKind::Cast,
517                    StatementKind::Assign(box (lhs, Rvalue::Use(Operand::Constant(f_c)))),
518                ) => {
519                    let operand = Operand::Copy(Place::from(discr_local));
520                    let r_val = if f_c.const_.ty() == discr_ty {
521                        Rvalue::Use(operand)
522                    } else {
523                        Rvalue::Cast(CastKind::IntToInt, operand, f_c.const_.ty())
524                    };
525                    patch.add_assign(parent_end, *lhs, r_val);
526                }
527                _ => unreachable!(),
528            }
529        }
530    }
531}