charon_lib/transform/
check_generics.rs

1//! Check that all supplied generic types match the corresponding generic parameters.
2use derive_generic_visitor::*;
3use index_vec::Idx;
4use itertools::Itertools;
5use std::{borrow::Cow, fmt::Display};
6
7use crate::{
8    ast::*,
9    errors::Level,
10    formatter::{AstFormatter, FmtCtx, IntoFormatter},
11    pretty::FmtWithCtx,
12    transform::utils::GenericsSource,
13};
14
15use super::{ctx::TransformPass, TransformCtx};
16
17#[derive(Visitor)]
18struct CheckGenericsVisitor<'a> {
19    ctx: &'a TransformCtx,
20    phase: &'static str,
21    /// Tracks an enclosing span for error reporting.
22    span: Span,
23    /// Track the binders seen so far.
24    // We can't keep the params by reference because the visitors don't tell us that everything
25    // we're visiting has lifetime `'a`.
26    binder_stack: BindingStack<GenericParams>,
27    /// Remember the names of the types visited up to here.
28    visit_stack: Vec<&'static str>,
29}
30
31impl CheckGenericsVisitor<'_> {
32    fn error(&self, message: impl Display) {
33        let msg = format!(
34            "Found inconsistent generics {}:\n{message}\n\
35            Visitor stack:\n  {}\n\
36            Binding stack (depth {}):\n  {}",
37            self.phase,
38            self.visit_stack.iter().rev().join("\n  "),
39            self.binder_stack.len(),
40            self.binder_stack
41                .iter_enumerated()
42                .map(|(i, params)| format!("{i}: {params}"))
43                .join("\n  "),
44        );
45        // This is a fatal error: the output llbc is inconsistent and should not be used.
46        self.ctx.span_err(self.span, &msg, Level::ERROR);
47    }
48
49    /// For pretty error printing. This can print values that we encounter because we track binders
50    /// properly. This doesn't have the right binders to print values we get from somewhere else
51    /// (namely the `GenericParam`s we get from elsewhere in the crate).
52    fn val_fmt_ctx(&self) -> FmtCtx<'_> {
53        let mut fmt = self.ctx.into_fmt();
54        fmt.generics = self.binder_stack.map_ref(Cow::Borrowed);
55        fmt
56    }
57
58    fn zip_assert_match<I, A, B, FmtA, FmtB>(
59        &self,
60        a: &Vector<I, A>,
61        b: &Vector<I, B>,
62        a_fmt: &FmtA,
63        b_fmt: &FmtB,
64        kind: &str,
65        target: &GenericsSource,
66        check_inner: impl Fn(&A, &B),
67    ) where
68        I: Idx,
69        FmtA: AstFormatter,
70        A: FmtWithCtx<FmtA>,
71        B: FmtWithCtx<FmtB>,
72    {
73        if a.elem_count() == b.elem_count() {
74            a.iter().zip(b.iter()).for_each(|(x, y)| check_inner(x, y));
75        } else {
76            let a = a.iter().map(|x| x.with_ctx(a_fmt)).join(", ");
77            let b = b.iter().map(|x| x.with_ctx(b_fmt)).join(", ");
78            let target = target.with_ctx(a_fmt);
79            self.error(format!(
80                "Mismatched {kind}:\
81                \ntarget: {target}\
82                \nexpected: [{a}]\
83                \n     got: [{b}]"
84            ))
85        }
86    }
87
88    fn assert_clause_matches(
89        &self,
90        params_fmt: &FmtCtx<'_>,
91        tclause: &TraitClause,
92        tref: &TraitRef,
93    ) {
94        let clause_trait_id = tclause.trait_.skip_binder.id;
95        let ref_trait_id = tref.trait_decl_ref.skip_binder.id;
96        if clause_trait_id != ref_trait_id {
97            let args_fmt = &self.val_fmt_ctx();
98            let tclause = tclause.with_ctx(params_fmt);
99            let tref_pred = tref.trait_decl_ref.with_ctx(args_fmt);
100            let tref = tref.with_ctx(args_fmt);
101            self.error(format!(
102                "Mismatched trait clause:\
103                \nexpected: {tclause}\
104                \n     got: {tref}: {tref_pred}"
105            ));
106        }
107    }
108
109    fn assert_matches(
110        &self,
111        params_fmt: &FmtCtx<'_>,
112        params: &GenericParams,
113        args: &GenericArgs,
114        target: &GenericsSource,
115    ) {
116        let args_fmt = &self.val_fmt_ctx();
117        self.zip_assert_match(
118            &params.regions,
119            &args.regions,
120            params_fmt,
121            args_fmt,
122            "regions",
123            target,
124            |_, _| {},
125        );
126        self.zip_assert_match(
127            &params.types,
128            &args.types,
129            params_fmt,
130            args_fmt,
131            "type generics",
132            target,
133            |_, _| {},
134        );
135        self.zip_assert_match(
136            &params.const_generics,
137            &args.const_generics,
138            params_fmt,
139            args_fmt,
140            "const generics",
141            target,
142            |_, _| {},
143        );
144        self.zip_assert_match(
145            &params.trait_clauses,
146            &args.trait_refs,
147            params_fmt,
148            args_fmt,
149            "trait clauses",
150            target,
151            |tclause, tref| self.assert_clause_matches(params_fmt, tclause, tref),
152        );
153    }
154
155    fn assert_matches_item(&self, id: impl Into<AnyTransId>, args: &GenericArgs) {
156        let id = id.into();
157        let Some(item) = self.ctx.translated.get_item(id) else {
158            return;
159        };
160        let params = item.generic_params();
161        let fmt1 = self.ctx.into_fmt();
162        let fmt = fmt1.push_binder(Cow::Borrowed(params));
163        self.assert_matches(&fmt, params, args, &GenericsSource::item(id));
164    }
165
166    fn assert_matches_method(
167        &self,
168        trait_id: TraitDeclId,
169        method_name: &TraitItemName,
170        args: &GenericArgs,
171    ) {
172        let target = &GenericsSource::Method(trait_id, method_name.clone());
173        let Some(trait_decl) = self.ctx.translated.trait_decls.get(trait_id) else {
174            return;
175        };
176        let Some((_, bound_fn)) = trait_decl.methods().find(|(n, _)| n == method_name) else {
177            return;
178        };
179        let params = &bound_fn.params;
180        let fmt1 = self.ctx.into_fmt();
181        let fmt2 = fmt1.push_binder(Cow::Borrowed(&trait_decl.generics));
182        let fmt = fmt2.push_binder(Cow::Borrowed(params));
183        self.assert_matches(&fmt, params, args, target);
184    }
185}
186
187impl VisitAst for CheckGenericsVisitor<'_> {
188    fn visit<'a, T: AstVisitable>(&'a mut self, x: &T) -> ControlFlow<Self::Break> {
189        self.visit_stack.push(x.name());
190        x.drive(self)?; // default behavior
191        self.visit_stack.pop();
192        Continue(())
193    }
194
195    fn visit_binder<T: AstVisitable>(&mut self, binder: &Binder<T>) -> ControlFlow<Self::Break> {
196        self.binder_stack.push(binder.params.clone());
197        self.visit_inner(binder)?;
198        self.binder_stack.pop();
199        Continue(())
200    }
201    fn visit_region_binder<T: AstVisitable>(
202        &mut self,
203        binder: &RegionBinder<T>,
204    ) -> ControlFlow<Self::Break> {
205        self.binder_stack.push(GenericParams {
206            regions: binder.regions.clone(),
207            ..Default::default()
208        });
209        self.visit_inner(binder)?;
210        self.binder_stack.pop();
211        Continue(())
212    }
213
214    // Check that generics are correctly bound.
215    fn enter_region(&mut self, x: &Region) {
216        if let Region::Var(var) = x {
217            if self.binder_stack.get_var(*var).is_none() {
218                self.error(format!("Found incorrect region var: {var}"));
219            }
220        }
221    }
222    fn enter_ty_kind(&mut self, x: &TyKind) {
223        if let TyKind::TypeVar(var) = x {
224            if self.binder_stack.get_var(*var).is_none() {
225                self.error(format!("Found incorrect type var: {var}"));
226            }
227        }
228    }
229    fn enter_const_generic(&mut self, x: &ConstGeneric) {
230        if let ConstGeneric::Var(var) = x {
231            if self.binder_stack.get_var(*var).is_none() {
232                self.error(format!("Found incorrect const-generic var: {var}"));
233            }
234        }
235    }
236    fn enter_trait_ref_kind(&mut self, x: &TraitRefKind) {
237        match x {
238            TraitRefKind::Clause(var) => {
239                if self.binder_stack.get_var(*var).is_none() {
240                    self.error(format!("Found incorrect clause var: {var}"));
241                }
242            }
243            TraitRefKind::BuiltinOrAuto {
244                trait_decl_ref,
245                parent_trait_refs,
246                types,
247            } => {
248                let trait_id = trait_decl_ref.skip_binder.id;
249                let target = GenericsSource::item(trait_id);
250                let Some(tdecl) = self.ctx.translated.trait_decls.get(trait_id) else {
251                    return;
252                };
253                if tdecl
254                    .item_meta
255                    .lang_item
256                    .as_deref()
257                    .is_some_and(|s| matches!(s, "pointee_trait" | "discriminant_kind"))
258                {
259                    // These traits have builtin assoc types that we can't resolve.
260                    return;
261                }
262                let fmt = &self.ctx.into_fmt();
263                let args_fmt = &self.val_fmt_ctx();
264                self.zip_assert_match(
265                    &tdecl.parent_clauses,
266                    parent_trait_refs,
267                    fmt,
268                    args_fmt,
269                    "builtin trait parent clauses",
270                    &target,
271                    |tclause, tref| self.assert_clause_matches(&fmt, tclause, tref),
272                );
273                let types_match = types.len() == tdecl.types.len()
274                    && tdecl
275                        .types
276                        .iter()
277                        .zip(types.iter())
278                        .all(|(dname, (iname, _, _))| dname == iname);
279                if !types_match {
280                    let target = target.with_ctx(args_fmt);
281                    let a = tdecl.types.iter().format(", ");
282                    let b = types
283                        .iter()
284                        .map(|(_, ty, _)| ty.with_ctx(args_fmt))
285                        .format(", ");
286                    self.error(format!(
287                        "Mismatched types in builtin trait ref:\
288                        \ntarget: {target}\
289                        \nexpected: [{a}]\
290                        \n     got: [{b}]"
291                    ));
292                }
293            }
294            _ => {}
295        }
296    }
297
298    // Check that generics match the parameters of the target item.
299    fn enter_type_decl_ref(&mut self, x: &TypeDeclRef) {
300        match x.id {
301            TypeId::Adt(id) => self.assert_matches_item(id, &x.generics),
302            // TODO: check builtin generics.
303            TypeId::Tuple => {}
304            TypeId::Builtin(_) => {}
305        }
306    }
307    fn enter_fun_decl_ref(&mut self, x: &FunDeclRef) {
308        self.assert_matches_item(x.id, &x.generics);
309    }
310    fn enter_fn_ptr(&mut self, x: &FnPtr) {
311        match x.func.as_ref() {
312            FunIdOrTraitMethodRef::Fun(FunId::Regular(id)) => {
313                self.assert_matches_item(*id, &x.generics)
314            }
315            // TODO: check builtin generics.
316            FunIdOrTraitMethodRef::Fun(FunId::Builtin(_)) => {}
317            FunIdOrTraitMethodRef::Trait(trait_ref, method_name, _) => {
318                let trait_id = trait_ref.trait_decl_ref.skip_binder.id;
319                self.assert_matches_method(trait_id, method_name, &x.generics);
320            }
321        }
322    }
323    fn enter_global_decl_ref(&mut self, x: &GlobalDeclRef) {
324        self.assert_matches_item(x.id, &x.generics);
325    }
326    fn enter_trait_decl_ref(&mut self, x: &TraitDeclRef) {
327        self.assert_matches_item(x.id, &x.generics);
328    }
329    fn enter_trait_impl_ref(&mut self, x: &TraitImplRef) {
330        self.assert_matches_item(x.id, &x.generics);
331    }
332    fn enter_trait_impl(&mut self, timpl: &TraitImpl) {
333        let Some(tdecl) = self.ctx.translated.trait_decls.get(timpl.impl_trait.id) else {
334            return;
335        };
336        // See `lift_associated_item_clauses`
337        assert!(timpl.type_clauses.is_empty());
338        assert!(tdecl.type_clauses.is_empty());
339
340        let fmt1 = self.ctx.into_fmt();
341        let tdecl_fmt = fmt1.push_binder(Cow::Borrowed(&tdecl.generics));
342        let args_fmt = &self.val_fmt_ctx();
343        self.zip_assert_match(
344            &tdecl.parent_clauses,
345            &timpl.parent_trait_refs,
346            &tdecl_fmt,
347            args_fmt,
348            "trait parent clauses",
349            &GenericsSource::item(timpl.impl_trait.id),
350            |tclause, tref| self.assert_clause_matches(&tdecl_fmt, tclause, tref),
351        );
352        let types_match = timpl.types.len() == tdecl.types.len()
353            && tdecl
354                .types
355                .iter()
356                .zip(timpl.types.iter())
357                .all(|(dname, (iname, _))| dname == iname);
358        if !types_match {
359            self.error(
360                "The associated types supplied by the trait impl don't match the trait decl.",
361            )
362        }
363        let consts_match = timpl.consts.len() == tdecl.consts.len()
364            && tdecl
365                .types
366                .iter()
367                .zip(timpl.types.iter())
368                .all(|(dname, (iname, _))| dname == iname);
369        if !consts_match {
370            self.error(
371                "The associated consts supplied by the trait impl don't match the trait decl.",
372            )
373        }
374        let methods_match = timpl.methods.len() == tdecl.methods.len();
375        if !methods_match && self.phase != "after translation" {
376            let decl_methods = tdecl
377                .methods()
378                .map(|(name, _)| format!("- {name}"))
379                .join("\n");
380            let impl_methods = timpl
381                .methods()
382                .map(|(name, _)| format!("- {name}"))
383                .join("\n");
384            self.error(format!(
385                "The methods supplied by the trait impl don't match the trait decl.\n\
386                Trait methods:\n{decl_methods}\n\
387                Impl methods:\n{impl_methods}"
388            ))
389        }
390    }
391
392    // Track span for more precise error messages.
393    fn visit_ullbc_statement(&mut self, st: &ullbc_ast::Statement) -> ControlFlow<Self::Break> {
394        let old_span = self.span;
395        self.span = st.span;
396        self.visit_inner(st)?;
397        self.span = old_span;
398        Continue(())
399    }
400    fn visit_llbc_statement(&mut self, st: &llbc_ast::Statement) -> ControlFlow<Self::Break> {
401        let old_span = self.span;
402        self.span = st.span;
403        self.visit_inner(st)?;
404        self.span = old_span;
405        Continue(())
406    }
407}
408
/// The generics-checking pass. The string argument names the run: it disambiguates the two
/// times we run this check, and is included in the error messages.
pub struct Check(pub &'static str);
411impl TransformPass for Check {
412    fn transform_ctx(&self, ctx: &mut TransformCtx) {
413        for item in ctx.translated.all_items() {
414            // Hack: the items generated by monomorphisation have incorrect generics.
415            if item
416                .item_meta()
417                .name
418                .name
419                .last()
420                .unwrap()
421                .is_monomorphized()
422            {
423                continue;
424            }
425            let mut visitor = CheckGenericsVisitor {
426                ctx,
427                phase: self.0,
428                span: item.item_meta().span,
429                binder_stack: BindingStack::new(item.generic_params().clone()),
430                visit_stack: Default::default(),
431            };
432            let _ = item.drive(&mut visitor);
433        }
434    }
435}