charon_lib/transform/normalize/
transform_dyn_trait_calls.rs

//! Transform method calls on `&dyn Trait` to vtable function pointer calls.
//!
//! This pass converts direct method calls on trait objects into calls through vtable
//! function pointers. For example:
//!
//! ```rust,ignore
//! let x: &dyn Trait = &obj;
//! x.check(args);
//! ```
//!
//! is transformed from:
//! ```text
//! @0 := call <dyn Trait as Trait>::check(x, args)
//! ```
//! to:
//! ```text
//! @0 := (copy (*x.ptr_metadata).method_check)(move (x), move (args)) // Call through function pointer
//! ```
19
20use super::super::ctx::UllbcPass;
21use crate::{
22    errors::Error,
23    formatter::IntoFormatter,
24    pretty::FmtWithCtx,
25    raise_error, register_error,
26    transform::{TransformCtx, ctx::UllbcStatementTransformCtx},
27    ullbc_ast::*,
28};
29
30/// Transform a call to a trait method on a dyn trait object
31fn transform_dyn_trait_call(
32    ctx: &mut UllbcStatementTransformCtx<'_>,
33    call: &mut Call,
34) -> Result<(), Error> {
35    let fmt_ctx = &ctx.ctx.into_fmt();
36
37    // Detect if this call should be transformed
38    let FnOperand::Regular(fn_ptr) = &call.func else {
39        return Ok(()); // Not a regular function call
40    };
41    let FnPtrKind::Trait(trait_ref, method_name, _) = &fn_ptr.kind else {
42        return Ok(()); // Not a trait method call
43    };
44    let TraitRefKind::Dyn = &trait_ref.kind else {
45        return Ok(()); // Not a dyn trait trait call
46    };
47    let trait_pred = trait_ref.trait_decl_ref.clone().erase();
48
49    // Get the type of the vtable struct.
50    let vtable_decl_ref: TypeDeclRef = {
51        // Get the trait declaration by its ID
52        let Some(trait_decl) = ctx.ctx.translated.trait_decls.get(trait_pred.id) else {
53            return Ok(()); // Unknown trait
54        };
55        // Get vtable ref from definition for correct ID.
56        let Some(vtable_ty) = &trait_decl.vtable else {
57            raise_error!(
58                ctx.ctx,
59                ctx.span,
60                "Found a `dyn Trait` method call for non-dyn-compatible trait `{}`!",
61                trait_pred.id.with_ctx(fmt_ctx)
62            );
63        };
64        vtable_ty
65            .clone()
66            .substitute_with_self(&trait_pred.generics, &trait_ref.kind)
67    };
68    let vtable_decl_id = *vtable_decl_ref.id.as_adt().unwrap();
69    let Some(vtable_decl) = ctx.ctx.translated.type_decls.get(vtable_decl_id) else {
70        return Ok(()); // Missing data
71    };
72    if matches!(vtable_decl.kind, TypeDeclKind::Opaque) {
73        return Ok(()); // Missing data
74    }
75
76    // Retreive the method field from the vtable struct definition.
77    let method_field_name = format!("method_{}", method_name);
78    let Some((method_field_id, method_field)) =
79        vtable_decl.get_field_by_name(None, &method_field_name)
80    else {
81        let vtable_name = vtable_decl_ref.id.with_ctx(fmt_ctx).to_string();
82        raise_error!(
83            ctx.ctx,
84            ctx.span,
85            "Could not determine method index for {} in vtable {}",
86            method_name,
87            vtable_name
88        );
89    };
90    let method_ty = method_field
91        .ty
92        .clone()
93        .substitute(&vtable_decl_ref.generics);
94
95    // Get the receiver (first argument).
96    if call.args.is_empty() {
97        raise_error!(ctx.ctx, ctx.span, "Dyn trait call has no arguments!");
98    }
99    let dyn_trait_place = match &call.args[0] {
100        Operand::Copy(place) | Operand::Move(place) => place,
101        Operand::Const(_) => {
102            panic!("Unexpected constant as receiver for dyn trait method call")
103        }
104    };
105
106    // Construct the `(*ptr.ptr_metadata).method_field` place.
107    let vtable_ty = TyKind::Adt(vtable_decl_ref).into_ty();
108    let ptr_to_vtable_ty = Ty::new(TyKind::RawPtr(vtable_ty.clone(), RefKind::Shared));
109    let method_field_place = dyn_trait_place
110        .clone()
111        .project(ProjectionElem::PtrMetadata, ptr_to_vtable_ty)
112        .project(ProjectionElem::Deref, vtable_ty)
113        .project(
114            ProjectionElem::Field(FieldProjKind::Adt(vtable_decl_id, None), method_field_id),
115            method_ty,
116        );
117
118    // Transform the original call to use the function pointer
119    call.func = FnOperand::Dynamic(Operand::Copy(method_field_place));
120
121    Ok(())
122}
123
124pub struct Transform;
125impl UllbcPass for Transform {
126    fn transform_function(&self, ctx: &mut TransformCtx, decl: &mut FunDecl) {
127        decl.transform_ullbc_terminators(ctx, |ctx, term| {
128            if let TerminatorKind::Call { call, .. } = &mut term.kind {
129                let _ = transform_dyn_trait_call(ctx, call);
130            }
131        });
132    }
133}