// charon_lib/transform/check_generics.rs

//! Check that all supplied generic types match the corresponding generic parameters.
2use derive_generic_visitor::*;
3use index_vec::Idx;
4use itertools::Itertools;
5use std::{borrow::Cow, fmt::Display};
6
7use crate::{
8    ast::*,
9    errors::Level,
10    formatter::{AstFormatter, FmtCtx, IntoFormatter},
11    pretty::FmtWithCtx,
12};
13
14use super::{ctx::TransformPass, TransformCtx};
15
16#[derive(Visitor)]
17struct CheckGenericsVisitor<'a> {
18    ctx: &'a TransformCtx,
19    phase: &'static str,
20    /// Tracks an enclosing span for error reporting.
21    span: Span,
22    /// Track the binders seen so far.
23    // We can't keep the params by reference because the visitors don't tell us that everything
24    // we're visiting has lifetime `'a`.
25    binder_stack: BindingStack<GenericParams>,
26    /// Remember the names of the types visited up to here.
27    visit_stack: Vec<&'static str>,
28}
29
30impl CheckGenericsVisitor<'_> {
31    fn error(&self, message: impl Display) {
32        let msg = format!(
33            "Found inconsistent generics {}:\n{message}\n\
34            Visitor stack:\n  {}\n\
35            Binding stack (depth {}):\n  {}",
36            self.phase,
37            self.visit_stack.iter().rev().join("\n  "),
38            self.binder_stack.len(),
39            self.binder_stack
40                .iter_enumerated()
41                .map(|(i, params)| format!("{i}: {params}"))
42                .join("\n  "),
43        );
44        // This is a fatal error: the output llbc is inconsistent and should not be used.
45        self.ctx.span_err(self.span, &msg, Level::ERROR);
46    }
47
48    /// For pretty error printing. This can print values that we encounter because we track binders
49    /// properly. This doesn't have the right binders to print values we get from somewhere else
50    /// (namely the `GenericParam`s we get from elsewhere in the crate).
51    fn val_fmt_ctx(&self) -> FmtCtx<'_> {
52        let mut fmt = self.ctx.into_fmt();
53        fmt.generics = self.binder_stack.map_ref(Cow::Borrowed);
54        fmt
55    }
56
57    fn zip_assert_match<I, A, B, FmtA, FmtB>(
58        &self,
59        a: &Vector<I, A>,
60        b: &Vector<I, B>,
61        a_fmt: &FmtA,
62        b_fmt: &FmtB,
63        kind: &str,
64        target: &GenericsSource,
65        check_inner: impl Fn(&A, &B),
66    ) where
67        I: Idx,
68        FmtA: AstFormatter,
69        A: FmtWithCtx<FmtA>,
70        B: FmtWithCtx<FmtB>,
71    {
72        if a.elem_count() == b.elem_count() {
73            a.iter().zip(b.iter()).for_each(|(x, y)| check_inner(x, y));
74        } else {
75            let a = a.iter().map(|x| x.with_ctx(a_fmt)).join(", ");
76            let b = b.iter().map(|x| x.with_ctx(b_fmt)).join(", ");
77            let target = target.with_ctx(a_fmt);
78            self.error(format!(
79                "Mismatched {kind}:\
80                \ntarget: {target}\
81                \nexpected: [{a}]\
82                \n     got: [{b}]"
83            ))
84        }
85    }
86
87    fn assert_clause_matches(
88        &self,
89        params_fmt: &FmtCtx<'_>,
90        tclause: &TraitClause,
91        tref: &TraitRef,
92    ) {
93        let clause_trait_id = tclause.trait_.skip_binder.trait_id;
94        let ref_trait_id = tref.trait_decl_ref.skip_binder.trait_id;
95        if clause_trait_id != ref_trait_id {
96            let args_fmt = &self.val_fmt_ctx();
97            let tclause = tclause.with_ctx(params_fmt);
98            let tref_pred = tref.trait_decl_ref.with_ctx(args_fmt);
99            let tref = tref.with_ctx(args_fmt);
100            self.error(format!(
101                "Mismatched trait clause:\
102                \nexpected: {tclause}\
103                \n     got: {tref}: {tref_pred}"
104            ));
105        }
106    }
107
108    fn assert_matches(&self, params_fmt: &FmtCtx<'_>, params: &GenericParams, args: &GenericArgs) {
109        let args_fmt = &self.val_fmt_ctx();
110        self.zip_assert_match(
111            &params.regions,
112            &args.regions,
113            params_fmt,
114            args_fmt,
115            "regions",
116            &args.target,
117            |_, _| {},
118        );
119        self.zip_assert_match(
120            &params.types,
121            &args.types,
122            params_fmt,
123            args_fmt,
124            "type generics",
125            &args.target,
126            |_, _| {},
127        );
128        self.zip_assert_match(
129            &params.const_generics,
130            &args.const_generics,
131            params_fmt,
132            args_fmt,
133            "const generics",
134            &args.target,
135            |_, _| {},
136        );
137        self.zip_assert_match(
138            &params.trait_clauses,
139            &args.trait_refs,
140            params_fmt,
141            args_fmt,
142            "trait clauses",
143            &args.target,
144            |tclause, tref| self.assert_clause_matches(params_fmt, tclause, tref),
145        );
146    }
147}
148
149impl VisitAst for CheckGenericsVisitor<'_> {
150    fn visit<'a, T: AstVisitable>(&'a mut self, x: &T) -> ControlFlow<Self::Break> {
151        self.visit_stack.push(x.name());
152        x.drive(self)?; // default behavior
153        self.visit_stack.pop();
154        Continue(())
155    }
156
157    fn visit_binder<T: AstVisitable>(&mut self, binder: &Binder<T>) -> ControlFlow<Self::Break> {
158        self.binder_stack.push(binder.params.clone());
159        self.visit_inner(binder)?;
160        self.binder_stack.pop();
161        Continue(())
162    }
163    fn visit_region_binder<T: AstVisitable>(
164        &mut self,
165        binder: &RegionBinder<T>,
166    ) -> ControlFlow<Self::Break> {
167        self.binder_stack.push(GenericParams {
168            regions: binder.regions.clone(),
169            ..Default::default()
170        });
171        self.visit_inner(binder)?;
172        self.binder_stack.pop();
173        Continue(())
174    }
175
176    fn enter_region(&mut self, x: &Region) {
177        if let Region::Var(var) = x {
178            if self.binder_stack.get_var(*var).is_none() {
179                self.error(format!("Found incorrect region var: {var}"));
180            }
181        }
182    }
183    fn enter_ty_kind(&mut self, x: &TyKind) {
184        if let TyKind::TypeVar(var) = x {
185            if self.binder_stack.get_var(*var).is_none() {
186                self.error(format!("Found incorrect type var: {var}"));
187            }
188        }
189    }
190    fn enter_const_generic(&mut self, x: &ConstGeneric) {
191        if let ConstGeneric::Var(var) = x {
192            if self.binder_stack.get_var(*var).is_none() {
193                self.error(format!("Found incorrect const-generic var: {var}"));
194            }
195        }
196    }
197    fn enter_trait_ref_kind(&mut self, x: &TraitRefKind) {
198        if let TraitRefKind::Clause(var) = x {
199            if self.binder_stack.get_var(*var).is_none() {
200                self.error(format!("Found incorrect clause var: {var}"));
201            }
202        }
203    }
204
205    fn visit_aggregate_kind(&mut self, agg: &AggregateKind) -> ControlFlow<Self::Break> {
206        match agg {
207            AggregateKind::Adt(..) | AggregateKind::Array(..) | AggregateKind::RawPtr(..) => {
208                self.visit_inner(agg)?
209            }
210        }
211        Continue(())
212    }
213
214    fn enter_generic_args(&mut self, args: &GenericArgs) {
215        let fmt1;
216        let fmt2;
217        let (params, params_fmt) = match &args.target {
218            GenericsSource::Item(item_id) => {
219                let Some(item) = self.ctx.translated.get_item(*item_id) else {
220                    return;
221                };
222                let params = item.generic_params();
223                fmt1 = self.ctx.into_fmt();
224                let fmt = fmt1.push_binder(Cow::Borrowed(params));
225                (params, fmt)
226            }
227            GenericsSource::Method(trait_id, method_name) => {
228                let Some(trait_decl) = self.ctx.translated.trait_decls.get(*trait_id) else {
229                    return;
230                };
231                let Some((_, bound_fn)) = trait_decl.methods().find(|(n, _)| n == method_name)
232                else {
233                    return;
234                };
235                let params = &bound_fn.params;
236                fmt1 = self.ctx.into_fmt();
237                fmt2 = fmt1.push_binder(Cow::Borrowed(&trait_decl.generics));
238                let fmt = fmt2.push_binder(Cow::Borrowed(params));
239                (params, fmt)
240            }
241            GenericsSource::Builtin => return,
242            GenericsSource::Other => {
243                self.error("`GenericsSource::Other` should not exist in the charon AST");
244                return;
245            }
246        };
247        self.assert_matches(&params_fmt, params, args);
248    }
249
250    // Special case that is not represented as a `GenericArgs`.
251    fn enter_trait_impl(&mut self, timpl: &TraitImpl) {
252        let Some(tdecl) = self
253            .ctx
254            .translated
255            .trait_decls
256            .get(timpl.impl_trait.trait_id)
257        else {
258            return;
259        };
260        // See `lift_associated_item_clauses`
261        assert!(timpl.type_clauses.is_empty());
262        assert!(tdecl.type_clauses.is_empty());
263
264        let fmt1 = self.ctx.into_fmt();
265        let tdecl_fmt = fmt1.push_binder(Cow::Borrowed(&tdecl.generics));
266        let args_fmt = &self.val_fmt_ctx();
267        self.zip_assert_match(
268            &tdecl.parent_clauses,
269            &timpl.parent_trait_refs,
270            &tdecl_fmt,
271            args_fmt,
272            "trait parent clauses",
273            &GenericsSource::item(timpl.impl_trait.trait_id),
274            |tclause, tref| self.assert_clause_matches(&tdecl_fmt, tclause, tref),
275        );
276        let types_match = timpl.types.len() == tdecl.types.len()
277            && tdecl
278                .types
279                .iter()
280                .zip(timpl.types.iter())
281                .all(|(dname, (iname, _))| dname == iname);
282        if !types_match {
283            self.error(
284                "The associated types supplied by the trait impl don't match the trait decl.",
285            )
286        }
287        let consts_match = timpl.consts.len() == tdecl.consts.len()
288            && tdecl
289                .types
290                .iter()
291                .zip(timpl.types.iter())
292                .all(|(dname, (iname, _))| dname == iname);
293        if !consts_match {
294            self.error(
295                "The associated consts supplied by the trait impl don't match the trait decl.",
296            )
297        }
298        let methods_match = timpl.methods.len() == tdecl.methods.len();
299        if !methods_match && self.phase != "after translation" {
300            let decl_methods = tdecl
301                .methods()
302                .map(|(name, _)| format!("- {name}"))
303                .join("\n");
304            let impl_methods = timpl
305                .methods()
306                .map(|(name, _)| format!("- {name}"))
307                .join("\n");
308            self.error(format!(
309                "The methods supplied by the trait impl don't match the trait decl.\n\
310                Trait methods:\n{decl_methods}\n\
311                Impl methods:\n{impl_methods}"
312            ))
313        }
314    }
315
316    fn visit_ullbc_statement(&mut self, st: &ullbc_ast::Statement) -> ControlFlow<Self::Break> {
317        // Track span for more precise error messages.
318        let old_span = self.span;
319        self.span = st.span;
320        self.visit_inner(st)?;
321        self.span = old_span;
322        Continue(())
323    }
324
325    fn visit_llbc_statement(&mut self, st: &llbc_ast::Statement) -> ControlFlow<Self::Break> {
326        // Track span for more precise error messages.
327        let old_span = self.span;
328        self.span = st.span;
329        self.visit_inner(st)?;
330        self.span = old_span;
331        Continue(())
332    }
333}
334
335// The argument is a name to disambiguate the two times we run this check.
336pub struct Check(pub &'static str);
337impl TransformPass for Check {
338    fn transform_ctx(&self, ctx: &mut TransformCtx) {
339        for item in ctx.translated.all_items() {
340            let mut visitor = CheckGenericsVisitor {
341                ctx,
342                phase: self.0,
343                span: item.item_meta().span,
344                binder_stack: BindingStack::new(item.generic_params().clone()),
345                visit_stack: Default::default(),
346            };
347            let _ = item.drive(&mut visitor);
348        }
349    }
350}