rustc_infer/infer/canonical/
instantiate.rs

//! This module contains code to instantiate new values into a
//! `Canonical<'tcx, T>`.
//!
//! For an overview of what canonicalization is and how it fits into
//! rustc, check out the [chapter in the chalk book][c].
//!
//! [c]: https://rust-lang.github.io/chalk/book/canonical_queries/canonicalization.html
8
9use rustc_macros::extension;
10use rustc_middle::ty::{
11    self, DelayedMap, Ty, TyCtxt, TypeFoldable, TypeFolder, TypeSuperFoldable, TypeSuperVisitable,
12    TypeVisitableExt, TypeVisitor,
13};
14use rustc_type_ir::TypeVisitable;
15
16use crate::infer::canonical::{Canonical, CanonicalVarValues};
17
18/// FIXME(-Znext-solver): This or public because it is shared with the
19/// new trait solver implementation. We should deduplicate canonicalization.
20#[extension(pub trait CanonicalExt<'tcx, V>)]
21impl<'tcx, V> Canonical<'tcx, V> {
22    /// Instantiate the wrapped value, replacing each canonical value
23    /// with the value given in `var_values`.
24    fn instantiate(&self, tcx: TyCtxt<'tcx>, var_values: &CanonicalVarValues<'tcx>) -> V
25    where
26        V: TypeFoldable<TyCtxt<'tcx>>,
27    {
28        self.instantiate_projected(tcx, var_values, |value| value.clone())
29    }
30
31    /// Allows one to apply a instantiation to some subset of
32    /// `self.value`. Invoke `projection_fn` with `self.value` to get
33    /// a value V that is expressed in terms of the same canonical
34    /// variables bound in `self` (usually this extracts from subset
35    /// of `self`). Apply the instantiation `var_values` to this value
36    /// V, replacing each of the canonical variables.
37    fn instantiate_projected<T>(
38        &self,
39        tcx: TyCtxt<'tcx>,
40        var_values: &CanonicalVarValues<'tcx>,
41        projection_fn: impl FnOnce(&V) -> T,
42    ) -> T
43    where
44        T: TypeFoldable<TyCtxt<'tcx>>,
45    {
46        assert_eq!(self.variables.len(), var_values.len());
47        let value = projection_fn(&self.value);
48        instantiate_value(tcx, var_values, value)
49    }
50}
51
52/// Instantiate the values from `var_values` into `value`. `var_values`
53/// must be values for the set of canonical variables that appear in
54/// `value`.
55pub(super) fn instantiate_value<'tcx, T>(
56    tcx: TyCtxt<'tcx>,
57    var_values: &CanonicalVarValues<'tcx>,
58    value: T,
59) -> T
60where
61    T: TypeFoldable<TyCtxt<'tcx>>,
62{
63    if var_values.var_values.is_empty() {
64        return value;
65    }
66
67    value.fold_with(&mut CanonicalInstantiator {
68        tcx,
69        current_index: ty::INNERMOST,
70        var_values: var_values.var_values,
71        cache: Default::default(),
72    })
73}
74
75/// Replaces the bound vars in a canonical binder with var values.
76struct CanonicalInstantiator<'tcx> {
77    tcx: TyCtxt<'tcx>,
78
79    // The values that the bound vars are are being instantiated with.
80    var_values: ty::GenericArgsRef<'tcx>,
81
82    /// As with `BoundVarReplacer`, represents the index of a binder *just outside*
83    /// the ones we have visited.
84    current_index: ty::DebruijnIndex,
85
86    // Instantiation is a pure function of `DebruijnIndex` and `Ty`.
87    cache: DelayedMap<(ty::DebruijnIndex, Ty<'tcx>), Ty<'tcx>>,
88}
89
90impl<'tcx> TypeFolder<TyCtxt<'tcx>> for CanonicalInstantiator<'tcx> {
91    fn cx(&self) -> TyCtxt<'tcx> {
92        self.tcx
93    }
94
95    fn fold_binder<T: TypeFoldable<TyCtxt<'tcx>>>(
96        &mut self,
97        t: ty::Binder<'tcx, T>,
98    ) -> ty::Binder<'tcx, T> {
99        self.current_index.shift_in(1);
100        let t = t.super_fold_with(self);
101        self.current_index.shift_out(1);
102        t
103    }
104
105    fn fold_ty(&mut self, t: Ty<'tcx>) -> Ty<'tcx> {
106        match *t.kind() {
107            ty::Bound(debruijn, bound_ty) if debruijn == self.current_index => {
108                self.var_values[bound_ty.var.as_usize()].expect_ty()
109            }
110            _ => {
111                if !t.has_vars_bound_at_or_above(self.current_index) {
112                    t
113                } else if let Some(&t) = self.cache.get(&(self.current_index, t)) {
114                    t
115                } else {
116                    let res = t.super_fold_with(self);
117                    assert!(self.cache.insert((self.current_index, t), res));
118                    res
119                }
120            }
121        }
122    }
123
124    fn fold_region(&mut self, r: ty::Region<'tcx>) -> ty::Region<'tcx> {
125        match r.kind() {
126            ty::ReBound(debruijn, br) if debruijn == self.current_index => {
127                self.var_values[br.var.as_usize()].expect_region()
128            }
129            _ => r,
130        }
131    }
132
133    fn fold_const(&mut self, ct: ty::Const<'tcx>) -> ty::Const<'tcx> {
134        match ct.kind() {
135            ty::ConstKind::Bound(debruijn, bound_const) if debruijn == self.current_index => {
136                self.var_values[bound_const.as_usize()].expect_const()
137            }
138            _ => ct.super_fold_with(self),
139        }
140    }
141
142    fn fold_predicate(&mut self, p: ty::Predicate<'tcx>) -> ty::Predicate<'tcx> {
143        if p.has_vars_bound_at_or_above(self.current_index) { p.super_fold_with(self) } else { p }
144    }
145
146    fn fold_clauses(&mut self, c: ty::Clauses<'tcx>) -> ty::Clauses<'tcx> {
147        if !c.has_vars_bound_at_or_above(self.current_index) {
148            return c;
149        }
150
151        // Since instantiation is a function of `DebruijnIndex`, we don't want
152        // to have to cache more copies of clauses when we're inside of binders.
153        // Since we currently expect to only have clauses in the outermost
154        // debruijn index, we just fold if we're inside of a binder.
155        if self.current_index > ty::INNERMOST {
156            return c.super_fold_with(self);
157        }
158
159        // Our cache key is `(clauses, var_values)`, but we also don't care about
160        // var values that aren't named in the clauses, since they can change without
161        // affecting the output. Since `ParamEnv`s are cached first, we compute the
162        // last var value that is mentioned in the clauses, and cut off the list so
163        // that we have more hits in the cache.
164
165        // We also cache the computation of "highest var named by clauses" since that
166        // is both expensive (depending on the size of the clauses) and a pure function.
167        let index = *self
168            .tcx
169            .highest_var_in_clauses_cache
170            .lock()
171            .entry(c)
172            .or_insert_with(|| highest_var_in_clauses(c));
173        let c_args = &self.var_values[..=index];
174
175        if let Some(c) = self.tcx.clauses_cache.lock().get(&(c, c_args)) {
176            c
177        } else {
178            let folded = c.super_fold_with(self);
179            self.tcx.clauses_cache.lock().insert((c, c_args), folded);
180            folded
181        }
182    }
183}
184
185fn highest_var_in_clauses<'tcx>(c: ty::Clauses<'tcx>) -> usize {
186    struct HighestVarInClauses {
187        max_var: usize,
188        current_index: ty::DebruijnIndex,
189    }
190    impl<'tcx> TypeVisitor<TyCtxt<'tcx>> for HighestVarInClauses {
191        fn visit_binder<T: TypeVisitable<TyCtxt<'tcx>>>(
192            &mut self,
193            t: &ty::Binder<'tcx, T>,
194        ) -> Self::Result {
195            self.current_index.shift_in(1);
196            let t = t.super_visit_with(self);
197            self.current_index.shift_out(1);
198            t
199        }
200        fn visit_ty(&mut self, t: Ty<'tcx>) {
201            if let ty::Bound(debruijn, bound_ty) = *t.kind()
202                && debruijn == self.current_index
203            {
204                self.max_var = self.max_var.max(bound_ty.var.as_usize());
205            } else if t.has_vars_bound_at_or_above(self.current_index) {
206                t.super_visit_with(self);
207            }
208        }
209        fn visit_region(&mut self, r: ty::Region<'tcx>) {
210            if let ty::ReBound(debruijn, bound_region) = r.kind()
211                && debruijn == self.current_index
212            {
213                self.max_var = self.max_var.max(bound_region.var.as_usize());
214            }
215        }
216        fn visit_const(&mut self, ct: ty::Const<'tcx>) {
217            if let ty::ConstKind::Bound(debruijn, bound_const) = ct.kind()
218                && debruijn == self.current_index
219            {
220                self.max_var = self.max_var.max(bound_const.as_usize());
221            } else if ct.has_vars_bound_at_or_above(self.current_index) {
222                ct.super_visit_with(self);
223            }
224        }
225    }
226    let mut visitor = HighestVarInClauses { max_var: 0, current_index: ty::INNERMOST };
227    c.visit_with(&mut visitor);
228    visitor.max_var
229}