Skip to main content

charon_lib/transform/normalize/
transform_dyn_trait_calls.rs

//! Transform method calls on `&dyn Trait` to vtable function pointer calls.
//!
//! This pass converts direct method calls on trait objects into calls through vtable
//! function pointers. For example, given:
//!
//! ```rust,ignore
//! let x: &dyn Trait = &obj;
//! x.check(args);
//! ```
//!
//! the call is rewritten from:
//! ```text
//! @0 := call <dyn Trait as Trait>::check(x, args)
//! ```
//! to a call through the `method_check` field of the vtable reached via the receiver's
//! pointer metadata:
//! ```text
//! @0 := (copy (*x.ptr_metadata).method_check)(x, args) // Call through function pointer
//! ```
//!
//! When `monomorphize_with_hax` is enabled, the call is instead redirected to the translated
//! vtable-method pre-shim matching the trait, the method, and the generic/associated type
//! arguments.
19
20use itertools::Itertools;
21
22use super::super::ctx::UllbcPass;
23use crate::{
24    errors::Error,
25    formatter::IntoFormatter,
26    pretty::FmtWithCtx,
27    raise_error, register_error,
28    transform::{TransformCtx, ctx::UllbcStatementTransformCtx},
29    ullbc_ast::*,
30};
31
32/// Transform a call to a trait method on a dyn trait object
33fn transform_dyn_trait_call(
34    ctx: &mut UllbcStatementTransformCtx<'_>,
35    call: &mut Call,
36) -> Result<(), Error> {
37    let fmt_ctx = &ctx.ctx.into_fmt();
38
39    // Detect if this call should be transformed
40    let FnOperand::Regular(fn_ptr) = &call.func else {
41        return Ok(()); // Not a regular function call
42    };
43    let FnPtrKind::Trait(trait_ref, method_name, _) = &fn_ptr.kind else {
44        return Ok(()); // Not a trait method call
45    };
46    let TraitRefKind::Dyn = &trait_ref.kind else {
47        return Ok(()); // Not a dyn trait trait call
48    };
49
50    // mono mode
51    // We fetch the specific preshim function
52    // by iterating `fun_decls` in `translated` with `trait_decl_id` and generic and associative arguments.
53    if ctx.ctx.options.monomorphize_with_hax {
54        // `receiver_types` contains associative arguments, if any.
55        let mut receiver_types = None;
56        // `types` contains generic arguments,
57        // which will be appended with `receiver_types` to form the complete list of arguments.
58        let mut types: Vec<_> = trait_ref
59            .trait_decl_ref
60            .skip_binder
61            .generics
62            .types
63            .clone()
64            .into_iter()
65            .skip(1)
66            .collect_vec();
67
68        // fetch associative arguments
69        if let Some(Operand::Copy(receiver) | Operand::Move(receiver)) = call.args.first() {
70            receiver_types = match receiver.ty().kind() {
71                TyKind::Ref(_, dyn_ty, _) | TyKind::RawPtr(dyn_ty, _) => match dyn_ty.kind() {
72                    TyKind::DynTrait(pred) => {
73                        let trait_type_constraints: Vec<_> = pred
74                            .binder
75                            .params
76                            .trait_type_constraints
77                            .clone()
78                            .into_iter()
79                            .collect();
80                        Some(
81                            trait_type_constraints
82                                .iter()
83                                .map(|ttc| ttc.skip_binder.ty.clone())
84                                .collect_vec(),
85                        )
86                    }
87                    _ => None,
88                },
89                TyKind::DynTrait(pred) => {
90                    let trait_type_constraints: Vec<_> =
91                        pred.binder.params.trait_type_constraints.iter().collect();
92                    // None
93                    Some(
94                        trait_type_constraints
95                            .iter()
96                            .map(|ttc| ttc.skip_binder.ty.clone())
97                            .collect_vec(),
98                    )
99                }
100                _ => None,
101            };
102        }
103
104        if let Some(mut receiver_types) = receiver_types {
105            types.append(&mut receiver_types);
106        }
107
108        // find the specific preshim function
109        let mut preshim = None;
110        for fun_decl in ctx.ctx.translated.fun_decls.iter() {
111            match &fun_decl.src {
112                ItemSource::VTableMethodPreShim(t_id, m_name, m_types) => {
113                    if *t_id == trait_ref.trait_id() && *m_name == *method_name && *m_types == types
114                    {
115                        preshim = Some(fun_decl);
116                    }
117                }
118                _ => {}
119            }
120        }
121
122        let Some(preshim) = preshim else {
123            panic!("MONO: preshim for {} is not translated", method_name);
124        };
125        // let preshim_fn_ptr = FnPtr::new(preshim.def_id.into(), GenericArgs::empty());
126        let preshim_args = GenericArgs::new(
127            preshim
128                .generics
129                .regions
130                .map_ref_indexed(|_, _| Region::Erased),
131            [].into(),
132            [].into(),
133            [].into(),
134        );
135        let preshim_fn_ptr = FnPtr::new(preshim.def_id.into(), preshim_args);
136        call.func = FnOperand::Regular(preshim_fn_ptr);
137
138        return Ok(());
139    }
140
141    // Get the type of the vtable struct.
142    let vtable_decl_ref: TypeDeclRef = {
143        // Get the trait declaration by its ID
144        let Some(trait_decl) = ctx.ctx.translated.trait_decls.get(trait_ref.trait_id()) else {
145            return Ok(()); // Unknown trait
146        };
147        // Get vtable ref from definition for correct ID.
148        let Some(vtable_ty) = &trait_decl.vtable else {
149            raise_error!(
150                ctx.ctx,
151                ctx.span,
152                "Found a `dyn Trait` method call for non-dyn-compatible trait `{}`!",
153                trait_ref.trait_id().with_ctx(fmt_ctx)
154            );
155        };
156        vtable_ty.clone().substitute_with_tref(trait_ref)
157    };
158    let vtable_decl_id = *vtable_decl_ref.id.as_adt().unwrap();
159    let Some(vtable_decl) = ctx.ctx.translated.type_decls.get(vtable_decl_id) else {
160        return Ok(()); // Missing data
161    };
162    if matches!(vtable_decl.kind, TypeDeclKind::Opaque) {
163        return Ok(()); // Missing data
164    }
165
166    // Retreive the method field from the vtable struct definition.
167    let method_field_name = format!("method_{}", method_name);
168    let Some((method_field_id, method_field)) =
169        vtable_decl.get_field_by_name(None, &method_field_name)
170    else {
171        let vtable_name = vtable_decl_ref.id.with_ctx(fmt_ctx).to_string();
172        raise_error!(
173            ctx.ctx,
174            ctx.span,
175            "Could not determine method index for {} in vtable {}",
176            method_name,
177            vtable_name
178        );
179    };
180    let method_ty = method_field
181        .ty
182        .clone()
183        .substitute(&vtable_decl_ref.generics);
184
185    // Get the receiver (first argument).
186    if call.args.is_empty() {
187        raise_error!(ctx.ctx, ctx.span, "Dyn trait call has no arguments!");
188    }
189    let dyn_trait_place = match &call.args[0] {
190        Operand::Copy(place) | Operand::Move(place) => place,
191        Operand::Const(_) => {
192            panic!("Unexpected constant as receiver for dyn trait method call")
193        }
194    };
195
196    // Construct the `(*ptr.ptr_metadata).method_field` place.
197    let vtable_ty = TyKind::Adt(vtable_decl_ref).into_ty();
198    let ptr_to_vtable_ty = Ty::new(TyKind::RawPtr(vtable_ty.clone(), RefKind::Shared));
199    let method_field_place = dyn_trait_place
200        .clone()
201        .project(ProjectionElem::PtrMetadata, ptr_to_vtable_ty)
202        .project(ProjectionElem::Deref, vtable_ty)
203        .project(
204            ProjectionElem::Field(FieldProjKind::Adt(vtable_decl_id, None), method_field_id),
205            method_ty,
206        );
207
208    // Transform the original call to use the function pointer
209    call.func = FnOperand::Dynamic(Operand::Copy(method_field_place));
210
211    Ok(())
212}
213
214pub struct Transform;
215impl UllbcPass for Transform {
216    fn transform_function(&self, ctx: &mut TransformCtx, decl: &mut FunDecl) {
217        decl.transform_ullbc_terminators(ctx, |ctx, term| {
218            if let TerminatorKind::Call { call, .. } = &mut term.kind {
219                let _ = transform_dyn_trait_call(ctx, call);
220            }
221        });
222    }
223}