// rustc_trait_selection/traits/query/evaluate_obligation.rs
use rustc_macros::extension;
use rustc_middle::span_bug;

use crate::infer::InferCtxt;
use crate::infer::canonical::OriginalQueryValues;
use crate::traits::{
    EvaluationResult, ObligationCtxt, OverflowError, PredicateObligation, SelectionContext,
};

10#[extension(pub trait InferCtxtExt<'tcx>)]
11impl<'tcx> InferCtxt<'tcx> {
12    fn predicate_may_hold(&self, obligation: &PredicateObligation<'tcx>) -> bool {
15        self.evaluate_obligation_no_overflow(obligation).may_apply()
16    }
17
18    fn predicate_must_hold_considering_regions(
46        &self,
47        obligation: &PredicateObligation<'tcx>,
48    ) -> bool {
49        self.evaluate_obligation_no_overflow(obligation).must_apply_considering_regions()
50    }
51
52    fn predicate_must_hold_modulo_regions(&self, obligation: &PredicateObligation<'tcx>) -> bool {
58        self.evaluate_obligation_no_overflow(obligation).must_apply_modulo_regions()
59    }
60
61    fn evaluate_obligation(
63        &self,
64        obligation: &PredicateObligation<'tcx>,
65    ) -> Result<EvaluationResult, OverflowError> {
66        let mut _orig_values = OriginalQueryValues::default();
67
68        let param_env = obligation.param_env;
69
70        if self.next_trait_solver() {
71            self.probe(|snapshot| {
72                let ocx = ObligationCtxt::new(self);
73                ocx.register_obligation(obligation.clone());
74                let mut result = EvaluationResult::EvaluatedToOk;
75                for error in ocx.select_all_or_error() {
76                    if error.is_true_error() {
77                        return Ok(EvaluationResult::EvaluatedToErr);
78                    } else {
79                        result = result.max(EvaluationResult::EvaluatedToAmbig);
80                    }
81                }
82                if self.opaque_types_added_in_snapshot(snapshot) {
83                    result = result.max(EvaluationResult::EvaluatedToOkModuloOpaqueTypes);
84                } else if self.region_constraints_added_in_snapshot(snapshot) {
85                    result = result.max(EvaluationResult::EvaluatedToOkModuloRegions);
86                }
87                Ok(result)
88            })
89        } else {
90            let c_pred =
91                self.canonicalize_query(param_env.and(obligation.predicate), &mut _orig_values);
92            self.tcx.at(obligation.cause.span).evaluate_obligation(c_pred)
93        }
94    }
95
96    fn evaluate_obligation_no_overflow(
100        &self,
101        obligation: &PredicateObligation<'tcx>,
102    ) -> EvaluationResult {
103        match self.evaluate_obligation(obligation) {
107            Ok(result) => result,
108            Err(OverflowError::Canonical) => {
109                let mut selcx = SelectionContext::new(self);
110                selcx.evaluate_root_obligation(obligation).unwrap_or_else(|r| match r {
111                    OverflowError::Canonical => {
112                        span_bug!(
113                            obligation.cause.span,
114                            "Overflow should be caught earlier in standard query mode: {:?}, {:?}",
115                            obligation,
116                            r,
117                        )
118                    }
119                    OverflowError::Error(_) => EvaluationResult::EvaluatedToErr,
120                })
121            }
122            Err(OverflowError::Error(_)) => EvaluationResult::EvaluatedToErr,
123        }
124    }
125}