Skip to main content

charon_lib/ast/
ullbc_ast_utils.rs

1//! Implementations for [crate::ullbc_ast]
2use smallvec::{SmallVec, smallvec};
3
4use crate::ids::IndexVec;
5use crate::meta::Span;
6use crate::ullbc_ast::*;
7use std::collections::HashMap;
8use std::mem;
9use std::ops::{Index, IndexMut};
10
11#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
12pub struct StmtLoc {
13    pub block: BlockId,
14    pub statement: usize,
15}
16
17impl StmtLoc {
18    pub fn new(block: BlockId, statement: usize) -> Self {
19        StmtLoc { block, statement }
20    }
21
22    pub fn block_start(block: BlockId) -> Self {
23        StmtLoc {
24            block,
25            statement: 0,
26        }
27    }
28
29    pub fn after(self) -> Self {
30        StmtLoc {
31            block: self.block,
32            statement: self.statement + 1,
33        }
34    }
35}
36
37impl SwitchTargets {
38    pub fn targets(&self) -> SmallVec<[BlockId; 2]> {
39        match self {
40            SwitchTargets::If(then_tgt, else_tgt) => {
41                smallvec![*then_tgt, *else_tgt]
42            }
43            SwitchTargets::SwitchInt(_, targets, otherwise) => targets
44                .iter()
45                .map(|(_, t)| t)
46                .chain([otherwise])
47                .copied()
48                .collect(),
49        }
50    }
51    pub fn targets_mut(&mut self) -> SmallVec<[&mut BlockId; 2]> {
52        match self {
53            SwitchTargets::If(then_tgt, else_tgt) => {
54                smallvec![then_tgt, else_tgt]
55            }
56            SwitchTargets::SwitchInt(_, targets, otherwise) => targets
57                .iter_mut()
58                .map(|(_, t)| t)
59                .chain([otherwise])
60                .collect(),
61        }
62    }
63}
64
65impl Statement {
66    pub fn new(span: Span, kind: StatementKind) -> Self {
67        Statement {
68            span,
69            kind,
70            comments_before: vec![],
71        }
72    }
73}
74
75impl Terminator {
76    pub fn new(span: Span, kind: TerminatorKind) -> Self {
77        Terminator {
78            span,
79            kind,
80            comments_before: vec![],
81        }
82    }
83    pub fn goto(span: Span, target: BlockId) -> Self {
84        Self::new(span, TerminatorKind::Goto { target })
85    }
86    /// Whether this terminator is an unconditional error (panic).
87    pub fn is_error(&self) -> bool {
88        use TerminatorKind::*;
89        match &self.kind {
90            Abort(..) => true,
91            Goto { .. }
92            | Switch { .. }
93            | InlineAsm { .. }
94            | Return
95            | Call { .. }
96            | Drop { .. }
97            | UnwindResume
98            | Assert { .. } => false,
99        }
100    }
101
102    pub fn into_block(self) -> BlockData {
103        BlockData {
104            statements: vec![],
105            terminator: self,
106        }
107    }
108
109    pub fn targets(&self) -> SmallVec<[BlockId; 2]> {
110        match &self.kind {
111            TerminatorKind::Goto { target } => {
112                smallvec![*target]
113            }
114            TerminatorKind::Switch { targets, .. } => targets.targets(),
115            TerminatorKind::InlineAsm {
116                targets, on_unwind, ..
117            } => targets.iter().copied().chain([*on_unwind]).collect(),
118            TerminatorKind::Call {
119                target, on_unwind, ..
120            }
121            | TerminatorKind::Drop {
122                target, on_unwind, ..
123            }
124            | TerminatorKind::Assert {
125                target, on_unwind, ..
126            } => smallvec![*target, *on_unwind],
127            TerminatorKind::Abort(..) | TerminatorKind::Return | TerminatorKind::UnwindResume => {
128                smallvec![]
129            }
130        }
131    }
132    pub fn targets_mut(&mut self) -> SmallVec<[&mut BlockId; 2]> {
133        match &mut self.kind {
134            TerminatorKind::Goto { target } => {
135                smallvec![target]
136            }
137            TerminatorKind::Switch { targets, .. } => targets.targets_mut(),
138            TerminatorKind::InlineAsm {
139                targets, on_unwind, ..
140            } => targets.iter_mut().chain([on_unwind]).collect(),
141            TerminatorKind::Call {
142                target, on_unwind, ..
143            }
144            | TerminatorKind::Drop {
145                target, on_unwind, ..
146            }
147            | TerminatorKind::Assert {
148                target, on_unwind, ..
149            } => smallvec![target, on_unwind],
150            TerminatorKind::Abort(..) | TerminatorKind::Return | TerminatorKind::UnwindResume => {
151                smallvec![]
152            }
153        }
154    }
155
156    pub fn targets_ignoring_unwind(&self) -> SmallVec<[BlockId; 2]> {
157        match &self.kind {
158            TerminatorKind::Goto { target } => {
159                smallvec![*target]
160            }
161            TerminatorKind::Switch { targets, .. } => targets.targets(),
162            TerminatorKind::InlineAsm { targets, .. } => targets.iter().copied().collect(),
163            TerminatorKind::Call { target, .. }
164            | TerminatorKind::Drop { target, .. }
165            | TerminatorKind::Assert { target, .. } => {
166                smallvec![*target]
167            }
168            TerminatorKind::Abort(..) | TerminatorKind::Return | TerminatorKind::UnwindResume => {
169                smallvec![]
170            }
171        }
172    }
173}
174
175impl BlockData {
176    /// Build a block that's just a goto terminator.
177    pub fn new_goto(span: Span, target: BlockId) -> Self {
178        BlockData {
179            statements: vec![],
180            terminator: Terminator::goto(span, target),
181        }
182    }
183    pub fn as_goto(&self) -> Option<BlockId> {
184        if let TerminatorKind::Goto { target } = self.terminator.kind {
185            Some(target)
186        } else {
187            None
188        }
189    }
190    pub fn as_trivial_goto(&self) -> Option<BlockId> {
191        self.as_goto().filter(|_| {
192            self.statements
193                .iter()
194                .all(|st| matches!(st.kind, StatementKind::Nop))
195        })
196    }
197
198    pub fn as_abort(&self) -> Option<AbortKind> {
199        if self.statements.iter().all(|st| st.kind.is_storage_live())
200            && let TerminatorKind::Abort(abort) = &self.terminator.kind
201        {
202            Some(abort.clone())
203        } else {
204            None
205        }
206    }
207
208    /// Build a block that's UB to reach.
209    pub fn new_unreachable() -> Self {
210        Terminator::new(
211            Span::dummy(),
212            TerminatorKind::Abort(AbortKind::UndefinedBehavior),
213        )
214        .into_block()
215    }
216
217    pub fn targets(&self) -> SmallVec<[BlockId; 2]> {
218        self.terminator.targets()
219    }
220    pub fn targets_ignoring_unwind(&self) -> SmallVec<[BlockId; 2]> {
221        self.terminator.targets_ignoring_unwind()
222    }
223
224    /// Apply a transformer to all the statements.
225    ///
226    /// The transformer should:
227    /// - mutate the current statement in place
228    /// - return the sequence of statements to introduce before the current statement
229    pub fn transform<F: FnMut(&mut Statement) -> Vec<Statement>>(&mut self, mut f: F) {
230        self.transform_sequences_fwd(|slice| {
231            let new_statements = f(&mut slice[0]);
232            if new_statements.is_empty() {
233                vec![]
234            } else {
235                vec![(0, new_statements)]
236            }
237        });
238    }
239
240    /// Helper, see `transform_sequences_fwd` and `transform_sequences_bwd`.
241    fn transform_sequences<F>(&mut self, mut f: F, forward: bool)
242    where
243        F: FnMut(&mut [Statement]) -> Vec<(usize, Vec<Statement>)>,
244    {
245        let mut to_insert = vec![];
246        let mut final_len = self.statements.len();
247        if forward {
248            for i in 0..self.statements.len() {
249                let new_to_insert = f(&mut self.statements[i..]);
250                to_insert.extend(new_to_insert.into_iter().map(|(j, stmts)| {
251                    final_len += stmts.len();
252                    (i + j, stmts)
253                }));
254            }
255        } else {
256            for i in (0..self.statements.len()).rev() {
257                let new_to_insert = f(&mut self.statements[i..]);
258                to_insert.extend(new_to_insert.into_iter().map(|(j, stmts)| {
259                    final_len += stmts.len();
260                    (i + j, stmts)
261                }));
262            }
263        }
264        if !to_insert.is_empty() {
265            to_insert.sort_by_key(|(i, _)| *i);
266            // Make it so the first element is always at the end so we can pop it.
267            to_insert.reverse();
268            // Construct the merged list of statements.
269            let old_statements = mem::replace(&mut self.statements, Vec::with_capacity(final_len));
270            for (i, stmt) in old_statements.into_iter().enumerate() {
271                while let Some((j, _)) = to_insert.last()
272                    && *j == i
273                {
274                    let (_, mut stmts) = to_insert.pop().unwrap();
275                    self.statements.append(&mut stmts);
276                }
277                self.statements.push(stmt);
278            }
279        }
280    }
281
282    /// Apply a transformer to all the statements.
283    ///
284    /// The transformer should:
285    /// - mutate the current statements in place
286    /// - return a list of `(i, statements)` where `statements` will be inserted before index `i`.
287    pub fn transform_sequences_fwd<F>(&mut self, f: F)
288    where
289        F: FnMut(&mut [Statement]) -> Vec<(usize, Vec<Statement>)>,
290    {
291        self.transform_sequences(f, true);
292    }
293
294    /// Apply a transformer to all the statements.
295    ///
296    /// The transformer should:
297    /// - mutate the current statements in place
298    /// - return a list of `(i, statements)` where `statements` will be inserted before index `i`.
299    pub fn transform_sequences_bwd<F>(&mut self, f: F)
300    where
301        F: FnMut(&mut [Statement]) -> Vec<(usize, Vec<Statement>)>,
302    {
303        self.transform_sequences(f, false);
304    }
305}
306
307impl ExprBody {
308    /// Returns a map from blocks in this body to their abort kind, if they correspond to an
309    /// abort block (ie. a block with no statements and an [TerminatorKind::Abort] terminator).
310    pub fn as_abort_map(&self) -> HashMap<BlockId, AbortKind> {
311        self.body
312            .iter_enumerated()
313            .filter_map(|(bid, block)| block.as_abort().map(|abort| (bid, abort)))
314            .collect()
315    }
316
317    pub fn transform_sequences_fwd<F>(&mut self, mut f: F)
318    where
319        F: FnMut(BlockId, &mut Locals, &mut [Statement]) -> Vec<(usize, Vec<Statement>)>,
320    {
321        for (id, block) in &mut self.body.iter_mut_enumerated() {
322            block.transform_sequences_fwd(|seq| f(id, &mut self.locals, seq));
323        }
324    }
325
326    pub fn transform_sequences_bwd<F>(&mut self, mut f: F)
327    where
328        F: FnMut(&mut Locals, &mut [Statement]) -> Vec<(usize, Vec<Statement>)>,
329    {
330        for block in &mut self.body {
331            block.transform_sequences_bwd(|seq| f(&mut self.locals, seq));
332        }
333    }
334
335    /// Apply a function to all the statements, in a bottom-up manner.
336    pub fn visit_statements<F: FnMut(&mut Statement)>(&mut self, mut f: F) {
337        for block in self.body.iter_mut().rev() {
338            for st in block.statements.iter_mut().rev() {
339                f(st);
340            }
341        }
342    }
343}
344
345impl Index<StmtLoc> for ExprBody {
346    type Output = Statement;
347    fn index(&self, loc: StmtLoc) -> &Self::Output {
348        &self.body[loc.block].statements[loc.statement]
349    }
350}
351
352impl IndexMut<StmtLoc> for ExprBody {
353    fn index_mut(&mut self, loc: StmtLoc) -> &mut Self::Output {
354        &mut self.body[loc.block].statements[loc.statement]
355    }
356}
357
358/// Helper to construct a small ullbc body.
359pub struct BodyBuilder {
360    /// The span to use for everything.
361    pub span: Span,
362    /// Body under construction.
363    pub body: ExprBody,
364    /// Block onto which we're adding statements. Its terminator is always `Return`.
365    pub current_block: BlockId,
366    /// Block to unwind to; created on demand.
367    pub unwind_block: Option<BlockId>,
368}
369
370fn mk_block(span: Span, term: TerminatorKind) -> BlockData {
371    BlockData {
372        statements: vec![],
373        terminator: Terminator::new(span, term),
374    }
375}
376
377impl BodyBuilder {
378    pub fn new(span: Span, arg_count: usize) -> Self {
379        let mut body: ExprBody = GExprBody {
380            span,
381            locals: Locals::new(arg_count),
382            bound_body_regions: 0,
383            body: IndexVec::new(),
384            comments: vec![],
385        };
386        let current_block = body.body.push(BlockData {
387            statements: Default::default(),
388            terminator: Terminator::new(span, TerminatorKind::Return),
389        });
390        Self {
391            span,
392            body,
393            current_block,
394            unwind_block: None,
395        }
396    }
397
398    /// Finalize the builder by returning the built body.
399    pub fn build(mut self) -> ExprBody {
400        // Replace erased regions with fresh ones.
401        let mut freshener: IndexMap<RegionId, ()> = IndexMap::new();
402        self.body.dyn_visit_mut(|r: &mut Region| {
403            if r.is_erased() || r.is_body() {
404                *r = Region::Body(freshener.push(()));
405            }
406        });
407        self.body.bound_body_regions = freshener.slot_count();
408        // Return the built body.
409        self.body
410    }
411
412    /// Create a new local. Adds a `StorageLive` statement if the local is not one of the special
413    /// ones (return or function argument).
414    pub fn new_var(&mut self, name: Option<String>, ty: Ty) -> Place {
415        let place = self.body.locals.new_var(name, ty);
416        let local_id = place.as_local().unwrap();
417        if !self.body.locals.is_return_or_arg(local_id) {
418            self.push_statement(StatementKind::StorageLive(local_id));
419        }
420        place
421    }
422
423    /// Helper.
424    fn current_block(&mut self) -> &mut BlockData {
425        &mut self.body.body[self.current_block]
426    }
427
428    pub fn push_statement(&mut self, kind: StatementKind) {
429        let st = Statement::new(self.span, kind);
430        self.current_block().statements.push(st);
431    }
432
433    fn unwind_block(&mut self) -> BlockId {
434        *self.unwind_block.get_or_insert_with(|| {
435            self.body
436                .body
437                .push(mk_block(self.span, TerminatorKind::UnwindResume))
438        })
439    }
440
441    pub fn call(&mut self, call: Call) {
442        let next_block = self
443            .body
444            .body
445            .push(mk_block(self.span, TerminatorKind::Return));
446        let term = TerminatorKind::Call {
447            target: next_block,
448            call,
449            on_unwind: self.unwind_block(),
450        };
451        self.current_block().terminator.kind = term;
452        self.current_block = next_block;
453    }
454
455    pub fn insert_drop(&mut self, place: Place, fn_ptr: FnPtr) {
456        let next_block = self
457            .body
458            .body
459            .push(mk_block(self.span, TerminatorKind::Return));
460        let term = TerminatorKind::Drop {
461            kind: DropKind::Precise,
462            place,
463            fn_ptr,
464            target: next_block,
465            on_unwind: self.unwind_block(),
466        };
467        self.current_block().terminator.kind = term;
468        self.current_block = next_block;
469    }
470}