Skip to main content

charon_driver/translate/
translate_generics.rs

1use std::collections::{HashMap, HashSet};
2use std::fmt::Debug;
3use std::mem;
4
5use hax::{BaseState, Symbol};
6use rustc_middle::ty;
7
8use super::translate_ctx::{ItemTransCtx, TraitImplSource, TransItemSourceKind};
9use charon_lib::ast::*;
10use charon_lib::common::CycleDetector;
11use charon_lib::ids::IndexVec;
12
13/// A level of binding for type-level variables. Each item has a top-level binding level
14/// corresponding to the parameters and clauses to the items. We may then encounter inner binding
15/// levels in the following cases:
16/// - `for<..>` binders in predicates;
17/// - `fn<..>` function pointer types;
18/// - `dyn Trait` types, represented as `dyn<T: Trait>`;
19/// - types in a trait declaration or implementation block;
20/// - methods in a trait declaration or implementation block.
21///
22/// At each level, we store two things: a `GenericParams` that contains the parameters bound at
23/// this level, and various maps from the rustc-internal indices to our indices.
24#[derive(Debug, Default)]
25pub(crate) struct BindingLevel {
26    /// The parameters and predicates bound at this level.
27    pub params: GenericParams,
28    /// Rust makes the distinction between early and late-bound region parameters. We do not make
29    /// this distinction, and merge early and late bound regions. For details, see:
30    /// <https://smallcultfollowing.com/babysteps/blog/2013/10/29/intermingled-parameter-lists/>
31    /// <https://smallcultfollowing.com/babysteps/blog/2013/11/04/intermingled-parameter-lists/>
32    ///
33    /// The map from rust early regions to translated region indices.
34    pub early_region_vars: HashMap<hax::EarlyParamRegion, RegionId>,
35    /// The map from rust late/bound regions to translated region indices.
36    pub bound_region_vars: Vec<RegionId>,
37    /// Region added for the lifetime bound in the signature of the `call`/`call_mut` methods.
38    pub closure_call_method_region: Option<RegionId>,
39    /// The map from rust type variable indices to translated type variable indices.
40    pub type_vars_map: HashMap<u32, TypeVarId>,
41    /// The map from rust const generic variables to translated const generic variable indices.
42    pub const_generic_vars_map: HashMap<u32, ConstGenericVarId>,
43    /// The map from trait predicates to translated trait clause indices.
44    pub trait_preds: HashMap<hax::GenericPredicateId, TraitClauseId>,
45    /// The types of the captured variables, when we're translating a closure item. This is
46    /// translated early because this translation requires adding new lifetime generics to the
47    /// current binder.
48    pub closure_upvar_tys: Option<IndexVec<FieldId, Ty>>,
49    /// The regions we added for the upvars.
50    pub closure_upvar_regions: Vec<RegionId>,
51    /// RPITIT can cause region names to be reused. To avoid clashes in our output, we rename
52    /// duplicate names.
53    pub used_region_names: HashSet<Symbol>,
54    /// Cache the translation of types. This harnesses the deduplication of `Ty` that hax does.
55    // Important: we can't reuse type caches from earlier binders as the new binder may change what
56    // a given variable resolves to.
57    pub type_trans_cache: HashMap<hax::Ty, Ty>,
58}
59
60/// Small helper: we ignore some region names (when they are equal to "'_")
61fn translate_region_name(s: hax::Symbol) -> Option<String> {
62    let s = s.to_string();
63    if s == "'_" { None } else { Some(s) }
64}
65
66impl BindingLevel {
67    pub(crate) fn new() -> Self {
68        Self {
69            ..Default::default()
70        }
71    }
72
73    /// Important: we must push all the early-bound regions before pushing any other region.
74    pub(crate) fn push_early_region(
75        &mut self,
76        region: hax::EarlyParamRegion,
77        mutability: LifetimeMutability,
78    ) -> RegionId {
79        let name = if self.used_region_names.insert(region.name) {
80            translate_region_name(region.name)
81        } else {
82            None
83        };
84        // Check that there are no late-bound regions
85        assert!(
86            self.bound_region_vars.is_empty(),
87            "Early regions must be translated before late ones"
88        );
89        let rid = self.params.regions.push_with(|index| RegionParam {
90            index,
91            name,
92            mutability,
93        });
94        self.early_region_vars.insert(region, rid);
95        rid
96    }
97
98    /// Important: we must push all the early-bound regions before pushing any other region.
99    pub(crate) fn push_bound_region(&mut self, region: hax::BoundRegionKind) -> RegionId {
100        use hax::BoundRegionKind::*;
101        let name = match region {
102            Anon => None,
103            NamedForPrinting(symbol) | Named(_, symbol) => translate_region_name(symbol),
104            ClosureEnv => Some("@env".to_owned()),
105        };
106        let rid = self
107            .params
108            .regions
109            .push_with(|index| RegionParam::new(index, name));
110        self.bound_region_vars.push(rid);
111        rid
112    }
113
114    /// Add a region for an upvar in a closure.
115    pub fn push_upvar_region(&mut self) -> RegionId {
116        // We musn't push to `bound_region_vars` because that will contain the higher-kinded
117        // signature lifetimes (if any) and they must be lookup-able.
118        let region_id = self
119            .params
120            .regions
121            .push_with(|index| RegionParam::new(index, None));
122        self.closure_upvar_regions.push(region_id);
123        region_id
124    }
125
126    pub(crate) fn push_type_var(&mut self, rid: u32, name: hax::Symbol) -> TypeVarId {
127        // Type vars comping from `impl Trait` arguments have as their name the whole `impl Trait`
128        // expression. We turn it into an identifier.
129        let mut name = name.to_string();
130        if name
131            .chars()
132            .any(|c| !(c.is_ascii_alphanumeric() || c == '_'))
133        {
134            name = format!("T{rid}")
135        }
136        let var_id = self
137            .params
138            .types
139            .push_with(|index| TypeParam { index, name });
140        self.type_vars_map.insert(rid, var_id);
141        var_id
142    }
143
144    pub(crate) fn push_const_generic_var(&mut self, rid: u32, ty: Ty, name: hax::Symbol) {
145        let var_id = self
146            .params
147            .const_generics
148            .push_with(|index| ConstGenericParam {
149                index,
150                name: name.to_string(),
151                ty,
152            });
153        self.const_generic_vars_map.insert(rid, var_id);
154    }
155
156    /// Translate a binder of regions by appending the stored reguions to the given vector.
157    pub(crate) fn push_params_from_binder(&mut self, binder: hax::Binder<()>) -> Result<(), Error> {
158        assert!(
159            self.bound_region_vars.is_empty(),
160            "Trying to use two binders at the same binding level"
161        );
162        use hax::BoundVariableKind::*;
163        for p in binder.bound_vars {
164            match p {
165                Region(region) => {
166                    self.push_bound_region(region);
167                }
168                Ty(_) => {
169                    panic!("Unexpected locally bound type variable");
170                }
171                Const => {
172                    panic!("Unexpected locally bound const generic variable");
173                }
174            }
175        }
176        Ok(())
177    }
178}
179
180impl<'tcx, 'ctx> ItemTransCtx<'tcx, 'ctx> {
181    /// Get the only binding level. Panics if there are other binding levels.
182    pub(crate) fn the_only_binder(&self) -> &BindingLevel {
183        assert_eq!(self.binding_levels.len(), 1);
184        self.innermost_binder()
185    }
186    /// Get the only binding level. Panics if there are other binding levels.
187    pub(crate) fn the_only_binder_mut(&mut self) -> &mut BindingLevel {
188        assert_eq!(self.binding_levels.len(), 1);
189        self.innermost_binder_mut()
190    }
191
192    pub(crate) fn outermost_binder(&self) -> &BindingLevel {
193        self.binding_levels.outermost()
194    }
195    pub(crate) fn outermost_binder_mut(&mut self) -> &mut BindingLevel {
196        self.binding_levels.outermost_mut()
197    }
198    pub(crate) fn innermost_binder(&self) -> &BindingLevel {
199        self.binding_levels.innermost()
200    }
201    pub(crate) fn innermost_binder_mut(&mut self) -> &mut BindingLevel {
202        self.binding_levels.innermost_mut()
203    }
204
205    pub(crate) fn outermost_generics(&self) -> &GenericParams {
206        &self.outermost_binder().params
207    }
208    #[expect(dead_code)]
209    pub(crate) fn outermost_generics_mut(&mut self) -> &mut GenericParams {
210        &mut self.outermost_binder_mut().params
211    }
212    pub(crate) fn innermost_generics(&self) -> &GenericParams {
213        &self.innermost_binder().params
214    }
215    pub(crate) fn innermost_generics_mut(&mut self) -> &mut GenericParams {
216        &mut self.innermost_binder_mut().params
217    }
218
219    pub(crate) fn lookup_bound_region(
220        &mut self,
221        span: Span,
222        dbid: hax::DebruijnIndex,
223        var: hax::BoundVar,
224    ) -> Result<RegionDbVar, Error> {
225        let dbid = DeBruijnId::new(dbid);
226        if let Some(rid) = self
227            .binding_levels
228            .get(dbid)
229            .and_then(|bl| bl.bound_region_vars.get(var))
230        {
231            Ok(DeBruijnVar::bound(dbid, *rid))
232        } else {
233            raise_error!(
234                self,
235                span,
236                "Unexpected error: could not find region '{dbid}_{var}"
237            )
238        }
239    }
240
241    pub(crate) fn lookup_param<Id: Copy>(
242        &mut self,
243        span: Span,
244        f: impl for<'a> Fn(&'a BindingLevel) -> Option<Id>,
245        mk_err: impl FnOnce() -> String,
246    ) -> Result<DeBruijnVar<Id>, Error> {
247        for (dbid, bl) in self.binding_levels.iter_enumerated() {
248            if let Some(id) = f(bl) {
249                return Ok(DeBruijnVar::bound(dbid, id));
250            }
251        }
252        let err = mk_err();
253        raise_error!(self, span, "Unexpected error: could not find {}", err)
254    }
255
256    pub(crate) fn lookup_early_region(
257        &mut self,
258        span: Span,
259        region: &hax::EarlyParamRegion,
260    ) -> Result<RegionDbVar, Error> {
261        self.lookup_param(
262            span,
263            |bl| bl.early_region_vars.get(region).copied(),
264            || format!("the region variable {region:?}"),
265        )
266    }
267
268    pub(crate) fn lookup_type_var(
269        &mut self,
270        span: Span,
271        param: &hax::ParamTy,
272    ) -> Result<TypeDbVar, Error> {
273        self.lookup_param(
274            span,
275            |bl| bl.type_vars_map.get(&param.index).copied(),
276            || format!("the type variable {}", param.name),
277        )
278    }
279
280    pub(crate) fn lookup_const_generic_var(
281        &mut self,
282        span: Span,
283        param: &hax::ParamConst,
284    ) -> Result<ConstGenericDbVar, Error> {
285        self.lookup_param(
286            span,
287            |bl| bl.const_generic_vars_map.get(&param.index).copied(),
288            || format!("the const generic variable {}", param.name),
289        )
290    }
291
292    pub(crate) fn lookup_clause_var(
293        &mut self,
294        span: Span,
295        id: &hax::GenericPredicateId,
296    ) -> Result<ClauseDbVar, Error> {
297        self.lookup_param(
298            span,
299            |bl| bl.trait_preds.get(id).copied(),
300            || format!("the trait clause variable {id:?}"),
301        )
302    }
303
304    pub(crate) fn push_generic_params(&mut self, generics: &hax::TyGenerics) -> Result<(), Error> {
305        for param in &generics.params {
306            self.push_generic_param(param)?;
307        }
308        Ok(())
309    }
310
311    pub(crate) fn push_generic_param(&mut self, param: &hax::GenericParamDef) -> Result<(), Error> {
312        match &param.kind {
313            hax::GenericParamDefKind::Lifetime => {
314                let region = hax::EarlyParamRegion {
315                    index: param.index,
316                    name: param.name.clone(),
317                };
318                let mutability = self
319                    .t_ctx
320                    .lt_mutability_computer
321                    .compute_lifetime_mutability(
322                        &self.hax_state,
323                        self.item_src.def_id(),
324                        param.index,
325                    );
326                let _ = self
327                    .innermost_binder_mut()
328                    .push_early_region(region, mutability);
329            }
330            hax::GenericParamDefKind::Type { .. } => {
331                let _ = self
332                    .innermost_binder_mut()
333                    .push_type_var(param.index, param.name);
334            }
335            hax::GenericParamDefKind::Const { ty, .. } => {
336                let span = self.def_span(&param.def_id);
337                // The type should be primitive, meaning it shouldn't contain variables,
338                // non-primitive adts, etc. As a result, we can use an empty context.
339                let ty = self.translate_ty(span, ty)?;
340                self.innermost_binder_mut()
341                    .push_const_generic_var(param.index, ty, param.name);
342            }
343        }
344
345        Ok(())
346    }
347
348    // The parameters (and in particular the lifetimes) are split between
349    // early bound and late bound parameters. See those blog posts for explanations:
350    // https://smallcultfollowing.com/babysteps/blog/2013/10/29/intermingled-parameter-lists/
351    // https://smallcultfollowing.com/babysteps/blog/2013/11/04/intermingled-parameter-lists/
352    // Note that only lifetimes can be late bound at the moment.
353    //
354    // [TyCtxt.generics_of] gives us the early-bound parameters. We add the late-bound parameters
355    // here.
356    fn push_late_bound_generics_for_def(
357        &mut self,
358        _span: Span,
359        def: &hax::FullDef,
360    ) -> Result<(), Error> {
361        if let hax::FullDefKind::Fn { sig, .. } | hax::FullDefKind::AssocFn { sig, .. } = def.kind()
362        {
363            let innermost_binder = self.innermost_binder_mut();
364            assert!(innermost_binder.bound_region_vars.is_empty());
365            innermost_binder.push_params_from_binder(sig.rebind(()))?;
366        }
367        Ok(())
368    }
369
370    /// Add the generics and predicates of this item and its parents to the current context.
371    #[tracing::instrument(skip(self, span, def))]
372    fn push_generics_for_def(&mut self, span: Span, def: &hax::FullDef) -> Result<(), Error> {
373        trace!("{:?}", def.param_env());
374        // Add generics from the parent item, recursively (recursivity is important for closures,
375        // as they can be nested).
376        if let Some(parent_item) = def.typing_parent(self.hax_state()) {
377            let parent_def = self.hax_def(&parent_item)?;
378            self.push_generics_for_def(span, &parent_def)?;
379        }
380        self.push_generics_for_def_without_parents(span, def)?;
381        Ok(())
382    }
383
384    /// Add the generics and predicates of this item. This does not include the parent generics;
385    /// use `push_generics_for_def` to get the full list.
386    fn push_generics_for_def_without_parents(
387        &mut self,
388        _span: Span,
389        def: &hax::FullDef,
390    ) -> Result<(), Error> {
391        use hax::FullDefKind;
392        if let Some(param_env) = def.param_env() {
393            // Add the generic params.
394            self.push_generic_params(&param_env.generics)?;
395            // Add the predicates.
396            let origin = match &def.kind {
397                FullDefKind::Adt { .. }
398                | FullDefKind::TyAlias { .. }
399                | FullDefKind::AssocTy { .. } => PredicateOrigin::WhereClauseOnType,
400                FullDefKind::Fn { .. }
401                | FullDefKind::AssocFn { .. }
402                | FullDefKind::Closure { .. }
403                | FullDefKind::Const { .. }
404                | FullDefKind::AssocConst { .. }
405                | FullDefKind::Static { .. } => PredicateOrigin::WhereClauseOnFn,
406                FullDefKind::TraitImpl { .. } | FullDefKind::InherentImpl { .. } => {
407                    PredicateOrigin::WhereClauseOnImpl
408                }
409                FullDefKind::Trait { .. } | FullDefKind::TraitAlias { .. } => {
410                    PredicateOrigin::WhereClauseOnTrait
411                }
412                _ => panic!("Unexpected def: {:?}", def.def_id().kind),
413            };
414            self.register_predicates(&param_env.predicates, origin.clone())?;
415        }
416
417        Ok(())
418    }
419
420    /// Translate the generics and predicates of this item and its parents. This adds generic
421    /// parameters and predicates to the current environment (as a binder in
422    /// `self.binding_levels`). The constructed `GenericParams` can be recovered at the end using
423    /// `self.into_generics()` and stored in the translated item.
424    ///
425    /// On top of the generics introduced by `push_generics_for_def`, this adds extra parameters
426    /// required by the `TransItemSourceKind`.
427    pub fn translate_item_generics(
428        &mut self,
429        span: Span,
430        def: &hax::FullDef,
431        kind: &TransItemSourceKind,
432    ) -> Result<(), Error> {
433        assert!(self.binding_levels.len() == 0);
434        self.binding_levels.push(BindingLevel::new());
435        self.push_generics_for_def(span, def)?;
436        self.push_late_bound_generics_for_def(span, def)?;
437
438        if let hax::FullDefKind::Closure { args, .. } = def.kind() {
439            // Add the lifetime generics coming from the upvars. We translate the upvar types early
440            // to know what lifetimes are needed.
441            let upvar_tys = self.translate_closure_upvar_tys(span, args)?;
442            // Add new lifetimes params to replace the erased ones.
443            let upvar_tys = upvar_tys.replace_erased_regions(|| {
444                let region_id = self.the_only_binder_mut().push_upvar_region();
445                Region::Var(DeBruijnVar::new_at_zero(region_id))
446            });
447            self.the_only_binder_mut().closure_upvar_tys = Some(upvar_tys);
448
449            // Add the lifetime generics coming from the higher-kindedness of the signature.
450            if let TransItemSourceKind::TraitImpl(TraitImplSource::Closure(..))
451            | TransItemSourceKind::ClosureMethod(..)
452            | TransItemSourceKind::ClosureAsFnCast = kind
453            {
454                self.the_only_binder_mut()
455                    .push_params_from_binder(args.fn_sig.rebind(()))?;
456            }
457            if let TransItemSourceKind::ClosureMethod(ClosureKind::Fn | ClosureKind::FnMut) = kind {
458                // Add the lifetime generics coming from the method itself.
459                let rid = self
460                    .the_only_binder_mut()
461                    .params
462                    .regions
463                    .push_with(|index| RegionParam::new(index, None));
464                self.the_only_binder_mut().closure_call_method_region = Some(rid);
465            }
466        }
467
468        self.innermost_binder_mut().params.check_consistency();
469        Ok(())
470    }
471
472    /// Push a new binding level, run the provided function inside it, then return the bound value.
473    pub(crate) fn inside_binder<F, U>(&mut self, kind: BinderKind, f: F) -> Result<Binder<U>, Error>
474    where
475        F: FnOnce(&mut Self) -> Result<U, Error>,
476    {
477        assert!(!self.binding_levels.is_empty());
478        self.binding_levels.push(BindingLevel::new());
479
480        // Call the continuation. Important: do not short-circuit on error here.
481        let res = f(self);
482
483        // Reset
484        let params = self.binding_levels.pop().unwrap().params;
485
486        // Return
487        res.map(|skip_binder| Binder {
488            kind,
489            params,
490            skip_binder,
491        })
492    }
493
494    /// Push a new binding level corresponding to the provided `def` for the duration of the inner
495    /// function call.
496    pub(crate) fn translate_binder_for_def<F, U>(
497        &mut self,
498        span: Span,
499        kind: BinderKind,
500        def: &hax::FullDef,
501        f: F,
502    ) -> Result<Binder<U>, Error>
503    where
504        F: FnOnce(&mut Self) -> Result<U, Error>,
505    {
506        let inner_hax_state = self.t_ctx.hax_state.clone().with_hax_owner(&def.def_id());
507        let outer_hax_state = mem::replace(&mut self.hax_state, inner_hax_state);
508        let ret = self.inside_binder(kind, |this| {
509            this.push_generics_for_def_without_parents(span, def)?;
510            this.push_late_bound_generics_for_def(span, def)?;
511            this.innermost_binder().params.check_consistency();
512            f(this)
513        });
514        self.hax_state = outer_hax_state;
515        ret
516    }
517
518    /// Push a group of bound regions and call the continuation.
519    /// We use this when diving into a `for<'a>`, or inside an arrow type (because
520    /// it contains universally quantified regions).
521    pub(crate) fn translate_region_binder<F, T, U>(
522        &mut self,
523        _span: Span,
524        binder: &hax::Binder<T>,
525        f: F,
526    ) -> Result<RegionBinder<U>, Error>
527    where
528        F: FnOnce(&mut Self, &T) -> Result<U, Error>,
529    {
530        let binder = self.inside_binder(BinderKind::Other, |this| {
531            this.innermost_binder_mut()
532                .push_params_from_binder(binder.rebind(()))?;
533            f(this, binder.hax_skip_binder_ref())
534        })?;
535        // Convert to a region-only binder.
536        Ok(RegionBinder {
537            regions: binder.params.regions,
538            skip_binder: binder.skip_binder,
539        })
540    }
541
542    pub(crate) fn into_generics(mut self) -> GenericParams {
543        assert!(self.binding_levels.len() == 1);
544        self.binding_levels.pop().unwrap().params
545    }
546}
547
548/// Struct to compute the "mutability" of each lifetime.
549#[derive(Default)]
550pub struct LifetimeMutabilityComputer {
551    lt_mutability: HashMap<hax::DefId, CycleDetector<HashSet<u32>>>,
552}
553
554impl LifetimeMutabilityComputer {
555    /// Compute the mutability of one lifetime.
556    pub(crate) fn compute_lifetime_mutability<'tcx>(
557        &mut self,
558        s: &impl BaseState<'tcx>,
559        item: &hax::DefId,
560        index: u32,
561    ) -> LifetimeMutability {
562        match self.compute_lifetime_mutabilities(s, item) {
563            Some(set) => {
564                if set.contains(&index) {
565                    LifetimeMutability::Mutable
566                } else {
567                    LifetimeMutability::Shared
568                }
569            }
570            None => LifetimeMutability::Unknown,
571        }
572    }
573
574    /// Compute the "mutability" of each lifetime, i.e. whether this lifetime is used in a `&'a mut
575    /// T` type or not. Returns a set of the known-mutable lifetimes for this ADT.
576    fn compute_lifetime_mutabilities<'tcx>(
577        &mut self,
578        s: &impl BaseState<'tcx>,
579        item: &hax::DefId,
580    ) -> Option<&HashSet<u32>> {
581        if !matches!(
582            item.kind,
583            hax::DefKind::Struct | hax::DefKind::Enum | hax::DefKind::Union
584        ) {
585            return None;
586        }
587        if self
588            .lt_mutability
589            .entry(item.clone())
590            .or_default()
591            .start_processing()
592        {
593            use hax::SInto;
594            use ty::{TypeSuperVisitable, TypeVisitable};
595
596            struct LtMutabilityVisitor<'a, S> {
597                s: &'a S,
598                computer: &'a mut LifetimeMutabilityComputer,
599                set: HashSet<u32>,
600            }
601            impl<'tcx, S: BaseState<'tcx>> ty::TypeVisitor<ty::TyCtxt<'tcx>> for LtMutabilityVisitor<'_, S> {
602                fn visit_ty(&mut self, ty: ty::Ty<'tcx>) {
603                    match ty.kind() {
604                        ty::Ref(r, _, ty::Mutability::Mut)
605                            if let ty::RegionKind::ReEarlyParam(r) = r.kind() =>
606                        {
607                            self.set.insert(r.index);
608                        }
609                        ty::Adt(adt, args) => {
610                            let item = adt.did().sinto(self.s);
611                            if let Some(mutabilities) =
612                                self.computer.compute_lifetime_mutabilities(self.s, &item)
613                            {
614                                for arg in args.iter() {
615                                    if let Some(r) = arg.as_region()
616                                        && let ty::RegionKind::ReEarlyParam(r) = r.kind()
617                                        && mutabilities.contains(&r.index)
618                                    {
619                                        self.set.insert(r.index);
620                                    }
621                                }
622                            }
623                        }
624                        _ => {}
625                    }
626                    ty.super_visit_with(self)
627                }
628            }
629            let mut visitor = LtMutabilityVisitor {
630                s,
631                computer: self,
632                set: HashSet::new(),
633            };
634
635            let tcx = s.base().tcx;
636            let def_id = item.real_rust_def_id();
637            let adt_def = tcx.adt_def(def_id);
638            let generics = ty::GenericArgs::identity_for_item(tcx, def_id);
639            for variant in adt_def.variants() {
640                for field in &variant.fields {
641                    field.ty(tcx, generics).visit_with(&mut visitor);
642                }
643            }
644            let set = visitor.set;
645
646            self.lt_mutability
647                .get_mut(item)
648                .unwrap()
649                .done_processing(set);
650        }
651        self.lt_mutability.get(item)?.as_processed()
652    }
653}