charon_lib/transform/duplicate_defaulted_methods.rs

1//! Add missing methods to trait impls by duplicating the default method.
2use std::{collections::HashMap, mem};
3
4use crate::ast::*;
5
6use super::{ctx::TransformPass, TransformCtx};
7
8pub struct Transform;
9impl TransformPass for Transform {
10    fn transform_ctx(&self, ctx: &mut TransformCtx) {
11        for impl_id in ctx.translated.trait_impls.all_indices() {
12            let Some(timpl) = ctx.translated.trait_impls.get_mut(impl_id) else {
13                continue;
14            };
15            let Some(tdecl) = ctx.translated.trait_decls.get(timpl.impl_trait.trait_id) else {
16                continue;
17            };
18            if tdecl.methods.len() == timpl.methods.len() {
19                continue;
20            }
21
22            // A `TraitRef` that points to this impl with the correct generics.
23            let self_impl_ref = TraitImplRef {
24                impl_id: timpl.def_id,
25                generics: Box::new(
26                    timpl
27                        .generics
28                        .identity_args(GenericsSource::item(timpl.def_id)),
29                ),
30            };
31            let self_predicate = TraitRef {
32                kind: TraitRefKind::TraitImpl(
33                    self_impl_ref.impl_id,
34                    self_impl_ref.generics.clone(),
35                ),
36                trait_decl_ref: RegionBinder::empty(timpl.impl_trait.clone()),
37            };
38            // Map of methods we already have in the impl.
39            let mut methods_map: HashMap<TraitItemName, _> =
40                mem::take(&mut timpl.methods).into_iter().collect();
41            // Borrow shared to get access to the rest of the crate.
42            let timpl = ctx.translated.trait_impls.get(impl_id).unwrap();
43            let mut methods = vec![];
44            for (name, decl_fn_ref) in &tdecl.methods {
45                if let Some(kv) = methods_map.remove_entry(name) {
46                    methods.push(kv);
47                    continue;
48                }
49                let declared_fun_id = decl_fn_ref.skip_binder.id;
50                let declared_fun_name = ctx.translated.item_name(declared_fun_id).unwrap();
51                let new_fun_name = {
52                    let mut item_name = timpl.item_meta.name.clone();
53                    item_name
54                        .name
55                        .push(declared_fun_name.name.last().unwrap().clone());
56                    item_name
57                };
58                let opacity = ctx.opacity_for_name(&new_fun_name);
59                let new_fun_id = ctx.translated.fun_decls.reserve_slot();
60                ctx.translated.all_ids.insert(new_fun_id.into());
61                ctx.translated
62                    .item_names
63                    .insert(new_fun_id.into(), new_fun_name.clone());
64
65                // Substitute the method reference to be valid in the context of the impl.
66                let bound_fn = decl_fn_ref
67                    .clone()
68                    .substitute_with_self(&timpl.impl_trait.generics, &self_predicate.kind);
69                // The new function item has for params the concatenation of impl params and method
70                // params. We build a FunDeclRef to this even if we don't end up adding the new
71                // function item below.
72                let new_fn_ref = Binder {
73                    skip_binder: FunDeclRef {
74                        id: new_fun_id,
75                        generics: Box::new(
76                            timpl
77                                .generics
78                                .identity_args_at_depth(
79                                    GenericsSource::item(timpl.def_id),
80                                    DeBruijnId::one(),
81                                )
82                                .concat(
83                                    GenericsSource::item(new_fun_id),
84                                    &bound_fn.params.identity_args_at_depth(
85                                        GenericsSource::Method(
86                                            timpl.impl_trait.trait_id.into(),
87                                            name.clone(),
88                                        ),
89                                        DeBruijnId::zero(),
90                                    ),
91                                ),
92                        ),
93                    },
94                    params: bound_fn.params.clone(),
95                    kind: bound_fn.kind.clone(),
96                };
97                methods.push((name.clone(), new_fn_ref));
98
99                if let Some(fun_decl) = ctx.translated.fun_decls.get(declared_fun_id)
100                    && !opacity.is_invisible()
101                {
102                    let bound_fn = Binder {
103                        params: timpl.generics.clone(),
104                        skip_binder: bound_fn,
105                        kind: BinderKind::Other,
106                    };
107                    // Flatten into a single binder level. This gives us the concatenated
108                    // parameters that we'll use for the new function item, and the arguments to
109                    // pass to the old function item.
110                    let bound_fn = bound_fn.flatten();
111                    // Create a copy of the provided method and update all the relevant data.
112                    let FunDecl {
113                        def_id: _,
114                        item_meta,
115                        signature,
116                        kind,
117                        is_global_initializer,
118                        body,
119                    } = fun_decl.clone();
120                    let item_meta = ItemMeta {
121                        name: new_fun_name,
122                        is_local: timpl.item_meta.is_local,
123                        opacity,
124                        ..item_meta
125                    };
126                    let signature = FunSig {
127                        generics: bound_fn.params,
128                        inputs: signature.inputs.substitute_with_self(
129                            &bound_fn.skip_binder.generics,
130                            &self_predicate.kind,
131                        ),
132                        output: signature.output.substitute_with_self(
133                            &bound_fn.skip_binder.generics,
134                            &self_predicate.kind,
135                        ),
136                        ..signature
137                    };
138                    let kind = if let ItemKind::TraitDecl {
139                        trait_ref,
140                        item_name,
141                        ..
142                    } = kind
143                    {
144                        ItemKind::TraitImpl {
145                            impl_ref: self_impl_ref.clone(),
146                            trait_ref: trait_ref.substitute_with_self(
147                                &bound_fn.skip_binder.generics,
148                                &self_predicate.kind,
149                            ),
150                            item_name,
151                            reuses_default: true,
152                        }
153                    } else {
154                        unreachable!()
155                    };
156                    let body = if opacity.is_transparent() {
157                        body.substitute_with_self(
158                            &bound_fn.skip_binder.generics,
159                            &self_predicate.kind,
160                        )
161                    } else {
162                        Err(Opaque)
163                    };
164                    ctx.translated.fun_decls.set_slot(
165                        new_fun_id,
166                        FunDecl {
167                            def_id: new_fun_id,
168                            item_meta,
169                            signature,
170                            kind,
171                            is_global_initializer,
172                            body,
173                        },
174                    );
175                }
176            }
177            let timpl = ctx.translated.trait_impls.get_mut(impl_id).unwrap();
178            timpl.methods = methods;
179        }
180    }
181}