charon_lib/transform/check_generics.rs

//! Check that all supplied generic arguments (regions, types, const generics and trait
//! references) match the corresponding generic parameters.
2use derive_generic_visitor::*;
3use index_vec::Idx;
4use itertools::Itertools;
5use std::{borrow::Cow, fmt::Display};
6
7use crate::{
8    ast::*,
9    errors::Level,
10    formatter::{FmtCtx, IntoFormatter, PushBinder},
11    pretty::FmtWithCtx,
12};
13
14use super::{ctx::TransformPass, TransformCtx};
15
16#[derive(Visitor)]
17struct CheckGenericsVisitor<'a> {
18    ctx: &'a TransformCtx,
19    phase: &'static str,
20    /// Tracks an enclosing span for error reporting.
21    span: Span,
22    /// Track the binders seen so far.
23    // We can't keep the params by reference because the visitors don't tell us that everything
24    // we're visiting has lifetime `'a`.
25    binder_stack: BindingStack<GenericParams>,
26    /// Remember the names of the types visited up to here.
27    visit_stack: Vec<&'static str>,
28}
29
30impl CheckGenericsVisitor<'_> {
31    fn error(&self, message: impl Display) {
32        let msg = format!(
33            "Found inconsistent generics {}:\n{message}\n\
34            Visitor stack:\n  {}\n\
35            Binding stack (depth {}):\n  {}",
36            self.phase,
37            self.visit_stack.iter().rev().join("\n  "),
38            self.binder_stack.len(),
39            self.binder_stack
40                .iter_enumerated()
41                .map(|(i, params)| format!("{i}: {params}"))
42                .join("\n  "),
43        );
44        // This is a fatal error: the output llbc is inconsistent and should not be used.
45        self.ctx.span_err(self.span, &msg, Level::ERROR);
46    }
47
48    /// For pretty error printing. This can print values that we encounter because we track binders
49    /// properly. This doesn't have the right binders to print values we get from somewhere else
50    /// (namely the `GenericParam`s we get from elsewhere in the crate).
51    fn val_fmt_ctx(&self) -> FmtCtx<'_> {
52        let mut fmt = self.ctx.into_fmt();
53        fmt.generics = self.binder_stack.map_ref(Cow::Borrowed);
54        fmt
55    }
56
57    fn zip_assert_match<I, A, B, FmtA, FmtB>(
58        &self,
59        a: &Vector<I, A>,
60        b: &Vector<I, B>,
61        a_fmt: &FmtA,
62        b_fmt: &FmtB,
63        kind: &str,
64        check_inner: impl Fn(&A, &B),
65    ) where
66        I: Idx,
67        A: for<'a> FmtWithCtx<FmtA>,
68        B: for<'a> FmtWithCtx<FmtB>,
69    {
70        if a.elem_count() == b.elem_count() {
71            a.iter().zip(b.iter()).for_each(|(x, y)| check_inner(x, y));
72        } else {
73            let a = a.iter().map(|x| x.fmt_with_ctx(a_fmt)).join(", ");
74            let b = b.iter().map(|x| x.fmt_with_ctx(b_fmt)).join(", ");
75            self.error(format!(
76                "Mismatched {kind}:\
77                \nexpected: [{a}]\
78                \n     got: [{b}]"
79            ))
80        }
81    }
82
83    fn assert_clause_matches(
84        &self,
85        params_fmt: &FmtCtx<'_>,
86        tclause: &TraitClause,
87        tref: &TraitRef,
88    ) {
89        let clause_trait_id = tclause.trait_.skip_binder.trait_id;
90        let ref_trait_id = tref.trait_decl_ref.skip_binder.trait_id;
91        if clause_trait_id != ref_trait_id {
92            let args_fmt = &self.val_fmt_ctx();
93            let tclause = tclause.fmt_with_ctx(params_fmt);
94            let tref_pred = tref.trait_decl_ref.fmt_with_ctx(args_fmt);
95            let tref = tref.fmt_with_ctx(args_fmt);
96            self.error(format!(
97                "Mismatched trait clause:\
98                \nexpected: {tclause}\
99                \n     got: {tref}: {tref_pred}"
100            ));
101        }
102    }
103
104    fn assert_matches(&self, params_fmt: &FmtCtx<'_>, params: &GenericParams, args: &GenericArgs) {
105        let args_fmt = &self.val_fmt_ctx();
106        self.zip_assert_match(
107            &params.regions,
108            &args.regions,
109            params_fmt,
110            args_fmt,
111            "regions",
112            |_, _| {},
113        );
114        self.zip_assert_match(
115            &params.types,
116            &args.types,
117            params_fmt,
118            args_fmt,
119            "type generics",
120            |_, _| {},
121        );
122        self.zip_assert_match(
123            &params.const_generics,
124            &args.const_generics,
125            params_fmt,
126            args_fmt,
127            "const generics",
128            |_, _| {},
129        );
130        self.zip_assert_match(
131            &params.trait_clauses,
132            &args.trait_refs,
133            params_fmt,
134            args_fmt,
135            "trait clauses",
136            |tclause, tref| self.assert_clause_matches(params_fmt, tclause, tref),
137        );
138    }
139}
140
141impl VisitAst for CheckGenericsVisitor<'_> {
142    fn visit<'a, T: AstVisitable>(&'a mut self, x: &T) -> ControlFlow<Self::Break> {
143        self.visit_stack.push(x.name());
144        x.drive(self)?; // default behavior
145        self.visit_stack.pop();
146        Continue(())
147    }
148
149    fn visit_binder<T: AstVisitable>(&mut self, binder: &Binder<T>) -> ControlFlow<Self::Break> {
150        self.binder_stack.push(binder.params.clone());
151        self.visit_inner(binder)?;
152        self.binder_stack.pop();
153        Continue(())
154    }
155    fn visit_region_binder<T: AstVisitable>(
156        &mut self,
157        binder: &RegionBinder<T>,
158    ) -> ControlFlow<Self::Break> {
159        self.binder_stack.push(GenericParams {
160            regions: binder.regions.clone(),
161            ..Default::default()
162        });
163        self.visit_inner(binder)?;
164        self.binder_stack.pop();
165        Continue(())
166    }
167
168    fn enter_region(&mut self, x: &Region) {
169        if let Region::Var(var) = x {
170            if self.binder_stack.get_var(*var).is_none() {
171                self.error(format!("Found incorrect region var: {var}"));
172            }
173        }
174    }
175    fn enter_ty_kind(&mut self, x: &TyKind) {
176        if let TyKind::TypeVar(var) = x {
177            if self.binder_stack.get_var(*var).is_none() {
178                self.error(format!("Found incorrect type var: {var}"));
179            }
180        }
181    }
182    fn enter_const_generic(&mut self, x: &ConstGeneric) {
183        if let ConstGeneric::Var(var) = x {
184            if self.binder_stack.get_var(*var).is_none() {
185                self.error(format!("Found incorrect const-generic var: {var}"));
186            }
187        }
188    }
189    fn enter_trait_ref_kind(&mut self, x: &TraitRefKind) {
190        if let TraitRefKind::Clause(var) = x {
191            if self.binder_stack.get_var(*var).is_none() {
192                self.error(format!("Found incorrect clause var: {var}"));
193            }
194        }
195    }
196
197    fn visit_aggregate_kind(&mut self, agg: &AggregateKind) -> ControlFlow<Self::Break> {
198        match agg {
199            AggregateKind::Adt(..) | AggregateKind::Array(..) | AggregateKind::RawPtr(..) => {
200                self.visit_inner(agg)?
201            }
202            AggregateKind::Closure(_id, args) => {
203                // TODO(#194): handle closure generics properly
204                // This does not visit the args themselves, only their contents, because we mess up
205                // closure generics for now.
206                self.visit_inner(args)?
207            }
208        }
209        Continue(())
210    }
211
212    fn enter_generic_args(&mut self, args: &GenericArgs) {
213        let fmt1;
214        let fmt2;
215        let (params, params_fmt) = match &args.target {
216            GenericsSource::Item(item_id) => {
217                let Some(item) = self.ctx.translated.get_item(*item_id) else {
218                    return;
219                };
220                let params = item.generic_params();
221                fmt1 = self.ctx.into_fmt();
222                let fmt = fmt1.push_binder(Cow::Borrowed(params));
223                (params, fmt)
224            }
225            GenericsSource::Method(trait_id, method_name) => {
226                let Some(trait_decl) = self.ctx.translated.trait_decls.get(*trait_id) else {
227                    return;
228                };
229                let Some((_, bound_fn)) = trait_decl.methods().find(|(n, _)| n == method_name)
230                else {
231                    return;
232                };
233                let params = &bound_fn.params;
234                fmt1 = self.ctx.into_fmt();
235                fmt2 = fmt1.push_binder(Cow::Borrowed(&trait_decl.generics));
236                let fmt = fmt2.push_binder(Cow::Borrowed(params));
237                (params, fmt)
238            }
239            GenericsSource::Builtin => return,
240            GenericsSource::Other => {
241                self.error("`GenericsSource::Other` should not exist in the charon AST");
242                return;
243            }
244        };
245        self.assert_matches(&params_fmt, params, args);
246    }
247
248    // Special case that is not represented as a `GenericArgs`.
249    fn enter_trait_impl(&mut self, timpl: &TraitImpl) {
250        let Some(tdecl) = self
251            .ctx
252            .translated
253            .trait_decls
254            .get(timpl.impl_trait.trait_id)
255        else {
256            return;
257        };
258        // See `lift_associated_item_clauses`
259        assert!(timpl.type_clauses.is_empty());
260        assert!(tdecl.type_clauses.is_empty());
261
262        let fmt1 = self.ctx.into_fmt();
263        let tdecl_fmt = fmt1.push_binder(Cow::Borrowed(&tdecl.generics));
264        let args_fmt = &self.val_fmt_ctx();
265        self.zip_assert_match(
266            &tdecl.parent_clauses,
267            &timpl.parent_trait_refs,
268            &tdecl_fmt,
269            args_fmt,
270            "trait parent clauses",
271            |tclause, tref| self.assert_clause_matches(&tdecl_fmt, tclause, tref),
272        );
273        let types_match = timpl.types.len() == tdecl.types.len()
274            && tdecl
275                .types
276                .iter()
277                .zip(timpl.types.iter())
278                .all(|(dname, (iname, _))| dname == iname);
279        if !types_match {
280            self.error(
281                "The associated types supplied by the trait impl don't match the trait decl.",
282            )
283        }
284        let consts_match = timpl.consts.len() == tdecl.consts.len()
285            && tdecl
286                .types
287                .iter()
288                .zip(timpl.types.iter())
289                .all(|(dname, (iname, _))| dname == iname);
290        if !consts_match {
291            self.error(
292                "The associated consts supplied by the trait impl don't match the trait decl.",
293            )
294        }
295        let methods_match = timpl.methods.len() == tdecl.methods.len();
296        if !methods_match && self.phase != "after translation" {
297            let decl_methods = tdecl
298                .methods()
299                .map(|(name, _)| format!("- {name}"))
300                .join("\n");
301            let impl_methods = timpl
302                .methods()
303                .map(|(name, _)| format!("- {name}"))
304                .join("\n");
305            self.error(format!(
306                "The methods supplied by the trait impl don't match the trait decl.\n\
307                Trait methods:\n{decl_methods}\n\
308                Impl methods:\n{impl_methods}"
309            ))
310        }
311    }
312
313    fn visit_ullbc_statement(&mut self, st: &ullbc_ast::Statement) -> ControlFlow<Self::Break> {
314        // Track span for more precise error messages.
315        let old_span = self.span;
316        self.span = st.span;
317        self.visit_inner(st)?;
318        self.span = old_span;
319        Continue(())
320    }
321
322    fn visit_llbc_statement(&mut self, st: &llbc_ast::Statement) -> ControlFlow<Self::Break> {
323        // Track span for more precise error messages.
324        let old_span = self.span;
325        self.span = st.span;
326        self.visit_inner(st)?;
327        self.span = old_span;
328        Continue(())
329    }
330}
331
332// The argument is a name to disambiguate the two times we run this check.
333pub struct Check(pub &'static str);
334impl TransformPass for Check {
335    fn transform_ctx(&self, ctx: &mut TransformCtx) {
336        for item in ctx.translated.all_items() {
337            let mut visitor = CheckGenericsVisitor {
338                ctx,
339                phase: self.0,
340                span: item.item_meta().span,
341                binder_stack: BindingStack::new(item.generic_params().clone()),
342                visit_stack: Default::default(),
343            };
344            let _ = item.drive(&mut visitor);
345        }
346    }
347}