Skip to main content

charon_lib/transform/normalize/
transform_dyn_trait_calls.rs

1//! Transform method calls on `&dyn Trait` to vtable function pointer calls.
2//!
3//! This pass converts direct method calls on trait objects into calls through vtable
4//! function pointers. For example:
5//!
6//! ```rust,ignore
7//! let x: &dyn Trait = &obj;
8//! x.method(args);
9//! ```
10//!
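//! is transformed from:
//! ```text
//! @0 := call <dyn Trait as Trait>::method(x, args)
//! ```
//! to a call through the function pointer stored in the vtable field `method_<name>`
//! (here `method_method`), reached through the receiver's pointer metadata:
//! ```text
//! @0 := (copy (*x.ptr_metadata).method_method)(x, args) // Call through function pointer
//! ```
//!
//! When `monomorphize_with_hax` is enabled, the call is instead redirected to the matching
//! `VTableMethodPreShim` function.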
19
20use itertools::Itertools;
21
22use super::super::ctx::UllbcPass;
23use crate::{
24    errors::Error,
25    formatter::IntoFormatter,
26    pretty::FmtWithCtx,
27    raise_error, register_error,
28    transform::{TransformCtx, ctx::UllbcStatementTransformCtx},
29    ullbc_ast::*,
30};
31
32/// Transform a call to a trait method on a dyn trait object
33fn transform_dyn_trait_call(
34    ctx: &mut UllbcStatementTransformCtx<'_>,
35    call: &mut Call,
36) -> Result<(), Error> {
37    let fmt_ctx = &ctx.ctx.into_fmt();
38
39    // Detect if this call should be transformed
40    let FnOperand::Regular(fn_ptr) = &call.func else {
41        return Ok(()); // Not a regular function call
42    };
43    let FnPtrKind::Trait(trait_ref, method_name, _) = &fn_ptr.kind else {
44        return Ok(()); // Not a trait method call
45    };
46    let TraitRefKind::Dyn = &trait_ref.kind else {
47        return Ok(()); // Not a dyn trait trait call
48    };
49
50    // mono mode
51    // We fetch the specific preshim function
52    // by iterating `fun_decls` in `translated` with `trait_decl_id` and generic and associative arguments.
53    if ctx.ctx.options.monomorphize_with_hax {
54        // `receiver_types` contains associative arguments, if any.
55        let mut receiver_types = None;
56        // `types` contains generic arguments,
57        // which will be appended with `receiver_types` to form the complete list of arguments.
58        let mut types: Vec<_> = trait_ref
59            .trait_decl_ref
60            .skip_binder
61            .generics
62            .types
63            .clone()
64            .into_iter()
65            .skip(1)
66            .collect_vec();
67
68        // fetch associative arguments
69        if let Some(Operand::Copy(receiver) | Operand::Move(receiver)) = call.args.first() {
70            receiver_types = match receiver.ty().kind() {
71                TyKind::Ref(_, dyn_ty, _) | TyKind::RawPtr(dyn_ty, _) => match dyn_ty.kind() {
72                    TyKind::DynTrait(pred) => {
73                        let trait_type_constraints: Vec<_> = pred
74                            .binder
75                            .params
76                            .trait_type_constraints
77                            .clone()
78                            .into_iter()
79                            .collect();
80                        Some(
81                            trait_type_constraints
82                                .iter()
83                                .map(|ttc| ttc.skip_binder.ty.clone())
84                                .collect_vec(),
85                        )
86                    }
87                    _ => None,
88                },
89                TyKind::DynTrait(pred) => {
90                    let trait_type_constraints: Vec<_> =
91                        pred.binder.params.trait_type_constraints.iter().collect();
92                    // None
93                    Some(
94                        trait_type_constraints
95                            .iter()
96                            .map(|ttc| ttc.skip_binder.ty.clone())
97                            .collect_vec(),
98                    )
99                }
100                _ => None,
101            };
102        }
103
104        if let Some(mut receiver_types) = receiver_types {
105            types.append(&mut receiver_types);
106        }
107
108        // find the specific preshim function
109        let mut preshim = None;
110        for fun_decl in ctx.ctx.translated.fun_decls.iter() {
111            if let ItemSource::VTableMethodPreShim(t_id, m_name, m_types) = &fun_decl.src
112                && *t_id == trait_ref.trait_id()
113                && *m_name == *method_name
114                && *m_types == types
115            {
116                preshim = Some(fun_decl);
117            }
118        }
119
120        let Some(preshim) = preshim else {
121            panic!("MONO: preshim for {} is not translated", method_name);
122        };
123        // let preshim_fn_ptr = FnPtr::new(preshim.def_id.into(), GenericArgs::empty());
124        let preshim_args = GenericArgs::new(
125            preshim
126                .generics
127                .regions
128                .map_ref_indexed(|_, _| Region::Erased),
129            [].into(),
130            [].into(),
131            [].into(),
132        );
133        let preshim_fn_ptr = FnPtr::new(preshim.def_id.into(), preshim_args);
134        call.func = FnOperand::Regular(preshim_fn_ptr);
135
136        return Ok(());
137    }
138
139    // Get the type of the vtable struct.
140    let vtable_decl_ref: TypeDeclRef = {
141        // Get the trait declaration by its ID
142        let Some(trait_decl) = ctx.ctx.translated.trait_decls.get(trait_ref.trait_id()) else {
143            return Ok(()); // Unknown trait
144        };
145        // Get vtable ref from definition for correct ID.
146        let Some(vtable_ty) = &trait_decl.vtable else {
147            raise_error!(
148                ctx.ctx,
149                ctx.span,
150                "Found a `dyn Trait` method call for non-dyn-compatible trait `{}`!",
151                trait_ref.trait_id().with_ctx(fmt_ctx)
152            );
153        };
154        vtable_ty.clone().substitute_with_tref(trait_ref)
155    };
156    let vtable_decl_id = *vtable_decl_ref.id.as_adt().unwrap();
157    let Some(vtable_decl) = ctx.ctx.translated.type_decls.get(vtable_decl_id) else {
158        return Ok(()); // Missing data
159    };
160    if matches!(vtable_decl.kind, TypeDeclKind::Opaque) {
161        return Ok(()); // Missing data
162    }
163
164    // Retreive the method field from the vtable struct definition.
165    let method_field_name = format!("method_{}", method_name);
166    let Some((method_field_id, method_field)) =
167        vtable_decl.get_field_by_name(None, &method_field_name)
168    else {
169        let vtable_name = vtable_decl_ref.id.with_ctx(fmt_ctx).to_string();
170        raise_error!(
171            ctx.ctx,
172            ctx.span,
173            "Could not determine method index for {} in vtable {}",
174            method_name,
175            vtable_name
176        );
177    };
178    let method_ty = method_field
179        .ty
180        .clone()
181        .substitute(&vtable_decl_ref.generics);
182
183    // Get the receiver (first argument).
184    if call.args.is_empty() {
185        raise_error!(ctx.ctx, ctx.span, "Dyn trait call has no arguments!");
186    }
187    let dyn_trait_place = match &call.args[0] {
188        Operand::Copy(place) | Operand::Move(place) => place,
189        Operand::Const(_) => {
190            panic!("Unexpected constant as receiver for dyn trait method call")
191        }
192    };
193
194    // Construct the `(*ptr.ptr_metadata).method_field` place.
195    let vtable_ty = TyKind::Adt(vtable_decl_ref).into_ty();
196    let ptr_to_vtable_ty = Ty::new(TyKind::RawPtr(vtable_ty.clone(), RefKind::Shared));
197    let method_field_place = dyn_trait_place
198        .clone()
199        .project(ProjectionElem::PtrMetadata, ptr_to_vtable_ty)
200        .project(ProjectionElem::Deref, vtable_ty)
201        .project(
202            ProjectionElem::Field(FieldProjKind::Adt(vtable_decl_id, None), method_field_id),
203            method_ty,
204        );
205
206    // Transform the original call to use the function pointer
207    call.func = FnOperand::Dynamic(Operand::Copy(method_field_place));
208
209    Ok(())
210}
211
212pub struct Transform;
213impl UllbcPass for Transform {
214    fn transform_function(&self, ctx: &mut TransformCtx, decl: &mut FunDecl) {
215        decl.transform_ullbc_terminators(ctx, |ctx, term| {
216            if let TerminatorKind::Call { call, .. } = &mut term.kind {
217                let _ = transform_dyn_trait_call(ctx, call);
218            }
219        });
220    }
221}