charon_lib/transform/normalize/
skip_trait_refs_when_known.rs1use derive_generic_visitor::Visitor;
2
3use crate::transform::ctx::TransformPass;
4use crate::{register_error, transform::TransformCtx, ullbc_ast::*};
5
6#[derive(Visitor)]
7struct NormalizeFnPtr<'a> {
8    ctx: &'a TransformCtx,
9    span: Span,
10}
11
12impl VisitorWithSpan for NormalizeFnPtr<'_> {
13    fn current_span(&mut self) -> &mut Span {
14        &mut self.span
15    }
16}
17
18impl VisitAstMut for NormalizeFnPtr<'_> {
19    fn visit<'a, T: AstVisitable>(&'a mut self, x: &mut T) -> ControlFlow<Self::Break> {
20        VisitWithSpan::new(self).visit(x)
22    }
23
24    fn enter_fn_ptr(&mut self, fn_ptr: &mut FnPtr) {
25        transform_fn_ptr(self.ctx, self.span, fn_ptr);
26    }
27}
28
29fn transform_fn_ptr(ctx: &TransformCtx, span: Span, fn_ptr: &mut FnPtr) {
30    let FnPtrKind::Trait(trait_ref, name, _) = fn_ptr.kind.as_ref() else {
32        return;
33    };
34    let TraitRefKind::TraitImpl(impl_ref) = &trait_ref.kind else {
35        return;
36    };
37    let Some(trait_impl) = &ctx.translated.trait_impls.get(impl_ref.id) else {
38        return;
39    };
40    let Some((_, bound_fn)) = trait_impl.methods().find(|(n, _)| n == name) else {
42        return;
43    };
44    let method_generics = &fn_ptr.generics;
45
46    if !method_generics.matches(&bound_fn.params) {
47        register_error!(
48            ctx,
49            span,
50            "Mismatched method generics:\nparams:   {:?}\nsupplied: {:?}",
51            bound_fn.params,
52            method_generics
53        );
54    }
55
56    let fn_ref: Binder<Binder<FunDeclRef>> = Binder::new(
59        BinderKind::Other,
60        trait_impl.generics.clone(),
61        bound_fn.clone(),
62    );
63    let fn_ref = fn_ref.apply(&impl_ref.generics).apply(method_generics);
65    fn_ptr.generics = fn_ref.generics;
66    fn_ptr.kind = Box::new(FnPtrKind::Fun(FunId::Regular(fn_ref.id)));
67}
68
69pub struct Transform;
70impl TransformPass for Transform {
71    fn transform_ctx(&self, ctx: &mut TransformCtx) {
72        ctx.for_each_item_mut(|ctx, mut item| {
73            let _ = item.drive_mut(&mut NormalizeFnPtr {
74                ctx,
75                span: Span::dummy(),
76            });
77        })
78    }
79}