charon_lib/transform/
duplicate_defaulted_methods.rs

//! Add missing methods to trait impls by duplicating the trait's default method bodies.
2use std::{collections::HashMap, mem};
3
4use crate::ast::*;
5
6use super::{ctx::TransformPass, TransformCtx};
7
8pub struct Transform;
9impl TransformPass for Transform {
10    fn transform_ctx(&self, ctx: &mut TransformCtx) {
11        for impl_id in ctx.translated.trait_impls.all_indices() {
12            let Some(timpl) = ctx.translated.trait_impls.get_mut(impl_id) else {
13                continue;
14            };
15            let Some(tdecl) = ctx.translated.trait_decls.get(timpl.impl_trait.trait_id) else {
16                continue;
17            };
18            if tdecl.methods.len() == timpl.methods.len() {
19                continue;
20            }
21
22            // A `TraitRef` that points to this impl with the correct generics.
23            let self_impl_ref = TraitImplRef {
24                impl_id: timpl.def_id,
25                generics: timpl
26                    .generics
27                    .identity_args(GenericsSource::item(timpl.def_id)),
28            };
29            let self_predicate = TraitRef {
30                kind: TraitRefKind::TraitImpl(
31                    self_impl_ref.impl_id,
32                    self_impl_ref.generics.clone(),
33                ),
34                trait_decl_ref: RegionBinder::empty(timpl.impl_trait.clone()),
35            };
36            // Map of methods we already have in the impl.
37            let mut methods_map: HashMap<TraitItemName, _> =
38                mem::take(&mut timpl.methods).into_iter().collect();
39            // Borrow shared to get access to the rest of the crate.
40            let timpl = ctx.translated.trait_impls.get(impl_id).unwrap();
41            let mut methods = vec![];
42            for (name, decl_fn_ref) in &tdecl.methods {
43                if let Some(kv) = methods_map.remove_entry(name) {
44                    methods.push(kv);
45                    continue;
46                }
47                let declared_fun_id = decl_fn_ref.skip_binder.id;
48                let declared_fun_name = ctx.translated.item_name(declared_fun_id).unwrap();
49                let new_fun_name = {
50                    let mut item_name = timpl.item_meta.name.clone();
51                    item_name
52                        .name
53                        .push(declared_fun_name.name.last().unwrap().clone());
54                    item_name
55                };
56                let opacity = ctx.opacity_for_name(&new_fun_name);
57                let new_fun_id = ctx.translated.fun_decls.reserve_slot();
58                ctx.translated.all_ids.insert(new_fun_id.into());
59                ctx.translated
60                    .item_names
61                    .insert(new_fun_id.into(), new_fun_name.clone());
62
63                // Substitute the method reference to be valid in the context of the impl.
64                let bound_fn = decl_fn_ref
65                    .clone()
66                    .substitute_with_self(&timpl.impl_trait.generics, &self_predicate.kind);
67                // The new function item has for params the concatenation of impl params and method
68                // params. We build a FunDeclRef to this even if we don't end up adding the new
69                // function item below.
70                let new_fn_ref = Binder {
71                    skip_binder: FunDeclRef {
72                        id: new_fun_id,
73                        generics: timpl
74                            .generics
75                            .identity_args_at_depth(
76                                GenericsSource::item(timpl.def_id),
77                                DeBruijnId::one(),
78                            )
79                            .concat(
80                                GenericsSource::item(new_fun_id),
81                                &bound_fn.params.identity_args_at_depth(
82                                    GenericsSource::Method(
83                                        timpl.impl_trait.trait_id.into(),
84                                        name.clone(),
85                                    ),
86                                    DeBruijnId::zero(),
87                                ),
88                            ),
89                    },
90                    params: bound_fn.params.clone(),
91                    kind: bound_fn.kind.clone(),
92                };
93                methods.push((name.clone(), new_fn_ref));
94
95                if let Some(fun_decl) = ctx.translated.fun_decls.get(declared_fun_id)
96                    && !opacity.is_invisible()
97                {
98                    let bound_fn = Binder {
99                        params: timpl.generics.clone(),
100                        skip_binder: bound_fn,
101                        kind: BinderKind::Other,
102                    };
103                    // Flatten into a single binder level. This gives us the concatenated
104                    // parameters that we'll use for the new function item, and the arguments to
105                    // pass to the old function item.
106                    let bound_fn = bound_fn.flatten();
107                    // Create a copy of the provided method and update all the relevant data.
108                    let FunDecl {
109                        def_id: _,
110                        item_meta,
111                        signature,
112                        kind,
113                        is_global_initializer,
114                        body,
115                    } = fun_decl.clone();
116                    let item_meta = ItemMeta {
117                        name: new_fun_name,
118                        is_local: timpl.item_meta.is_local,
119                        opacity,
120                        ..item_meta
121                    };
122                    let signature = FunSig {
123                        generics: bound_fn.params,
124                        inputs: signature.inputs.substitute_with_self(
125                            &bound_fn.skip_binder.generics,
126                            &self_predicate.kind,
127                        ),
128                        output: signature.output.substitute_with_self(
129                            &bound_fn.skip_binder.generics,
130                            &self_predicate.kind,
131                        ),
132                        ..signature
133                    };
134                    let kind = if let ItemKind::TraitDecl {
135                        trait_ref,
136                        item_name,
137                        ..
138                    } = kind
139                    {
140                        ItemKind::TraitImpl {
141                            impl_ref: self_impl_ref.clone(),
142                            trait_ref: trait_ref.substitute_with_self(
143                                &bound_fn.skip_binder.generics,
144                                &self_predicate.kind,
145                            ),
146                            item_name,
147                            reuses_default: true,
148                        }
149                    } else {
150                        unreachable!()
151                    };
152                    let body = if opacity.is_transparent() {
153                        body.substitute_with_self(
154                            &bound_fn.skip_binder.generics,
155                            &self_predicate.kind,
156                        )
157                    } else {
158                        Err(Opaque)
159                    };
160                    ctx.translated.fun_decls.set_slot(
161                        new_fun_id,
162                        FunDecl {
163                            def_id: new_fun_id,
164                            item_meta,
165                            signature,
166                            kind,
167                            is_global_initializer,
168                            body,
169                        },
170                    );
171                }
172            }
173            let timpl = ctx.translated.trait_impls.get_mut(impl_id).unwrap();
174            timpl.methods = methods;
175        }
176    }
177}