Skip to main content

charon_lib/transform/normalize/
transform_dyn_trait_calls.rs

//! Transform method calls on `&dyn Trait` to vtable function pointer calls.
//!
//! This pass converts direct method calls on trait objects into calls through vtable
//! function pointers. For example, given:
//!
//! ```rust,ignore
//! let x: &dyn Trait = &obj;
//! x.method(args);
//! ```
//!
//! the call is transformed from:
//! ```text
//! @0 := call <dyn Trait as Trait>::method(move (@receiver), move (@args))
//! ```
//! to a call through the function pointer stored in the receiver's vtable:
//! ```text
//! @0 := (move (*@receiver.ptr_metadata).method_check)(move (@receiver), move (@args)) // Call through function pointer
//! ```
19use super::super::ctx::UllbcPass;
20use crate::{
21    errors::Error,
22    formatter::IntoFormatter,
23    pretty::FmtWithCtx,
24    raise_error, register_error,
25    transform::{
26        TransformCtx,
27        ctx::{BodyTransformCtx, UllbcStatementTransformCtx},
28    },
29    ullbc_ast::*,
30};
31
32/// Transform a call to a trait method on a dyn trait object
33fn transform_dyn_trait_call(
34    ctx: &mut UllbcStatementTransformCtx<'_>,
35    call: &mut Call,
36) -> Result<(), Error> {
37    let fmt_ctx = &ctx.ctx.into_fmt();
38
39    // Detect if this call should be transformed
40    let FnOperand::Regular(fn_ptr) = &call.func else {
41        return Ok(()); // Not a regular function call
42    };
43    let FnPtrKind::Trait(trait_ref, method_id, _) = fn_ptr.kind.as_ref() else {
44        return Ok(()); // Not a trait method call
45    };
46    let TraitRefKind::Dyn = &trait_ref.kind else {
47        return Ok(()); // Not a dyn trait trait call
48    };
49
50    // Get the type of the vtable struct.
51    let vtable_decl_ref: TypeDeclRef = {
52        // Get the trait declaration by its ID
53        let Some(trait_decl) = ctx.ctx.translated.trait_decls.get(trait_ref.trait_id()) else {
54            return Ok(()); // Unknown trait
55        };
56        // Get vtable ref from definition for correct ID.
57        let Some(vtable_ty) = &trait_decl.vtable else {
58            raise_error!(
59                ctx.ctx,
60                ctx.span,
61                "Found a `dyn Trait` method call for non-dyn-compatible trait `{}`!",
62                trait_ref.trait_id().with_ctx(fmt_ctx)
63            );
64        };
65        vtable_ty.clone().substitute_with_tref(trait_ref)
66    };
67    let vtable_decl_id = *vtable_decl_ref.id.as_adt().unwrap();
68    let Some(vtable_decl) = ctx.ctx.translated.type_decls.get(vtable_decl_id) else {
69        return Ok(()); // Missing data
70    };
71
72    let TypeDeclKind::Struct(fields) = &vtable_decl.kind else {
73        return Ok(()); // Missing data
74    };
75    let ItemSource::VTableTy { field_map, .. } = &vtable_decl.src else {
76        return Ok(()); // Weird
77    };
78    // Retrieve the method field from the vtable struct definition.
79    let Some((method_field_id, _)) = field_map
80        .iter_enumerated()
81        .find(|(_, field)| **field == VTableField::Method(*method_id))
82    else {
83        let vtable_name = vtable_decl_ref.id.with_ctx(fmt_ctx).to_string();
84        raise_error!(
85            ctx.ctx,
86            ctx.span,
87            "Could not determine method index for method {} in vtable {}",
88            method_id,
89            vtable_name
90        );
91    };
92
93    let method_field = &fields[method_field_id];
94    let method_ty = method_field
95        .ty
96        .clone()
97        .substitute(&vtable_decl_ref.generics);
98
99    // Get the receiver (first argument).
100    if call.args.is_empty() {
101        raise_error!(ctx.ctx, ctx.span, "Dyn trait call has no arguments!");
102    }
103    let dyn_trait_place = match &call.args[0] {
104        Operand::Copy(place) | Operand::Move(place) => place,
105        Operand::Const(_) => {
106            panic!("Unexpected constant as receiver for dyn trait method call")
107        }
108    };
109
110    // Construct the `(*ptr.ptr_metadata).method_field` place.
111    let vtable_ty = TyKind::Adt(vtable_decl_ref).into_ty();
112    let ptr_to_vtable_ty = Ty::new(TyKind::RawPtr(vtable_ty.clone(), RefKind::Shared));
113    let method_field_place = dyn_trait_place
114        .clone()
115        .project(ProjectionElem::PtrMetadata, ptr_to_vtable_ty)
116        .project(ProjectionElem::Deref, vtable_ty)
117        .project(
118            ProjectionElem::Field(FieldProjKind::Adt(vtable_decl_id, None), method_field_id),
119            method_ty,
120        );
121
122    let fn_ptr_place = if ctx.ctx.options.monomorphize_with_hax {
123        // In mono mode, the vtable contains erased function pointers, cast to `*const ()`.
124        // This casts back to the expected signature.
125        let real_sig_ty = TyKind::FnPtr(RegionBinder::empty(FunSig {
126            is_unsafe: true,
127            inputs: call.args.iter().map(|op| op.ty().clone()).collect(),
128            output: call.dest.ty.clone(),
129        }))
130        .into_ty();
131        let fn_ptr_place = ctx.fresh_var(None, real_sig_ty);
132        let rval_cast = Rvalue::UnaryOp(
133            UnOp::Cast(CastKind::RawPtr(
134                method_field_place.ty().clone(),
135                fn_ptr_place.ty().clone(),
136            )),
137            Operand::Copy(method_field_place),
138        );
139        ctx.insert_assn_stmt(fn_ptr_place.clone(), rval_cast);
140        fn_ptr_place
141    } else {
142        method_field_place
143    };
144
145    // Transform the original call to use the function pointer
146    call.func = FnOperand::Dynamic(Operand::Copy(fn_ptr_place));
147
148    Ok(())
149}
150
151pub struct Transform;
152impl UllbcPass for Transform {
153    fn transform_function(&self, ctx: &mut TransformCtx, decl: &mut FunDecl) {
154        decl.transform_ullbc_terminators(ctx, |ctx, term| {
155            if let TerminatorKind::Call { call, .. } = &mut term.kind {
156                let _ = transform_dyn_trait_call(ctx, call);
157            }
158        });
159    }
160}