rustc_trait_selection/traits/util.rs

1use std::collections::VecDeque;
2
3use rustc_data_structures::fx::{FxHashSet, FxIndexMap};
4use rustc_hir::LangItem;
5use rustc_hir::def_id::DefId;
6use rustc_infer::infer::InferCtxt;
7use rustc_infer::traits::PolyTraitObligation;
8pub use rustc_infer::traits::util::*;
9use rustc_middle::bug;
10use rustc_middle::ty::fast_reject::DeepRejectCtxt;
11use rustc_middle::ty::{
12    self, PolyTraitPredicate, SizedTraitKind, TraitPredicate, TraitRef, Ty, TyCtxt, TypeFoldable,
13    TypeFolder, TypeSuperFoldable, TypeVisitableExt,
14};
15pub use rustc_next_trait_solver::placeholder::BoundVarReplacer;
16use rustc_span::Span;
17use smallvec::{SmallVec, smallvec};
18use tracing::debug;
19
20/// Return the trait and projection predicates that come from eagerly expanding the
21/// trait aliases in the list of clauses. For each trait predicate, record a stack
22/// of spans that trace from the user-written trait alias bound. For projection predicates,
23/// just record the span of the projection itself.
24///
25/// For trait aliases, we don't deduplicte the predicates, since we currently do not
26/// consider duplicated traits as a single trait for the purposes of our "one trait principal"
27/// restriction; however, for projections we do deduplicate them.
28///
29/// ```rust,ignore (fails)
30/// trait Bar {}
31/// trait Foo = Bar + Bar;
32///
33/// let dyn_incompatible: dyn Foo; // bad, two `Bar` principals.
34/// ```
35pub fn expand_trait_aliases<'tcx>(
36    tcx: TyCtxt<'tcx>,
37    clauses: impl IntoIterator<Item = (ty::Clause<'tcx>, Span)>,
38) -> (
39    Vec<(ty::PolyTraitPredicate<'tcx>, SmallVec<[Span; 1]>)>,
40    Vec<(ty::PolyProjectionPredicate<'tcx>, Span)>,
41) {
42    let mut trait_preds = vec![];
43    let mut projection_preds = vec![];
44    let mut seen_projection_preds = FxHashSet::default();
45
46    let mut queue: VecDeque<_> = clauses.into_iter().map(|(p, s)| (p, smallvec![s])).collect();
47
48    while let Some((clause, spans)) = queue.pop_front() {
49        match clause.kind().skip_binder() {
50            ty::ClauseKind::Trait(trait_pred) => {
51                if tcx.is_trait_alias(trait_pred.def_id()) {
52                    queue.extend(
53                        tcx.explicit_super_predicates_of(trait_pred.def_id())
54                            .iter_identity_copied()
55                            .map(|(super_clause, span)| {
56                                let mut spans = spans.clone();
57                                spans.push(span);
58                                (
59                                    super_clause.instantiate_supertrait(
60                                        tcx,
61                                        clause.kind().rebind(trait_pred.trait_ref),
62                                    ),
63                                    spans,
64                                )
65                            }),
66                    );
67                } else {
68                    trait_preds.push((clause.kind().rebind(trait_pred), spans));
69                }
70            }
71            ty::ClauseKind::Projection(projection_pred) => {
72                let projection_pred = clause.kind().rebind(projection_pred);
73                if !seen_projection_preds.insert(tcx.anonymize_bound_vars(projection_pred)) {
74                    continue;
75                }
76                projection_preds.push((projection_pred, *spans.last().unwrap()));
77            }
78            ty::ClauseKind::RegionOutlives(..)
79            | ty::ClauseKind::TypeOutlives(..)
80            | ty::ClauseKind::ConstArgHasType(_, _)
81            | ty::ClauseKind::WellFormed(_)
82            | ty::ClauseKind::ConstEvaluatable(_)
83            | ty::ClauseKind::HostEffect(..) => {}
84        }
85    }
86
87    (trait_preds, projection_preds)
88}
89
90///////////////////////////////////////////////////////////////////////////
91// Other
92///////////////////////////////////////////////////////////////////////////
93
94/// Casts a trait reference into a reference to one of its super
95/// traits; returns `None` if `target_trait_def_id` is not a
96/// supertrait.
97pub fn upcast_choices<'tcx>(
98    tcx: TyCtxt<'tcx>,
99    source_trait_ref: ty::PolyTraitRef<'tcx>,
100    target_trait_def_id: DefId,
101) -> Vec<ty::PolyTraitRef<'tcx>> {
102    if source_trait_ref.def_id() == target_trait_def_id {
103        return vec![source_trait_ref]; // Shortcut the most common case.
104    }
105
106    supertraits(tcx, source_trait_ref).filter(|r| r.def_id() == target_trait_def_id).collect()
107}
108
109pub(crate) fn closure_trait_ref_and_return_type<'tcx>(
110    tcx: TyCtxt<'tcx>,
111    fn_trait_def_id: DefId,
112    self_ty: Ty<'tcx>,
113    sig: ty::PolyFnSig<'tcx>,
114    tuple_arguments: TupleArgumentsFlag,
115) -> ty::Binder<'tcx, (ty::TraitRef<'tcx>, Ty<'tcx>)> {
116    assert!(!self_ty.has_escaping_bound_vars());
117    let arguments_tuple = match tuple_arguments {
118        TupleArgumentsFlag::No => sig.skip_binder().inputs()[0],
119        TupleArgumentsFlag::Yes => Ty::new_tup(tcx, sig.skip_binder().inputs()),
120    };
121    let trait_ref = ty::TraitRef::new(tcx, fn_trait_def_id, [self_ty, arguments_tuple]);
122    sig.map_bound(|sig| (trait_ref, sig.output()))
123}
124
125pub(crate) fn coroutine_trait_ref_and_outputs<'tcx>(
126    tcx: TyCtxt<'tcx>,
127    fn_trait_def_id: DefId,
128    self_ty: Ty<'tcx>,
129    sig: ty::GenSig<TyCtxt<'tcx>>,
130) -> (ty::TraitRef<'tcx>, Ty<'tcx>, Ty<'tcx>) {
131    assert!(!self_ty.has_escaping_bound_vars());
132    let trait_ref = ty::TraitRef::new(tcx, fn_trait_def_id, [self_ty, sig.resume_ty]);
133    (trait_ref, sig.yield_ty, sig.return_ty)
134}
135
136pub(crate) fn future_trait_ref_and_outputs<'tcx>(
137    tcx: TyCtxt<'tcx>,
138    fn_trait_def_id: DefId,
139    self_ty: Ty<'tcx>,
140    sig: ty::GenSig<TyCtxt<'tcx>>,
141) -> (ty::TraitRef<'tcx>, Ty<'tcx>) {
142    assert!(!self_ty.has_escaping_bound_vars());
143    let trait_ref = ty::TraitRef::new(tcx, fn_trait_def_id, [self_ty]);
144    (trait_ref, sig.return_ty)
145}
146
147pub(crate) fn iterator_trait_ref_and_outputs<'tcx>(
148    tcx: TyCtxt<'tcx>,
149    iterator_def_id: DefId,
150    self_ty: Ty<'tcx>,
151    sig: ty::GenSig<TyCtxt<'tcx>>,
152) -> (ty::TraitRef<'tcx>, Ty<'tcx>) {
153    assert!(!self_ty.has_escaping_bound_vars());
154    let trait_ref = ty::TraitRef::new(tcx, iterator_def_id, [self_ty]);
155    (trait_ref, sig.yield_ty)
156}
157
158pub(crate) fn async_iterator_trait_ref_and_outputs<'tcx>(
159    tcx: TyCtxt<'tcx>,
160    async_iterator_def_id: DefId,
161    self_ty: Ty<'tcx>,
162    sig: ty::GenSig<TyCtxt<'tcx>>,
163) -> (ty::TraitRef<'tcx>, Ty<'tcx>) {
164    assert!(!self_ty.has_escaping_bound_vars());
165    let trait_ref = ty::TraitRef::new(tcx, async_iterator_def_id, [self_ty]);
166    (trait_ref, sig.yield_ty)
167}
168
169pub fn impl_item_is_final(tcx: TyCtxt<'_>, assoc_item: &ty::AssocItem) -> bool {
170    assoc_item.defaultness(tcx).is_final()
171        && tcx.defaultness(assoc_item.container_id(tcx)).is_final()
172}
173
/// Tells `closure_trait_ref_and_return_type` how to form the arguments type of the
/// `Fn*` trait reference from a signature's inputs.
pub(crate) enum TupleArgumentsFlag {
    /// Pack every input of the signature into a fresh tuple type.
    Yes,
    /// Use the signature's first input as the arguments type unchanged.
    No,
}
178
179/// Executes `f` on `value` after replacing all escaping bound variables with placeholders
180/// and then replaces these placeholders with the original bound variables in the result.
181///
182/// In most places, bound variables should be replaced right when entering a binder, making
183/// this function unnecessary. However, normalization currently does not do that, so we have
184/// to do this lazily.
185///
186/// You should not add any additional uses of this function, at least not without first
187/// discussing it with t-types.
188///
189/// FIXME(@lcnr): We may even consider experimenting with eagerly replacing bound vars during
190/// normalization as well, at which point this function will be unnecessary and can be removed.
191pub fn with_replaced_escaping_bound_vars<
192    'a,
193    'tcx,
194    T: TypeFoldable<TyCtxt<'tcx>>,
195    R: TypeFoldable<TyCtxt<'tcx>>,
196>(
197    infcx: &'a InferCtxt<'tcx>,
198    universe_indices: &'a mut Vec<Option<ty::UniverseIndex>>,
199    value: T,
200    f: impl FnOnce(T) -> R,
201) -> R {
202    if value.has_escaping_bound_vars() {
203        let (value, mapped_regions, mapped_types, mapped_consts) =
204            BoundVarReplacer::replace_bound_vars(infcx, universe_indices, value);
205        let result = f(value);
206        PlaceholderReplacer::replace_placeholders(
207            infcx,
208            mapped_regions,
209            mapped_types,
210            mapped_consts,
211            universe_indices,
212            result,
213        )
214    } else {
215        f(value)
216    }
217}
218
219/// The inverse of [`BoundVarReplacer`]: replaces placeholders with the bound vars from which they came.
220pub struct PlaceholderReplacer<'a, 'tcx> {
221    infcx: &'a InferCtxt<'tcx>,
222    mapped_regions: FxIndexMap<ty::PlaceholderRegion, ty::BoundRegion>,
223    mapped_types: FxIndexMap<ty::PlaceholderType, ty::BoundTy>,
224    mapped_consts: FxIndexMap<ty::PlaceholderConst, ty::BoundVar>,
225    universe_indices: &'a [Option<ty::UniverseIndex>],
226    current_index: ty::DebruijnIndex,
227}
228
229impl<'a, 'tcx> PlaceholderReplacer<'a, 'tcx> {
230    pub fn replace_placeholders<T: TypeFoldable<TyCtxt<'tcx>>>(
231        infcx: &'a InferCtxt<'tcx>,
232        mapped_regions: FxIndexMap<ty::PlaceholderRegion, ty::BoundRegion>,
233        mapped_types: FxIndexMap<ty::PlaceholderType, ty::BoundTy>,
234        mapped_consts: FxIndexMap<ty::PlaceholderConst, ty::BoundVar>,
235        universe_indices: &'a [Option<ty::UniverseIndex>],
236        value: T,
237    ) -> T {
238        let mut replacer = PlaceholderReplacer {
239            infcx,
240            mapped_regions,
241            mapped_types,
242            mapped_consts,
243            universe_indices,
244            current_index: ty::INNERMOST,
245        };
246        value.fold_with(&mut replacer)
247    }
248}
249
250impl<'tcx> TypeFolder<TyCtxt<'tcx>> for PlaceholderReplacer<'_, 'tcx> {
251    fn cx(&self) -> TyCtxt<'tcx> {
252        self.infcx.tcx
253    }
254
255    fn fold_binder<T: TypeFoldable<TyCtxt<'tcx>>>(
256        &mut self,
257        t: ty::Binder<'tcx, T>,
258    ) -> ty::Binder<'tcx, T> {
259        if !t.has_placeholders() && !t.has_infer() {
260            return t;
261        }
262        self.current_index.shift_in(1);
263        let t = t.super_fold_with(self);
264        self.current_index.shift_out(1);
265        t
266    }
267
268    fn fold_region(&mut self, r0: ty::Region<'tcx>) -> ty::Region<'tcx> {
269        let r1 = match r0.kind() {
270            ty::ReVar(vid) => self
271                .infcx
272                .inner
273                .borrow_mut()
274                .unwrap_region_constraints()
275                .opportunistic_resolve_var(self.infcx.tcx, vid),
276            _ => r0,
277        };
278
279        let r2 = match r1.kind() {
280            ty::RePlaceholder(p) => {
281                let replace_var = self.mapped_regions.get(&p);
282                match replace_var {
283                    Some(replace_var) => {
284                        let index = self
285                            .universe_indices
286                            .iter()
287                            .position(|u| matches!(u, Some(pu) if *pu == p.universe))
288                            .unwrap_or_else(|| bug!("Unexpected placeholder universe."));
289                        let db = ty::DebruijnIndex::from_usize(
290                            self.universe_indices.len() - index + self.current_index.as_usize() - 1,
291                        );
292                        ty::Region::new_bound(self.cx(), db, *replace_var)
293                    }
294                    None => r1,
295                }
296            }
297            _ => r1,
298        };
299
300        debug!(?r0, ?r1, ?r2, "fold_region");
301
302        r2
303    }
304
305    fn fold_ty(&mut self, ty: Ty<'tcx>) -> Ty<'tcx> {
306        let ty = self.infcx.shallow_resolve(ty);
307        match *ty.kind() {
308            ty::Placeholder(p) => {
309                let replace_var = self.mapped_types.get(&p);
310                match replace_var {
311                    Some(replace_var) => {
312                        let index = self
313                            .universe_indices
314                            .iter()
315                            .position(|u| matches!(u, Some(pu) if *pu == p.universe))
316                            .unwrap_or_else(|| bug!("Unexpected placeholder universe."));
317                        let db = ty::DebruijnIndex::from_usize(
318                            self.universe_indices.len() - index + self.current_index.as_usize() - 1,
319                        );
320                        Ty::new_bound(self.infcx.tcx, db, *replace_var)
321                    }
322                    None => {
323                        if ty.has_infer() {
324                            ty.super_fold_with(self)
325                        } else {
326                            ty
327                        }
328                    }
329                }
330            }
331
332            _ if ty.has_placeholders() || ty.has_infer() => ty.super_fold_with(self),
333            _ => ty,
334        }
335    }
336
337    fn fold_const(&mut self, ct: ty::Const<'tcx>) -> ty::Const<'tcx> {
338        let ct = self.infcx.shallow_resolve_const(ct);
339        if let ty::ConstKind::Placeholder(p) = ct.kind() {
340            let replace_var = self.mapped_consts.get(&p);
341            match replace_var {
342                Some(replace_var) => {
343                    let index = self
344                        .universe_indices
345                        .iter()
346                        .position(|u| matches!(u, Some(pu) if *pu == p.universe))
347                        .unwrap_or_else(|| bug!("Unexpected placeholder universe."));
348                    let db = ty::DebruijnIndex::from_usize(
349                        self.universe_indices.len() - index + self.current_index.as_usize() - 1,
350                    );
351                    ty::Const::new_bound(self.infcx.tcx, db, *replace_var)
352                }
353                None => {
354                    if ct.has_infer() {
355                        ct.super_fold_with(self)
356                    } else {
357                        ct
358                    }
359                }
360            }
361        } else {
362            ct.super_fold_with(self)
363        }
364    }
365}
366
367pub fn sizedness_fast_path<'tcx>(tcx: TyCtxt<'tcx>, predicate: ty::Predicate<'tcx>) -> bool {
368    // Proving `Sized`/`MetaSized`, very often on "obviously sized" types like
369    // `&T`, accounts for about 60% percentage of the predicates we have to prove. No need to
370    // canonicalize and all that for such cases.
371    if let ty::PredicateKind::Clause(ty::ClauseKind::Trait(trait_ref)) =
372        predicate.kind().skip_binder()
373    {
374        let sizedness = match tcx.as_lang_item(trait_ref.def_id()) {
375            Some(LangItem::Sized) => SizedTraitKind::Sized,
376            Some(LangItem::MetaSized) => SizedTraitKind::MetaSized,
377            _ => return false,
378        };
379
380        if trait_ref.self_ty().has_trivial_sizedness(tcx, sizedness) {
381            debug!("fast path -- trivial sizedness");
382            return true;
383        }
384    }
385
386    false
387}
388
389/// To improve performance, sizedness traits are not elaborated and so special-casing is required
390/// in the trait solver to find a `Sized` candidate for a `MetaSized` obligation. Returns the
391/// predicate to used in the candidate for such a `obligation`, given a `candidate`.
392pub(crate) fn lazily_elaborate_sizedness_candidate<'tcx>(
393    infcx: &InferCtxt<'tcx>,
394    obligation: &PolyTraitObligation<'tcx>,
395    candidate: PolyTraitPredicate<'tcx>,
396) -> PolyTraitPredicate<'tcx> {
397    if !infcx.tcx.is_lang_item(obligation.predicate.def_id(), LangItem::MetaSized)
398        || !infcx.tcx.is_lang_item(candidate.def_id(), LangItem::Sized)
399    {
400        return candidate;
401    }
402
403    if obligation.predicate.polarity() != candidate.polarity() {
404        return candidate;
405    }
406
407    let drcx = DeepRejectCtxt::relate_rigid_rigid(infcx.tcx);
408    if !drcx.args_may_unify(
409        obligation.predicate.skip_binder().trait_ref.args,
410        candidate.skip_binder().trait_ref.args,
411    ) {
412        return candidate;
413    }
414
415    candidate.map_bound(|c| TraitPredicate {
416        trait_ref: TraitRef::new_from_args(
417            infcx.tcx,
418            obligation.predicate.def_id(),
419            c.trait_ref.args,
420        ),
421        polarity: c.polarity,
422    })
423}