// rustc_mir_transform/match_branches.rs

1use std::iter;
2
3use rustc_abi::Integer;
4use rustc_index::IndexSlice;
5use rustc_middle::mir::*;
6use rustc_middle::ty::layout::{IntegerExt, TyAndLayout};
7use rustc_middle::ty::{self, ScalarInt, Ty, TyCtxt};
8use tracing::instrument;
9
10use super::simplify::simplify_cfg;
11use crate::patch::MirPatch;
12
13pub(super) struct MatchBranchSimplification;
14
15impl<'tcx> crate::MirPass<'tcx> for MatchBranchSimplification {
16    fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
17        sess.mir_opt_level() >= 1
18    }
19
20    fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
21        let typing_env = body.typing_env(tcx);
22        let mut should_cleanup = false;
23        for i in 0..body.basic_blocks.len() {
24            let bbs = &*body.basic_blocks;
25            let bb_idx = BasicBlock::from_usize(i);
26            match bbs[bb_idx].terminator().kind {
27                TerminatorKind::SwitchInt {
28                    discr: ref _discr @ (Operand::Copy(_) | Operand::Move(_)),
29                    ref targets,
30                    ..
31                    // We require that the possible target blocks don't contain this block.
32                } if !targets.all_targets().contains(&bb_idx) => {}
33                // Only optimize switch int statements
34                _ => continue,
35            };
36
37            if SimplifyToIf.simplify(tcx, body, bb_idx, typing_env).is_some() {
38                should_cleanup = true;
39                continue;
40            }
41            if SimplifyToExp::default().simplify(tcx, body, bb_idx, typing_env).is_some() {
42                should_cleanup = true;
43                continue;
44            }
45        }
46
47        if should_cleanup {
48            simplify_cfg(body);
49        }
50    }
51
52    fn is_required(&self) -> bool {
53        false
54    }
55}
56
57trait SimplifyMatch<'tcx> {
58    /// Simplifies a match statement, returning `Some` if the simplification succeeds, `None`
59    /// otherwise. Generic code is written here, and we generally don't need a custom
60    /// implementation.
61    fn simplify(
62        &mut self,
63        tcx: TyCtxt<'tcx>,
64        body: &mut Body<'tcx>,
65        switch_bb_idx: BasicBlock,
66        typing_env: ty::TypingEnv<'tcx>,
67    ) -> Option<()> {
68        let bbs = &body.basic_blocks;
69        let (discr, targets) = match bbs[switch_bb_idx].terminator().kind {
70            TerminatorKind::SwitchInt { ref discr, ref targets, .. } => (discr, targets),
71            _ => unreachable!(),
72        };
73
74        let discr_ty = discr.ty(body.local_decls(), tcx);
75        self.can_simplify(tcx, targets, typing_env, bbs, discr_ty)?;
76
77        let mut patch = MirPatch::new(body);
78
79        // Take ownership of items now that we know we can optimize.
80        let discr = discr.clone();
81
82        // Introduce a temporary for the discriminant value.
83        let source_info = bbs[switch_bb_idx].terminator().source_info;
84        let discr_local = patch.new_temp(discr_ty, source_info.span);
85
86        let (_, first) = targets.iter().next().unwrap();
87        let statement_index = bbs[switch_bb_idx].statements.len();
88        let parent_end = Location { block: switch_bb_idx, statement_index };
89        patch.add_statement(parent_end, StatementKind::StorageLive(discr_local));
90        patch.add_assign(parent_end, Place::from(discr_local), Rvalue::Use(discr));
91        self.new_stmts(
92            tcx,
93            targets,
94            typing_env,
95            &mut patch,
96            parent_end,
97            bbs,
98            discr_local,
99            discr_ty,
100        );
101        patch.add_statement(parent_end, StatementKind::StorageDead(discr_local));
102        patch.patch_terminator(switch_bb_idx, bbs[first].terminator().kind.clone());
103        patch.apply(body);
104        Some(())
105    }
106
107    /// Check that the BBs to be simplified satisfies all distinct and
108    /// that the terminator are the same.
109    /// There are also conditions for different ways of simplification.
110    fn can_simplify(
111        &mut self,
112        tcx: TyCtxt<'tcx>,
113        targets: &SwitchTargets,
114        typing_env: ty::TypingEnv<'tcx>,
115        bbs: &IndexSlice<BasicBlock, BasicBlockData<'tcx>>,
116        discr_ty: Ty<'tcx>,
117    ) -> Option<()>;
118
119    fn new_stmts(
120        &self,
121        tcx: TyCtxt<'tcx>,
122        targets: &SwitchTargets,
123        typing_env: ty::TypingEnv<'tcx>,
124        patch: &mut MirPatch<'tcx>,
125        parent_end: Location,
126        bbs: &IndexSlice<BasicBlock, BasicBlockData<'tcx>>,
127        discr_local: Local,
128        discr_ty: Ty<'tcx>,
129    );
130}
131
132struct SimplifyToIf;
133
134/// If a source block is found that switches between two blocks that are exactly
135/// the same modulo const bool assignments (e.g., one assigns true another false
136/// to the same place), merge a target block statements into the source block,
137/// using Eq / Ne comparison with switch value where const bools value differ.
138///
139/// For example:
140///
141/// ```ignore (MIR)
142/// bb0: {
143///     switchInt(move _3) -> [42_isize: bb1, otherwise: bb2];
144/// }
145///
146/// bb1: {
147///     _2 = const true;
148///     goto -> bb3;
149/// }
150///
151/// bb2: {
152///     _2 = const false;
153///     goto -> bb3;
154/// }
155/// ```
156///
157/// into:
158///
159/// ```ignore (MIR)
160/// bb0: {
161///    _2 = Eq(move _3, const 42_isize);
162///    goto -> bb3;
163/// }
164/// ```
165impl<'tcx> SimplifyMatch<'tcx> for SimplifyToIf {
166    #[instrument(level = "debug", skip(self, tcx), ret)]
167    fn can_simplify(
168        &mut self,
169        tcx: TyCtxt<'tcx>,
170        targets: &SwitchTargets,
171        typing_env: ty::TypingEnv<'tcx>,
172        bbs: &IndexSlice<BasicBlock, BasicBlockData<'tcx>>,
173        _discr_ty: Ty<'tcx>,
174    ) -> Option<()> {
175        let (first, second) = match targets.all_targets() {
176            &[first, otherwise] => (first, otherwise),
177            &[first, second, otherwise] if bbs[otherwise].is_empty_unreachable() => (first, second),
178            _ => {
179                return None;
180            }
181        };
182
183        // We require that the possible target blocks all be distinct.
184        if first == second {
185            return None;
186        }
187        // Check that destinations are identical, and if not, then don't optimize this block
188        if bbs[first].terminator().kind != bbs[second].terminator().kind {
189            return None;
190        }
191
192        // Check that blocks are assignments of consts to the same place or same statement,
193        // and match up 1-1, if not don't optimize this block.
194        let first_stmts = &bbs[first].statements;
195        let second_stmts = &bbs[second].statements;
196        if first_stmts.len() != second_stmts.len() {
197            return None;
198        }
199        for (f, s) in iter::zip(first_stmts, second_stmts) {
200            match (&f.kind, &s.kind) {
201                // If two statements are exactly the same, we can optimize.
202                (f_s, s_s) if f_s == s_s => {}
203
204                // If two statements are const bool assignments to the same place, we can optimize.
205                (
206                    StatementKind::Assign(box (lhs_f, Rvalue::Use(Operand::Constant(f_c)))),
207                    StatementKind::Assign(box (lhs_s, Rvalue::Use(Operand::Constant(s_c)))),
208                ) if lhs_f == lhs_s
209                    && f_c.const_.ty().is_bool()
210                    && s_c.const_.ty().is_bool()
211                    && f_c.const_.try_eval_bool(tcx, typing_env).is_some()
212                    && s_c.const_.try_eval_bool(tcx, typing_env).is_some() => {}
213
214                // Otherwise we cannot optimize. Try another block.
215                _ => return None,
216            }
217        }
218        Some(())
219    }
220
221    fn new_stmts(
222        &self,
223        tcx: TyCtxt<'tcx>,
224        targets: &SwitchTargets,
225        typing_env: ty::TypingEnv<'tcx>,
226        patch: &mut MirPatch<'tcx>,
227        parent_end: Location,
228        bbs: &IndexSlice<BasicBlock, BasicBlockData<'tcx>>,
229        discr_local: Local,
230        discr_ty: Ty<'tcx>,
231    ) {
232        let ((val, first), second) = match (targets.all_targets(), targets.all_values()) {
233            (&[first, otherwise], &[val]) => ((val, first), otherwise),
234            (&[first, second, otherwise], &[val, _]) if bbs[otherwise].is_empty_unreachable() => {
235                ((val, first), second)
236            }
237            _ => unreachable!(),
238        };
239
240        // We already checked that first and second are different blocks,
241        // and bb_idx has a different terminator from both of them.
242        let first = &bbs[first];
243        let second = &bbs[second];
244        for (f, s) in iter::zip(&first.statements, &second.statements) {
245            match (&f.kind, &s.kind) {
246                (f_s, s_s) if f_s == s_s => {
247                    patch.add_statement(parent_end, f.kind.clone());
248                }
249
250                (
251                    StatementKind::Assign(box (lhs, Rvalue::Use(Operand::Constant(f_c)))),
252                    StatementKind::Assign(box (_, Rvalue::Use(Operand::Constant(s_c)))),
253                ) => {
254                    // From earlier loop we know that we are dealing with bool constants only:
255                    let f_b = f_c.const_.try_eval_bool(tcx, typing_env).unwrap();
256                    let s_b = s_c.const_.try_eval_bool(tcx, typing_env).unwrap();
257                    if f_b == s_b {
258                        // Same value in both blocks. Use statement as is.
259                        patch.add_statement(parent_end, f.kind.clone());
260                    } else {
261                        // Different value between blocks. Make value conditional on switch
262                        // condition.
263                        let size = tcx.layout_of(typing_env.as_query_input(discr_ty)).unwrap().size;
264                        let const_cmp = Operand::const_from_scalar(
265                            tcx,
266                            discr_ty,
267                            rustc_const_eval::interpret::Scalar::from_uint(val, size),
268                            rustc_span::DUMMY_SP,
269                        );
270                        let op = if f_b { BinOp::Eq } else { BinOp::Ne };
271                        let rhs = Rvalue::BinaryOp(
272                            op,
273                            Box::new((Operand::Copy(Place::from(discr_local)), const_cmp)),
274                        );
275                        patch.add_assign(parent_end, *lhs, rhs);
276                    }
277                }
278
279                _ => unreachable!(),
280            }
281        }
282    }
283}
284
285/// Check if the cast constant using `IntToInt` is equal to the target constant.
286fn can_cast(
287    tcx: TyCtxt<'_>,
288    src_val: impl Into<u128>,
289    src_layout: TyAndLayout<'_>,
290    cast_ty: Ty<'_>,
291    target_scalar: ScalarInt,
292) -> bool {
293    let from_scalar = ScalarInt::try_from_uint(src_val.into(), src_layout.size).unwrap();
294    let v = match src_layout.ty.kind() {
295        ty::Uint(_) => from_scalar.to_uint(src_layout.size),
296        ty::Int(_) => from_scalar.to_int(src_layout.size) as u128,
297        _ => unreachable!("invalid int"),
298    };
299    let size = match *cast_ty.kind() {
300        ty::Int(t) => Integer::from_int_ty(&tcx, t).size(),
301        ty::Uint(t) => Integer::from_uint_ty(&tcx, t).size(),
302        _ => unreachable!("invalid int"),
303    };
304    let v = size.truncate(v);
305    let cast_scalar = ScalarInt::try_from_uint(v, size).unwrap();
306    cast_scalar == target_scalar
307}
308
309#[derive(Default)]
310struct SimplifyToExp {
311    transform_kinds: Vec<TransformKind>,
312}
313
314#[derive(Clone, Copy, Debug)]
315enum ExpectedTransformKind<'a, 'tcx> {
316    /// Identical statements.
317    Same(&'a StatementKind<'tcx>),
318    /// Assignment statements have the same value.
319    SameByEq { place: &'a Place<'tcx>, ty: Ty<'tcx>, scalar: ScalarInt },
320    /// Enum variant comparison type.
321    Cast { place: &'a Place<'tcx>, ty: Ty<'tcx> },
322}
323
/// How a statement of the first target block is emitted into the switching block.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum TransformKind {
    /// The statement is copied unchanged.
    Same,
    /// The constant assignment is replaced by an assignment of the (possibly cast)
    /// discriminant.
    Cast,
}
328
329impl From<ExpectedTransformKind<'_, '_>> for TransformKind {
330    fn from(compare_type: ExpectedTransformKind<'_, '_>) -> Self {
331        match compare_type {
332            ExpectedTransformKind::Same(_) => TransformKind::Same,
333            ExpectedTransformKind::SameByEq { .. } => TransformKind::Same,
334            ExpectedTransformKind::Cast { .. } => TransformKind::Cast,
335        }
336    }
337}
338
339/// If we find that the value of match is the same as the assignment,
340/// merge a target block statements into the source block,
341/// using cast to transform different integer types.
342///
343/// For example:
344///
345/// ```ignore (MIR)
346/// bb0: {
347///     switchInt(_1) -> [1: bb2, 2: bb3, 3: bb4, otherwise: bb1];
348/// }
349///
350/// bb1: {
351///     unreachable;
352/// }
353///
354/// bb2: {
355///     _0 = const 1_i16;
356///     goto -> bb5;
357/// }
358///
359/// bb3: {
360///     _0 = const 2_i16;
361///     goto -> bb5;
362/// }
363///
364/// bb4: {
365///     _0 = const 3_i16;
366///     goto -> bb5;
367/// }
368/// ```
369///
370/// into:
371///
372/// ```ignore (MIR)
373/// bb0: {
374///    _0 = _3 as i16 (IntToInt);
375///    goto -> bb5;
376/// }
377/// ```
378impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp {
379    #[instrument(level = "debug", skip(self, tcx), ret)]
380    fn can_simplify(
381        &mut self,
382        tcx: TyCtxt<'tcx>,
383        targets: &SwitchTargets,
384        typing_env: ty::TypingEnv<'tcx>,
385        bbs: &IndexSlice<BasicBlock, BasicBlockData<'tcx>>,
386        discr_ty: Ty<'tcx>,
387    ) -> Option<()> {
388        if targets.iter().len() < 2 || targets.iter().len() > 64 {
389            return None;
390        }
391        // We require that the possible target blocks all be distinct.
392        if !targets.is_distinct() {
393            return None;
394        }
395        if !bbs[targets.otherwise()].is_empty_unreachable() {
396            return None;
397        }
398        let mut target_iter = targets.iter();
399        let (first_case_val, first_target) = target_iter.next().unwrap();
400        let first_terminator_kind = &bbs[first_target].terminator().kind;
401        // Check that destinations are identical, and if not, then don't optimize this block
402        if !targets
403            .iter()
404            .all(|(_, other_target)| first_terminator_kind == &bbs[other_target].terminator().kind)
405        {
406            return None;
407        }
408
409        let discr_layout = tcx.layout_of(typing_env.as_query_input(discr_ty)).unwrap();
410        let first_stmts = &bbs[first_target].statements;
411        let (second_case_val, second_target) = target_iter.next().unwrap();
412        let second_stmts = &bbs[second_target].statements;
413        if first_stmts.len() != second_stmts.len() {
414            return None;
415        }
416
417        // We first compare the two branches, and then the other branches need to fulfill the same
418        // conditions.
419        let mut expected_transform_kinds = Vec::new();
420        for (f, s) in iter::zip(first_stmts, second_stmts) {
421            let compare_type = match (&f.kind, &s.kind) {
422                // If two statements are exactly the same, we can optimize.
423                (f_s, s_s) if f_s == s_s => ExpectedTransformKind::Same(f_s),
424
425                // If two statements are assignments with the match values to the same place, we
426                // can optimize.
427                (
428                    StatementKind::Assign(box (lhs_f, Rvalue::Use(Operand::Constant(f_c)))),
429                    StatementKind::Assign(box (lhs_s, Rvalue::Use(Operand::Constant(s_c)))),
430                ) if lhs_f == lhs_s
431                    && f_c.const_.ty() == s_c.const_.ty()
432                    && f_c.const_.ty().is_integral() =>
433                {
434                    match (
435                        f_c.const_.try_eval_scalar_int(tcx, typing_env),
436                        s_c.const_.try_eval_scalar_int(tcx, typing_env),
437                    ) {
438                        (Some(f), Some(s)) if f == s => ExpectedTransformKind::SameByEq {
439                            place: lhs_f,
440                            ty: f_c.const_.ty(),
441                            scalar: f,
442                        },
443                        // Enum variants can also be simplified to an assignment statement,
444                        // if we can use `IntToInt` cast to get an equal value.
445                        (Some(f), Some(s))
446                            if (can_cast(
447                                tcx,
448                                first_case_val,
449                                discr_layout,
450                                f_c.const_.ty(),
451                                f,
452                            ) && can_cast(
453                                tcx,
454                                second_case_val,
455                                discr_layout,
456                                f_c.const_.ty(),
457                                s,
458                            )) =>
459                        {
460                            ExpectedTransformKind::Cast { place: lhs_f, ty: f_c.const_.ty() }
461                        }
462                        _ => {
463                            return None;
464                        }
465                    }
466                }
467
468                // Otherwise we cannot optimize. Try another block.
469                _ => return None,
470            };
471            expected_transform_kinds.push(compare_type);
472        }
473
474        // All remaining BBs need to fulfill the same pattern as the two BBs from the previous step.
475        for (other_val, other_target) in target_iter {
476            let other_stmts = &bbs[other_target].statements;
477            if expected_transform_kinds.len() != other_stmts.len() {
478                return None;
479            }
480            for (f, s) in iter::zip(&expected_transform_kinds, other_stmts) {
481                match (*f, &s.kind) {
482                    (ExpectedTransformKind::Same(f_s), s_s) if f_s == s_s => {}
483                    (
484                        ExpectedTransformKind::SameByEq { place: lhs_f, ty: f_ty, scalar },
485                        StatementKind::Assign(box (lhs_s, Rvalue::Use(Operand::Constant(s_c)))),
486                    ) if lhs_f == lhs_s
487                        && s_c.const_.ty() == f_ty
488                        && s_c.const_.try_eval_scalar_int(tcx, typing_env) == Some(scalar) => {}
489                    (
490                        ExpectedTransformKind::Cast { place: lhs_f, ty: f_ty },
491                        StatementKind::Assign(box (lhs_s, Rvalue::Use(Operand::Constant(s_c)))),
492                    ) if let Some(f) = s_c.const_.try_eval_scalar_int(tcx, typing_env)
493                        && lhs_f == lhs_s
494                        && s_c.const_.ty() == f_ty
495                        && can_cast(tcx, other_val, discr_layout, f_ty, f) => {}
496                    _ => return None,
497                }
498            }
499        }
500        self.transform_kinds = expected_transform_kinds.into_iter().map(|c| c.into()).collect();
501        Some(())
502    }
503
504    fn new_stmts(
505        &self,
506        _tcx: TyCtxt<'tcx>,
507        targets: &SwitchTargets,
508        _typing_env: ty::TypingEnv<'tcx>,
509        patch: &mut MirPatch<'tcx>,
510        parent_end: Location,
511        bbs: &IndexSlice<BasicBlock, BasicBlockData<'tcx>>,
512        discr_local: Local,
513        discr_ty: Ty<'tcx>,
514    ) {
515        let (_, first) = targets.iter().next().unwrap();
516        let first = &bbs[first];
517
518        for (t, s) in iter::zip(&self.transform_kinds, &first.statements) {
519            match (t, &s.kind) {
520                (TransformKind::Same, _) => {
521                    patch.add_statement(parent_end, s.kind.clone());
522                }
523                (
524                    TransformKind::Cast,
525                    StatementKind::Assign(box (lhs, Rvalue::Use(Operand::Constant(f_c)))),
526                ) => {
527                    let operand = Operand::Copy(Place::from(discr_local));
528                    let r_val = if f_c.const_.ty() == discr_ty {
529                        Rvalue::Use(operand)
530                    } else {
531                        Rvalue::Cast(CastKind::IntToInt, operand, f_c.const_.ty())
532                    };
533                    patch.add_assign(parent_end, *lhs, r_val);
534                }
535                _ => unreachable!(),
536            }
537        }
538    }
539}