Skip to main content

charon_driver/translate/
translate_generics.rs

1use std::collections::{HashMap, HashSet};
2use std::fmt::Debug;
3use std::mem;
4
5use crate::hax;
6use crate::hax::{BaseState, Symbol};
7use rustc_middle::ty;
8
9use super::translate_ctx::{ItemTransCtx, TraitImplSource, TransItemSourceKind};
10use charon_lib::ast::*;
11use charon_lib::common::CycleDetector;
12use charon_lib::ids::IndexVec;
13
14/// A level of binding for type-level variables. Each item has a top-level binding level
15/// corresponding to the parameters and clauses to the items. We may then encounter inner binding
16/// levels in the following cases:
17/// - `for<..>` binders in predicates;
18/// - `fn<..>` function pointer types;
19/// - `dyn Trait` types, represented as `dyn<T: Trait>`;
20/// - types in a trait declaration or implementation block;
21/// - methods in a trait declaration or implementation block.
22///
23/// At each level, we store two things: a `GenericParams` that contains the parameters bound at
24/// this level, and various maps from the rustc-internal indices to our indices.
25#[derive(Debug, Default)]
26pub(crate) struct BindingLevel {
27    /// The parameters and predicates bound at this level.
28    pub params: GenericParams,
29    /// Rust makes the distinction between early and late-bound region parameters. We do not make
30    /// this distinction, and merge early and late bound regions. For details, see:
31    /// <https://smallcultfollowing.com/babysteps/blog/2013/10/29/intermingled-parameter-lists/>
32    /// <https://smallcultfollowing.com/babysteps/blog/2013/11/04/intermingled-parameter-lists/>
33    ///
34    /// The map from rust early regions to translated region indices.
35    pub early_region_vars: HashMap<hax::EarlyParamRegion, RegionId>,
36    /// The map from rust late/bound regions to translated region indices.
37    pub bound_region_vars: Vec<RegionId>,
38    /// Region added for the lifetime bound in the signature of the `call`/`call_mut` methods.
39    pub closure_call_method_region: Option<RegionId>,
40    /// The map from rust type variable indices to translated type variable indices.
41    pub type_vars_map: HashMap<u32, TypeVarId>,
42    /// The map from rust const generic variables to translated const generic variable indices.
43    pub const_generic_vars_map: HashMap<u32, ConstGenericVarId>,
44    /// The map from trait predicates to translated trait clause indices.
45    pub trait_preds: HashMap<hax::GenericPredicateId, TraitClauseId>,
46    /// The types of the captured variables, when we're translating a closure item. This is
47    /// translated early because this translation requires adding new lifetime generics to the
48    /// current binder.
49    pub closure_upvar_tys: Option<IndexVec<FieldId, Ty>>,
50    /// The regions we added for the upvars.
51    pub closure_upvar_regions: Vec<RegionId>,
52    /// RPITIT can cause region names to be reused. To avoid clashes in our output, we rename
53    /// duplicate names.
54    pub used_region_names: HashSet<Symbol>,
55    /// Cache the translation of types. This harnesses the deduplication of `Ty` that hax does.
56    // Important: we can't reuse type caches from earlier binders as the new binder may change what
57    // a given variable resolves to.
58    pub type_trans_cache: HashMap<hax::Ty, Ty>,
59}
60
61/// Small helper: we ignore some region names (when they are equal to "'_")
62fn translate_region_name(s: hax::Symbol) -> Option<String> {
63    let s = s.to_string();
64    if s == "'_" { None } else { Some(s) }
65}
66
67impl BindingLevel {
68    pub(crate) fn new() -> Self {
69        Self {
70            ..Default::default()
71        }
72    }
73
74    /// Important: we must push all the early-bound regions before pushing any other region.
75    pub(crate) fn push_early_region(
76        &mut self,
77        region: hax::EarlyParamRegion,
78        mutability: LifetimeMutability,
79    ) -> RegionId {
80        let name = if self.used_region_names.insert(region.name) {
81            translate_region_name(region.name)
82        } else {
83            None
84        };
85        // Check that there are no late-bound regions
86        assert!(
87            self.bound_region_vars.is_empty(),
88            "Early regions must be translated before late ones"
89        );
90        let rid = self.params.regions.push_with(|index| RegionParam {
91            index,
92            name,
93            mutability,
94        });
95        self.early_region_vars.insert(region, rid);
96        rid
97    }
98
99    /// Important: we must push all the early-bound regions before pushing any other region.
100    pub(crate) fn push_bound_region(&mut self, region: hax::BoundRegionKind) -> RegionId {
101        use crate::hax::BoundRegionKind::*;
102        let name = match region {
103            Anon => None,
104            NamedForPrinting(symbol) | Named(_, symbol) => translate_region_name(symbol),
105            ClosureEnv => Some("@env".to_owned()),
106        };
107        let rid = self
108            .params
109            .regions
110            .push_with(|index| RegionParam::new(index, name));
111        self.bound_region_vars.push(rid);
112        rid
113    }
114
115    /// Add a region for an upvar in a closure.
116    pub fn push_upvar_region(&mut self) -> RegionId {
117        // We musn't push to `bound_region_vars` because that will contain the higher-kinded
118        // signature lifetimes (if any) and they must be lookup-able.
119        let region_id = self
120            .params
121            .regions
122            .push_with(|index| RegionParam::new(index, None));
123        self.closure_upvar_regions.push(region_id);
124        region_id
125    }
126
127    pub(crate) fn push_type_var(&mut self, rid: u32, name: hax::Symbol) -> TypeVarId {
128        // Type vars comping from `impl Trait` arguments have as their name the whole `impl Trait`
129        // expression. We turn it into an identifier.
130        let mut name = name.to_string();
131        if name
132            .chars()
133            .any(|c| !(c.is_ascii_alphanumeric() || c == '_'))
134        {
135            name = format!("T{rid}")
136        }
137        let var_id = self
138            .params
139            .types
140            .push_with(|index| TypeParam { index, name });
141        self.type_vars_map.insert(rid, var_id);
142        var_id
143    }
144
145    pub(crate) fn push_const_generic_var(&mut self, rid: u32, ty: Ty, name: hax::Symbol) {
146        let var_id = self
147            .params
148            .const_generics
149            .push_with(|index| ConstGenericParam {
150                index,
151                name: name.to_string(),
152                ty,
153            });
154        self.const_generic_vars_map.insert(rid, var_id);
155    }
156
157    /// Translate a binder of regions by appending the stored reguions to the given vector.
158    pub(crate) fn push_params_from_binder(&mut self, binder: hax::Binder<()>) -> Result<(), Error> {
159        assert!(
160            self.bound_region_vars.is_empty(),
161            "Trying to use two binders at the same binding level"
162        );
163        use crate::hax::BoundVariableKind::*;
164        for p in binder.bound_vars {
165            match p {
166                Region(region) => {
167                    self.push_bound_region(region);
168                }
169                Ty(_) => {
170                    panic!("Unexpected locally bound type variable");
171                }
172                Const => {
173                    panic!("Unexpected locally bound const generic variable");
174                }
175            }
176        }
177        Ok(())
178    }
179}
180
181impl<'tcx, 'ctx> ItemTransCtx<'tcx, 'ctx> {
182    /// Get the only binding level. Panics if there are other binding levels.
183    pub(crate) fn the_only_binder(&self) -> &BindingLevel {
184        assert_eq!(self.binding_levels.len(), 1);
185        self.innermost_binder()
186    }
187    /// Get the only binding level. Panics if there are other binding levels.
188    pub(crate) fn the_only_binder_mut(&mut self) -> &mut BindingLevel {
189        assert_eq!(self.binding_levels.len(), 1);
190        self.innermost_binder_mut()
191    }
192
193    pub(crate) fn outermost_binder(&self) -> &BindingLevel {
194        self.binding_levels.outermost()
195    }
196    pub(crate) fn outermost_binder_mut(&mut self) -> &mut BindingLevel {
197        self.binding_levels.outermost_mut()
198    }
199    pub(crate) fn innermost_binder(&self) -> &BindingLevel {
200        self.binding_levels.innermost()
201    }
202    pub(crate) fn innermost_binder_mut(&mut self) -> &mut BindingLevel {
203        self.binding_levels.innermost_mut()
204    }
205
206    pub(crate) fn outermost_generics(&self) -> &GenericParams {
207        &self.outermost_binder().params
208    }
209    #[expect(dead_code)]
210    pub(crate) fn outermost_generics_mut(&mut self) -> &mut GenericParams {
211        &mut self.outermost_binder_mut().params
212    }
213    pub(crate) fn innermost_generics(&self) -> &GenericParams {
214        &self.innermost_binder().params
215    }
216    pub(crate) fn innermost_generics_mut(&mut self) -> &mut GenericParams {
217        &mut self.innermost_binder_mut().params
218    }
219
220    pub(crate) fn lookup_bound_region(
221        &mut self,
222        span: Span,
223        dbid: hax::DebruijnIndex,
224        var: hax::BoundVar,
225    ) -> Result<RegionDbVar, Error> {
226        let dbid = DeBruijnId::new(dbid);
227        if let Some(rid) = self
228            .binding_levels
229            .get(dbid)
230            .and_then(|bl| bl.bound_region_vars.get(var))
231        {
232            Ok(DeBruijnVar::bound(dbid, *rid))
233        } else {
234            raise_error!(
235                self,
236                span,
237                "Unexpected error: could not find region '{dbid}_{var}"
238            )
239        }
240    }
241
242    pub(crate) fn lookup_param<Id: Copy>(
243        &mut self,
244        span: Span,
245        f: impl for<'a> Fn(&'a BindingLevel) -> Option<Id>,
246        mk_err: impl FnOnce() -> String,
247    ) -> Result<DeBruijnVar<Id>, Error> {
248        for (dbid, bl) in self.binding_levels.iter_enumerated() {
249            if let Some(id) = f(bl) {
250                return Ok(DeBruijnVar::bound(dbid, id));
251            }
252        }
253        let err = mk_err();
254        raise_error!(self, span, "Unexpected error: could not find {}", err)
255    }
256
257    pub(crate) fn lookup_early_region(
258        &mut self,
259        span: Span,
260        region: &hax::EarlyParamRegion,
261    ) -> Result<RegionDbVar, Error> {
262        self.lookup_param(
263            span,
264            |bl| bl.early_region_vars.get(region).copied(),
265            || format!("the region variable {region:?}"),
266        )
267    }
268
269    pub(crate) fn lookup_type_var(
270        &mut self,
271        span: Span,
272        param: &hax::ParamTy,
273    ) -> Result<TypeDbVar, Error> {
274        self.lookup_param(
275            span,
276            |bl| bl.type_vars_map.get(&param.index).copied(),
277            || format!("the type variable {}", param.name),
278        )
279    }
280
281    pub(crate) fn lookup_const_generic_var(
282        &mut self,
283        span: Span,
284        param: &hax::ParamConst,
285    ) -> Result<ConstGenericDbVar, Error> {
286        self.lookup_param(
287            span,
288            |bl| bl.const_generic_vars_map.get(&param.index).copied(),
289            || format!("the const generic variable {}", param.name),
290        )
291    }
292
293    pub(crate) fn lookup_clause_var(
294        &mut self,
295        span: Span,
296        id: &hax::GenericPredicateId,
297    ) -> Result<ClauseDbVar, Error> {
298        self.lookup_param(
299            span,
300            |bl| bl.trait_preds.get(id).copied(),
301            || format!("the trait clause variable {id:?}"),
302        )
303    }
304
305    pub(crate) fn push_generic_params(&mut self, generics: &hax::TyGenerics) -> Result<(), Error> {
306        for param in &generics.params {
307            self.push_generic_param(param)?;
308        }
309        Ok(())
310    }
311
312    pub(crate) fn push_generic_param(&mut self, param: &hax::GenericParamDef) -> Result<(), Error> {
313        match &param.kind {
314            hax::GenericParamDefKind::Lifetime => {
315                let region = hax::EarlyParamRegion {
316                    index: param.index,
317                    name: param.name,
318                };
319                let mutability = self
320                    .t_ctx
321                    .lt_mutability_computer
322                    .compute_lifetime_mutability(
323                        &self.hax_state,
324                        self.item_src.def_id(),
325                        param.index,
326                    );
327                let _ = self
328                    .innermost_binder_mut()
329                    .push_early_region(region, mutability);
330            }
331            hax::GenericParamDefKind::Type { .. } => {
332                let _ = self
333                    .innermost_binder_mut()
334                    .push_type_var(param.index, param.name);
335            }
336            hax::GenericParamDefKind::Const { ty, .. } => {
337                let span = self.def_span(&param.def_id);
338                // The type should be primitive, meaning it shouldn't contain variables,
339                // non-primitive adts, etc. As a result, we can use an empty context.
340                let ty = self.translate_ty(span, ty)?;
341                self.innermost_binder_mut()
342                    .push_const_generic_var(param.index, ty, param.name);
343            }
344        }
345
346        Ok(())
347    }
348
349    // The parameters (and in particular the lifetimes) are split between
350    // early bound and late bound parameters. See those blog posts for explanations:
351    // https://smallcultfollowing.com/babysteps/blog/2013/10/29/intermingled-parameter-lists/
352    // https://smallcultfollowing.com/babysteps/blog/2013/11/04/intermingled-parameter-lists/
353    // Note that only lifetimes can be late bound at the moment.
354    //
355    // [TyCtxt.generics_of] gives us the early-bound parameters. We add the late-bound parameters
356    // here.
357    fn push_late_bound_generics_for_def(
358        &mut self,
359        _span: Span,
360        def: &hax::FullDef,
361    ) -> Result<(), Error> {
362        if let hax::FullDefKind::Fn { sig, .. } | hax::FullDefKind::AssocFn { sig, .. } = def.kind()
363        {
364            let innermost_binder = self.innermost_binder_mut();
365            assert!(innermost_binder.bound_region_vars.is_empty());
366            innermost_binder.push_params_from_binder(sig.rebind(()))?;
367        }
368        Ok(())
369    }
370
371    /// Add the generics and predicates of this item and its parents to the current context.
372    #[tracing::instrument(skip(self, span, def))]
373    fn push_generics_for_def(&mut self, span: Span, def: &hax::FullDef) -> Result<(), Error> {
374        trace!("{:?}", def.param_env());
375        // Add generics from the parent item, recursively (recursivity is important for closures,
376        // as they can be nested).
377        if let Some(parent_item) = def.typing_parent(self.hax_state()) {
378            let parent_def = self.hax_def(&parent_item)?;
379            self.push_generics_for_def(span, &parent_def)?;
380        }
381        self.push_generics_for_def_without_parents(span, def)?;
382        Ok(())
383    }
384
385    /// Add the generics and predicates of this item. This does not include the parent generics;
386    /// use `push_generics_for_def` to get the full list.
387    fn push_generics_for_def_without_parents(
388        &mut self,
389        _span: Span,
390        def: &hax::FullDef,
391    ) -> Result<(), Error> {
392        use crate::hax::FullDefKind;
393        if let Some(param_env) = def.param_env() {
394            // Add the generic params.
395            self.push_generic_params(&param_env.generics)?;
396            // Add the predicates.
397            let origin = match &def.kind {
398                FullDefKind::Adt { .. }
399                | FullDefKind::TyAlias { .. }
400                | FullDefKind::AssocTy { .. } => PredicateOrigin::WhereClauseOnType,
401                FullDefKind::Fn { .. }
402                | FullDefKind::AssocFn { .. }
403                | FullDefKind::Closure { .. }
404                | FullDefKind::Const { .. }
405                | FullDefKind::AssocConst { .. }
406                | FullDefKind::Static { .. } => PredicateOrigin::WhereClauseOnFn,
407                FullDefKind::TraitImpl { .. } | FullDefKind::InherentImpl { .. } => {
408                    PredicateOrigin::WhereClauseOnImpl
409                }
410                FullDefKind::Trait { .. } | FullDefKind::TraitAlias { .. } => {
411                    PredicateOrigin::WhereClauseOnTrait
412                }
413                _ => panic!("Unexpected def: {:?}", def.def_id().kind),
414            };
415            self.register_predicates(&param_env.predicates, origin.clone())?;
416        }
417
418        Ok(())
419    }
420
421    /// Translate the generics and predicates of this item and its parents. This adds generic
422    /// parameters and predicates to the current environment (as a binder in
423    /// `self.binding_levels`). The constructed `GenericParams` can be recovered at the end using
424    /// `self.into_generics()` and stored in the translated item.
425    ///
426    /// On top of the generics introduced by `push_generics_for_def`, this adds extra parameters
427    /// required by the `TransItemSourceKind`.
428    pub fn translate_item_generics(
429        &mut self,
430        span: Span,
431        def: &hax::FullDef,
432        kind: &TransItemSourceKind,
433    ) -> Result<(), Error> {
434        assert!(self.binding_levels.is_empty());
435        self.binding_levels.push(BindingLevel::new());
436        self.push_generics_for_def(span, def)?;
437        self.push_late_bound_generics_for_def(span, def)?;
438
439        if let hax::FullDefKind::Closure { args, .. } = def.kind() {
440            // Add the lifetime generics coming from the upvars. We translate the upvar types early
441            // to know what lifetimes are needed.
442            let upvar_tys = self.translate_closure_upvar_tys(span, args)?;
443            // Add new lifetimes params to replace the erased ones.
444            let upvar_tys = upvar_tys.replace_erased_regions(|| {
445                let region_id = self.the_only_binder_mut().push_upvar_region();
446                Region::Var(DeBruijnVar::new_at_zero(region_id))
447            });
448            self.the_only_binder_mut().closure_upvar_tys = Some(upvar_tys);
449
450            // Add the lifetime generics coming from the higher-kindedness of the signature.
451            if let TransItemSourceKind::TraitImpl(TraitImplSource::Closure(..))
452            | TransItemSourceKind::ClosureMethod(..)
453            | TransItemSourceKind::ClosureAsFnCast = kind
454            {
455                self.the_only_binder_mut()
456                    .push_params_from_binder(args.fn_sig.rebind(()))?;
457            }
458            if let TransItemSourceKind::ClosureMethod(ClosureKind::Fn | ClosureKind::FnMut) = kind {
459                // Add the lifetime generics coming from the method itself.
460                let rid = self
461                    .the_only_binder_mut()
462                    .params
463                    .regions
464                    .push_with(|index| RegionParam::new(index, None));
465                self.the_only_binder_mut().closure_call_method_region = Some(rid);
466            }
467        }
468
469        self.innermost_binder_mut().params.check_consistency();
470        Ok(())
471    }
472
473    /// Push a new binding level, run the provided function inside it, then return the bound value.
474    pub(crate) fn inside_binder<F, U>(&mut self, kind: BinderKind, f: F) -> Result<Binder<U>, Error>
475    where
476        F: FnOnce(&mut Self) -> Result<U, Error>,
477    {
478        assert!(!self.binding_levels.is_empty());
479        self.binding_levels.push(BindingLevel::new());
480
481        // Call the continuation. Important: do not short-circuit on error here.
482        let res = f(self);
483
484        // Reset
485        let params = self.binding_levels.pop().unwrap().params;
486
487        // Return
488        res.map(|skip_binder| Binder {
489            kind,
490            params,
491            skip_binder,
492        })
493    }
494
495    /// Push a new binding level corresponding to the provided `def` for the duration of the inner
496    /// function call.
497    pub(crate) fn translate_binder_for_def<F, U>(
498        &mut self,
499        span: Span,
500        kind: BinderKind,
501        def: &hax::FullDef,
502        f: F,
503    ) -> Result<Binder<U>, Error>
504    where
505        F: FnOnce(&mut Self) -> Result<U, Error>,
506    {
507        let inner_hax_state = self.t_ctx.hax_state.clone().with_hax_owner(def.def_id());
508        let outer_hax_state = mem::replace(&mut self.hax_state, inner_hax_state);
509        let ret = self.inside_binder(kind, |this| {
510            this.push_generics_for_def_without_parents(span, def)?;
511            this.push_late_bound_generics_for_def(span, def)?;
512            this.innermost_binder().params.check_consistency();
513            f(this)
514        });
515        self.hax_state = outer_hax_state;
516        ret
517    }
518
519    /// Push a group of bound regions and call the continuation.
520    /// We use this when diving into a `for<'a>`, or inside an arrow type (because
521    /// it contains universally quantified regions).
522    pub(crate) fn translate_region_binder<F, T, U>(
523        &mut self,
524        _span: Span,
525        binder: &hax::Binder<T>,
526        f: F,
527    ) -> Result<RegionBinder<U>, Error>
528    where
529        F: FnOnce(&mut Self, &T) -> Result<U, Error>,
530    {
531        let binder = self.inside_binder(BinderKind::Other, |this| {
532            this.innermost_binder_mut()
533                .push_params_from_binder(binder.rebind(()))?;
534            f(this, binder.hax_skip_binder_ref())
535        })?;
536        // Convert to a region-only binder.
537        Ok(RegionBinder {
538            regions: binder.params.regions,
539            skip_binder: binder.skip_binder,
540        })
541    }
542
543    pub(crate) fn into_generics(mut self) -> GenericParams {
544        assert!(self.binding_levels.len() == 1);
545        self.binding_levels.pop().unwrap().params
546    }
547}
548
549/// Struct to compute the "mutability" of each lifetime.
550#[derive(Default)]
551pub struct LifetimeMutabilityComputer {
552    lt_mutability: HashMap<hax::DefId, CycleDetector<HashSet<u32>>>,
553}
554
555impl LifetimeMutabilityComputer {
556    /// Compute the mutability of one lifetime.
557    pub(crate) fn compute_lifetime_mutability<'tcx>(
558        &mut self,
559        s: &impl BaseState<'tcx>,
560        item: &hax::DefId,
561        index: u32,
562    ) -> LifetimeMutability {
563        match self.compute_lifetime_mutabilities(s, item) {
564            Some(set) => {
565                if set.contains(&index) {
566                    LifetimeMutability::Mutable
567                } else {
568                    LifetimeMutability::Shared
569                }
570            }
571            None => LifetimeMutability::Unknown,
572        }
573    }
574
575    /// Compute the "mutability" of each lifetime, i.e. whether this lifetime is used in a `&'a mut
576    /// T` type or not. Returns a set of the known-mutable lifetimes for this ADT.
577    fn compute_lifetime_mutabilities<'tcx>(
578        &mut self,
579        s: &impl BaseState<'tcx>,
580        item: &hax::DefId,
581    ) -> Option<&HashSet<u32>> {
582        if !matches!(
583            item.kind,
584            hax::DefKind::Struct | hax::DefKind::Enum | hax::DefKind::Union
585        ) {
586            return None;
587        }
588        if self
589            .lt_mutability
590            .entry(item.clone())
591            .or_default()
592            .start_processing()
593        {
594            use crate::hax::SInto;
595            use ty::{TypeSuperVisitable, TypeVisitable};
596
597            struct LtMutabilityVisitor<'a, S> {
598                s: &'a S,
599                computer: &'a mut LifetimeMutabilityComputer,
600                set: HashSet<u32>,
601            }
602            impl<'tcx, S: BaseState<'tcx>> ty::TypeVisitor<ty::TyCtxt<'tcx>> for LtMutabilityVisitor<'_, S> {
603                fn visit_ty(&mut self, ty: ty::Ty<'tcx>) {
604                    match ty.kind() {
605                        ty::Ref(r, _, ty::Mutability::Mut)
606                            if let ty::RegionKind::ReEarlyParam(r) = r.kind() =>
607                        {
608                            self.set.insert(r.index);
609                        }
610                        ty::Adt(adt, args) => {
611                            let item = adt.did().sinto(self.s);
612                            if let Some(mutabilities) =
613                                self.computer.compute_lifetime_mutabilities(self.s, &item)
614                            {
615                                for arg in args.iter() {
616                                    if let Some(r) = arg.as_region()
617                                        && let ty::RegionKind::ReEarlyParam(r) = r.kind()
618                                        && mutabilities.contains(&r.index)
619                                    {
620                                        self.set.insert(r.index);
621                                    }
622                                }
623                            }
624                        }
625                        _ => {}
626                    }
627                    ty.super_visit_with(self)
628                }
629            }
630            let mut visitor = LtMutabilityVisitor {
631                s,
632                computer: self,
633                set: HashSet::new(),
634            };
635
636            let tcx = s.base().tcx;
637            let def_id = item.real_rust_def_id();
638            let adt_def = tcx.adt_def(def_id);
639            let generics = ty::GenericArgs::identity_for_item(tcx, def_id);
640            for variant in adt_def.variants() {
641                for field in &variant.fields {
642                    field.ty(tcx, generics).visit_with(&mut visitor);
643                }
644            }
645            let set = visitor.set;
646
647            self.lt_mutability
648                .get_mut(item)
649                .unwrap()
650                .done_processing(set);
651        }
652        self.lt_mutability.get(item)?.as_processed()
653    }
654}