charon_lib/transform/control_flow/
ullbc_to_llbc.rs

//! ULLBC to LLBC
//!
//! We reconstruct structured control flow from the Unstructured LLBC (ULLBC).
//!
//! The reconstruction algorithm is not efficient (its complexity is probably very bad), and
//! that is deliberate: this is still an early stage, and we want the algorithm to generate the
//! best possible reconstruction. We still need to test the algorithm on more interesting
//! examples, and will consider making it more efficient once it is more mature and well
//! tested.
//!
//! More importantly, we focus on making the algorithm sound: the reconstructed program must
//! always be equivalent to the original MIR program, and the fact that the reconstruction
//! preserves this property must be obvious.
14use itertools::Itertools;
15use petgraph::algo::dijkstra;
16use petgraph::algo::dominators::{Dominators, simple_fast};
17use petgraph::graphmap::DiGraphMap;
18use petgraph::visit::{
19    Dfs, DfsPostOrder, EdgeFiltered, EdgeRef, GraphRef, IntoNeighbors, VisitMap, Visitable, Walker,
20};
21use smallvec::SmallVec;
22use std::cmp::Reverse;
23use std::collections::HashSet;
24use std::mem;
25
26use crate::common::ensure_sufficient_stack;
27use crate::errors::sanity_check;
28use crate::ids::IndexVec;
29use crate::llbc_ast as tgt;
30use crate::meta::{Span, combine_span};
31use crate::transform::TransformCtx;
32use crate::transform::ctx::TransformPass;
33use crate::ullbc_ast::{self as src, BlockId};
34use crate::{ast::*, register_error};
35
/// An entry on the work stack of [`DfsWithPath`].
pub enum StackAction<N> {
    /// Remove the last node from the current path. This is pushed when a node is explored,
    /// below its successors, so that the node leaves the path once all of its successors have
    /// been handled.
    PopPath,
    /// Visit this node, unless it has already been discovered.
    Explore(N),
}
/// A depth-first traversal that, in addition to the set of discovered nodes, maintains the path
/// from the start node to the node most recently returned by [`DfsWithPath::next`].
pub struct DfsWithPath<N, VM> {
    /// The stack of pending actions: nodes to visit, interleaved with markers telling when to
    /// pop a node off `path`.
    pub stack: Vec<StackAction<N>>,
    /// The map of discovered nodes
    pub discovered: VM,
    /// The path from start node to current node (the current node is the last element).
    pub path: Vec<N>,
}
48impl<N, VM> DfsWithPath<N, VM>
49where
50    N: Copy + PartialEq,
51    VM: VisitMap<N>,
52{
53    /// Create a new **DfsWithPath**, using the graph's visitor map, and put **start** in the stack
54    /// of nodes to visit.
55    pub fn new<G>(graph: G, start: N) -> Self
56    where
57        G: GraphRef + Visitable<NodeId = N, Map = VM>,
58    {
59        Self {
60            stack: vec![StackAction::Explore(start)],
61            discovered: graph.visit_map(),
62            path: vec![],
63        }
64    }
65
66    /// Return the next node in the dfs, or **None** if the traversal is done.
67    pub fn next<G>(&mut self, graph: G) -> Option<N>
68    where
69        G: IntoNeighbors<NodeId = N>,
70    {
71        while let Some(action) = self.stack.pop() {
72            match action {
73                StackAction::Explore(node) => {
74                    if self.discovered.visit(node) {
75                        self.path.push(node);
76                        self.stack.push(StackAction::PopPath);
77                        for succ in graph.neighbors(node) {
78                            if !self.discovered.is_visited(&succ) {
79                                self.stack.push(StackAction::Explore(succ));
80                            }
81                        }
82                        return Some(node);
83                    }
84                }
85                StackAction::PopPath => {
86                    self.path.pop();
87                }
88            }
89        }
90        None
91    }
92}
93
/// Arbitrary-precision unsigned integers.
type BigUint = fraction::DynaInt<u64, fraction::BigUint>;
/// Arbitrary-precision rationals, used to compute the `flow` of each block exactly.
type BigRational = fraction::Ratio<BigUint>;

/// Control-Flow Graph: nodes are block ids, edges carry no data.
type Cfg = DiGraphMap<src::BlockId, ()>;
100
/// Information precomputed about a function's CFG.
#[derive(Debug)]
struct CfgInfo<'a> {
    /// The CFG (unwind edges are ignored).
    pub cfg: Cfg,
    /// The CFG where all the backward edges have been removed. Aka "forward CFG".
    pub fwd_cfg: Cfg,
    /// We consider the destination of the backward edges to be loop entries and
    /// store them here. Note that this is a `HashSet`: its iteration order is arbitrary.
    pub loop_entries: HashSet<src::BlockId>,
    /// The reachable blocks whose terminators are a switch are stored here.
    pub switch_blocks: HashSet<src::BlockId>,
    /// Tree of which nodes dominate which other nodes (rooted at the start block).
    #[expect(unused)]
    pub dominator_tree: Dominators<BlockId>,
    /// Computed data about each block, indexed by block id.
    pub block_data: IndexVec<BlockId, BlockData<'a>>,
}
119
/// Information precomputed about a single block of the CFG.
#[derive(Debug)]
struct BlockData<'a> {
    /// The id of this block.
    pub id: BlockId,
    /// The contents of this block.
    pub contents: &'a src::BlockData,
    /// Whether this block is the entrypoint of a loop, i.e. the target of a backward edge. Each
    /// loop has a unique entrypoint because we error on irreducible CFGs.
    pub is_loop_header: bool,
    /// Whether this block's terminator is a switch.
    pub is_switch: bool,
    /// Whether this block has multiple incoming control-flow edges (backward edges included).
    pub is_merge_target: bool,
    /// Order in a reverse postorder numbering. `None` if the block is unreachable.
    pub reverse_postorder: Option<u32>,
    /// Nodes that this block immediately dominates. Sorted by `reverse_postorder`, with the
    /// largest number first.
    pub immediately_dominates: SmallVec<[BlockId; 2]>,
    /// List of the loops this node is inside of (loops are identified by their header). A node
    /// is considered inside a loop if it is reachable from the loop header and if it can reach the
    /// loop header using only the backward edges into it (i.e. we don't count a path that enters
    /// the loop header through a forward edge).
    ///
    /// Note that we might have to take several backward edges to reach the loop header, e.g.:
    ///   'a: loop {
    ///       // ...
    ///       'b: loop {
    ///           // ...
    ///           if true {
    ///               continue 'a;
    ///           } else {
    ///               if true {
    ///                   break 'a;
    ///               }
    ///               // This node has to take two backward edges in order to reach the start of `'a`.
    ///           }
    ///       }
    ///   }
    ///
    /// The restriction on backward edges is for the following case:
    ///   loop {
    ///     loop {
    ///       ..
    ///     }
    ///     // Not in inner loop
    ///   }
    ///
    /// This is sorted by path order from the graph root.
    pub within_loops: SmallVec<[BlockId; 2]>,
    /// Whether all the paths from this node lead to error nodes (panic, etc.). Infinite loops are
    /// not considered as reaching an error node.
    pub only_reach_error: bool,
    /// The nodes reachable from this one in the forward CFG, with the length of the shortest
    /// path to each of them. Includes the current node.
    pub shortest_paths: hashbrown::HashMap<BlockId, usize>,
    /// Let's say we put a quantity of water equal to 1 on the block, and the water flows
    /// downwards. Whenever there is a branching, the quantity of water gets equally divided
    /// between the branches. When the control flows join, we put the water back together. The
    /// vector below stores the amount of water received by each descendant of the node in the
    /// forward CFG (blocks that are not descendants receive 0).
    ///
    /// TODO: there must be a known algorithm which computes this, right?...
    /// This is exactly this problem:
    /// <https://stackoverflow.com/questions/78221666/algorithm-for-total-flow-through-weighted-directed-acyclic-graph>
    /// TODO: the way I compute this is not efficient.
    pub flow: IndexVec<BlockId, BigRational>,
    /// Reconstructed information about loops and switches.
    pub exit_info: ExitInfo,
}
184
/// The exits reconstructed for the loop and/or switch associated with a block.
#[derive(Debug, Default, Clone)]
struct ExitInfo {
    /// The loop exit, if one was chosen (set on loop headers).
    loop_exit: Option<src::BlockId>,
    /// The switch exit, if one was chosen.
    switch_exit: Option<src::BlockId>,
}
192
/// Error indicating that the control-flow graph is not reducible. The contained block id is a
/// block involved in an irreducible subgraph: `CfgInfo::build` returns the source of a backward
/// edge whose target does not dominate that source.
struct Irreducible(BlockId);
196
impl<'a> CfgInfo<'a> {
    /// Build the CFGs (the "regular" CFG and the CFG without backward edges) and precompute a
    /// bunch of graph information about the CFG.
    ///
    /// Once the graph data is computed, this also picks an exit for each loop and each switch.
    ///
    /// # Errors
    ///
    /// Returns `Err(Irreducible(from))` if the CFG is not reducible, i.e. if we find a backward
    /// edge `from -> target` whose target does not dominate `from`.
    fn build(ctx: &TransformCtx, body: &'a src::BodyContents) -> Result<Self, Irreducible> {
        // The steps in this function follow a precise order, as each step typically requires the
        // previous one.
        let start_block = BlockId::ZERO;

        // Each block starts with a zero flow towards every block; flows are filled in below.
        let empty_flow = body.map_ref(|_| BigRational::new(0u64.into(), 1u64.into()));
        let mut block_data: IndexVec<BlockId, BlockData> =
            body.map_ref_indexed(|id, contents| BlockData {
                id,
                contents,
                is_loop_header: false,
                is_switch: false,
                is_merge_target: false,
                reverse_postorder: None,
                immediately_dominates: Default::default(),
                within_loops: Default::default(),
                only_reach_error: false,
                shortest_paths: Default::default(),
                flow: empty_flow.clone(),
                exit_info: Default::default(),
            });

        // Build the node graph (we ignore unwind paths for now).
        let mut cfg = Cfg::new();
        for (block_id, block) in body.iter_indexed() {
            cfg.add_node(block_id);
            for tgt in block.targets_ignoring_unwind() {
                cfg.add_edge(block_id, tgt, ());
            }
        }

        // Compute the dominator tree.
        let dominator_tree = simple_fast(&cfg, start_block);

        // Compute reverse postorder numbering. Only the blocks reachable from `start_block` are
        // visited; unreachable blocks keep `reverse_postorder == None`.
        for (i, block_id) in DfsPostOrder::new(&cfg, start_block).iter(&cfg).enumerate() {
            // The `i`-th block in postorder gets the number `body.len() - i`: the later a block
            // appears in postorder, the smaller its number.
            let rev_post_id = body.len() - i;
            block_data[block_id].reverse_postorder = Some(rev_post_id.try_into().unwrap());

            // Store the dominator tree in `block_data`.
            if let Some(dominator) = dominator_tree.immediate_dominator(block_id) {
                block_data[dominator].immediately_dominates.push(block_id);
            }

            // Detect merge targets. Incoming edges are counted in the full CFG, so backward edges
            // and edges from unreachable blocks count too.
            if cfg
                .neighbors_directed(block_id, petgraph::Direction::Incoming)
                .count()
                >= 2
            {
                block_data[block_id].is_merge_target = true;
            }
        }

        // Compute the forward graph (without backward edges). We visit the reachable blocks and
        // classify each of their incoming edges using the reverse postorder numbering computed
        // above.
        let mut fwd_cfg = Cfg::new();
        let mut loop_entries = HashSet::new();
        let mut switch_blocks = HashSet::new();
        for block_id in Dfs::new(&cfg, start_block).iter(&cfg) {
            fwd_cfg.add_node(block_id);

            if body[block_id].terminator.kind.is_switch() {
                switch_blocks.insert(block_id);
                block_data[block_id].is_switch = true;
            }

            // Iterate over edges into this node (so that we can determine whether this node is a
            // loop header).
            for from in cfg.neighbors_directed(block_id, petgraph::Direction::Incoming) {
                // Check if the edge is a backward edge, i.e. if its source does not come strictly
                // before its target. Since `None < Some(_)`, an edge coming from an unreachable
                // block is never considered a backward edge.
                if block_data[from].reverse_postorder >= block_data[block_id].reverse_postorder {
                    // This is a backward edge
                    block_data[block_id].is_loop_header = true;
                    loop_entries.insert(block_id);
                    // A cfg is reducible iff the target of every back edge dominates the
                    // edge's source.
                    if !dominator_tree.dominators(from).unwrap().contains(&block_id) {
                        return Err(Irreducible(from));
                    }
                } else {
                    fwd_cfg.add_edge(from, block_id, ());
                }
            }
        }

        // Finish filling in information. Because the forward CFG is acyclic and we iterate in
        // postorder, every forward successor of `block_id` has already been processed when we
        // reach it, so its `flow` and `shortest_paths` are available.
        for block_id in DfsPostOrder::new(&fwd_cfg, start_block).iter(&fwd_cfg) {
            let block = &body[block_id];
            let targets = cfg.neighbors(block_id).collect_vec();
            let fwd_targets = fwd_cfg.neighbors(block_id).collect_vec();

            // Compute the nodes that can only reach error nodes.
            // The node can only reach error nodes if:
            // - it is an error node;
            // - or it has neighbors and they all lead to errors.
            // Note that if there is a backward edge, `only_reach_error` cannot contain this
            // node yet. In other words, this does not consider infinite loops as reaching an
            // error node.
            if block.terminator.is_error()
                || (!targets.is_empty()
                    && targets.iter().all(|&tgt| block_data[tgt].only_reach_error))
            {
                block_data[block_id].only_reach_error = true;
            }

            // Compute the flows between each pair of nodes.
            let mut flow: IndexVec<src::BlockId, BigRational> =
                mem::take(&mut block_data[block_id].flow);
            if !fwd_targets.is_empty() {
                // We need to divide the initial flow equally between the children
                let factor = BigRational::new(1u64.into(), fwd_targets.len().into());

                // For each child, multiply the flows of its own children by the ratio,
                // and add.
                for &child_id in &fwd_targets {
                    // First, add the child itself
                    flow[child_id] += factor.clone();

                    // Then add its successors
                    let child = &block_data[child_id];
                    for grandchild in child.reachable_excluding_self() {
                        // Flow from `child` to `grandchild`
                        let child_flow = child.flow[grandchild].clone();
                        flow[grandchild] += factor.clone() * child_flow;
                    }
                }
            }
            block_data[block_id].flow = flow;

            // Compute shortest paths to all reachable nodes in the forward graph (every edge has
            // length 1).
            block_data[block_id].shortest_paths = dijkstra(&fwd_cfg, block_id, None, |_| 1usize);

            // Fill in the rest of the domination data: sort the dominated blocks by reverse
            // postorder number, largest first.
            let mut dominatees = mem::take(&mut block_data[block_id].immediately_dominates);
            dominatees.sort_by_key(|&child| block_data[child].reverse_postorder);
            dominatees.reverse();
            block_data[block_id].immediately_dominates = dominatees;
        }

        // Fill in the within_loop information. See the docs of `within_loops` to understand what
        // we're computing.
        let mut path_dfs = DfsWithPath::new(&cfg, start_block);
        while let Some(block_id) = path_dfs.next(&cfg) {
            // Store all the loops on the path to this node. The path ends with `block_id`
            // itself, so a loop header counts as being on the path to itself.
            let mut within_loops: SmallVec<_> = path_dfs
                .path
                .iter()
                .copied()
                .filter(|&loop_id| block_data[loop_id].is_loop_header)
                .collect();
            // The loops that we can reach by taking a single backward edge.
            let loops_directly_within = within_loops
                .iter()
                .copied()
                .filter(|&loop_header| {
                    cfg.neighbors_directed(loop_header, petgraph::Direction::Incoming)
                        .any(|bid| block_data[block_id].shortest_paths.contains_key(&bid))
                })
                .collect_vec();
            // The loops that we can reach by taking any number of backward edges. Headers that
            // come before `block_id` on the path were returned earlier by the traversal, so their
            // `within_loops` is already filled in (for `block_id` itself it is still empty here).
            let loops_indirectly_within: HashSet<_> = loops_directly_within
                .iter()
                .copied()
                .flat_map(|loop_header| &block_data[loop_header].within_loops)
                .chain(&loops_directly_within)
                .copied()
                .collect();
            within_loops.retain(|id| loops_indirectly_within.contains(id));
            block_data[block_id].within_loops = within_loops;
        }

        let mut cfg = CfgInfo {
            cfg,
            fwd_cfg,
            loop_entries,
            switch_blocks,
            dominator_tree,
            block_data,
        };

        // Pick an exit block for each loop, if we find one.
        ExitInfo::compute_loop_exits(ctx, &mut cfg);

        // Pick an exit block for each switch, if we find one.
        ExitInfo::compute_switch_exits(&mut cfg);

        Ok(cfg)
    }

    /// Get the precomputed data of `block_id`.
    fn block_data(&self, block_id: BlockId) -> &BlockData<'a> {
        &self.block_data[block_id]
    }
    // fn can_reach(&self, src: BlockId, tgt: BlockId) -> bool {
    //     self.block_data[src].shortest_paths.contains_key(&tgt)
    // }
    /// The reverse postorder number of `block_id`.
    ///
    /// Panics if `block_id` is not reachable from the start block (it then has no number).
    fn topo_rank(&self, block_id: BlockId) -> u32 {
        self.block_data[block_id].reverse_postorder.unwrap()
    }
    /// Whether `src -> tgt` is an edge of the CFG whose source does not come strictly before its
    /// target in reverse postorder.
    #[expect(unused)]
    fn is_backward_edge(&self, src: BlockId, tgt: BlockId) -> bool {
        self.block_data[src].reverse_postorder >= self.block_data[tgt].reverse_postorder
            && self.cfg.contains_edge(src, tgt)
    }

    /// Check if the node is within the given loop.
    fn is_within_loop(&self, loop_header: src::BlockId, block_id: src::BlockId) -> bool {
        self.block_data[block_id]
            .within_loops
            .contains(&loop_header)
    }

    /// Check if all paths from `src` to nodes in `target_set` go through `through_node`. If `src`
    /// is already in `target_set`, we ignore that empty path.
    ///
    /// Paths are taken in the forward CFG: we drop the edges leaving `through_node` and check
    /// whether a node of `target_set` is still reachable from `src`.
    fn all_paths_go_through(
        &self,
        src: src::BlockId,
        through_node: src::BlockId,
        target_set: &HashSet<src::BlockId>,
    ) -> bool {
        let graph = EdgeFiltered::from_fn(&self.fwd_cfg, |edge| edge.source() != through_node);
        !Dfs::new(&graph, src)
            .iter(&graph)
            .skip(1) // skip src
            .any(|bid| target_set.contains(&bid))
    }
}
428
429impl BlockData<'_> {
430    fn shortest_paths_including_self(&self) -> impl Iterator<Item = (BlockId, usize)> {
431        self.shortest_paths.iter().map(|(bid, d)| (*bid, *d))
432    }
433    fn shortest_paths_excluding_self(&self) -> impl Iterator<Item = (BlockId, usize)> {
434        self.shortest_paths_including_self()
435            .filter(move |&(bid, _)| bid != self.id)
436    }
437    fn reachable_including_self(&self) -> impl Iterator<Item = BlockId> {
438        self.shortest_paths_including_self().map(|(bid, _)| bid)
439    }
440    fn reachable_excluding_self(&self) -> impl Iterator<Item = BlockId> {
441        self.shortest_paths_excluding_self().map(|(bid, _)| bid)
442    }
443    #[expect(unused)]
444    fn can_reach_excluding_self(&self, other: BlockId) -> bool {
445        self.shortest_paths.contains_key(&other) && self.id != other
446    }
447}
448
/// The rank of a loop exit candidate. See [`ExitInfo::compute_loop_exit_ranks`].
///
/// The derived ordering compares `path_count` first, then `distance_from_header`. Because the
/// distance is wrapped in `Reverse`, a candidate closer to the loop header ranks higher. The
/// field order is therefore significant.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
struct LoopExitRank {
    /// Number of paths we found going to this exit.
    path_count: usize,
    /// Distance from the loop header (shortest path length in the forward CFG).
    distance_from_header: Reverse<usize>,
}
457
458impl ExitInfo {
459    /// Compute the first node on each path that exits the loop.
460    fn compute_loop_exit_starting_points(
461        cfg: &CfgInfo,
462        loop_header: src::BlockId,
463    ) -> Vec<src::BlockId> {
464        let mut loop_exits = Vec::new();
465        // Do a dfs from the loop header while keeping track of the path from the loop header to
466        // the current node.
467        let mut dfs = Dfs::new(&cfg.fwd_cfg, loop_header);
468        while let Some(block_id) = dfs.next(&cfg.fwd_cfg) {
469            // If we've exited all the loops after and including the target one, this node is an
470            // exit node for the target loop.
471            if !cfg.is_within_loop(loop_header, block_id) {
472                loop_exits.push(block_id);
473                // Don't explore any more paths from this node.
474                dfs.discovered.extend(cfg.fwd_cfg.neighbors(block_id));
475            }
476        }
477        loop_exits
478    }
479
480    /// Compute the loop exit candidates along with a rank.
481    ///
482    /// In the simple case, there is one exit node through which all exit paths go. We want to be
483    /// sure to catch that case, and when that's not possible we want to still find a node through
484    /// which a lot of exit paths go.
485    ///
486    /// To do that, we first count for each exit node how many exit paths go through it, and pick
487    /// the node with most occurrences. If there are many such nodes, we pick the one with shortest
488    /// distance from the loop header. Finally if there are still many such nodes, we keep the
489    /// first node found (the order in which we explore the graph is deterministic, and we use an
490    /// insertion-order hash map).
491    ///
492    /// Note that exit candidates will typically be referenced more than once for one loop. This
493    /// comes from the fact that whenever we reach a node outside the current loop, we register
494    /// this node as well as all its children as exit candidates.
495    /// Consider the following example:
496    /// ```text
497    /// while i < max {
498    ///     if cond {
499    ///         break;
500    ///     }
501    ///     s += i;
502    ///     i += 1
503    /// }
504    /// // All the below nodes are exit candidates (each of them is referenced twice)
505    /// s += 1;
506    /// return s;
507    /// ```
508    fn compute_loop_exit_ranks(
509        cfg: &CfgInfo,
510        loop_header: src::BlockId,
511    ) -> SeqHashMap<src::BlockId, LoopExitRank> {
512        let mut loop_exits: SeqHashMap<BlockId, LoopExitRank> = SeqHashMap::new();
513        for block_id in Self::compute_loop_exit_starting_points(cfg, loop_header) {
514            for bid in cfg.block_data(block_id).reachable_including_self() {
515                loop_exits
516                    .entry(bid)
517                    .or_insert_with(|| LoopExitRank {
518                        path_count: 0,
519                        distance_from_header: Reverse(
520                            cfg.block_data[loop_header].shortest_paths[&bid],
521                        ),
522                    })
523                    .path_count += 1;
524            }
525        }
526        loop_exits
527    }
528
529    /// A loop exit is any block reachable from the loop header that isn't inside the loop.
530    /// This function choses an exit for every loop. See `compute_loop_exit_ranks` for how we
531    /// select them.
532    ///
533    /// For example:
534    /// ```text
535    /// while ... {
536    ///    ...
537    ///    if ... {
538    ///        // We can't reach the loop entry from here: this is an exit
539    ///        // candidate
540    ///        return;
541    ///    }
542    /// }
543    /// // This is another exit candidate - and this is the one we want to use
544    /// // as the "real" exit...
545    /// ...
546    /// ```
547    ///
548    /// Once we listed all the exit candidates, we find the "best" one for every loop. The best
549    /// exit is the following one:
550    /// - it is the one which is used the most times (note that there can be
551    ///   several candidates which are referenced strictly more than once: see the
552    ///   comment below)
553    /// - if several exits have the same number of occurrences, we choose the one
554    ///   for which we goto the "earliest" (earliest meaning that the goto is close to
555    ///   the loop entry node in the AST). The reason is that all the loops should
556    ///   have an outer if ... then ... else ... which executes the loop body or goes
557    ///   to the exit (note that this is not necessarily the first
558    ///   if ... then ... else ... we find: loop conditions can be arbitrary
559    ///   expressions, containing branchings).
560    ///
561    /// # Several candidates for a loop exit:
562    /// =====================================
563    /// There used to be a sanity check to ensure there are no two different
564    /// candidates with exactly the same number of occurrences and distance from
565    /// the entry of the loop, if the number of occurrences is > 1.
566    ///
567    /// We removed it because it does happen, for instance here (the match
568    /// introduces an `unreachable` node, and it has the same number of
569    /// occurrences and the same distance to the loop entry as the `panic`
570    /// node):
571    ///
572    /// ```text
573    /// pub fn list_nth_mut_loop_pair<'a, T>(
574    ///     mut ls: &'a mut List<T>,
575    ///     mut i: u32,
576    /// ) -> &'a mut T {
577    ///     loop {
578    ///         match ls {
579    ///             List::Nil => {
580    ///                 panic!() // <-- best candidate
581    ///             }
582    ///             List::Cons(x, tl) => {
583    ///                 if i == 0 {
584    ///                     return x;
585    ///                 } else {
586    ///                     ls = tl;
587    ///                     i -= 1;
588    ///                 }
589    ///             }
590    ///             _ => {
591    ///               // Note that Rustc always introduces an unreachable branch after
592    ///               // desugaring matches.
593    ///               unreachable!(), // <-- best candidate
594    ///             }
595    ///         }
596    ///     }
597    /// }
598    /// ```
599    ///
600    /// When this happens we choose an exit candidate whose edges don't necessarily
601    /// lead to an error (above there are none, so we don't choose any exits). Note
602    /// that this last condition is important to prevent loops from being unnecessarily
603    /// nested:
604    ///
605    /// ```text
606    /// pub fn nested_loops_enum(step_out: usize, step_in: usize) -> usize {
607    ///     let mut s = 0;
608    ///
609    ///     for _ in 0..128 { // We don't want this loop to be nested with the loops below
610    ///         s += 1;
611    ///     }
612    ///
613    ///     for _ in 0..(step_out) {
614    ///         for _ in 0..(step_in) {
615    ///             s += 1;
616    ///         }
617    ///     }
618    ///
619    ///     s
620    /// }
621    /// ```
622    fn compute_loop_exits(ctx: &TransformCtx, cfg: &mut CfgInfo) {
623        for &loop_id in &cfg.loop_entries {
624            // Compute the candidates.
625            let loop_exits: SeqHashMap<BlockId, LoopExitRank> =
626                Self::compute_loop_exit_ranks(cfg, loop_id);
627            // We choose the exit with:
628            // - the most occurrences
629            // - the least total distance (if there are several possibilities)
630            // - doesn't necessarily lead to an error (panic, unreachable)
631            let best_exits: Vec<(BlockId, LoopExitRank)> =
632                loop_exits.into_iter().max_set_by_key(|&(_, rank)| rank);
633            // If there is exactly one best candidate, use it. Otherwise we need to split further.
634            let chosen_exit = match best_exits.into_iter().map(|(bid, _)| bid).exactly_one() {
635                Ok(best_exit) => Some(best_exit),
636                Err(best_exits) => {
637                    // Remove the candidates which only lead to errors (panic or unreachable).
638                    // If there is exactly one candidate we select it, otherwise we do not select any
639                    // exit.
640                    // We don't want to select any exit if we are in the below situation
641                    // (all paths lead to errors). We added a sanity check below to
642                    // catch the situations where there are several exits which don't
643                    // lead to errors.
644                    //
645                    // Example:
646                    // ========
647                    // ```
648                    // loop {
649                    //     match ls {
650                    //         List::Nil => {
651                    //             panic!() // <-- best candidate
652                    //         }
653                    //         List::Cons(x, tl) => {
654                    //             if i == 0 {
655                    //                 return x;
656                    //             } else {
657                    //                 ls = tl;
658                    //                 i -= 1;
659                    //             }
660                    //         }
661                    //         _ => {
662                    //           unreachable!(); // <-- best candidate (Rustc introduces an `unreachable` case)
663                    //         }
664                    //     }
665                    // }
666                    // ```
667                    best_exits
668                        .filter(|&bid| !cfg.block_data[bid].only_reach_error)
669                        .exactly_one()
670                        .map_err(|mut candidates| {
671                            // Adding this sanity check so that we can see when there are several
672                            // candidates.
673                            let span = cfg.block_data[loop_id].contents.terminator.span;
674                            sanity_check!(ctx, span, candidates.next().is_none());
675                        })
676                        .ok()
677                }
678            };
679            cfg.block_data[loop_id].exit_info.loop_exit = chosen_exit;
680        }
681    }
682
    /// Compute the exit of every switch block, i.e. the block where its branches join again.
    ///
    /// Let's consider the following piece of code:
    /// ```text
    /// if cond1 { ... } else { ... };
    /// if cond2 { ... } else { ... };
    /// ```
    /// Once converted to MIR, the control-flow is destructured, which means we
    /// have gotos everywhere. When reconstructing the control-flow, we have
    /// to be careful about the point where we should join the two branches of
    /// the first if.
    /// For instance, if we don't notice they should be joined at some point (i.e.,
    /// whichever branch we take, there is a moment when we go to the exact
    /// same place, just before the second if), we might generate code like
    /// this, with some duplication:
    /// ```text
    /// if cond1 { ...; if cond2 { ... } else { ...} }
    /// else { ...; if cond2 { ... } else { ...} }
    /// ```
    ///
    /// Such a reconstructed program is valid, but it is definitely non-optimal:
    /// it is very different from the original program (making it less clean and
    /// clear), more bloated, and might involve duplicating the proof effort.
    ///
    /// For this reason, we need to find the "exit" of the first switch, which is
    /// the point where the two branches join. Note that this can be a bit tricky,
    /// because there may be more than two branches (if we do `switch(x) { ... }`),
    /// and some of them might not join (if they contain a `break`, `panic`,
    /// `return`, etc.).
    ///
    /// For each switch block, the exit candidate is picked among the blocks reachable
    /// from it (excluding the switch block itself and, if the switch is inside a loop,
    /// excluding blocks outside the last loop of its `within_loops` list, presumably the
    /// innermost one). We take the block with the highest `flow` value, breaking ties by
    /// preferring the lowest topological rank, then the highest block id. The candidate
    /// does not need to be reached by every branch, because some branches might never
    /// join the control-flow of the others if they contain a `break`, `return`, `panic`,
    /// etc., like here:
    /// ```text
    /// if b { x = 3; } else { return; }
    /// y += x;
    /// ...
    /// ```
    /// Note that with nested switches, the branches of the inner switches might
    /// goto the exits of the outer switches: for this reason, we give precedence
    /// to the outer switches. A candidate is rejected when an outer switch already
    /// uses it as its exit, or when `cfg.all_paths_go_through(bid, exit_id, &exits_set)`
    /// does not hold (see the comments in the body). In that case the switch's
    /// `switch_exit` is left untouched.
    fn compute_switch_exits(cfg: &mut CfgInfo) {
        // We need to give precedence to the outer switches: we thus iterate
        // over the switch blocks in topological order (ties broken by block id).
        // `exits_set` accumulates the exits chosen so far, so the iteration order matters.
        let mut exits_set = HashSet::new();
        for bid in cfg
            .switch_blocks
            .iter()
            .copied()
            .sorted_unstable_by_key(|&bid| (cfg.topo_rank(bid), bid))
        {
            let block_data = &cfg.block_data[bid];
            // Find the best successor: this is the node with the highest flow, and the lowest
            // topological rank. If several nodes have the same flow, we want to take the highest
            // one in the hierarchy: hence the use of the topological rank. Remaining ties are
            // broken by taking the highest block id.
            //
            // Ex.:
            // ```text
            // A  -- we start here
            // |
            // |---------------------------------------
            // |            |            |            |
            // B:(0.25,-1)  C:(0.25,-2)  D:(0.25,-3)  E:(0.25,-4)
            // |            |            |
            // |--------------------------
            // |
            // F:(0.75,-5)
            // |
            // |
            // G:(0.75,-6)
            // ```
            // The "best" node (with the highest (flow, rank)) in the graph above is F.
            // If the switch is inside a loop, we also only consider exits that are inside that
            // same loop. If no block qualifies, `best_exit` is `None` and no exit is recorded.
            let current_loop = block_data.within_loops.last().copied();
            let best_exit: Option<BlockId> = block_data
                .reachable_excluding_self()
                .filter(|&b| {
                    current_loop.is_none_or(|current_loop| cfg.is_within_loop(current_loop, b))
                })
                .max_by_key(|&id| {
                    let flow = &block_data.flow[id];
                    let rank = Reverse(cfg.topo_rank(id));
                    ((flow, rank), id)
                });
            // We have an exit candidate: we first check that it was not already taken by an
            // external switch.
            //
            // We then check that we can't reach the exit of an external switch from one of the
            // branches, without going through the exit candidate. We do this by simply checking
            // that we can't reach any of the exits of outer switches.
            //
            // The reason is that it can lead to code like the following:
            // ```
            // if ... { // if #1
            //   if ... { // if #2
            //     ...
            //     // here, we have a `goto b1`, where b1 is the exit
            //     // of if #2: we thus stop translating the blocks.
            //   }
            //   else {
            //     ...
            //     // here, we have a `goto b2`, where b2 is the exit
            //     // of if #1: we thus stop translating the blocks.
            //   }
            //   // We insert code for the block b1 here (which is the exit of
            //   // if #2). However, this block should only
            //   // be executed in the branch "then" of the if #2, not in
            //   // the branch "else".
            //   ...
            // }
            // else {
            //   ...
            // }
            // ```
            if let Some(exit_id) = best_exit
                && !exits_set.contains(&exit_id)
                && cfg.all_paths_go_through(bid, exit_id, &exits_set)
            {
                exits_set.insert(exit_id);
                cfg.block_data[bid].exit_info.switch_exit = Some(exit_id);
            }
        }
    }
814}
815
/// How a jump to a given block is rendered in the structured (LLBC) output.
enum GotoKind {
    /// No special treatment: the target block is translated inline at the jump site.
    Goto,
    /// Control reaches the target next anyway, so the jump becomes a no-op.
    NextBlock,
    /// `break` out of enclosing `loop`s; the payload is the depth relative to the jump site.
    Break(usize),
    /// `continue` an enclosing `loop`; the payload is the depth relative to the jump site.
    Continue(usize),
}
822
/// Absolute nesting level of `loop` contexts that can be targeted by `break`/`continue`.
type Depth = usize;

/// What a jump to a block recorded on the special-jump stack turns into, instead of an
/// inline translation of that block.
#[derive(Clone, Copy, Debug)]
enum SpecialJump {
    /// The block is translated right after the current one, so the jump does nothing.
    NextBlock,
    /// The block is a loop header: jumping to it becomes a `continue` to the given depth.
    LoopContinue(Depth),
    /// The block is a loop exit: jumping to it becomes a `break` to the given depth.
    LoopBreak(Depth),
    /// The block is a merge node: jumping to it becomes a `break` to the given depth, out of a
    /// `loop` context introduced only to allow such forward jumps.
    ForwardBreak(Depth),
}
838
/// Strategy used to turn the unstructured control-flow graph into structured code.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ReconstructMode {
    /// Reconstruct using flow heuristics.
    Flow,
    /// Reconstruct using the algorithm from "Beyond Relooper"
    /// (<https://dl.acm.org/doi/10.1145/3547621>).
    /// A useful invariant is that the block at the top of the jump stack is the block where
    /// control-flow will jump naturally at the end of the current block.
    Reloop,
}
847
848struct ReconstructCtx<'a> {
849    cfg: CfgInfo<'a>,
850    /// The depth of `loop` contexts we may `break`/`continue` to.
851    break_context_depth: Depth,
852    /// Stack of block ids that should be translated to special jumps (`break`/`continue`/do
853    /// nothing) in the current context.
854    /// The block where control-flow continues naturally after this block is kept at the top of the
855    /// stack.
856    special_jump_stack: Vec<(BlockId, SpecialJump)>,
857    mode: ReconstructMode,
858}
859
860impl<'a> ReconstructCtx<'a> {
861    fn build(ctx: &TransformCtx, src_body: &'a src::ExprBody) -> Result<Self, Irreducible> {
862        // Compute all sorts of graph-related information about the control-flow graph, including
863        // reachability, the dominator tree, loop entries, and loop/switch exits.
864        let cfg = CfgInfo::build(ctx, &src_body.body)?;
865
866        // Translate the body by reconstructing the loops and the
867        // conditional branchings.
868        let use_relooper = false;
869        Ok(ReconstructCtx {
870            cfg,
871            break_context_depth: 0,
872            special_jump_stack: Vec::new(),
873            mode: if use_relooper {
874                ReconstructMode::Reloop
875            } else {
876                ReconstructMode::Flow
877            },
878        })
879    }
880
881    fn translate_statement(&self, st: &src::Statement) -> tgt::Statement {
882        let src_span = st.span;
883        let st = match st.kind.clone() {
884            src::StatementKind::Assign(place, rvalue) => tgt::StatementKind::Assign(place, rvalue),
885            src::StatementKind::SetDiscriminant(place, variant_id) => {
886                tgt::StatementKind::SetDiscriminant(place, variant_id)
887            }
888            src::StatementKind::CopyNonOverlapping(copy) => {
889                tgt::StatementKind::CopyNonOverlapping(copy)
890            }
891            src::StatementKind::StorageLive(var_id) => tgt::StatementKind::StorageLive(var_id),
892            src::StatementKind::StorageDead(var_id) => tgt::StatementKind::StorageDead(var_id),
893            src::StatementKind::Deinit(place) => tgt::StatementKind::Deinit(place),
894            src::StatementKind::Assert(assert) => tgt::StatementKind::Assert(assert),
895            src::StatementKind::Nop => tgt::StatementKind::Nop,
896        };
897        tgt::Statement::new(src_span, st)
898    }
899
900    fn get_goto_kind(&self, target_block: src::BlockId) -> GotoKind {
901        match self
902            .special_jump_stack
903            .iter()
904            .rev()
905            .enumerate()
906            .find(|(_, (b, _))| *b == target_block)
907        {
908            Some((i, (_, jump_target))) => match jump_target {
909                // The top of the stack is where control-flow goes naturally, no need to add a
910                // `break`/`continue`.
911                SpecialJump::LoopContinue(_) | SpecialJump::ForwardBreak(_)
912                    if i == 0 && matches!(self.mode, ReconstructMode::Reloop) =>
913                {
914                    GotoKind::NextBlock
915                }
916                SpecialJump::LoopContinue(depth) => {
917                    GotoKind::Continue(self.break_context_depth - depth)
918                }
919                SpecialJump::ForwardBreak(depth) | SpecialJump::LoopBreak(depth) => {
920                    GotoKind::Break(self.break_context_depth - depth)
921                }
922                SpecialJump::NextBlock => GotoKind::NextBlock,
923            },
924            // Translate the block without a jump.
925            None => GotoKind::Goto,
926        }
927    }
928
929    /// Translate a jump to the given block. The span is used to create the jump statement, if any.
930    fn translate_jump(&mut self, span: Span, target_block: src::BlockId) -> tgt::Block {
931        let st = match self.get_goto_kind(target_block) {
932            GotoKind::Break(index) => tgt::StatementKind::Break(index),
933            GotoKind::Continue(index) => tgt::StatementKind::Continue(index),
934            GotoKind::NextBlock => tgt::StatementKind::Nop,
935            // "Standard" goto: we recursively translate the block.
936            GotoKind::Goto => return self.translate_block(target_block),
937        };
938        tgt::Statement::new(span, st).into_block()
939    }
940
941    fn translate_terminator(&mut self, terminator: &src::Terminator) -> tgt::Block {
942        let src_span = terminator.span;
943
944        match &terminator.kind {
945            src::TerminatorKind::Abort(kind) => {
946                tgt::Statement::new(src_span, tgt::StatementKind::Abort(kind.clone())).into_block()
947            }
948            src::TerminatorKind::Return => {
949                tgt::Statement::new(src_span, tgt::StatementKind::Return).into_block()
950            }
951            src::TerminatorKind::UnwindResume => {
952                tgt::Statement::new(src_span, tgt::StatementKind::Abort(AbortKind::Panic(None)))
953                    .into_block()
954            }
955            src::TerminatorKind::Call {
956                call,
957                target,
958                on_unwind: _,
959            } => {
960                // TODO: Have unwinds in the LLBC
961                let st = tgt::Statement::new(src_span, tgt::StatementKind::Call(call.clone()));
962                let mut block = self.translate_jump(terminator.span, *target);
963                block.statements.insert(0, st);
964                block
965            }
966            src::TerminatorKind::Drop {
967                kind,
968                place,
969                tref,
970                target,
971                on_unwind: _,
972            } => {
973                // TODO: Have unwinds in the LLBC
974                let st = tgt::Statement::new(
975                    src_span,
976                    tgt::StatementKind::Drop(place.clone(), tref.clone(), kind.clone()),
977                );
978                let mut block = self.translate_jump(terminator.span, *target);
979                block.statements.insert(0, st);
980                block
981            }
982            src::TerminatorKind::Goto { target } => self.translate_jump(terminator.span, *target),
983            src::TerminatorKind::Switch { discr, targets } => {
984                // Translate the target expressions
985                let switch = match &targets {
986                    src::SwitchTargets::If(then_tgt, else_tgt) => {
987                        let then_block = self.translate_jump(terminator.span, *then_tgt);
988                        let else_block = self.translate_jump(terminator.span, *else_tgt);
989                        tgt::Switch::If(discr.clone(), then_block, else_block)
990                    }
991                    src::SwitchTargets::SwitchInt(int_ty, targets, otherwise) => {
992                        // Note that some branches can be grouped together, like
993                        // here:
994                        // ```
995                        // match e {
996                        //   E::V1 | E::V2 => ..., // Grouped
997                        //   E::V3 => ...
998                        // }
999                        // ```
1000                        // We detect this by checking if a block has already been
1001                        // translated as one of the branches of the switch.
1002                        //
1003                        // Rk.: note there may be intermediate gotos depending
1004                        // on the MIR we use. Typically, we manage to detect the
1005                        // grouped branches with Optimized MIR, but not with Promoted
1006                        // MIR. See the comment in "tests/src/matches.rs".
1007
1008                        // We link block ids to:
1009                        // - vector of matched integer values
1010                        // - translated blocks
1011                        let mut branches: SeqHashMap<src::BlockId, (Vec<Literal>, tgt::Block)> =
1012                            SeqHashMap::new();
1013
1014                        // Translate the children expressions
1015                        for (v, bid) in targets.iter() {
1016                            // Check if the block has already been translated:
1017                            // if yes, it means we need to group branches
1018                            if branches.contains_key(bid) {
1019                                // Already translated: add the matched value to
1020                                // the list of values
1021                                let branch = branches.get_mut(bid).unwrap();
1022                                branch.0.push(v.clone());
1023                            } else {
1024                                // Not translated: translate it
1025                                let block = self.translate_jump(terminator.span, *bid);
1026                                // We use the terminator span information in case then
1027                                // then statement is `None`
1028                                branches.insert(*bid, (vec![v.clone()], block));
1029                            }
1030                        }
1031                        let targets_blocks: Vec<(Vec<Literal>, tgt::Block)> =
1032                            branches.into_iter().map(|(_, x)| x).collect();
1033
1034                        let otherwise_block = self.translate_jump(terminator.span, *otherwise);
1035
1036                        // Translate
1037                        tgt::Switch::SwitchInt(
1038                            discr.clone(),
1039                            *int_ty,
1040                            targets_blocks,
1041                            otherwise_block,
1042                        )
1043                    }
1044                };
1045
1046                // Return
1047                let span = tgt::combine_switch_targets_span(&switch);
1048                let span = combine_span(&src_span, &span);
1049                let st = tgt::StatementKind::Switch(switch);
1050                tgt::Statement::new(span, st).into_block()
1051            }
1052        }
1053    }
1054
1055    /// Translate just the block statements and terminator.
1056    fn translate_block_itself(&mut self, block_id: src::BlockId) -> tgt::Block {
1057        let block = self.cfg.block_data[block_id].contents;
1058        // Translate the statements inside the block
1059        let statements = block
1060            .statements
1061            .iter()
1062            .map(|st| self.translate_statement(st))
1063            .collect_vec();
1064        // Translate the terminator.
1065        let terminator = self.translate_terminator(&block.terminator);
1066        // Prepend the statements to the terminator.
1067        if let Some(st) = tgt::Block::from_seq(statements) {
1068            st.merge(terminator)
1069        } else {
1070            terminator
1071        }
1072    }
1073
1074    /// Translate a block including surrounding control-flow like looping.
1075    fn translate_block(&mut self, block_id: src::BlockId) -> tgt::Block {
1076        ensure_sufficient_stack(|| self.translate_block_inner(block_id))
1077    }
1078    fn translate_block_inner(&mut self, block_id: src::BlockId) -> tgt::Block {
1079        // Some of the blocks we might jump to inside this tree can't be translated as normal
1080        // blocks: the loop backward edges must become `continue`s and the merge nodes may need
1081        // some care if we're jumping to them from distant locations.
1082        // For this purpose, we push to the `special_jump_stack` the block ids that must be
1083        // translated specially. In `translate_jump` we check the stack. At the end of this
1084        // function we restore the stack to its previous state.
1085        let old_context_depth = self.special_jump_stack.len();
1086        let block_data = &self.cfg.block_data[block_id];
1087        let span = block_data.contents.terminator.span;
1088
1089        // Catch jumps to the loop header or loop exit.
1090        if block_data.is_loop_header {
1091            self.break_context_depth += 1;
1092            if let ReconstructMode::Flow = self.mode
1093                && let Some(exit_id) = block_data.exit_info.loop_exit
1094            {
1095                self.special_jump_stack
1096                    .push((exit_id, SpecialJump::LoopBreak(self.break_context_depth)));
1097            }
1098            // Put the next block at the top of the stack.
1099            self.special_jump_stack.push((
1100                block_id,
1101                SpecialJump::LoopContinue(self.break_context_depth),
1102            ));
1103        }
1104
1105        // Catch jumps to a merge node.
1106        match self.mode {
1107            ReconstructMode::Flow => {
1108                // We only support next-block jumps to merge nodes.
1109                if let Some(bid) = block_data.exit_info.switch_exit
1110                    && !block_data.is_loop_header
1111                {
1112                    self.special_jump_stack.push((bid, SpecialJump::NextBlock));
1113                }
1114            }
1115            ReconstructMode::Reloop => {
1116                // We support forward-jumps using `break`
1117                // The child with highest postorder numbering is nested outermost in this scheme.
1118                let merge_children = block_data
1119                    .immediately_dominates
1120                    .iter()
1121                    .copied()
1122                    .filter(|&child| self.cfg.block_data[child].is_merge_target);
1123                for child in merge_children {
1124                    self.break_context_depth += 1;
1125                    self.special_jump_stack
1126                        .push((child, SpecialJump::ForwardBreak(self.break_context_depth)));
1127                }
1128            }
1129        }
1130
1131        // Translate this block. Any jumps to a loop header or a merge node will be replaced with
1132        // `continue`/`break`.
1133        let mut block = self.translate_block_itself(block_id);
1134
1135        // Reset the state to what it was previously, and translate what remains.
1136        let new_block = move |kind| tgt::Statement::new(block.span, kind).into_block();
1137        while self.special_jump_stack.len() > old_context_depth {
1138            match self.special_jump_stack.pop().unwrap() {
1139                (_loop_header, SpecialJump::LoopContinue(_)) => {
1140                    self.break_context_depth -= 1;
1141                    block = new_block(tgt::StatementKind::Loop(block));
1142                }
1143                (next_block, SpecialJump::ForwardBreak(_)) => {
1144                    self.break_context_depth -= 1;
1145                    // We add a `loop { ...; break }` so that we can use `break` to jump forward.
1146                    block = block.merge(new_block(tgt::StatementKind::Break(0)));
1147                    block = new_block(tgt::StatementKind::Loop(block));
1148                    // We must translate the merge nodes after the block used for forward jumps to
1149                    // them.
1150                    let next_block = self.translate_jump(span, next_block);
1151                    block = block.merge(next_block);
1152                }
1153                (next_block, SpecialJump::NextBlock) | (next_block, SpecialJump::LoopBreak(..)) => {
1154                    let next_block = self.translate_jump(span, next_block);
1155                    block = block.merge(next_block);
1156                }
1157            }
1158        }
1159        block
1160    }
1161}
1162
1163fn translate_body(ctx: &mut TransformCtx, body: &mut gast::Body) {
1164    use gast::Body::{Structured, Unstructured};
1165    let Unstructured(src_body) = body else {
1166        panic!("Called `ullbc_to_llbc` on an already restructured body")
1167    };
1168    trace!("About to translate to ullbc: {:?}", src_body.span);
1169
1170    // Calculate info about the graph and heuristically determine loop and switch exit blocks.
1171    let start_block = BlockId::ZERO;
1172    let mut ctx = match ReconstructCtx::build(ctx, src_body) {
1173        Ok(ctx) => ctx,
1174        Err(Irreducible(bid)) => {
1175            let span = src_body.body[bid].terminator.span;
1176            register_error!(
1177                ctx,
1178                span,
1179                "the control-flow graph of this function is not reducible"
1180            );
1181            panic!("can't reconstruct irreducible control-flow")
1182        }
1183    };
1184    // Translate the blocks using the computed data.
1185    let tgt_body = ctx.translate_block(start_block);
1186
1187    let tgt_body = tgt::ExprBody {
1188        span: src_body.span,
1189        locals: src_body.locals.clone(),
1190        bound_body_regions: src_body.bound_body_regions,
1191        body: tgt_body,
1192        comments: src_body.comments.clone(),
1193    };
1194    *body = Structured(tgt_body);
1195}
1196
1197pub struct Transform;
1198impl TransformPass for Transform {
1199    fn transform_ctx(&self, ctx: &mut TransformCtx) {
1200        // Translate the bodies one at a time.
1201        ctx.for_each_body(|ctx, body| {
1202            translate_body(ctx, body);
1203        });
1204
1205        if ctx.options.print_built_llbc {
1206            eprintln!("# LLBC resulting from control-flow reconstruction:\n\n{ctx}\n",);
1207        } else {
1208            trace!("# LLBC resulting from control-flow reconstruction:\n\n{ctx}\n",);
1209        }
1210    }
1211}