1use std::iter;
2
3use rustc_abi::Integer;
4use rustc_index::IndexSlice;
5use rustc_middle::mir::*;
6use rustc_middle::ty::layout::{IntegerExt, TyAndLayout};
7use rustc_middle::ty::{self, ScalarInt, Ty, TyCtxt};
8use tracing::instrument;
9
10use super::simplify::simplify_cfg;
11use crate::patch::MirPatch;
12
13pub(super) struct MatchBranchSimplification;
14
15impl<'tcx> crate::MirPass<'tcx> for MatchBranchSimplification {
16 fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
17 sess.mir_opt_level() >= 1
18 }
19
20 fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
21 let typing_env = body.typing_env(tcx);
22 let mut apply_patch = false;
23 let mut patch = MirPatch::new(body);
24 for (bb, bb_data) in body.basic_blocks.iter_enumerated() {
25 match &bb_data.terminator().kind {
26 TerminatorKind::SwitchInt {
27 discr: Operand::Copy(_) | Operand::Move(_),
28 targets,
29 ..
30 } if !targets.all_targets().contains(&bb) => {}
32 _ => continue,
34 };
35
36 if SimplifyToIf.simplify(tcx, body, &mut patch, bb, typing_env).is_some() {
37 apply_patch = true;
38 continue;
39 }
40 if SimplifyToExp::default().simplify(tcx, body, &mut patch, bb, typing_env).is_some() {
41 apply_patch = true;
42 continue;
43 }
44 }
45
46 if apply_patch {
47 patch.apply(body);
48 simplify_cfg(tcx, body);
49 }
50 }
51
52 fn is_required(&self) -> bool {
53 false
54 }
55}
56
57trait SimplifyMatch<'tcx> {
58 fn simplify(
62 &mut self,
63 tcx: TyCtxt<'tcx>,
64 body: &Body<'tcx>,
65 patch: &mut MirPatch<'tcx>,
66 switch_bb_idx: BasicBlock,
67 typing_env: ty::TypingEnv<'tcx>,
68 ) -> Option<()> {
69 let bbs = &body.basic_blocks;
70 let TerminatorKind::SwitchInt { discr, targets, .. } =
71 &bbs[switch_bb_idx].terminator().kind
72 else {
73 unreachable!();
74 };
75
76 let discr_ty = discr.ty(body.local_decls(), tcx);
77 self.can_simplify(tcx, targets, typing_env, bbs, discr_ty)?;
78
79 let discr = discr.clone();
81
82 let source_info = bbs[switch_bb_idx].terminator().source_info;
84 let discr_local = patch.new_temp(discr_ty, source_info.span);
85
86 let (_, first) = targets.iter().next().unwrap();
87 let statement_index = bbs[switch_bb_idx].statements.len();
88 let parent_end = Location { block: switch_bb_idx, statement_index };
89 patch.add_statement(parent_end, StatementKind::StorageLive(discr_local));
90 patch.add_assign(parent_end, Place::from(discr_local), Rvalue::Use(discr));
91 self.new_stmts(tcx, targets, typing_env, patch, parent_end, bbs, discr_local, discr_ty);
92 patch.add_statement(parent_end, StatementKind::StorageDead(discr_local));
93 patch.patch_terminator(switch_bb_idx, bbs[first].terminator().kind.clone());
94 Some(())
95 }
96
97 fn can_simplify(
101 &mut self,
102 tcx: TyCtxt<'tcx>,
103 targets: &SwitchTargets,
104 typing_env: ty::TypingEnv<'tcx>,
105 bbs: &IndexSlice<BasicBlock, BasicBlockData<'tcx>>,
106 discr_ty: Ty<'tcx>,
107 ) -> Option<()>;
108
109 fn new_stmts(
110 &self,
111 tcx: TyCtxt<'tcx>,
112 targets: &SwitchTargets,
113 typing_env: ty::TypingEnv<'tcx>,
114 patch: &mut MirPatch<'tcx>,
115 parent_end: Location,
116 bbs: &IndexSlice<BasicBlock, BasicBlockData<'tcx>>,
117 discr_local: Local,
118 discr_ty: Ty<'tcx>,
119 );
120}
121
122struct SimplifyToIf;
123
124impl<'tcx> SimplifyMatch<'tcx> for SimplifyToIf {
156 #[instrument(level = "debug", skip(self, tcx), ret)]
157 fn can_simplify(
158 &mut self,
159 tcx: TyCtxt<'tcx>,
160 targets: &SwitchTargets,
161 typing_env: ty::TypingEnv<'tcx>,
162 bbs: &IndexSlice<BasicBlock, BasicBlockData<'tcx>>,
163 _discr_ty: Ty<'tcx>,
164 ) -> Option<()> {
165 let (first, second) = match targets.all_targets() {
166 &[first, otherwise] => (first, otherwise),
167 &[first, second, otherwise] if bbs[otherwise].is_empty_unreachable() => (first, second),
168 _ => {
169 return None;
170 }
171 };
172
173 if first == second {
175 return None;
176 }
177 if bbs[first].terminator().kind != bbs[second].terminator().kind {
179 return None;
180 }
181
182 let first_stmts = &bbs[first].statements;
185 let second_stmts = &bbs[second].statements;
186 if first_stmts.len() != second_stmts.len() {
187 return None;
188 }
189 for (f, s) in iter::zip(first_stmts, second_stmts) {
190 match (&f.kind, &s.kind) {
191 (f_s, s_s) if f_s == s_s => {}
193
194 (
196 StatementKind::Assign(box (lhs_f, Rvalue::Use(Operand::Constant(f_c)))),
197 StatementKind::Assign(box (lhs_s, Rvalue::Use(Operand::Constant(s_c)))),
198 ) if lhs_f == lhs_s
199 && f_c.const_.ty().is_bool()
200 && s_c.const_.ty().is_bool()
201 && f_c.const_.try_eval_bool(tcx, typing_env).is_some()
202 && s_c.const_.try_eval_bool(tcx, typing_env).is_some() => {}
203
204 _ => return None,
206 }
207 }
208 Some(())
209 }
210
211 fn new_stmts(
212 &self,
213 tcx: TyCtxt<'tcx>,
214 targets: &SwitchTargets,
215 typing_env: ty::TypingEnv<'tcx>,
216 patch: &mut MirPatch<'tcx>,
217 parent_end: Location,
218 bbs: &IndexSlice<BasicBlock, BasicBlockData<'tcx>>,
219 discr_local: Local,
220 discr_ty: Ty<'tcx>,
221 ) {
222 let ((val, first), second) = match (targets.all_targets(), targets.all_values()) {
223 (&[first, otherwise], &[val]) => ((val, first), otherwise),
224 (&[first, second, otherwise], &[val, _]) if bbs[otherwise].is_empty_unreachable() => {
225 ((val, first), second)
226 }
227 _ => unreachable!(),
228 };
229
230 let first = &bbs[first];
233 let second = &bbs[second];
234 for (f, s) in iter::zip(&first.statements, &second.statements) {
235 match (&f.kind, &s.kind) {
236 (f_s, s_s) if f_s == s_s => {
237 patch.add_statement(parent_end, f.kind.clone());
238 }
239
240 (
241 StatementKind::Assign(box (lhs, Rvalue::Use(Operand::Constant(f_c)))),
242 StatementKind::Assign(box (_, Rvalue::Use(Operand::Constant(s_c)))),
243 ) => {
244 let f_b = f_c.const_.try_eval_bool(tcx, typing_env).unwrap();
246 let s_b = s_c.const_.try_eval_bool(tcx, typing_env).unwrap();
247 if f_b == s_b {
248 patch.add_statement(parent_end, f.kind.clone());
250 } else {
251 let size = tcx.layout_of(typing_env.as_query_input(discr_ty)).unwrap().size;
254 let const_cmp = Operand::const_from_scalar(
255 tcx,
256 discr_ty,
257 rustc_const_eval::interpret::Scalar::from_uint(val, size),
258 rustc_span::DUMMY_SP,
259 );
260 let op = if f_b { BinOp::Eq } else { BinOp::Ne };
261 let rhs = Rvalue::BinaryOp(
262 op,
263 Box::new((Operand::Copy(Place::from(discr_local)), const_cmp)),
264 );
265 patch.add_assign(parent_end, *lhs, rhs);
266 }
267 }
268
269 _ => unreachable!(),
270 }
271 }
272 }
273}
274
275fn can_cast(
277 tcx: TyCtxt<'_>,
278 src_val: impl Into<u128>,
279 src_layout: TyAndLayout<'_>,
280 cast_ty: Ty<'_>,
281 target_scalar: ScalarInt,
282) -> bool {
283 let from_scalar = ScalarInt::try_from_uint(src_val.into(), src_layout.size).unwrap();
284 let v = match src_layout.ty.kind() {
285 ty::Uint(_) => from_scalar.to_uint(src_layout.size),
286 ty::Int(_) => from_scalar.to_int(src_layout.size) as u128,
287 _ => return false,
290 };
291 let size = match *cast_ty.kind() {
292 ty::Int(t) => Integer::from_int_ty(&tcx, t).size(),
293 ty::Uint(t) => Integer::from_uint_ty(&tcx, t).size(),
294 _ => return false,
295 };
296 let v = size.truncate(v);
297 let cast_scalar = ScalarInt::try_from_uint(v, size).unwrap();
298 cast_scalar == target_scalar
299}
300
/// Strategy that replaces a multi-way switch with straight-line code.
///
/// It applies when, in every target, each statement is either shared verbatim, assigns the
/// same integer constant, or assigns a constant equal to that target's switch value under an
/// `IntToInt` cast.
#[derive(Default)]
struct SimplifyToExp {
    // Per-statement rewrite recorded by `can_simplify` and consumed by `new_stmts`, in the
    // statement order of the first target block.
    transform_kinds: Vec<TransformKind>,
}
305
/// How a statement of the first target block must relate to the statement at the same
/// position in every other target block.
///
/// The relation is inferred by comparing the first two targets and then checked against the
/// remaining ones.
#[derive(Clone, Copy, Debug)]
enum ExpectedTransformKind<'a, 'tcx> {
    /// The other statements are identical to this one.
    Same(&'a StatementKind<'tcx>),
    /// The other statements assign the integer constant `scalar` of type `ty` to `place`.
    SameByEq { place: &'a Place<'tcx>, ty: Ty<'tcx>, scalar: ScalarInt },
    /// The other statements assign to `place` a constant of type `ty` that equals their
    /// target's switch value cast to `ty` (see `can_cast`).
    Cast { place: &'a Place<'tcx>, ty: Ty<'tcx> },
}
315
316enum TransformKind {
317 Same,
318 Cast,
319}
320
321impl From<ExpectedTransformKind<'_, '_>> for TransformKind {
322 fn from(compare_type: ExpectedTransformKind<'_, '_>) -> Self {
323 match compare_type {
324 ExpectedTransformKind::Same(_) => TransformKind::Same,
325 ExpectedTransformKind::SameByEq { .. } => TransformKind::Same,
326 ExpectedTransformKind::Cast { .. } => TransformKind::Cast,
327 }
328 }
329}
330
impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp {
    /// Accepts the switch when all of the following hold:
    /// - it has between 2 and 64 distinct targets and an empty unreachable `otherwise` block;
    /// - every target ends in the same terminator;
    /// - every target's statements line up 1:1 with the first target's, each pair being
    ///   identical, an assignment of an equal integer constant to the same place, or an
    ///   assignment to the same place of an integer constant that equals that target's switch
    ///   value under an `IntToInt` cast (see `can_cast`).
    ///
    /// On success, stores the per-statement rewrite in `self.transform_kinds`.
    #[instrument(level = "debug", skip(self, tcx), ret)]
    fn can_simplify(
        &mut self,
        tcx: TyCtxt<'tcx>,
        targets: &SwitchTargets,
        typing_env: ty::TypingEnv<'tcx>,
        bbs: &IndexSlice<BasicBlock, BasicBlockData<'tcx>>,
        discr_ty: Ty<'tcx>,
    ) -> Option<()> {
        // At least two cases are needed to infer a pattern below.
        // NOTE(review): the rationale for the upper bound of 64 is not visible here.
        if targets.iter().len() < 2 || targets.iter().len() > 64 {
            return None;
        }
        // The possible target blocks must all be distinct.
        if !targets.is_distinct() {
            return None;
        }
        // The `otherwise` edge must never be taken, so only the listed values matter.
        if !bbs[targets.otherwise()].is_empty_unreachable() {
            return None;
        }
        let mut target_iter = targets.iter();
        let (first_case_val, first_target) = target_iter.next().unwrap();
        let first_terminator_kind = &bbs[first_target].terminator().kind;
        // All targets must continue to the same place, because the switch is replaced by this
        // single terminator.
        if !targets
            .iter()
            .all(|(_, other_target)| first_terminator_kind == &bbs[other_target].terminator().kind)
        {
            return None;
        }

        let discr_layout = tcx.layout_of(typing_env.as_query_input(discr_ty)).unwrap();
        let first_stmts = &bbs[first_target].statements;
        let (second_case_val, second_target) = target_iter.next().unwrap();
        let second_stmts = &bbs[second_target].statements;
        if first_stmts.len() != second_stmts.len() {
            return None;
        }

        // Phase 1: compare the first two targets to infer, per statement, which relation all
        // targets must satisfy.
        let mut expected_transform_kinds = Vec::new();
        for (f, s) in iter::zip(first_stmts, second_stmts) {
            let compare_type = match (&f.kind, &s.kind) {
                // Identical statements can be hoisted as-is.
                (f_s, s_s) if f_s == s_s => ExpectedTransformKind::Same(f_s),

                // Assignments of same-typed integer constants to the same place.
                (
                    StatementKind::Assign(box (lhs_f, Rvalue::Use(Operand::Constant(f_c)))),
                    StatementKind::Assign(box (lhs_s, Rvalue::Use(Operand::Constant(s_c)))),
                ) if lhs_f == lhs_s
                    && f_c.const_.ty() == s_c.const_.ty()
                    && f_c.const_.ty().is_integral() =>
                {
                    match (
                        f_c.const_.try_eval_scalar_int(tcx, typing_env),
                        s_c.const_.try_eval_scalar_int(tcx, typing_env),
                    ) {
                        // Equal values, even though the statements themselves were not
                        // identical (identical ones are caught by the arm above).
                        (Some(f), Some(s)) if f == s => ExpectedTransformKind::SameByEq {
                            place: lhs_f,
                            ty: f_c.const_.ty(),
                            scalar: f,
                        },
                        // Each constant equals its target's switch value after an `IntToInt`
                        // cast, so the assignment can become a cast of the discriminant.
                        (Some(f), Some(s))
                            if (can_cast(
                                tcx,
                                first_case_val,
                                discr_layout,
                                f_c.const_.ty(),
                                f,
                            ) && can_cast(
                                tcx,
                                second_case_val,
                                discr_layout,
                                f_c.const_.ty(),
                                s,
                            )) =>
                        {
                            ExpectedTransformKind::Cast { place: lhs_f, ty: f_c.const_.ty() }
                        }
                        _ => {
                            return None;
                        }
                    }
                }

                // Any other statement pair blocks the optimization.
                _ => return None,
            };
            expected_transform_kinds.push(compare_type);
        }

        // Phase 2: every remaining target must satisfy the relations inferred above,
        // statement by statement.
        for (other_val, other_target) in target_iter {
            let other_stmts = &bbs[other_target].statements;
            if expected_transform_kinds.len() != other_stmts.len() {
                return None;
            }
            for (f, s) in iter::zip(&expected_transform_kinds, other_stmts) {
                match (*f, &s.kind) {
                    (ExpectedTransformKind::Same(f_s), s_s) if f_s == s_s => {}
                    (
                        ExpectedTransformKind::SameByEq { place: lhs_f, ty: f_ty, scalar },
                        StatementKind::Assign(box (lhs_s, Rvalue::Use(Operand::Constant(s_c)))),
                    ) if lhs_f == lhs_s
                        && s_c.const_.ty() == f_ty
                        && s_c.const_.try_eval_scalar_int(tcx, typing_env) == Some(scalar) => {}
                    (
                        ExpectedTransformKind::Cast { place: lhs_f, ty: f_ty },
                        StatementKind::Assign(box (lhs_s, Rvalue::Use(Operand::Constant(s_c)))),
                    ) if let Some(f) = s_c.const_.try_eval_scalar_int(tcx, typing_env)
                        && lhs_f == lhs_s
                        && s_c.const_.ty() == f_ty
                        && can_cast(tcx, other_val, discr_layout, f_ty, f) => {}
                    _ => return None,
                }
            }
        }
        // Only the kind of rewrite is needed later; the borrowed comparison data is dropped.
        self.transform_kinds = expected_transform_kinds.into_iter().map(|c| c.into()).collect();
        Some(())
    }

    /// Queues the first target's statements at `parent_end`.
    ///
    /// A statement marked `TransformKind::Cast` becomes `lhs = discr` when the constant's type
    /// equals `discr_ty`, and `lhs = discr as T` (an `IntToInt` cast) otherwise.
    fn new_stmts(
        &self,
        _tcx: TyCtxt<'tcx>,
        targets: &SwitchTargets,
        _typing_env: ty::TypingEnv<'tcx>,
        patch: &mut MirPatch<'tcx>,
        parent_end: Location,
        bbs: &IndexSlice<BasicBlock, BasicBlockData<'tcx>>,
        discr_local: Local,
        discr_ty: Ty<'tcx>,
    ) {
        let (_, first) = targets.iter().next().unwrap();
        let first = &bbs[first];

        // `transform_kinds` was built by `can_simplify` in the statement order of this block.
        for (t, s) in iter::zip(&self.transform_kinds, &first.statements) {
            match (t, &s.kind) {
                (TransformKind::Same, _) => {
                    patch.add_statement(parent_end, s.kind.clone());
                }
                (
                    TransformKind::Cast,
                    StatementKind::Assign(box (lhs, Rvalue::Use(Operand::Constant(f_c)))),
                ) => {
                    let operand = Operand::Copy(Place::from(discr_local));
                    let r_val = if f_c.const_.ty() == discr_ty {
                        Rvalue::Use(operand)
                    } else {
                        Rvalue::Cast(CastKind::IntToInt, operand, f_c.const_.ty())
                    };
                    patch.add_assign(parent_end, *lhs, r_val);
                }
                // `can_simplify` only records `Cast` for constant assignments.
                _ => unreachable!(),
            }
        }
    }
}