charon_driver/translate/
translate_generics.rs

1use crate::translate::translate_predicates::PredicateLocation;
2
3use super::translate_ctx::ItemTransCtx;
4use charon_lib::ast::*;
5use charon_lib::common::hash_by_addr::HashByAddr;
6use hax_frontend_exporter as hax;
7use std::collections::HashMap;
8use std::fmt::Debug;
9use std::sync::Arc;
10
11/// A level of binding for type-level variables. Each item has a top-level binding level
12/// corresponding to the parameters and clauses to the items. We may then encounter inner binding
13/// levels in the following cases:
14/// - `for<..>` binders in predicates;
15/// - `fn<..>` function pointer types;
16/// - `dyn Trait` types, represented as `dyn<T: Trait>` (TODO);
17/// - types in a trait declaration or implementation block (TODO);
18/// - methods in a trait declaration or implementation block (TODO).
19///
20/// At each level, we store two things: a `GenericParams` that contains the parameters bound at
21/// this level, and various maps from the rustc-internal indices to our indices.
22#[derive(Debug, Default)]
23pub(crate) struct BindingLevel {
24    /// The parameters and predicates bound at this level.
25    pub params: GenericParams,
26    /// Whether this binder corresponds to an item (method, type) or not (`for<..>` predicate, `fn`
27    /// pointer, etc). This indicates whether it corresponds to a rustc `ParamEnv` and therefore
28    /// whether we should resolve rustc variables there.
29    pub is_item_binder: bool,
30    /// Rust makes the distinction between early and late-bound region parameters. We do not make
31    /// this distinction, and merge early and late bound regions. For details, see:
32    /// <https://smallcultfollowing.com/babysteps/blog/2013/10/29/intermingled-parameter-lists/>
33    /// <https://smallcultfollowing.com/babysteps/blog/2013/11/04/intermingled-parameter-lists/>
34    ///
35    /// The map from rust early regions to translated region indices.
36    pub early_region_vars: std::collections::BTreeMap<hax::EarlyParamRegion, RegionId>,
37    /// The map from rust late/bound regions to translated region indices.
38    pub bound_region_vars: Vec<RegionId>,
39    /// The regions added for by-ref upvars, in order of upvars.
40    pub by_ref_upvar_regions: Vec<RegionId>,
41    /// The map from rust type variable indices to translated type variable indices.
42    pub type_vars_map: HashMap<u32, TypeVarId>,
43    /// The map from rust const generic variables to translate const generic variable indices.
44    pub const_generic_vars_map: HashMap<u32, ConstGenericVarId>,
45    /// Cache the translation of types. This harnesses the deduplication of `TyKind` that hax does.
46    // Important: we can't reuse type caches from earlier binders as the new binder may change what
47    // a given variable resolves to.
48    pub type_trans_cache: HashMap<HashByAddr<Arc<hax::TyKind>>, Ty>,
49}
50
/// Map a rustc region name to the name we store: the anonymous placeholder `'_` carries no
/// useful information, so it becomes `None`; every other name is kept as-is.
fn translate_region_name(s: String) -> Option<String> {
    let is_placeholder = s == "'_";
    (!is_placeholder).then_some(s)
}
55
56impl BindingLevel {
57    pub(crate) fn new(is_item_binder: bool) -> Self {
58        Self {
59            is_item_binder,
60            ..Default::default()
61        }
62    }
63
64    /// Important: we must push all the early-bound regions before pushing any other region.
65    pub(crate) fn push_early_region(&mut self, region: hax::EarlyParamRegion) -> RegionId {
66        let name = translate_region_name(region.name.clone());
67        // Check that there are no late-bound regions
68        assert!(
69            self.bound_region_vars.is_empty(),
70            "Early regions must be translated before late ones"
71        );
72        let rid = self
73            .params
74            .regions
75            .push_with(|index| RegionVar { index, name });
76        self.early_region_vars.insert(region, rid);
77        rid
78    }
79
80    /// Important: we must push all the early-bound regions before pushing any other region.
81    pub(crate) fn push_bound_region(&mut self, region: hax::BoundRegionKind) -> RegionId {
82        use hax::BoundRegionKind::*;
83        let name = match region {
84            Anon => None,
85            NamedAnon(symbol) | Named(_, symbol) => translate_region_name(symbol.clone()),
86            ClosureEnv => Some("@env".to_owned()),
87        };
88        let rid = self
89            .params
90            .regions
91            .push_with(|index| RegionVar { index, name });
92        self.bound_region_vars.push(rid);
93        rid
94    }
95
96    /// Add a region for a by_ref upvar in a closure.
97    pub fn push_upvar_region(&mut self) -> RegionId {
98        // We musn't push to `bound_region_vars` because that will contain the higher-kinded
99        // signature lifetimes if any and they must be lookup-able.
100        let region_id = self
101            .params
102            .regions
103            .push_with(|index| RegionVar { index, name: None });
104        self.by_ref_upvar_regions.push(region_id);
105        region_id
106    }
107
108    pub(crate) fn push_type_var(&mut self, rid: u32, name: String) -> TypeVarId {
109        let var_id = self.params.types.push_with(|index| TypeVar { index, name });
110        self.type_vars_map.insert(rid, var_id);
111        var_id
112    }
113
114    pub(crate) fn push_const_generic_var(&mut self, rid: u32, ty: LiteralTy, name: String) {
115        let var_id = self
116            .params
117            .const_generics
118            .push_with(|index| ConstGenericVar { index, name, ty });
119        self.const_generic_vars_map.insert(rid, var_id);
120    }
121
122    /// Translate a binder of regions by appending the stored reguions to the given vector.
123    pub(crate) fn push_params_from_binder(&mut self, binder: hax::Binder<()>) -> Result<(), Error> {
124        assert!(
125            self.bound_region_vars.is_empty(),
126            "Trying to use two binders at the same binding level"
127        );
128        use hax::BoundVariableKind::*;
129        for p in binder.bound_vars {
130            match p {
131                Region(region) => {
132                    self.push_bound_region(region);
133                }
134                Ty(_) => {
135                    panic!("Unexpected locally bound type variable");
136                }
137                Const => {
138                    panic!("Unexpected locally bound const generic variable");
139                }
140            }
141        }
142        Ok(())
143    }
144}
145
146impl<'tcx, 'ctx> ItemTransCtx<'tcx, 'ctx> {
147    /// Get the only binding level. Panics if there are other binding levels.
148    pub(crate) fn the_only_binder(&self) -> &BindingLevel {
149        assert_eq!(self.binding_levels.len(), 1);
150        self.innermost_binder()
151    }
152
153    /// Get the only binding level. Panics if there are other binding levels.
154    pub(crate) fn the_only_binder_mut(&mut self) -> &mut BindingLevel {
155        assert_eq!(self.binding_levels.len(), 1);
156        self.innermost_binder_mut()
157    }
158
159    pub(crate) fn outermost_binder(&self) -> &BindingLevel {
160        self.binding_levels.outermost()
161    }
162
163    pub(crate) fn innermost_binder(&self) -> &BindingLevel {
164        self.binding_levels.innermost()
165    }
166
167    pub(crate) fn innermost_binder_mut(&mut self) -> &mut BindingLevel {
168        self.binding_levels.innermost_mut()
169    }
170
171    pub(crate) fn innermost_generics_mut(&mut self) -> &mut GenericParams {
172        &mut self.innermost_binder_mut().params
173    }
174
175    pub(crate) fn lookup_bound_region(
176        &mut self,
177        span: Span,
178        dbid: hax::DebruijnIndex,
179        var: hax::BoundVar,
180    ) -> Result<RegionDbVar, Error> {
181        let dbid = DeBruijnId::new(dbid);
182        if let Some(rid) = self
183            .binding_levels
184            .get(dbid)
185            .and_then(|bl| bl.bound_region_vars.get(var))
186        {
187            Ok(DeBruijnVar::bound(dbid, *rid))
188        } else {
189            raise_error!(
190                self,
191                span,
192                "Unexpected error: could not find region '{dbid}_{var}"
193            )
194        }
195    }
196
197    pub(crate) fn lookup_param<Id: Copy>(
198        &mut self,
199        span: Span,
200        f: impl for<'a> Fn(&'a BindingLevel) -> Option<Id>,
201        mk_err: impl FnOnce() -> String,
202    ) -> Result<DeBruijnVar<Id>, Error> {
203        for (dbid, bl) in self.binding_levels.iter_enumerated() {
204            if let Some(id) = f(bl) {
205                return Ok(DeBruijnVar::bound(dbid, id));
206            }
207        }
208        let err = mk_err();
209        raise_error!(self, span, "Unexpected error: could not find {}", err)
210    }
211
212    pub(crate) fn lookup_early_region(
213        &mut self,
214        span: Span,
215        region: &hax::EarlyParamRegion,
216    ) -> Result<RegionDbVar, Error> {
217        self.lookup_param(
218            span,
219            |bl| bl.early_region_vars.get(region).copied(),
220            || format!("the region variable {region:?}"),
221        )
222    }
223
224    pub(crate) fn lookup_type_var(
225        &mut self,
226        span: Span,
227        param: &hax::ParamTy,
228    ) -> Result<TypeDbVar, Error> {
229        self.lookup_param(
230            span,
231            |bl| bl.type_vars_map.get(&param.index).copied(),
232            || format!("the type variable {}", param.name),
233        )
234    }
235
236    pub(crate) fn lookup_const_generic_var(
237        &mut self,
238        span: Span,
239        param: &hax::ParamConst,
240    ) -> Result<ConstGenericDbVar, Error> {
241        self.lookup_param(
242            span,
243            |bl| bl.const_generic_vars_map.get(&param.index).copied(),
244            || format!("the const generic variable {}", param.name),
245        )
246    }
247
248    pub(crate) fn lookup_clause_var(
249        &mut self,
250        span: Span,
251        mut id: usize,
252    ) -> Result<ClauseDbVar, Error> {
253        // The clause indices returned by hax count clauses in order, starting from the parentmost.
254        // While adding clauses to a binding level we already need to translate types and clauses,
255        // so the innermost item binder may not have all the clauses yet. Hence for that binder we
256        // ignore the clause count.
257        let innermost_item_binder_id = self
258            .binding_levels
259            .iter_enumerated()
260            .find(|(_, bl)| bl.is_item_binder)
261            .unwrap()
262            .0;
263        // Iterate over the binders, starting from the outermost.
264        for (dbid, bl) in self.binding_levels.iter_enumerated().rev() {
265            let num_clauses_bound_at_this_level = bl.params.trait_clauses.elem_count();
266            if id < num_clauses_bound_at_this_level || dbid == innermost_item_binder_id {
267                let id = TraitClauseId::from_usize(id);
268                return Ok(DeBruijnVar::bound(dbid, id));
269            } else {
270                id -= num_clauses_bound_at_this_level
271            }
272        }
273        // Actually unreachable
274        raise_error!(
275            self,
276            span,
277            "Unexpected error: could not find clause variable {}",
278            id
279        )
280    }
281
282    pub(crate) fn push_generic_params(&mut self, generics: &hax::TyGenerics) -> Result<(), Error> {
283        for param in &generics.params {
284            self.push_generic_param(param)?;
285        }
286        Ok(())
287    }
288
289    pub(crate) fn push_generic_param(&mut self, param: &hax::GenericParamDef) -> Result<(), Error> {
290        match &param.kind {
291            hax::GenericParamDefKind::Lifetime => {
292                let region = hax::EarlyParamRegion {
293                    index: param.index,
294                    name: param.name.clone(),
295                };
296                let _ = self.innermost_binder_mut().push_early_region(region);
297            }
298            hax::GenericParamDefKind::Type { .. } => {
299                let _ = self
300                    .innermost_binder_mut()
301                    .push_type_var(param.index, param.name.clone());
302            }
303            hax::GenericParamDefKind::Const { ty, .. } => {
304                let span = self.def_span(&param.def_id);
305                // The type should be primitive, meaning it shouldn't contain variables,
306                // non-primitive adts, etc. As a result, we can use an empty context.
307                let ty = self.translate_ty(span, ty)?;
308                match ty.kind().as_literal() {
309                    Some(ty) => self.innermost_binder_mut().push_const_generic_var(
310                        param.index,
311                        *ty,
312                        param.name.clone(),
313                    ),
314                    None => raise_error!(
315                        self,
316                        span,
317                        "Constant parameters of non-literal type are not supported"
318                    ),
319                }
320            }
321        }
322
323        Ok(())
324    }
325
326    /// Add the generics and predicates of this item and its parents to the current context.
327    #[tracing::instrument(skip(self, span))]
328    fn push_generics_for_def(
329        &mut self,
330        span: Span,
331        def: &hax::FullDef,
332        is_parent: bool,
333    ) -> Result<(), Error> {
334        use hax::FullDefKind::*;
335        // Add generics from the parent item, recursively (recursivity is important for closures,
336        // as they can be nested).
337        match def.kind() {
338            AssocTy { .. }
339            | AssocFn { .. }
340            | AssocConst { .. }
341            | Const {
342                kind:
343                    hax::ConstKind::AnonConst
344                    | hax::ConstKind::InlineConst
345                    | hax::ConstKind::PromotedConst,
346                ..
347            }
348            | Closure { .. }
349            | Ctor { .. }
350            | Variant { .. } => {
351                let parent_def_id = def.def_id().parent.as_ref().unwrap();
352                let parent_def = self.hax_def(parent_def_id)?;
353                self.push_generics_for_def(span, &parent_def, true)?;
354            }
355            _ => {}
356        }
357        self.push_generics_for_def_without_parents(span, def, !is_parent, !is_parent)?;
358        Ok(())
359    }
360
361    /// Add the generics and predicates of this item. This does not include the parent generics;
362    /// use `push_generics_for_def` to get the full list.
363    fn push_generics_for_def_without_parents(
364        &mut self,
365        span: Span,
366        def: &hax::FullDef,
367        include_late_bound: bool,
368        include_assoc_ty_clauses: bool,
369    ) -> Result<(), Error> {
370        use hax::FullDefKind;
371        if let Some(param_env) = def.param_env() {
372            // Add the generic params.
373            self.push_generic_params(&param_env.generics)?;
374            // Add the predicates.
375            let origin = match &def.kind {
376                FullDefKind::Adt { .. }
377                | FullDefKind::TyAlias { .. }
378                | FullDefKind::AssocTy { .. } => PredicateOrigin::WhereClauseOnType,
379                FullDefKind::Fn { .. }
380                | FullDefKind::AssocFn { .. }
381                | FullDefKind::Const { .. }
382                | FullDefKind::AssocConst { .. }
383                | FullDefKind::Static { .. } => PredicateOrigin::WhereClauseOnFn,
384                FullDefKind::TraitImpl { .. } | FullDefKind::InherentImpl { .. } => {
385                    PredicateOrigin::WhereClauseOnImpl
386                }
387                FullDefKind::Trait { .. } | FullDefKind::TraitAlias { .. } => {
388                    let _ = self.register_trait_decl_id(span, &def.def_id);
389                    PredicateOrigin::WhereClauseOnTrait
390                }
391                _ => panic!("Unexpected def: {def:?}"),
392            };
393            self.register_predicates(
394                &param_env.predicates,
395                origin.clone(),
396                &PredicateLocation::Base,
397            )?;
398            // Also register implied predicates.
399            if let FullDefKind::Trait {
400                implied_predicates, ..
401            }
402            | FullDefKind::TraitAlias {
403                implied_predicates, ..
404            }
405            | FullDefKind::AssocTy {
406                implied_predicates, ..
407            } = &def.kind
408            {
409                self.register_predicates(implied_predicates, origin, &PredicateLocation::Parent)?;
410            }
411
412            if let hax::FullDefKind::Trait { items, .. } = &def.kind
413                && include_assoc_ty_clauses
414            {
415                // Also add the predicates on associated types.
416                // FIXME(gat): don't skip GATs.
417                // FIXME: don't mix up implied and required predicates.
418                for (_item, item_def) in items {
419                    if let hax::FullDefKind::AssocTy {
420                        param_env,
421                        implied_predicates,
422                        ..
423                    } = &item_def.kind
424                        && param_env.generics.params.is_empty()
425                    {
426                        let name = self.t_ctx.translate_trait_item_name(item_def.def_id())?;
427                        self.register_predicates(
428                            &implied_predicates,
429                            PredicateOrigin::TraitItem(name.clone()),
430                            &PredicateLocation::Item(name),
431                        )?;
432                    }
433                }
434            }
435        }
436
437        if let hax::FullDefKind::Closure { args, .. } = def.kind()
438            && include_late_bound
439        {
440            // Add the lifetime generics coming from the by-ref upvars.
441            args.upvar_tys.iter().for_each(|ty| {
442                if matches!(
443                    ty.kind(),
444                    hax::TyKind::Ref(
445                        hax::Region {
446                            kind: hax::RegionKind::ReErased
447                        },
448                        ..
449                    )
450                ) {
451                    self.the_only_binder_mut().push_upvar_region();
452                }
453            });
454        }
455
456        // The parameters (and in particular the lifetimes) are split between
457        // early bound and late bound parameters. See those blog posts for explanations:
458        // https://smallcultfollowing.com/babysteps/blog/2013/10/29/intermingled-parameter-lists/
459        // https://smallcultfollowing.com/babysteps/blog/2013/11/04/intermingled-parameter-lists/
460        // Note that only lifetimes can be late bound.
461        //
462        // [TyCtxt.generics_of] gives us the early-bound parameters. We add the late-bound
463        // parameters here.
464        let signature = match &def.kind {
465            hax::FullDefKind::Fn { sig, .. } => Some(sig),
466            hax::FullDefKind::AssocFn { sig, .. } => Some(sig),
467            _ => None,
468        };
469        if let Some(signature) = signature
470            && include_late_bound
471        {
472            let innermost_binder = self.innermost_binder_mut();
473            assert!(innermost_binder.bound_region_vars.is_empty());
474            innermost_binder.push_params_from_binder(signature.rebind(()))?;
475        }
476
477        Ok(())
478    }
479
480    /// Translate the generics and predicates of this item and its parents.
481    pub(crate) fn translate_def_generics(
482        &mut self,
483        span: Span,
484        def: &hax::FullDef,
485    ) -> Result<(), Error> {
486        assert!(self.binding_levels.len() == 0);
487        self.binding_levels.push(BindingLevel::new(true));
488        self.push_generics_for_def(span, def, false)?;
489        self.innermost_binder_mut().params.check_consistency();
490        Ok(())
491    }
492
493    /// Translate the generics and predicates of this item without its parents.
494    pub(crate) fn translate_def_generics_without_parents(
495        &mut self,
496        span: Span,
497        def: &hax::FullDef,
498    ) -> Result<(), Error> {
499        self.binding_levels.push(BindingLevel::new(true));
500        self.push_generics_for_def_without_parents(span, def, true, true)?;
501        self.innermost_binder().params.check_consistency();
502        Ok(())
503    }
504
505    /// Push a new binding level corresponding to the provided `def` for the duration of the inner
506    /// function call.
507    pub(crate) fn translate_binder_for_def<F, U>(
508        &mut self,
509        span: Span,
510        kind: BinderKind,
511        def: &hax::FullDef,
512        f: F,
513    ) -> Result<Binder<U>, Error>
514    where
515        F: FnOnce(&mut Self) -> Result<U, Error>,
516    {
517        assert!(!self.binding_levels.is_empty());
518
519        // Register the type-level parameters. This pushes a new binding level.
520        self.translate_def_generics_without_parents(span, def)?;
521
522        // Call the continuation. Important: do not short-circuit on error here.
523        let res = f(self);
524
525        // Reset
526        let params = self.binding_levels.pop().unwrap().params;
527
528        // Return
529        res.map(|skip_binder| Binder {
530            kind,
531            params,
532            skip_binder,
533        })
534    }
535
536    /// Push a group of bound regions and call the continuation.
537    /// We use this when diving into a `for<'a>`, or inside an arrow type (because
538    /// it contains universally quantified regions).
539    pub(crate) fn translate_region_binder<F, T, U>(
540        &mut self,
541        _span: Span,
542        binder: &hax::Binder<T>,
543        f: F,
544    ) -> Result<RegionBinder<U>, Error>
545    where
546        F: FnOnce(&mut Self, &T) -> Result<U, Error>,
547    {
548        assert!(!self.binding_levels.is_empty());
549
550        // Register the variables
551        let mut binding_level = BindingLevel::new(false);
552        binding_level.push_params_from_binder(binder.rebind(()))?;
553        self.binding_levels.push(binding_level);
554
555        // Call the continuation. Important: do not short-circuit on error here.
556        let res = f(self, binder.hax_skip_binder_ref());
557
558        // Reset
559        let regions = self.binding_levels.pop().unwrap().params.regions;
560
561        // Return
562        res.map(|skip_binder| RegionBinder {
563            regions,
564            skip_binder,
565        })
566    }
567
568    pub(crate) fn into_generics(mut self) -> GenericParams {
569        assert!(self.binding_levels.len() == 1);
570        self.binding_levels.pop().unwrap().params
571    }
572}