charon_lib/ast/types_utils.rs

1//! This file groups everything which is linked to implementations about [crate::types]
2use crate::ast::*;
3use crate::formatter::FmtCtx;
4use crate::ids::Vector;
5use crate::pretty::FmtWithCtx;
6use derive_generic_visitor::*;
7use std::collections::HashSet;
8use std::convert::Infallible;
9use std::fmt::Debug;
10use std::iter::Iterator;
11use std::mem;
12use std::ops::Index;
13
14impl TraitClause {
15    /// Constructs the trait ref that refers to this clause.
16    pub fn identity_tref(&self) -> TraitRef {
17        self.identity_tref_at_depth(DeBruijnId::zero())
18    }
19
20    /// Like `identity_tref` but uses variables bound at the given depth.
21    pub fn identity_tref_at_depth(&self, depth: DeBruijnId) -> TraitRef {
22        TraitRef {
23            kind: TraitRefKind::Clause(DeBruijnVar::bound(depth, self.clause_id)),
24            trait_decl_ref: self.trait_.clone().move_under_binders(depth),
25        }
26    }
27}
28
29impl GenericParams {
30    pub fn empty() -> Self {
31        Self::default()
32    }
33
34    pub fn is_empty(&self) -> bool {
35        self.len() == 0
36    }
37    /// Whether this has any explicit arguments (types, regions or const generics).
38    pub fn has_explicits(&self) -> bool {
39        !self.regions.is_empty() || !self.types.is_empty() || !self.const_generics.is_empty()
40    }
41    /// Whether this has any implicit arguments (trait clauses, outlives relations, associated type
42    /// equality constraints).
43    pub fn has_predicates(&self) -> bool {
44        !self.trait_clauses.is_empty()
45            || !self.types_outlive.is_empty()
46            || !self.regions_outlive.is_empty()
47            || !self.trait_type_constraints.is_empty()
48    }
49
50    /// Run some sanity checks.
51    pub fn check_consistency(&self) {
52        // Sanity check: check the clause ids are consistent.
53        assert!(self
54            .trait_clauses
55            .iter()
56            .enumerate()
57            .all(|(i, c)| c.clause_id.index() == i));
58
59        // Sanity check: region names are pairwise distinct (this caused trouble when generating
60        // names for the backward functions in Aeneas): at some point, Rustc introduced names equal
61        // to `Some("'_")` for the anonymous regions, instead of using `None` (we now check in
62        // [translate_region_name] and ignore names equal to "'_").
63        let mut s = HashSet::new();
64        for r in &self.regions {
65            if let Some(name) = &r.name {
66                assert!(
67                    !s.contains(name),
68                    "Name \"{}\" reused for two different lifetimes",
69                    name
70                );
71                s.insert(name);
72            }
73        }
74    }
75
76    pub fn len(&self) -> usize {
77        let GenericParams {
78            regions,
79            types,
80            const_generics,
81            trait_clauses,
82            regions_outlive,
83            types_outlive,
84            trait_type_constraints,
85        } = self;
86        regions.elem_count()
87            + types.elem_count()
88            + const_generics.elem_count()
89            + trait_clauses.elem_count()
90            + regions_outlive.len()
91            + types_outlive.len()
92            + trait_type_constraints.elem_count()
93    }
94
95    /// Construct a set of generic arguments in the scope of `self` that matches `self` and feeds
96    /// each required parameter with itself. E.g. given parameters for `<T, U> where U:
97    /// PartialEq<T>`, the arguments would be `<T, U>[@TraitClause0]`.
98    pub fn identity_args(&self, target: GenericsSource) -> GenericArgs {
99        self.identity_args_at_depth(target, DeBruijnId::zero())
100    }
101
102    /// Like `identity_args` but uses variables bound at the given depth.
103    pub fn identity_args_at_depth(&self, target: GenericsSource, depth: DeBruijnId) -> GenericArgs {
104        GenericArgs {
105            regions: self
106                .regions
107                .map_ref_indexed(|id, _| Region::Var(DeBruijnVar::bound(depth, id))),
108            types: self
109                .types
110                .map_ref_indexed(|id, _| TyKind::TypeVar(DeBruijnVar::bound(depth, id)).into_ty()),
111            const_generics: self
112                .const_generics
113                .map_ref_indexed(|id, _| ConstGeneric::Var(DeBruijnVar::bound(depth, id))),
114            trait_refs: self
115                .trait_clauses
116                .map_ref(|clause| clause.identity_tref_at_depth(depth)),
117            target,
118        }
119    }
120}
121
122impl<T> Binder<T> {
123    pub fn new(kind: BinderKind, params: GenericParams, skip_binder: T) -> Self {
124        Self {
125            params,
126            skip_binder,
127            kind,
128        }
129    }
130
131    /// Substitute the provided arguments for the variables bound in this binder and return the
132    /// substituted inner value.
133    pub fn apply(self, args: &GenericArgs) -> T
134    where
135        T: TyVisitable,
136    {
137        self.skip_binder.substitute(args)
138    }
139}
140
impl<T: AstVisitable> Binder<Binder<T>> {
    /// Flatten two levels of binders into a single one.
    ///
    /// The parameters of the returned binder are the outer parameters followed by the inner
    /// ones (appended kind by kind). Variables bound by the inner binder are renumbered to point
    /// past the outer parameters of the same kind, and variables referring to binders beyond the
    /// inner one lose one level since a binder disappears. The result has kind
    /// [`BinderKind::Other`].
    pub fn flatten(self) -> Binder<T> {
        // Rewrites de Bruijn variables so that they make sense under the single flattened
        // binder.
        #[derive(Visitor)]
        struct FlattenVisitor<'a> {
            // The outer parameters: inner variable ids are shifted by their slot counts.
            shift_by: &'a GenericParams,
            // Number of binders entered since the visit started (zero means "directly under the
            // inner binder").
            binder_depth: DeBruijnId,
        }
        impl VisitAstMut for FlattenVisitor<'_> {
            // Keep `binder_depth` in sync with the binders nested inside the visited value.
            fn enter_region_binder<T: AstVisitable>(&mut self, _: &mut RegionBinder<T>) {
                self.binder_depth = self.binder_depth.incr()
            }
            fn exit_region_binder<T: AstVisitable>(&mut self, _: &mut RegionBinder<T>) {
                self.binder_depth = self.binder_depth.decr()
            }
            fn enter_binder<T: AstVisitable>(&mut self, _: &mut Binder<T>) {
                self.binder_depth = self.binder_depth.incr()
            }
            fn exit_binder<T: AstVisitable>(&mut self, _: &mut Binder<T>) {
                self.binder_depth = self.binder_depth.decr()
            }
            fn enter_de_bruijn_id(&mut self, db_id: &mut DeBruijnId) {
                if *db_id > self.binder_depth {
                    // We started visiting at the inner binder, so in this branch we're either
                    // mentioning the outer binder or a binder further beyond. Either way we
                    // decrease the depth; variables that point to the outer binder don't have to
                    // be shifted.
                    *db_id = db_id.decr();
                }
            }
            // The four hooks below fire when entering a region / type kind / const generic /
            // trait ref kind. This relies on enter hooks running before the variable's inner
            // `DeBruijnId` is seen by `enter_de_bruijn_id`. A variable bound by the inner binder
            // (depth == `binder_depth`) gets its id shifted past the outer parameters of the
            // same kind; `enter_de_bruijn_id` then leaves its depth alone since it is not
            // strictly greater than `binder_depth`.
            // NOTE(review): `bound_at_depth_mut` presumably yields the var id only when the
            // variable is bound at exactly that depth — confirm.
            fn enter_region(&mut self, x: &mut Region) {
                if let Region::Var(var) = x
                    && let Some(id) = var.bound_at_depth_mut(self.binder_depth)
                {
                    *id += self.shift_by.regions.slot_count();
                }
            }
            fn enter_ty_kind(&mut self, x: &mut TyKind) {
                if let TyKind::TypeVar(var) = x
                    && let Some(id) = var.bound_at_depth_mut(self.binder_depth)
                {
                    *id += self.shift_by.types.slot_count();
                }
            }
            fn enter_const_generic(&mut self, x: &mut ConstGeneric) {
                if let ConstGeneric::Var(var) = x
                    && let Some(id) = var.bound_at_depth_mut(self.binder_depth)
                {
                    *id += self.shift_by.const_generics.slot_count();
                }
            }
            fn enter_trait_ref_kind(&mut self, x: &mut TraitRefKind) {
                if let TraitRefKind::Clause(var) = x
                    && let Some(id) = var.bound_at_depth_mut(self.binder_depth)
                {
                    *id += self.shift_by.trait_clauses.slot_count();
                }
            }
        }

        // We will concatenate both sets of params.
        let mut outer_params = self.params;

        // The inner value needs to change:
        // - at binder level 0 we shift all variable ids to match the concatenated params;
        // - at binder level > 0 we decrease binding level because there's one fewer binder.
        let mut bound_value = self.skip_binder.skip_binder;
        let _ = bound_value.drive_mut(&mut FlattenVisitor {
            shift_by: &outer_params,
            binder_depth: Default::default(),
        });

        // The inner params must also be updated, as they can refer to themselves and the outer
        // one.
        let mut inner_params = self.skip_binder.params;
        let _ = inner_params.drive_mut(&mut FlattenVisitor {
            shift_by: &outer_params,
            binder_depth: Default::default(),
        });
        // The declarations' own indices are shifted by the same amounts, so that they keep
        // matching the (already shifted) variables that refer to them.
        inner_params
            .regions
            .iter_mut()
            .for_each(|v| v.index += outer_params.regions.slot_count());
        inner_params
            .types
            .iter_mut()
            .for_each(|v| v.index += outer_params.types.slot_count());
        inner_params
            .const_generics
            .iter_mut()
            .for_each(|v| v.index += outer_params.const_generics.slot_count());
        inner_params
            .trait_clauses
            .iter_mut()
            .for_each(|v| v.clause_id += outer_params.trait_clauses.slot_count());

        // Append every inner list after the corresponding outer one.
        let GenericParams {
            regions,
            types,
            const_generics,
            trait_clauses,
            regions_outlive,
            types_outlive,
            trait_type_constraints,
        } = &inner_params;
        outer_params.regions.extend_from_slice(regions);
        outer_params.types.extend_from_slice(types);
        outer_params
            .const_generics
            .extend_from_slice(const_generics);
        outer_params.trait_clauses.extend_from_slice(trait_clauses);
        outer_params
            .regions_outlive
            .extend_from_slice(regions_outlive);
        outer_params.types_outlive.extend_from_slice(types_outlive);
        outer_params
            .trait_type_constraints
            .extend_from_slice(trait_type_constraints);

        Binder {
            params: outer_params,
            skip_binder: bound_value,
            kind: BinderKind::Other,
        }
    }
}
267
268impl<T> RegionBinder<T> {
269    /// Wrap the value in an empty region binder, shifting variables appropriately.
270    pub fn empty(x: T) -> Self
271    where
272        T: TyVisitable,
273    {
274        RegionBinder {
275            regions: Default::default(),
276            skip_binder: x.move_under_binder(),
277        }
278    }
279
280    pub fn map_ref<U>(&self, f: impl FnOnce(&T) -> U) -> RegionBinder<U> {
281        RegionBinder {
282            regions: self.regions.clone(),
283            skip_binder: f(&self.skip_binder),
284        }
285    }
286
287    /// Substitute the bound variables with erased lifetimes.
288    pub fn erase(self) -> T
289    where
290        T: AstVisitable,
291    {
292        let args = GenericArgs {
293            regions: self.regions.map_ref_indexed(|_, _| Region::Erased),
294            ..GenericArgs::empty(GenericsSource::Other)
295        };
296        self.skip_binder.substitute(&args)
297    }
298}
299
300impl GenericArgs {
301    pub fn len(&self) -> usize {
302        let GenericArgs {
303            regions,
304            types,
305            const_generics,
306            trait_refs,
307            target: _,
308        } = self;
309        regions.elem_count()
310            + types.elem_count()
311            + const_generics.elem_count()
312            + trait_refs.elem_count()
313    }
314
315    pub fn is_empty(&self) -> bool {
316        self.len() == 0
317    }
318    /// Whether this has any explicit arguments (types, regions or const generics).
319    pub fn has_explicits(&self) -> bool {
320        !self.regions.is_empty() || !self.types.is_empty() || !self.const_generics.is_empty()
321    }
322    /// Whether this has any implicit arguments (trait refs).
323    pub fn has_implicits(&self) -> bool {
324        !self.trait_refs.is_empty()
325    }
326
327    pub fn empty(target: GenericsSource) -> Self {
328        GenericArgs {
329            regions: Default::default(),
330            types: Default::default(),
331            const_generics: Default::default(),
332            trait_refs: Default::default(),
333            target,
334        }
335    }
336
337    pub fn new_for_builtin(types: Vector<TypeVarId, Ty>) -> Self {
338        GenericArgs {
339            types,
340            ..Self::empty(GenericsSource::Builtin)
341        }
342    }
343
344    pub fn new(
345        regions: Vector<RegionId, Region>,
346        types: Vector<TypeVarId, Ty>,
347        const_generics: Vector<ConstGenericVarId, ConstGeneric>,
348        trait_refs: Vector<TraitClauseId, TraitRef>,
349        target: GenericsSource,
350    ) -> Self {
351        Self {
352            regions,
353            types,
354            const_generics,
355            trait_refs,
356            target,
357        }
358    }
359
360    pub fn new_types(types: Vector<TypeVarId, Ty>, target: GenericsSource) -> Self {
361        Self {
362            types,
363            ..Self::empty(target)
364        }
365    }
366
367    /// Changes the target.
368    pub fn with_target(mut self, target: GenericsSource) -> Self {
369        self.target = target;
370        self
371    }
372
373    /// Check whether this matches the given `GenericParams`.
374    /// TODO: check more things, e.g. that the trait refs use the correct trait and generics.
375    pub fn matches(&self, params: &GenericParams) -> bool {
376        params.regions.elem_count() == self.regions.elem_count()
377            && params.types.elem_count() == self.types.elem_count()
378            && params.const_generics.elem_count() == self.const_generics.elem_count()
379            && params.trait_clauses.elem_count() == self.trait_refs.elem_count()
380    }
381
382    /// Return the same generics, but where we pop the first type arguments.
383    /// This is useful for trait references (for pretty printing for instance),
384    /// because the first type argument is the type for which the trait is
385    /// implemented.
386    pub fn pop_first_type_arg(&self) -> (Ty, Self) {
387        let mut generics = self.clone();
388        let mut it = mem::take(&mut generics.types).into_iter();
389        let ty = it.next().unwrap();
390        generics.types = it.collect();
391        (ty, generics)
392    }
393
394    /// Concatenate this set of arguments with another one. Use with care, you must manage the
395    /// order of arguments correctly.
396    pub fn concat(mut self, target: GenericsSource, other: &Self) -> Self {
397        let Self {
398            regions,
399            types,
400            const_generics,
401            trait_refs,
402            target: _,
403        } = other;
404        self.regions.extend_from_slice(regions);
405        self.types.extend_from_slice(types);
406        self.const_generics.extend_from_slice(const_generics);
407        self.trait_refs.extend_from_slice(trait_refs);
408        self.target = target;
409        self
410    }
411}
412
413impl GenericsSource {
414    pub fn item<I: Into<AnyTransId>>(id: I) -> Self {
415        Self::Item(id.into())
416    }
417
418    /// Return a path that represents the target item.
419    pub fn item_name(&self, translated: &TranslatedCrate, fmt_ctx: &FmtCtx) -> String {
420        match self {
421            GenericsSource::Item(id) => translated
422                .item_name(*id)
423                .unwrap()
424                .to_string_with_ctx(fmt_ctx),
425            GenericsSource::Method(trait_id, method_name) => format!(
426                "{}::{method_name}",
427                translated
428                    .item_name(*trait_id)
429                    .unwrap()
430                    .to_string_with_ctx(fmt_ctx),
431            ),
432            GenericsSource::Builtin => format!("<built-in>"),
433            GenericsSource::Other => format!("<unknown>"),
434        }
435    }
436}
437
438impl IntegerTy {
439    pub fn is_signed(&self) -> bool {
440        matches!(
441            self,
442            IntegerTy::Isize
443                | IntegerTy::I8
444                | IntegerTy::I16
445                | IntegerTy::I32
446                | IntegerTy::I64
447                | IntegerTy::I128
448        )
449    }
450
451    pub fn is_unsigned(&self) -> bool {
452        !(self.is_signed())
453    }
454
455    /// Return the size (in bytes) of an integer of the proper type
456    pub fn size(&self) -> usize {
457        use std::mem::size_of;
458        match self {
459            IntegerTy::Isize => size_of::<isize>(),
460            IntegerTy::I8 => size_of::<i8>(),
461            IntegerTy::I16 => size_of::<i16>(),
462            IntegerTy::I32 => size_of::<i32>(),
463            IntegerTy::I64 => size_of::<i64>(),
464            IntegerTy::I128 => size_of::<i128>(),
465            IntegerTy::Usize => size_of::<isize>(),
466            IntegerTy::U8 => size_of::<u8>(),
467            IntegerTy::U16 => size_of::<u16>(),
468            IntegerTy::U32 => size_of::<u32>(),
469            IntegerTy::U64 => size_of::<u64>(),
470            IntegerTy::U128 => size_of::<u128>(),
471        }
472    }
473}
474
/// A value of type `T` bound by the generic parameters of item
/// `item_id`. Used when dealing with multiple items at a time, to
/// ensure we don't mix up generics.
///
/// To get the value, use `under_binder_of` (or `under_current_binder` when the item is
/// [`CurrentItem`]); to move the value into another item's context, use `substitute`.
#[derive(Debug, Clone, Copy)]
pub struct ItemBinder<ItemId, T> {
    // The item whose generic parameters `val` is expressed in.
    pub item_id: ItemId,
    // Private: outside this module the value is only reachable through the accessors above.
    val: T,
}
485
486impl<ItemId, T> ItemBinder<ItemId, T>
487where
488    ItemId: Debug + Copy + PartialEq,
489{
490    pub fn new(item_id: ItemId, val: T) -> Self {
491        Self { item_id, val }
492    }
493
494    pub fn as_ref(&self) -> ItemBinder<ItemId, &T> {
495        ItemBinder {
496            item_id: self.item_id,
497            val: &self.val,
498        }
499    }
500
501    pub fn map_bound<U>(self, f: impl FnOnce(T) -> U) -> ItemBinder<ItemId, U> {
502        ItemBinder {
503            item_id: self.item_id,
504            val: f(self.val),
505        }
506    }
507
508    fn assert_item_id(&self, item_id: ItemId) {
509        assert_eq!(
510            self.item_id, item_id,
511            "Trying to use item bound for {:?} as if it belonged to {:?}",
512            self.item_id, item_id
513        );
514    }
515
516    /// Assert that the value is bound for item `item_id`, and returns it. This is used when we
517    /// plan to store the returned value inside that item.
518    pub fn under_binder_of(self, item_id: ItemId) -> T {
519        self.assert_item_id(item_id);
520        self.val
521    }
522
523    /// Given generic args for `item_id`, assert that the value is bound for `item_id` and
524    /// substitute it with the provided generic arguments. Because the arguments are bound in the
525    /// context of another item, so it the resulting substituted value.
526    pub fn substitute<OtherItem: Debug + Copy + PartialEq>(
527        self,
528        args: ItemBinder<OtherItem, &GenericArgs>,
529    ) -> ItemBinder<OtherItem, T>
530    where
531        ItemId: Into<AnyTransId>,
532        T: TyVisitable,
533    {
534        args.map_bound(|args| {
535            assert_eq!(
536                args.target,
537                GenericsSource::item(self.item_id),
538                "These `GenericArgs` are meant for {:?} but were used on {:?}",
539                args.target,
540                self.item_id
541            );
542            self.val.substitute(args)
543        })
544    }
545}
546
547/// Dummy item identifier that represents the current item when not ambiguous.
548#[derive(Debug, Clone, Copy, PartialEq, Eq)]
549pub struct CurrentItem;
550
551impl<T> ItemBinder<CurrentItem, T> {
552    pub fn under_current_binder(self) -> T {
553        self.val
554    }
555}
556
557impl Ty {
558    /// Return true if it is actually unit (i.e.: 0-tuple)
559    pub fn is_unit(&self) -> bool {
560        match self.kind() {
561            TyKind::Adt(TypeId::Tuple, args) => {
562                assert!(args.regions.is_empty());
563                assert!(args.const_generics.is_empty());
564                args.types.is_empty()
565            }
566            _ => false,
567        }
568    }
569
570    /// Return the unit type
571    pub fn mk_unit() -> Ty {
572        Self::mk_tuple(vec![])
573    }
574
575    pub fn mk_tuple(tys: Vec<Ty>) -> Ty {
576        TyKind::Adt(
577            TypeId::Tuple,
578            GenericArgs::new(
579                Vector::new(),
580                tys.into(),
581                Vector::new(),
582                Vector::new(),
583                GenericsSource::Builtin,
584            ),
585        )
586        .into_ty()
587    }
588
589    /// Return true if this is a scalar type
590    pub fn is_scalar(&self) -> bool {
591        match self.kind() {
592            TyKind::Literal(kind) => kind.is_integer(),
593            _ => false,
594        }
595    }
596
597    pub fn is_unsigned_scalar(&self) -> bool {
598        match self.kind() {
599            TyKind::Literal(LiteralTy::Integer(kind)) => kind.is_unsigned(),
600            _ => false,
601        }
602    }
603
604    pub fn is_signed_scalar(&self) -> bool {
605        match self.kind() {
606            TyKind::Literal(LiteralTy::Integer(kind)) => kind.is_signed(),
607            _ => false,
608        }
609    }
610
611    /// Return true if the type is Box
612    pub fn is_box(&self) -> bool {
613        match self.kind() {
614            TyKind::Adt(TypeId::Builtin(BuiltinTy::Box), generics) => {
615                assert!(generics.regions.is_empty());
616                assert!(generics.types.elem_count() == 1);
617                assert!(generics.const_generics.is_empty());
618                true
619            }
620            _ => false,
621        }
622    }
623
624    pub fn as_box(&self) -> Option<&Ty> {
625        match self.kind() {
626            TyKind::Adt(TypeId::Builtin(BuiltinTy::Box), generics) => {
627                assert!(generics.regions.is_empty());
628                assert!(generics.types.elem_count() == 1);
629                assert!(generics.const_generics.is_empty());
630                Some(&generics.types[0])
631            }
632            _ => None,
633        }
634    }
635
636    pub fn as_array_or_slice(&self) -> Option<&Ty> {
637        match self.kind() {
638            TyKind::Adt(TypeId::Builtin(BuiltinTy::Array | BuiltinTy::Slice), generics) => {
639                assert!(generics.regions.is_empty());
640                assert!(generics.types.elem_count() == 1);
641                Some(&generics.types[0])
642            }
643            _ => None,
644        }
645    }
646
647    pub fn as_tuple(&self) -> Option<&Vector<TypeVarId, Ty>> {
648        match self.kind() {
649            TyKind::Adt(TypeId::Tuple, generics) => {
650                assert!(generics.regions.is_empty());
651                assert!(generics.const_generics.is_empty());
652                Some(&generics.types)
653            }
654            _ => None,
655        }
656    }
657
658    pub fn as_adt(&self) -> Option<(TypeId, &GenericArgs)> {
659        match self.kind() {
660            TyKind::Adt(id, generics) => Some((*id, generics)),
661            _ => None,
662        }
663    }
664}
665
666impl TyKind {
667    pub fn into_ty(self) -> Ty {
668        Ty::new(self)
669    }
670}
671
672impl From<TyKind> for Ty {
673    fn from(kind: TyKind) -> Ty {
674        kind.into_ty()
675    }
676}
677
/// Convenience for migration purposes.
///
/// Lets a `Ty` be used wherever a `&TyKind` is expected, by forwarding to `Ty::kind`.
impl std::ops::Deref for Ty {
    type Target = TyKind;

    fn deref(&self) -> &Self::Target {
        self.kind()
    }
}
/// For deref patterns.
// SAFETY: `DerefPure` requires `deref` to act like a plain projection (no side effects, the
// same result on every call). The `deref` above only forwards to `Ty::kind`.
// NOTE(review): this assumes `Ty::kind` is a side-effect-free accessor that always returns
// the same `TyKind` — confirm against its definition.
unsafe impl std::ops::DerefPure for Ty {}
688
689impl TypeId {
690    pub fn generics_target(&self) -> GenericsSource {
691        match *self {
692            TypeId::Adt(decl_id) => GenericsSource::item(decl_id),
693            TypeId::Tuple | TypeId::Builtin(..) => GenericsSource::Builtin,
694        }
695    }
696}
697
698impl FunId {
699    pub fn generics_target(&self) -> GenericsSource {
700        match *self {
701            FunId::Regular(fun_id) => GenericsSource::item(fun_id),
702            FunId::Builtin(..) => GenericsSource::Builtin,
703        }
704    }
705}
706
707impl FunIdOrTraitMethodRef {
708    pub fn generics_target(&self) -> GenericsSource {
709        match self {
710            FunIdOrTraitMethodRef::Fun(fun_id) => fun_id.generics_target(),
711            FunIdOrTraitMethodRef::Trait(trait_ref, name, _) => {
712                GenericsSource::Method(trait_ref.trait_decl_ref.skip_binder.trait_id, name.clone())
713            }
714        }
715    }
716}
717
718impl Field {
719    /// The new name for this field, as suggested by the `#[charon::rename]` attribute.
720    pub fn renamed_name(&self) -> Option<&str> {
721        self.attr_info.rename.as_deref().or(self.name.as_deref())
722    }
723
724    /// Whether this field has a `#[charon::opaque]` annotation.
725    pub fn is_opaque(&self) -> bool {
726        self.attr_info
727            .attributes
728            .iter()
729            .any(|attr| attr.is_opaque())
730    }
731}
732
733impl Variant {
734    /// The new name for this variant, as suggested by the `#[charon::rename]` and
735    /// `#[charon::variants_prefix]` attributes.
736    pub fn renamed_name(&self) -> &str {
737        self.attr_info
738            .rename
739            .as_deref()
740            .unwrap_or(self.name.as_ref())
741    }
742
743    /// Whether this variant has a `#[charon::opaque]` annotation.
744    pub fn is_opaque(&self) -> bool {
745        self.attr_info
746            .attributes
747            .iter()
748            .any(|attr| attr.is_opaque())
749    }
750}
751
752impl RefKind {
753    pub fn mutable(x: bool) -> Self {
754        if x {
755            Self::Mut
756        } else {
757            Self::Shared
758        }
759    }
760}
761
/// Visitor for the [TyVisitable::substitute] function.
/// This substitutes variables bound at the level where we start to substitute (level 0).
#[derive(Visitor)]
pub(crate) struct SubstVisitor<'a> {
    // Arguments for the variables bound at the substitution level, indexed by variable id.
    generics: &'a GenericArgs,
    // What every `TraitRefKind::SelfId` encountered is replaced with.
    self_ref: &'a TraitRefKind,
    // Tracks the depth of binders we're inside of.
    // Important: we must update it whenever we go inside a binder.
    binder_depth: DeBruijnId,
}
772
impl<'a> SubstVisitor<'a> {
    /// Creates a visitor substituting `generics` for the variables bound at level 0 and
    /// `self_ref` for `Self` trait references, starting outside of any binder.
    pub(crate) fn new(generics: &'a GenericArgs, self_ref: &'a TraitRefKind) -> Self {
        Self {
            generics,
            self_ref,
            binder_depth: DeBruijnId::zero(),
        }
    }

    /// Process the variable, either modifying the variable in-place or returning the new value to
    /// assign to the type/region/const generic/trait ref that was this variable.
    fn process_var<Id, T>(&self, var: &mut DeBruijnVar<Id>) -> Option<T>
    where
        Id: Copy,
        GenericArgs: Index<Id, Output = T>,
        T: Clone + TyVisitable,
    {
        use std::cmp::Ordering::*;
        match var {
            DeBruijnVar::Bound(dbid, varid) => match (*dbid).cmp(&self.binder_depth) {
                // Bound by the binder we are substituting for: replace it with the matching
                // argument, shifted under the binders we are currently inside of.
                Equal => Some(
                    self.generics[*varid]
                        .clone()
                        .move_under_binders(self.binder_depth),
                ),
                Greater => {
                    // This is bound outside the binder we're substituting for.
                    // That binder goes away, so the variable is now one level closer.
                    *dbid = dbid.decr();
                    None
                }
                // Bound by a binder inside the value being substituted: left unchanged.
                Less => None,
            },
            // Free variables are not affected by substitution.
            DeBruijnVar::Free(..) => None,
        }
    }
}
809
// Substitutions are performed in `exit_*` hooks, i.e. after a node's children have been
// visited. The value written in place of a variable is therefore not visited again (its
// variables were already shifted by `process_var`).
impl VisitAstMut for SubstVisitor<'_> {
    // Keep `binder_depth` in sync with the binders we traverse.
    fn enter_region_binder<T: AstVisitable>(&mut self, _: &mut RegionBinder<T>) {
        self.binder_depth = self.binder_depth.incr()
    }
    fn exit_region_binder<T: AstVisitable>(&mut self, _: &mut RegionBinder<T>) {
        self.binder_depth = self.binder_depth.decr()
    }
    fn enter_binder<T: AstVisitable>(&mut self, _: &mut Binder<T>) {
        self.binder_depth = self.binder_depth.incr()
    }
    fn exit_binder<T: AstVisitable>(&mut self, _: &mut Binder<T>) {
        self.binder_depth = self.binder_depth.decr()
    }

    fn exit_region(&mut self, r: &mut Region) {
        match r {
            Region::Var(var) => {
                if let Some(new_r) = self.process_var(var) {
                    *r = new_r;
                }
            }
            _ => (),
        }
    }

    fn exit_ty(&mut self, ty: &mut Ty) {
        // `with_kind_mut` gives mutable access to the kind so that `process_var` can adjust
        // the variable in place; a returned replacement type overwrites `ty` afterwards.
        let new_ty = ty.with_kind_mut(|kind| match kind {
            TyKind::TypeVar(var) => self.process_var(var),
            _ => None,
        });
        if let Some(new_ty) = new_ty {
            *ty = new_ty
        }
    }

    fn exit_const_generic(&mut self, cg: &mut ConstGeneric) {
        match cg {
            ConstGeneric::Var(var) => {
                if let Some(new_cg) = self.process_var(var) {
                    *cg = new_cg;
                }
            }
            _ => (),
        }
    }

    fn exit_constant_expr(&mut self, ce: &mut ConstantExpr) {
        match &mut ce.value {
            RawConstantExpr::Var(var) => {
                if let Some(new_ce) = self.process_var(var) {
                    // A const generic variable inside an expression is replaced by the
                    // expression form of the substituted const generic.
                    ce.value = match new_ce {
                        ConstGeneric::Global(id) => RawConstantExpr::Global(GlobalDeclRef {
                            id,
                            generics: Box::new(GenericArgs::empty(GenericsSource::item(id))),
                        }),
                        ConstGeneric::Var(var) => RawConstantExpr::Var(var),
                        ConstGeneric::Value(lit) => RawConstantExpr::Literal(lit),
                    };
                }
            }
            _ => (),
        }
    }

    fn exit_trait_ref_kind(&mut self, kind: &mut TraitRefKind) {
        match kind {
            TraitRefKind::SelfId => {
                *kind = self.self_ref.clone().move_under_binders(self.binder_depth);
            }
            TraitRefKind::Clause(var) => {
                // Only the kind is replaced here; the enclosing trait ref's `trait_decl_ref`
                // is substituted separately when it is visited.
                if let Some(new_tr) = self.process_var(var) {
                    *kind = new_tr.kind;
                }
            }
            _ => (),
        }
    }
}
888
/// Types that are involved at the type-level and may be substituted around.
///
/// Every method has a default implementation built on `drive_mut`.
pub trait TyVisitable: Sized + AstVisitable {
    /// Substitute `generics` for the variables bound at level 0 (as seen from `self`).
    /// `Self` trait references are mapped to `TraitRefKind::SelfId`.
    fn substitute(self, generics: &GenericArgs) -> Self {
        self.substitute_with_self(generics, &TraitRefKind::SelfId)
    }

    /// Like [`Self::substitute`], but also replaces every `TraitRefKind::SelfId` with
    /// `self_ref` (shifted under the binders it ends up inside of).
    fn substitute_with_self(mut self, generics: &GenericArgs, self_ref: &TraitRefKind) -> Self {
        let _ = self.drive_mut(&mut SubstVisitor::new(generics, self_ref));
        self
    }

    /// Move under one binder.
    fn move_under_binder(self) -> Self {
        self.move_under_binders(DeBruijnId::one())
    }

    /// Move under `depth` binders.
    ///
    /// Every variable bound outside of `self` has its de Bruijn index increased by `depth`;
    /// variables bound inside `self` are untouched.
    fn move_under_binders(mut self, depth: DeBruijnId) -> Self {
        if !depth.is_zero() {
            // The closure never breaks (`Infallible`), so `Continue` is the only possible result.
            let Continue(()) = self.visit_db_id::<Infallible>(|id| {
                *id = id.plus(depth);
                Continue(())
            });
        }
        self
    }

    /// Move from under one binder.
    fn move_from_under_binder(self) -> Option<Self> {
        self.move_from_under_binders(DeBruijnId::one())
    }

    /// Move the value out of `depth` binders. Returns `None` if it contains a variable bound in
    /// one of these `depth` binders.
    fn move_from_under_binders(mut self, depth: DeBruijnId) -> Option<Self> {
        // `id.sub(depth)` fails for variables bound by one of the `depth` binders being removed;
        // the visit then stops and the (partially rewritten) value is discarded.
        self.visit_db_id::<()>(|id| match id.sub(depth) {
            Some(sub) => {
                *id = sub;
                Continue(())
            }
            None => Break(()),
        })
        .is_continue()
        .then_some(self)
    }

    /// Visit the de Bruijn ids contained in `self`, as seen from the outside of `self`. This means
    /// that any variable bound inside `self` will be skipped, and all the seen indices will count
    /// from the outside of self.
    fn visit_db_id<B>(
        &mut self,
        f: impl FnMut(&mut DeBruijnId) -> ControlFlow<B>,
    ) -> ControlFlow<B> {
        // Wraps `f`, tracking how many binders inside `self` we are currently under.
        struct Wrap<F> {
            f: F,
            depth: DeBruijnId,
        }
        impl<B, F> Visitor for Wrap<F>
        where
            F: FnMut(&mut DeBruijnId) -> ControlFlow<B>,
        {
            type Break = B;
        }
        impl<B, F> VisitAstMut for Wrap<F>
        where
            F: FnMut(&mut DeBruijnId) -> ControlFlow<B>,
        {
            fn enter_region_binder<T: AstVisitable>(&mut self, _: &mut RegionBinder<T>) {
                self.depth = self.depth.incr()
            }
            fn exit_region_binder<T: AstVisitable>(&mut self, _: &mut RegionBinder<T>) {
                self.depth = self.depth.decr()
            }
            fn enter_binder<T: AstVisitable>(&mut self, _: &mut Binder<T>) {
                self.depth = self.depth.incr()
            }
            fn exit_binder<T: AstVisitable>(&mut self, _: &mut Binder<T>) {
                self.depth = self.depth.decr()
            }

            fn visit_de_bruijn_id(&mut self, x: &mut DeBruijnId) -> ControlFlow<Self::Break> {
                // `x.sub(depth)` is `None` for ids bound by a binder inside `self`: those are
                // skipped. Otherwise `f` sees the id relative to the outside of `self`, and the
                // result is re-shifted by the current depth before being written back.
                if let Some(mut shifted) = x.sub(self.depth) {
                    (self.f)(&mut shifted)?;
                    *x = shifted.plus(self.depth)
                }
                Continue(())
            }
        }
        self.drive_mut(&mut Wrap {
            f,
            depth: DeBruijnId::zero(),
        })
    }
}
983
// Every `AstVisitable` type is `TyVisitable`: the trait only has default methods.
impl<T: AstVisitable> TyVisitable for T {}

// NOTE(review): `Eq` is implemented by hand rather than derived; presumably `PartialEq` for
// `TraitClause` is provided elsewhere and is a total equivalence — confirm.
impl Eq for TraitClause {}

// Allow indexing `GenericArgs` / `GenericParams` directly by the typed id of the relevant
// field (e.g. `args[type_var_id]` for `args.types[type_var_id]`); `SubstVisitor::process_var`
// relies on `GenericArgs: Index<Id>`. NOTE(review): the exact impls generated by
// `mk_index_impls!` are defined elsewhere.
mk_index_impls!(GenericArgs.regions[RegionId]: Region);
mk_index_impls!(GenericArgs.types[TypeVarId]: Ty);
mk_index_impls!(GenericArgs.const_generics[ConstGenericVarId]: ConstGeneric);
mk_index_impls!(GenericArgs.trait_refs[TraitClauseId]: TraitRef);
mk_index_impls!(GenericParams.regions[RegionId]: RegionVar);
mk_index_impls!(GenericParams.types[TypeVarId]: TypeVar);
mk_index_impls!(GenericParams.const_generics[ConstGenericVarId]: ConstGenericVar);
mk_index_impls!(GenericParams.trait_clauses[TraitClauseId]: TraitClause);