// rustc_mir_transform/jump_threading.rs

//! A jump threading optimization.
//!
//! This optimization seeks to replace join-then-switch control flow patterns with straight jumps:
//!    X = 0                                      X = 0
//! ------------\      /--------              ------------
//!    X = 1     X----X SwitchInt(X)     =>       X = 1
//! ------------/      \--------              ------------
//!
//! We proceed by walking the CFG backwards starting from each `SwitchInt` terminator,
//! looking for assignments that will turn the `SwitchInt` into a simple `Goto`.
//!
//! The algorithm maintains a set of replacement conditions:
//! - `conditions[place]` contains `Condition { value, polarity: Eq, target }`
//!   if assigning `value` to `place` turns the `SwitchInt` into `Goto { target }`.
//! - `conditions[place]` contains `Condition { value, polarity: Ne, target }`
//!   if assigning anything different from `value` to `place` turns the `SwitchInt`
//!   into `Goto { target }`.
//!
//! In this file, we write `place ?= value` for the existence of a replacement condition
//! on `place` with the given `value`, irrespective of the polarity and target of that
//! replacement condition.
//!
//! We then walk the CFG backwards transforming the set of conditions.
//! When we find a fulfilling assignment, we record a `ThreadingOpportunity`.
//! All `ThreadingOpportunity`s are applied to the body, duplicating blocks if required.
//!
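//! As a simplified sketch (block and local names here are illustrative), consider:
//!     bb0: _1 = const true; goto -> bb1;
//!     bb1: switchInt(_1) -> [0: bb2, otherwise: bb3];
//! Starting from the `switchInt`, we record the conditions `_1 ?= 0` (`Eq`, jump to `bb2`)
//! and `_1 ?= 0` (`Ne`, jump to `bb3`). Walking backwards into `bb0`, the assignment
//! `_1 = const true` fulfills the `Ne` condition, so we record
//! `ThreadingOpportunity { chain: [bb0, bb1], target: bb3 }`.
//!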
//! The optimization search can be very heavy, as it performs a DFS on MIR starting from
//! each `SwitchInt` terminator. To manage the complexity, we:
//! - bound the maximum depth by a constant `MAX_BACKTRACK`;
//! - only traverse `Goto` terminators.
//!
//! We try to avoid creating irreducible control-flow by not threading through a loop header.
//!
//! Likewise, applying the optimization can create a lot of new MIR, so we bound the instruction
//! cost by `MAX_COST`.

use rustc_arena::DroplessArena;
use rustc_const_eval::const_eval::DummyMachine;
use rustc_const_eval::interpret::{ImmTy, Immediate, InterpCx, OpTy, Projectable};
use rustc_data_structures::fx::FxHashSet;
use rustc_index::IndexVec;
use rustc_index::bit_set::DenseBitSet;
use rustc_middle::bug;
use rustc_middle::mir::interpret::Scalar;
use rustc_middle::mir::visit::Visitor;
use rustc_middle::mir::*;
use rustc_middle::ty::{self, ScalarInt, TyCtxt};
use rustc_mir_dataflow::lattice::HasBottom;
use rustc_mir_dataflow::value_analysis::{Map, PlaceIndex, State, TrackElem};
use rustc_span::DUMMY_SP;
use tracing::{debug, instrument, trace};

use crate::cost_checker::CostChecker;

pub(super) struct JumpThreading;

const MAX_BACKTRACK: usize = 5;
const MAX_COST: usize = 100;
const MAX_PLACES: usize = 100;

impl<'tcx> crate::MirPass<'tcx> for JumpThreading {
    fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
        sess.mir_opt_level() >= 2
    }

    #[instrument(skip_all, level = "debug")]
    fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
        let def_id = body.source.def_id();
        debug!(?def_id);

        // Optimizing coroutines creates query cycles.
        if tcx.is_coroutine(def_id) {
            trace!("Skipped for coroutine {:?}", def_id);
            return;
        }

        let typing_env = body.typing_env(tcx);
        let arena = &DroplessArena::default();
        let mut finder = TOFinder {
            tcx,
            typing_env,
            ecx: InterpCx::new(tcx, DUMMY_SP, typing_env, DummyMachine),
            body,
            arena,
            map: Map::new(tcx, body, Some(MAX_PLACES)),
            loop_headers: loop_headers(body),
            opportunities: Vec::new(),
        };

        for (bb, _) in traversal::preorder(body) {
            finder.start_from_switch(bb);
        }

        let opportunities = finder.opportunities;
        debug!(?opportunities);
        if opportunities.is_empty() {
            return;
        }

        // Verify that we do not thread through a loop header.
        for to in opportunities.iter() {
            assert!(to.chain.iter().all(|&block| !finder.loop_headers.contains(block)));
        }
        OpportunitySet::new(body, opportunities).apply(body);
    }

    fn is_required(&self) -> bool {
        false
    }
}

#[derive(Debug)]
struct ThreadingOpportunity {
    /// The list of `BasicBlock`s from the one that found the opportunity to the `SwitchInt`.
    chain: Vec<BasicBlock>,
    /// The `SwitchInt` will be replaced by `Goto { target }`.
    target: BasicBlock,
}

struct TOFinder<'a, 'tcx> {
    tcx: TyCtxt<'tcx>,
    typing_env: ty::TypingEnv<'tcx>,
    ecx: InterpCx<'tcx, DummyMachine>,
    body: &'a Body<'tcx>,
    map: Map<'tcx>,
    loop_headers: DenseBitSet<BasicBlock>,
    /// We use an arena to avoid cloning the slices when cloning `state`.
    arena: &'a DroplessArena,
    opportunities: Vec<ThreadingOpportunity>,
}

/// Represents the following statement: if we can prove that the current local is equal/not-equal
/// to `value`, jump to `target`.
#[derive(Copy, Clone, Debug)]
struct Condition {
    value: ScalarInt,
    polarity: Polarity,
    target: BasicBlock,
}

#[derive(Copy, Clone, Debug, Eq, PartialEq)]
enum Polarity {
    Ne,
    Eq,
}

impl Condition {
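    /// Whether assigning `value` to the tracked place fulfills this condition. For example,
    /// `Condition { value: 0, polarity: Eq, .. }` matches only the value 0, while
    /// `Condition { value: 0, polarity: Ne, .. }` matches any value except 0.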
    fn matches(&self, value: ScalarInt) -> bool {
        (self.value == value) == (self.polarity == Polarity::Eq)
    }
}

#[derive(Copy, Clone, Debug)]
struct ConditionSet<'a>(&'a [Condition]);

impl HasBottom for ConditionSet<'_> {
    const BOTTOM: Self = ConditionSet(&[]);

    fn is_bottom(&self) -> bool {
        self.0.is_empty()
    }
}

impl<'a> ConditionSet<'a> {
    fn iter(self) -> impl Iterator<Item = Condition> {
        self.0.iter().copied()
    }

    fn iter_matches(self, value: ScalarInt) -> impl Iterator<Item = Condition> {
        self.iter().filter(move |c| c.matches(value))
    }

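    /// Transform each condition with `f`, allocating the result in the arena.
    /// Returns `None` if `f` fails on any condition, discarding the whole set.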
    fn map(
        self,
        arena: &'a DroplessArena,
        f: impl Fn(Condition) -> Option<Condition>,
    ) -> Option<ConditionSet<'a>> {
        let set = arena.try_alloc_from_iter(self.iter().map(|c| f(c).ok_or(()))).ok()?;
        Some(ConditionSet(set))
    }
}

impl<'a, 'tcx> TOFinder<'a, 'tcx> {
    fn is_empty(&self, state: &State<ConditionSet<'a>>) -> bool {
        state.all_bottom()
    }

    /// Recursion entry point to find threading opportunities.
    #[instrument(level = "trace", skip(self))]
    fn start_from_switch(&mut self, bb: BasicBlock) {
        let bbdata = &self.body[bb];
        if bbdata.is_cleanup || self.loop_headers.contains(bb) {
            return;
        }
        let Some((discr, targets)) = bbdata.terminator().kind.as_switch() else { return };
        let Some(discr) = discr.place() else { return };
        debug!(?discr, ?bb);

        let discr_ty = discr.ty(self.body, self.tcx).ty;
        let Ok(discr_layout) = self.ecx.layout_of(discr_ty) else { return };

        let Some(discr) = self.map.find(discr.as_ref()) else { return };
        debug!(?discr);

        let cost = CostChecker::new(self.tcx, self.typing_env, None, self.body);
        let mut state = State::new_reachable();

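        // Seed the state with the conditions read off the `SwitchInt` targets. For a two-way
        // `switchInt(x) -> [V: bb_then, otherwise: bb_else]` (names illustrative), that is
        // `x ?= V` (`Eq`, jump to `bb_then`) and `x ?= V` (`Ne`, jump to `bb_else`); for a
        // general switch, one `Eq` condition per listed value.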
        let conds = if let Some((value, then, else_)) = targets.as_static_if() {
            let Some(value) = ScalarInt::try_from_uint(value, discr_layout.size) else { return };
            self.arena.alloc_from_iter([
                Condition { value, polarity: Polarity::Eq, target: then },
                Condition { value, polarity: Polarity::Ne, target: else_ },
            ])
        } else {
            self.arena.alloc_from_iter(targets.iter().filter_map(|(value, target)| {
                let value = ScalarInt::try_from_uint(value, discr_layout.size)?;
                Some(Condition { value, polarity: Polarity::Eq, target })
            }))
        };
        let conds = ConditionSet(conds);
        state.insert_value_idx(discr, conds, &self.map);

        self.find_opportunity(bb, state, cost, 0)
    }

    /// Recursively walk statements backwards from this bb's terminator to find threading
    /// opportunities.
    #[instrument(level = "trace", skip(self, cost), ret)]
    fn find_opportunity(
        &mut self,
        bb: BasicBlock,
        mut state: State<ConditionSet<'a>>,
        mut cost: CostChecker<'_, 'tcx>,
        depth: usize,
    ) {
        // Do not thread through loop headers.
        if self.loop_headers.contains(bb) {
            return;
        }

        debug!(cost = ?cost.cost());
        for (statement_index, stmt) in
            self.body.basic_blocks[bb].statements.iter().enumerate().rev()
        {
            if self.is_empty(&state) {
                return;
            }

            cost.visit_statement(stmt, Location { block: bb, statement_index });
            if cost.cost() > MAX_COST {
                return;
            }

            // Attempt to turn the `current_condition` on `lhs` into a condition on another place.
            self.process_statement(bb, stmt, &mut state);

            // When a statement mutates a place, assignments to that place that happen
            // above the mutation cannot fulfill a condition.
            //   _1 = 5 // Whatever happens here, it won't change the result of a `SwitchInt`.
            //   _1 = 6
            if let Some((lhs, tail)) = self.mutated_statement(stmt) {
                state.flood_with_tail_elem(lhs.as_ref(), tail, &self.map, ConditionSet::BOTTOM);
            }
        }

        if self.is_empty(&state) || depth >= MAX_BACKTRACK {
            return;
        }

        let last_non_rec = self.opportunities.len();

        let predecessors = &self.body.basic_blocks.predecessors()[bb];
        if let &[pred] = &predecessors[..]
            && bb != START_BLOCK
        {
            let term = self.body.basic_blocks[pred].terminator();
            match term.kind {
                TerminatorKind::SwitchInt { ref discr, ref targets } => {
                    self.process_switch_int(discr, targets, bb, &mut state);
                    self.find_opportunity(pred, state, cost, depth + 1);
                }
                _ => self.recurse_through_terminator(pred, || state, &cost, depth),
            }
        } else if let &[ref predecessors @ .., last_pred] = &predecessors[..] {
            for &pred in predecessors {
                self.recurse_through_terminator(pred, || state.clone(), &cost, depth);
            }
            self.recurse_through_terminator(last_pred, || state, &cost, depth);
        }

        let new_tos = &mut self.opportunities[last_non_rec..];
        debug!(?new_tos);

        // Try to deduplicate threading opportunities.
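        // For example, if every predecessor `predN` of `bb` yielded
        // `ThreadingOpportunity { chain: [predN], target: bbT }` for the same `bbT` (names
        // illustrative), we can replace them all with the single
        // `ThreadingOpportunity { chain: [bb], target: bbT }`.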
        if new_tos.len() > 1
            && new_tos.len() == predecessors.len()
            && predecessors
                .iter()
                .zip(new_tos.iter())
                .all(|(&pred, to)| to.chain == &[pred] && to.target == new_tos[0].target)
        {
            // All predecessors have a threading opportunity, and they all point to the same block.
            debug!(?new_tos, "dedup");
            let first = &mut new_tos[0];
            *first = ThreadingOpportunity { chain: vec![bb], target: first.target };
            self.opportunities.truncate(last_non_rec + 1);
            return;
        }

        for op in self.opportunities[last_non_rec..].iter_mut() {
            op.chain.push(bb);
        }
    }

    /// Extract the mutated place from a statement.
    ///
    /// This method returns the `Place` so we can flood the state in case of a partial assignment.
    ///     (_1 as Ok).0 = _5;
    ///     (_1 as Err).0 = _6;
    /// We want to ensure that a `SwitchInt((_1 as Ok).0)` does not see the first assignment, as
    /// the value may have been mangled by the second assignment.
    ///
    /// In case we assign to a discriminant, we return `Some(TrackElem::Discriminant)`, so we can
    /// restrict flooding to the discriminant and preserve the variant fields.
    ///     (_1 as Some).0 = _6;
    ///     SetDiscriminant(_1, 1);
    ///     switchInt((_1 as Some).0)
    #[instrument(level = "trace", skip(self), ret)]
    fn mutated_statement(
        &self,
        stmt: &Statement<'tcx>,
    ) -> Option<(Place<'tcx>, Option<TrackElem>)> {
        match stmt.kind {
            StatementKind::Assign(box (place, _))
            | StatementKind::Deinit(box place) => Some((place, None)),
            StatementKind::SetDiscriminant { box place, variant_index: _ } => {
                Some((place, Some(TrackElem::Discriminant)))
            }
            StatementKind::StorageLive(local) | StatementKind::StorageDead(local) => {
                Some((Place::from(local), None))
            }
            StatementKind::Retag(..)
            | StatementKind::Intrinsic(box NonDivergingIntrinsic::Assume(..))
            // copy_nonoverlapping takes pointers and mutates the pointed-to value.
            | StatementKind::Intrinsic(box NonDivergingIntrinsic::CopyNonOverlapping(..))
            | StatementKind::AscribeUserType(..)
            | StatementKind::Coverage(..)
            | StatementKind::FakeRead(..)
            | StatementKind::ConstEvalCounter
            | StatementKind::PlaceMention(..)
            | StatementKind::BackwardIncompatibleDropHint { .. }
            | StatementKind::Nop => None,
        }
    }

    #[instrument(level = "trace", skip(self))]
    fn process_immediate(
        &mut self,
        bb: BasicBlock,
        lhs: PlaceIndex,
        rhs: ImmTy<'tcx>,
        state: &mut State<ConditionSet<'a>>,
    ) {
        let register_opportunity = |c: Condition| {
            debug!(?bb, ?c.target, "register");
            self.opportunities.push(ThreadingOpportunity { chain: vec![bb], target: c.target })
        };

        if let Some(conditions) = state.try_get_idx(lhs, &self.map)
            && let Immediate::Scalar(Scalar::Int(int)) = *rhs
        {
            conditions.iter_matches(int).for_each(register_opportunity);
        }
    }

    /// If we expect `lhs ?= A`, we have an opportunity if we assume `constant == A`.
    #[instrument(level = "trace", skip(self))]
    fn process_constant(
        &mut self,
        bb: BasicBlock,
        lhs: PlaceIndex,
        constant: OpTy<'tcx>,
        state: &mut State<ConditionSet<'a>>,
    ) {
        self.map.for_each_projection_value(
            lhs,
            constant,
            &mut |elem, op| match elem {
                TrackElem::Field(idx) => self.ecx.project_field(op, idx).discard_err(),
                TrackElem::Variant(idx) => self.ecx.project_downcast(op, idx).discard_err(),
                TrackElem::Discriminant => {
                    let variant = self.ecx.read_discriminant(op).discard_err()?;
                    let discr_value =
                        self.ecx.discriminant_for_variant(op.layout.ty, variant).discard_err()?;
                    Some(discr_value.into())
                }
                TrackElem::DerefLen => {
                    let op: OpTy<'_> = self.ecx.deref_pointer(op).discard_err()?.into();
                    let len_usize = op.len(&self.ecx).discard_err()?;
                    let layout = self.ecx.layout_of(self.tcx.types.usize).unwrap();
                    Some(ImmTy::from_uint(len_usize, layout).into())
                }
            },
            &mut |place, op| {
                if let Some(conditions) = state.try_get_idx(place, &self.map)
                    && let Some(imm) = self.ecx.read_immediate_raw(op).discard_err()
                    && let Some(imm) = imm.right()
                    && let Immediate::Scalar(Scalar::Int(int)) = *imm
                {
                    conditions.iter_matches(int).for_each(|c: Condition| {
                        self.opportunities
                            .push(ThreadingOpportunity { chain: vec![bb], target: c.target })
                    })
                }
            },
        );
    }

    #[instrument(level = "trace", skip(self))]
    fn process_operand(
        &mut self,
        bb: BasicBlock,
        lhs: PlaceIndex,
        rhs: &Operand<'tcx>,
        state: &mut State<ConditionSet<'a>>,
    ) {
        match rhs {
            // If we expect `lhs ?= A`, we have an opportunity if we assume `constant == A`.
            Operand::Constant(constant) => {
                let Some(constant) =
                    self.ecx.eval_mir_constant(&constant.const_, constant.span, None).discard_err()
                else {
                    return;
                };
                self.process_constant(bb, lhs, constant, state);
            }
            // Transfer the conditions on the copied rhs.
            Operand::Move(rhs) | Operand::Copy(rhs) => {
                let Some(rhs) = self.map.find(rhs.as_ref()) else { return };
                state.insert_place_idx(rhs, lhs, &self.map);
            }
        }
    }

    #[instrument(level = "trace", skip(self))]
    fn process_assign(
        &mut self,
        bb: BasicBlock,
        lhs_place: &Place<'tcx>,
        rhs: &Rvalue<'tcx>,
        state: &mut State<ConditionSet<'a>>,
    ) {
        let Some(lhs) = self.map.find(lhs_place.as_ref()) else { return };
        match rhs {
            Rvalue::Use(operand) => self.process_operand(bb, lhs, operand, state),
            // Transfer the conditions on the copy rhs.
            Rvalue::CopyForDeref(rhs) => self.process_operand(bb, lhs, &Operand::Copy(*rhs), state),
            Rvalue::Discriminant(rhs) => {
                let Some(rhs) = self.map.find_discr(rhs.as_ref()) else { return };
                state.insert_place_idx(rhs, lhs, &self.map);
            }
            // If we expect `lhs ?= A`, we have an opportunity if we assume `constant == A`.
            Rvalue::Aggregate(box kind, operands) => {
                let agg_ty = lhs_place.ty(self.body, self.tcx).ty;
                let lhs = match kind {
                    // Do not support unions.
                    AggregateKind::Adt(.., Some(_)) => return,
                    AggregateKind::Adt(_, variant_index, ..) if agg_ty.is_enum() => {
                        if let Some(discr_target) = self.map.apply(lhs, TrackElem::Discriminant)
                            && let Some(discr_value) = self
                                .ecx
                                .discriminant_for_variant(agg_ty, *variant_index)
                                .discard_err()
                        {
                            self.process_immediate(bb, discr_target, discr_value, state);
                        }
                        if let Some(idx) = self.map.apply(lhs, TrackElem::Variant(*variant_index)) {
                            idx
                        } else {
                            return;
                        }
                    }
                    _ => lhs,
                };
                for (field_index, operand) in operands.iter_enumerated() {
                    if let Some(field) = self.map.apply(lhs, TrackElem::Field(field_index)) {
                        self.process_operand(bb, field, operand, state);
                    }
                }
            }
            // Transfer the conditions on the copy rhs, after inverting the value of the condition.
            Rvalue::UnaryOp(UnOp::Not, Operand::Move(place) | Operand::Copy(place)) => {
                let layout = self.ecx.layout_of(place.ty(self.body, self.tcx).ty).unwrap();
                let Some(conditions) = state.try_get_idx(lhs, &self.map) else { return };
                let Some(place) = self.map.find(place.as_ref()) else { return };
                let Some(conds) = conditions.map(self.arena, |mut cond| {
                    cond.value = self
                        .ecx
                        .unary_op(UnOp::Not, &ImmTy::from_scalar_int(cond.value, layout))
                        .discard_err()?
                        .to_scalar_int()
                        .discard_err()?;
                    Some(cond)
                }) else {
                    return;
                };
                state.insert_value_idx(place, conds, &self.map);
            }
            // We expect `lhs ?= A`. We found `lhs = Eq(rhs, B)`.
            // Create a condition on `rhs ?= B`.
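            // For instance, from `_2 = Eq(_1, const 5)` and the condition `_2 ?= true` (`Eq`),
            // we derive `_1 ?= 5` (`Eq`) with the same target; `_2 ?= true` (`Ne`) likewise
            // becomes `_1 ?= 5` (`Ne`).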
            Rvalue::BinaryOp(
                op,
                box (Operand::Move(place) | Operand::Copy(place), Operand::Constant(value))
                | box (Operand::Constant(value), Operand::Move(place) | Operand::Copy(place)),
            ) => {
                let Some(conditions) = state.try_get_idx(lhs, &self.map) else { return };
                let Some(place) = self.map.find(place.as_ref()) else { return };
                let equals = match op {
                    BinOp::Eq => ScalarInt::TRUE,
                    BinOp::Ne => ScalarInt::FALSE,
                    _ => return,
                };
                if value.const_.ty().is_floating_point() {
                    // Floating point equality does not follow bit-patterns.
                    // -0.0 and NaN both have special rules for equality,
                    // and therefore we cannot use integer comparisons for them.
                    // Avoid handling them, though this could be extended in the future.
                    return;
                }
                let Some(value) = value.const_.try_eval_scalar_int(self.tcx, self.typing_env)
                else {
                    return;
                };
                let Some(conds) = conditions.map(self.arena, |c| {
                    Some(Condition {
                        value,
                        polarity: if c.matches(equals) { Polarity::Eq } else { Polarity::Ne },
                        ..c
                    })
                }) else {
                    return;
                };
                state.insert_value_idx(place, conds, &self.map);
            }

            _ => {}
        }
    }

    #[instrument(level = "trace", skip(self))]
    fn process_statement(
        &mut self,
        bb: BasicBlock,
        stmt: &Statement<'tcx>,
        state: &mut State<ConditionSet<'a>>,
    ) {
        let register_opportunity = |c: Condition| {
            debug!(?bb, ?c.target, "register");
            self.opportunities.push(ThreadingOpportunity { chain: vec![bb], target: c.target })
        };

        // Below, `lhs` is the return value of `mutated_statement`,
        // the place to which `conditions` apply.

        match &stmt.kind {
            // If we expect `discriminant(place) ?= A`,
            // we have an opportunity if `variant_index ?= A`.
            StatementKind::SetDiscriminant { box place, variant_index } => {
                let Some(discr_target) = self.map.find_discr(place.as_ref()) else { return };
                let enum_ty = place.ty(self.body, self.tcx).ty;
                // `SetDiscriminant` guarantees that the discriminant is now `variant_index`.
                // Even if the discriminant write does nothing due to niches, it is UB to set the
                // discriminant when the data does not encode the desired discriminant.
                let Some(discr) =
                    self.ecx.discriminant_for_variant(enum_ty, *variant_index).discard_err()
                else {
                    return;
                };
                self.process_immediate(bb, discr_target, discr, state)
            }
            // If we expect `lhs ?= true`, we have an opportunity if we assume `lhs == true`.
            StatementKind::Intrinsic(box NonDivergingIntrinsic::Assume(
                Operand::Copy(place) | Operand::Move(place),
            )) => {
                let Some(conditions) = state.try_get(place.as_ref(), &self.map) else { return };
                conditions.iter_matches(ScalarInt::TRUE).for_each(register_opportunity)
            }
            StatementKind::Assign(box (lhs_place, rhs)) => {
                self.process_assign(bb, lhs_place, rhs, state)
            }
            _ => {}
        }
    }

    #[instrument(level = "trace", skip(self, state, cost))]
    fn recurse_through_terminator(
        &mut self,
        bb: BasicBlock,
        // Pass a closure that may clone the state, as we don't want to do it each time.
        state: impl FnOnce() -> State<ConditionSet<'a>>,
        cost: &CostChecker<'_, 'tcx>,
        depth: usize,
    ) {
        let term = self.body.basic_blocks[bb].terminator();
        let place_to_flood = match term.kind {
            // We come from a target, so those are not possible.
            TerminatorKind::UnwindResume
            | TerminatorKind::UnwindTerminate(_)
            | TerminatorKind::Return
            | TerminatorKind::TailCall { .. }
            | TerminatorKind::Unreachable
            | TerminatorKind::CoroutineDrop => bug!("{term:?} has no successors"),
            // Disallowed during optimizations.
            TerminatorKind::FalseEdge { .. }
            | TerminatorKind::FalseUnwind { .. }
            | TerminatorKind::Yield { .. } => bug!("{term:?} invalid"),
            // Cannot reason about inline asm.
            TerminatorKind::InlineAsm { .. } => return,
            // `SwitchInt` is handled specially.
            TerminatorKind::SwitchInt { .. } => return,
            // We can recurse, nothing particular to do.
            TerminatorKind::Goto { .. } => None,
            // Flood the overwritten place, and progress through.
            TerminatorKind::Drop { place: destination, .. }
            | TerminatorKind::Call { destination, .. } => Some(destination),
            // Ignore, as this can be a no-op at codegen time.
            TerminatorKind::Assert { .. } => None,
        };

        // We can recurse through this terminator.
        let mut state = state();
        if let Some(place_to_flood) = place_to_flood {
            state.flood_with(place_to_flood.as_ref(), &self.map, ConditionSet::BOTTOM);
        }
        self.find_opportunity(bb, state, cost.clone(), depth + 1)
    }

    #[instrument(level = "trace", skip(self))]
    fn process_switch_int(
        &mut self,
        discr: &Operand<'tcx>,
        targets: &SwitchTargets,
        target_bb: BasicBlock,
        state: &mut State<ConditionSet<'a>>,
    ) {
        debug_assert_ne!(target_bb, START_BLOCK);
        debug_assert_eq!(self.body.basic_blocks.predecessors()[target_bb].len(), 1);

        let Some(discr) = discr.place() else { return };
        let discr_ty = discr.ty(self.body, self.tcx).ty;
        let Ok(discr_layout) = self.ecx.layout_of(discr_ty) else {
            return;
        };
        let Some(conditions) = state.try_get(discr.as_ref(), &self.map) else { return };

        if let Some((value, _)) = targets.iter().find(|&(_, target)| target == target_bb) {
            let Some(value) = ScalarInt::try_from_uint(value, discr_layout.size) else { return };
            debug_assert_eq!(targets.iter().filter(|&(_, target)| target == target_bb).count(), 1);

            // We are inside `target_bb`. Since we have a single predecessor, we know we passed
            // through the `SwitchInt` before arriving here. Therefore, we know that
            // `discr == value`. If one condition can be fulfilled by `discr == value`,
            // that's an opportunity.
            for c in conditions.iter_matches(value) {
                debug!(?target_bb, ?c.target, "register");
                self.opportunities.push(ThreadingOpportunity { chain: vec![], target: c.target });
            }
        } else if let Some((value, _, else_bb)) = targets.as_static_if()
            && target_bb == else_bb
        {
            let Some(value) = ScalarInt::try_from_uint(value, discr_layout.size) else { return };

            // We only know that `discr != value`. That's much weaker information than
            // the equality we had in the previous arm. All we can conclude is that
            // the replacement condition `discr != value` can be threaded, and nothing else.
            for c in conditions.iter() {
                if c.value == value && c.polarity == Polarity::Ne {
                    debug!(?target_bb, ?c.target, "register");
                    self.opportunities
                        .push(ThreadingOpportunity { chain: vec![], target: c.target });
                }
            }
        }
    }
}

struct OpportunitySet {
    opportunities: Vec<ThreadingOpportunity>,
    /// For each bb, give the TOs in which it appears. The pair corresponds to the index
    /// in `opportunities` and the index in `ThreadingOpportunity::chain`.
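    /// For example, if `opportunities[2]` has `chain: [bb5, bb7]` and `target: bb9`, then
    /// `involving_tos[bb5]` contains `(2, 0)`, `involving_tos[bb7]` contains `(2, 1)`, and
    /// `involving_tos[bb9]` contains `(2, 2)`.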
    involving_tos: IndexVec<BasicBlock, Vec<(usize, usize)>>,
    /// Cache the number of predecessors for each block, as we clear the basic block cache.
    predecessors: IndexVec<BasicBlock, usize>,
}

impl OpportunitySet {
    fn new(body: &Body<'_>, opportunities: Vec<ThreadingOpportunity>) -> OpportunitySet {
        let mut involving_tos = IndexVec::from_elem(Vec::new(), &body.basic_blocks);
        for (index, to) in opportunities.iter().enumerate() {
            for (ibb, &bb) in to.chain.iter().enumerate() {
                involving_tos[bb].push((index, ibb));
            }
            involving_tos[to.target].push((index, to.chain.len()));
        }
        let predecessors = predecessor_count(body);
        OpportunitySet { opportunities, involving_tos, predecessors }
    }

    /// Apply the opportunities on the graph.
    fn apply(&mut self, body: &mut Body<'_>) {
        for i in 0..self.opportunities.len() {
            self.apply_once(i, body);
        }
    }

    #[instrument(level = "trace", skip(self, body))]
    fn apply_once(&mut self, index: usize, body: &mut Body<'_>) {
        debug!(?self.predecessors);
        debug!(?self.involving_tos);

        // Check that `predecessors` satisfies its invariant.
        debug_assert_eq!(self.predecessors, predecessor_count(body));

        // Take the chain out of the TO so we can modify the other TOs later.
        let op = &mut self.opportunities[index];
        debug!(?op);
        let op_chain = std::mem::take(&mut op.chain);
        let op_target = op.target;
        debug_assert_eq!(op_chain.len(), op_chain.iter().collect::<FxHashSet<_>>().len());

        let Some((current, chain)) = op_chain.split_first() else { return };
        let basic_blocks = body.basic_blocks.as_mut();

        // Invariant: the control-flow is well-formed at the end of each iteration.
        let mut current = *current;
        for &succ in chain {
            debug!(?current, ?succ);

            // `succ` must be a successor of `current`. If it is not, this means this TO is not
            // satisfiable and a previous TO erased this edge, so we bail out.
            if !basic_blocks[current].terminator().successors().any(|s| s == succ) {
                debug!("impossible");
                return;
            }

            // Fast path: `succ` is only used once, so we can reuse it directly.
            if self.predecessors[succ] == 1 {
                debug!("single");
                current = succ;
                continue;
            }

            let new_succ = basic_blocks.push(basic_blocks[succ].clone());
            debug!(?new_succ);

            // Replace `succ` by `new_succ` where it appears.
            let mut num_edges = 0;
            basic_blocks[current].terminator_mut().successors_mut(|s| {
                if *s == succ {
                    *s = new_succ;
                    num_edges += 1;
                }
            });

            // Update predecessors with the new block.
            let _new_succ = self.predecessors.push(num_edges);
            debug_assert_eq!(new_succ, _new_succ);
            self.predecessors[succ] -= num_edges;
            self.update_predecessor_count(basic_blocks[new_succ].terminator(), Update::Incr);

            // Replace the `current -> succ` edge by `current -> new_succ` in all the following
            // TOs. This is necessary to avoid trying to thread through a non-existing edge. We
            // use `involving_tos` here to avoid traversing the full set of TOs on each iteration.
            let mut new_involved = Vec::new();
            for &(to_index, in_to_index) in &self.involving_tos[current] {
                // That TO has already been applied, do nothing.
                if to_index <= index {
                    continue;
                }

                let other_to = &mut self.opportunities[to_index];
                if other_to.chain.get(in_to_index) != Some(&current) {
                    continue;
                }
                let s = other_to.chain.get_mut(in_to_index + 1).unwrap_or(&mut other_to.target);
                if *s == succ {
                    // `other_to` references the `current -> succ` edge, so replace `succ`.
                    *s = new_succ;
                    new_involved.push((to_index, in_to_index + 1));
                }
            }

            // The TOs that we just updated now reference `new_succ`. Update `involving_tos`
            // in case we need to duplicate an edge starting at `new_succ` later.
            let _new_succ = self.involving_tos.push(new_involved);
            debug_assert_eq!(new_succ, _new_succ);

            current = new_succ;
        }

        let current = &mut basic_blocks[current];
        self.update_predecessor_count(current.terminator(), Update::Decr);
        current.terminator_mut().kind = TerminatorKind::Goto { target: op_target };
        self.predecessors[op_target] += 1;
    }

    fn update_predecessor_count(&mut self, terminator: &Terminator<'_>, incr: Update) {
        match incr {
            Update::Incr => {
                for s in terminator.successors() {
                    self.predecessors[s] += 1;
                }
            }
            Update::Decr => {
                for s in terminator.successors() {
                    self.predecessors[s] -= 1;
                }
            }
        }
    }
}

fn predecessor_count(body: &Body<'_>) -> IndexVec<BasicBlock, usize> {
    let mut predecessors: IndexVec<_, _> =
        body.basic_blocks.predecessors().iter().map(|ps| ps.len()).collect();
    predecessors[START_BLOCK] += 1; // Account for the implicit entry edge.
    predecessors
}

enum Update {
    Incr,
    Decr,
}

/// Compute the set of loop headers in the given body. We define a loop header as a block which
/// has at least one predecessor which it dominates. This definition is only correct for reducible
/// CFGs. But if the CFG is already irreducible, there is no point in trying much harder.
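/// For example, in a CFG `bb0 -> bb1 -> bb2 -> bb1`, `bb1` dominates its predecessor `bb2`,
/// so `bb1` is marked as a loop header.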
fn loop_headers(body: &Body<'_>) -> DenseBitSet<BasicBlock> {
    let mut loop_headers = DenseBitSet::new_empty(body.basic_blocks.len());
    let dominators = body.basic_blocks.dominators();
    // Only visit reachable blocks.
    for (bb, bbdata) in traversal::preorder(body) {
        for succ in bbdata.terminator().successors() {
            if dominators.dominates(succ, bb) {
                loop_headers.insert(succ);
            }
        }
    }
    loop_headers
}