Skip to main content

charon_lib/transform/normalize/skip_trait_refs_when_known.rs

1use derive_generic_visitor::Visitor;
2
3use crate::transform::ctx::UllbcPass;
4use crate::{transform::TransformCtx, ullbc_ast::*};
5
6#[derive(Visitor)]
7struct NormalizeFnPtr<'a> {
8    ctx: &'a TransformCtx,
9}
10
11impl VisitAstMut for NormalizeFnPtr<'_> {
12    fn enter_fn_ptr(&mut self, fn_ptr: &mut FnPtr) {
13        if let Some(new_fn_ptr) = normalize_default_method_call_on_known_impl(self.ctx, fn_ptr)
14            .or_else(|| normalize_method_call_on_known_impl(self.ctx, fn_ptr))
15        {
16            *fn_ptr = new_fn_ptr;
17        }
18    }
19}
20
21/// Transform `Trait::default_method<X>[impl_trait_for_X]` to a direct method call.
22fn normalize_default_method_call_on_known_impl(
23    ctx: &TransformCtx,
24    fn_ptr: &FnPtr,
25) -> Option<FnPtr> {
26    let fun_id = fn_ptr.kind.as_ref().as_fun()?.as_regular()?;
27    let fun_decl = ctx.translated.fun_decls.get(*fun_id)?;
28    let ItemSource::TraitDecl {
29        trait_ref,
30        item_id: AssocItemId::Method(method_id),
31    } = &fun_decl.src
32    else {
33        return None;
34    };
35    // If the first trait proof (for the self clause) is a known impl.
36    let impl_ref = fn_ptr
37        .generics
38        .trait_refs
39        .get(TraitClauseId::ZERO)
40        .as_ref()?
41        .kind
42        .as_trait_impl()?;
43    let method_generics = {
44        let generics = &fn_ptr.generics;
45        let trait_generics = trait_ref.generics.as_ref();
46        GenericArgs {
47            regions: generics
48                .regions
49                .clone()
50                .split_off(trait_generics.regions.len()),
51            types: generics.types.clone().split_off(trait_generics.types.len()),
52            const_generics: generics
53                .const_generics
54                .clone()
55                .split_off(trait_generics.const_generics.len()),
56            // The `+ 1` is for the self clause.
57            trait_refs: generics
58                .trait_refs
59                .clone()
60                .split_off(trait_generics.trait_refs.len() + 1),
61        }
62    };
63    normalize_method_call(ctx, impl_ref, method_id, &method_generics)
64}
65
66/// Transform `impl_trait_for_X::method` to a direct method call.
67fn normalize_method_call_on_known_impl(ctx: &TransformCtx, fn_ptr: &FnPtr) -> Option<FnPtr> {
68    let FnPtrKind::Trait(trait_ref, method_id) = fn_ptr.kind.as_ref() else {
69        return None;
70    };
71    let TraitRefKind::TraitImpl(impl_ref) = &trait_ref.kind else {
72        return None;
73    };
74    normalize_method_call(ctx, impl_ref, method_id, &fn_ptr.generics)
75}
76
77fn normalize_method_call(
78    ctx: &TransformCtx,
79    impl_ref: &TraitImplRef,
80    method_id: &TraitMethodId,
81    method_generics: &GenericArgs,
82) -> Option<FnPtr> {
83    let trait_impl = &ctx.translated.trait_impls.get(impl_ref.id)?;
84    // Find the function declaration corresponding to this impl.
85    let bound_fn = trait_impl.methods.get(*method_id)?;
86    if !method_generics.matches(&bound_fn.params) {
87        return None;
88    }
89
90    // Make the two levels of binding explicit: outer binder for the impl block, inner binder for
91    // the method.
92    let fn_ref: Binder<Binder<FunDeclRef>> = Binder::new(
93        BinderKind::Other,
94        trait_impl.generics.clone(),
95        bound_fn.clone(),
96    );
97    // Substitute the appropriate generics into the function call.
98    let fn_ref = fn_ref.apply(&impl_ref.generics).apply(method_generics);
99    Some(FnPtr::new(
100        FnPtrKind::Fun(FunId::Regular(fn_ref.id)),
101        fn_ref.generics,
102    ))
103}
104
105pub struct Transform;
106impl UllbcPass for Transform {
107    fn transform_item(&self, ctx: &mut TransformCtx, mut item: ItemRefMut<'_>) {
108        let _ = item.drive_mut(&mut NormalizeFnPtr { ctx });
109    }
110}