charon_lib/transform/
skip_trait_refs_when_known.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
use crate::{register_error, transform::TransformCtx, ullbc_ast::*};

use super::ctx::UllbcPass;

/// Rewrites a call to a trait method into a direct call to the implementing function when the
/// trait impl is statically known.
///
/// The call is left untouched when any of the following holds:
/// - the callee is not a `FnOperand::Regular` function pointer;
/// - the function pointer does not refer to a trait method;
/// - the trait reference is not a `TraitRefKind::TraitImpl`;
/// - the impl is absent from `ctx.translated.trait_impls`;
/// - the impl has no required or provided method named like the one being called;
/// - the generics supplied at the call site do not match the method's parameters, in which case
///   an error is also registered at `span`.
fn transform_call(ctx: &mut TransformCtx, span: Span, call: &mut Call) {
    // We find calls to a trait method where the impl is known; otherwise we return.
    let FnOperand::Regular(fn_ptr) = &mut call.func else {
        return;
    };
    let FunIdOrTraitMethodRef::Trait(trait_ref, name, _) = &fn_ptr.func else {
        return;
    };
    let TraitRefKind::TraitImpl(impl_id, impl_generics) = &trait_ref.kind else {
        return;
    };
    let Some(trait_impl) = ctx.translated.trait_impls.get(*impl_id) else {
        return;
    };
    // Find the function declaration corresponding to this impl: required methods are searched
    // first, then provided ones.
    let Some((_, bound_fn)) = trait_impl
        .required_methods
        .iter()
        .chain(trait_impl.provided_methods.iter())
        .find(|(n, _)| n == name)
    else {
        return;
    };
    let method_generics = &fn_ptr.generics;

    if !method_generics.matches(&bound_fn.params) {
        register_error!(
            ctx,
            span,
            "Mismatched method generics:\nparams:   {:?}\nsupplied: {:?}",
            bound_fn.params,
            method_generics
        );
        // Do not substitute generics that are known not to fit the method binder: keep the call
        // as a trait method call instead of producing an ill-formed direct call.
        // NOTE(review): `Binder::apply` presumably expects matching generics -- confirm.
        return;
    }

    // Make the two levels of binding explicit: outer binder for the impl block, inner binder for
    // the method.
    let fn_ref: Binder<Binder<FunDeclRef>> = Binder::new(
        BinderKind::Other,
        trait_impl.generics.clone(),
        bound_fn.clone(),
    );
    // Substitute the appropriate generics into the function call: impl-level generics first,
    // then the method-level ones.
    let fn_ref = fn_ref.apply(impl_generics).apply(method_generics);
    fn_ptr.generics = fn_ref.generics;
    fn_ptr.func = FunIdOrTraitMethodRef::Fun(FunId::Regular(fn_ref.id));
}

/// ULLBC pass that runs `transform_call` on every call statement of a body.
pub struct Transform;

impl UllbcPass for Transform {
    /// Visits each statement of each block in `b` and hands the `RawStatement::Call` ones to
    /// `transform_call`, together with the statement's span.
    fn transform_body(&self, ctx: &mut TransformCtx, b: &mut ExprBody) {
        let statements = b
            .body
            .iter_mut()
            .flat_map(|block| block.statements.iter_mut());
        for statement in statements {
            // Copy the span out first so that only `content` is mutably borrowed below.
            let span = statement.span;
            if let RawStatement::Call(call) = &mut statement.content {
                transform_call(ctx, span, call);
            }
        }
    }
}