charon_lib/transform/
hide_allocator_param.rs

1use derive_generic_visitor::*;
2use itertools::Itertools;
3use std::collections::HashSet;
4
5use crate::{ast::*, name_matcher::NamePattern};
6
7use super::{TransformCtx, ctx::TransformPass};
8
9#[derive(Visitor)]
10struct RemoveLastParamVisitor {
11    types: HashSet<TypeId>,
12}
13
14impl VisitAstMut for RemoveLastParamVisitor {
15    fn enter_type_decl_ref(&mut self, x: &mut TypeDeclRef) {
16        if self.types.contains(&x.id) {
17            // Remove the last param.
18            x.generics.types.pop();
19        }
20    }
21}
22
23pub struct Transform;
24impl TransformPass for Transform {
25    fn transform_ctx(&self, ctx: &mut TransformCtx) {
26        if !ctx.options.hide_allocator {
27            return;
28        }
29        let types = &[
30            "alloc::boxed::Box",
31            "alloc::vec::Vec",
32            "alloc::rc::Rc",
33            "alloc::sync::Arc",
34        ];
35
36        let types: Vec<NamePattern> = types
37            .into_iter()
38            .map(|s| NamePattern::parse(s).unwrap())
39            .collect_vec();
40        let types: HashSet<TypeId> = ctx
41            .translated
42            .item_names
43            .iter()
44            .filter(|(_, name)| types.iter().any(|p| p.matches(&ctx.translated, name)))
45            .filter_map(|(id, _)| id.as_type())
46            .copied()
47            .map(TypeId::Adt)
48            .chain([TypeId::Builtin(BuiltinTy::Box)])
49            .collect();
50
51        for &id in &types {
52            if let Some(&id) = id.as_adt()
53                && let Some(tdecl) = ctx.translated.type_decls.get_mut(id)
54            {
55                struct SubstWithErrorVisitor(TypeVarId);
56                impl VarsVisitor for SubstWithErrorVisitor {
57                    fn visit_type_var(&mut self, v: TypeDbVar) -> Option<Ty> {
58                        if let DeBruijnVar::Bound(DeBruijnId::ZERO, var_id) = v
59                            && var_id == self.0
60                        {
61                            Some(TyKind::Error("removed allocator parameter".to_owned()).into_ty())
62                        } else {
63                            None
64                        }
65                    }
66                }
67                let tvar = tdecl.generics.types.pop().unwrap();
68                tdecl.visit_vars(&mut SubstWithErrorVisitor(tvar.index));
69            }
70        }
71
72        let _ = ctx
73            .translated
74            .drive_mut(&mut RemoveLastParamVisitor { types });
75    }
76}