charon_lib/transform/check_generics.rs

1//! Check that all supplied generic types match the corresponding generic parameters.
2use derive_generic_visitor::*;
3use index_vec::Idx;
4use itertools::Itertools;
5use std::{borrow::Cow, fmt::Display};
6
7use crate::{
8    ast::*,
9    errors::Level,
10    formatter::{AstFormatter, FmtCtx, IntoFormatter},
11    pretty::FmtWithCtx,
12    transform::utils::GenericsSource,
13};
14
15use super::{ctx::TransformPass, TransformCtx};
16
17#[derive(Visitor)]
18struct CheckGenericsVisitor<'a> {
19    ctx: &'a TransformCtx,
20    phase: &'static str,
21    /// Tracks an enclosing span for error reporting.
22    span: Span,
23    /// Track the binders seen so far.
24    // We can't keep the params by reference because the visitors don't tell us that everything
25    // we're visiting has lifetime `'a`.
26    binder_stack: BindingStack<GenericParams>,
27    /// Remember the names of the types visited up to here.
28    visit_stack: Vec<&'static str>,
29}
30
31impl CheckGenericsVisitor<'_> {
32    fn error(&self, message: impl Display) {
33        let msg = format!(
34            "Found inconsistent generics {}:\n{message}\n\
35            Visitor stack:\n  {}\n\
36            Binding stack (depth {}):\n  {}",
37            self.phase,
38            self.visit_stack.iter().rev().join("\n  "),
39            self.binder_stack.len(),
40            self.binder_stack
41                .iter_enumerated()
42                .map(|(i, params)| format!("{i}: {params}"))
43                .join("\n  "),
44        );
45        // This is a fatal error: the output llbc is inconsistent and should not be used.
46        self.ctx.span_err(self.span, &msg, Level::ERROR);
47    }
48
49    /// For pretty error printing. This can print values that we encounter because we track binders
50    /// properly. This doesn't have the right binders to print values we get from somewhere else
51    /// (namely the `GenericParam`s we get from elsewhere in the crate).
52    fn val_fmt_ctx(&self) -> FmtCtx<'_> {
53        let mut fmt = self.ctx.into_fmt();
54        fmt.generics = self.binder_stack.map_ref(Cow::Borrowed);
55        fmt
56    }
57
58    fn zip_assert_match<I, A, B, FmtA, FmtB>(
59        &self,
60        a: &Vector<I, A>,
61        b: &Vector<I, B>,
62        a_fmt: &FmtA,
63        b_fmt: &FmtB,
64        kind: &str,
65        target: &GenericsSource,
66        check_inner: impl Fn(&A, &B),
67    ) where
68        I: Idx,
69        FmtA: AstFormatter,
70        A: FmtWithCtx<FmtA>,
71        B: FmtWithCtx<FmtB>,
72    {
73        if a.elem_count() == b.elem_count() {
74            a.iter().zip(b.iter()).for_each(|(x, y)| check_inner(x, y));
75        } else {
76            let a = a.iter().map(|x| x.with_ctx(a_fmt)).join(", ");
77            let b = b.iter().map(|x| x.with_ctx(b_fmt)).join(", ");
78            let target = target.with_ctx(a_fmt);
79            self.error(format!(
80                "Mismatched {kind}:\
81                \ntarget: {target}\
82                \nexpected: [{a}]\
83                \n     got: [{b}]"
84            ))
85        }
86    }
87
88    fn assert_clause_matches(
89        &self,
90        params_fmt: &FmtCtx<'_>,
91        tclause: &TraitClause,
92        tref: &TraitRef,
93    ) {
94        let clause_trait_id = tclause.trait_.skip_binder.id;
95        let ref_trait_id = tref.trait_decl_ref.skip_binder.id;
96        if clause_trait_id != ref_trait_id {
97            let args_fmt = &self.val_fmt_ctx();
98            let tclause = tclause.with_ctx(params_fmt);
99            let tref_pred = tref.trait_decl_ref.with_ctx(args_fmt);
100            let tref = tref.with_ctx(args_fmt);
101            self.error(format!(
102                "Mismatched trait clause:\
103                \nexpected: {tclause}\
104                \n     got: {tref}: {tref_pred}"
105            ));
106        }
107    }
108
109    fn assert_matches(
110        &self,
111        params_fmt: &FmtCtx<'_>,
112        params: &GenericParams,
113        args: &GenericArgs,
114        target: &GenericsSource,
115    ) {
116        let args_fmt = &self.val_fmt_ctx();
117        self.zip_assert_match(
118            &params.regions,
119            &args.regions,
120            params_fmt,
121            args_fmt,
122            "regions",
123            target,
124            |_, _| {},
125        );
126        self.zip_assert_match(
127            &params.types,
128            &args.types,
129            params_fmt,
130            args_fmt,
131            "type generics",
132            target,
133            |_, _| {},
134        );
135        self.zip_assert_match(
136            &params.const_generics,
137            &args.const_generics,
138            params_fmt,
139            args_fmt,
140            "const generics",
141            target,
142            |_, _| {},
143        );
144        self.zip_assert_match(
145            &params.trait_clauses,
146            &args.trait_refs,
147            params_fmt,
148            args_fmt,
149            "trait clauses",
150            target,
151            |tclause, tref| self.assert_clause_matches(params_fmt, tclause, tref),
152        );
153    }
154
155    fn assert_matches_item(&self, id: impl Into<AnyTransId>, args: &GenericArgs) {
156        let id = id.into();
157        let Some(item) = self.ctx.translated.get_item(id) else {
158            return;
159        };
160        let params = item.generic_params();
161        let fmt1 = self.ctx.into_fmt();
162        let fmt = fmt1.push_binder(Cow::Borrowed(params));
163        self.assert_matches(&fmt, params, args, &GenericsSource::item(id));
164    }
165
166    fn assert_matches_method(
167        &self,
168        trait_id: TraitDeclId,
169        method_name: &TraitItemName,
170        args: &GenericArgs,
171    ) {
172        let target = &GenericsSource::Method(trait_id, method_name.clone());
173        let Some(trait_decl) = self.ctx.translated.trait_decls.get(trait_id) else {
174            return;
175        };
176        let Some((_, bound_fn)) = trait_decl.methods().find(|(n, _)| n == method_name) else {
177            return;
178        };
179        let params = &bound_fn.params;
180        let fmt1 = self.ctx.into_fmt();
181        let fmt2 = fmt1.push_binder(Cow::Borrowed(&trait_decl.generics));
182        let fmt = fmt2.push_binder(Cow::Borrowed(params));
183        self.assert_matches(&fmt, params, args, target);
184    }
185}
186
187impl VisitAst for CheckGenericsVisitor<'_> {
188    fn visit<'a, T: AstVisitable>(&'a mut self, x: &T) -> ControlFlow<Self::Break> {
189        self.visit_stack.push(x.name());
190        x.drive(self)?; // default behavior
191        self.visit_stack.pop();
192        Continue(())
193    }
194
195    fn visit_binder<T: AstVisitable>(&mut self, binder: &Binder<T>) -> ControlFlow<Self::Break> {
196        self.binder_stack.push(binder.params.clone());
197        self.visit_inner(binder)?;
198        self.binder_stack.pop();
199        Continue(())
200    }
201    fn visit_region_binder<T: AstVisitable>(
202        &mut self,
203        binder: &RegionBinder<T>,
204    ) -> ControlFlow<Self::Break> {
205        self.binder_stack.push(GenericParams {
206            regions: binder.regions.clone(),
207            ..Default::default()
208        });
209        self.visit_inner(binder)?;
210        self.binder_stack.pop();
211        Continue(())
212    }
213
214    // Check that generics are correctly bound.
215    fn enter_region(&mut self, x: &Region) {
216        if let Region::Var(var) = x {
217            if self.binder_stack.get_var(*var).is_none() {
218                self.error(format!("Found incorrect region var: {var}"));
219            }
220        }
221    }
222    fn enter_ty_kind(&mut self, x: &TyKind) {
223        if let TyKind::TypeVar(var) = x {
224            if self.binder_stack.get_var(*var).is_none() {
225                self.error(format!("Found incorrect type var: {var}"));
226            }
227        }
228    }
229    fn enter_const_generic(&mut self, x: &ConstGeneric) {
230        if let ConstGeneric::Var(var) = x {
231            if self.binder_stack.get_var(*var).is_none() {
232                self.error(format!("Found incorrect const-generic var: {var}"));
233            }
234        }
235    }
236    fn enter_trait_ref_kind(&mut self, x: &TraitRefKind) {
237        if let TraitRefKind::Clause(var) = x {
238            if self.binder_stack.get_var(*var).is_none() {
239                self.error(format!("Found incorrect clause var: {var}"));
240            }
241        }
242    }
243
244    // Check that generics match the parameters of the target item.
245    fn enter_type_decl_ref(&mut self, x: &TypeDeclRef) {
246        match x.id {
247            TypeId::Adt(id) => self.assert_matches_item(id, &x.generics),
248            // TODO: check builtin generics.
249            TypeId::Tuple => {}
250            TypeId::Builtin(_) => {}
251        }
252    }
253    fn enter_fun_decl_ref(&mut self, x: &FunDeclRef) {
254        self.assert_matches_item(x.id, &x.generics);
255    }
256    fn enter_fn_ptr(&mut self, x: &FnPtr) {
257        match x.func.as_ref() {
258            FunIdOrTraitMethodRef::Fun(FunId::Regular(id)) => {
259                self.assert_matches_item(*id, &x.generics)
260            }
261            // TODO: check builtin generics.
262            FunIdOrTraitMethodRef::Fun(FunId::Builtin(_)) => {}
263            FunIdOrTraitMethodRef::Trait(trait_ref, method_name, _) => {
264                let trait_id = trait_ref.trait_decl_ref.skip_binder.id;
265                self.assert_matches_method(trait_id, method_name, &x.generics);
266            }
267        }
268    }
269    fn enter_global_decl_ref(&mut self, x: &GlobalDeclRef) {
270        self.assert_matches_item(x.id, &x.generics);
271    }
272    fn enter_trait_decl_ref(&mut self, x: &TraitDeclRef) {
273        self.assert_matches_item(x.id, &x.generics);
274    }
275    fn enter_trait_impl_ref(&mut self, x: &TraitImplRef) {
276        self.assert_matches_item(x.id, &x.generics);
277    }
278    fn enter_trait_impl(&mut self, timpl: &TraitImpl) {
279        let Some(tdecl) = self.ctx.translated.trait_decls.get(timpl.impl_trait.id) else {
280            return;
281        };
282        // See `lift_associated_item_clauses`
283        assert!(timpl.type_clauses.is_empty());
284        assert!(tdecl.type_clauses.is_empty());
285
286        let fmt1 = self.ctx.into_fmt();
287        let tdecl_fmt = fmt1.push_binder(Cow::Borrowed(&tdecl.generics));
288        let args_fmt = &self.val_fmt_ctx();
289        self.zip_assert_match(
290            &tdecl.parent_clauses,
291            &timpl.parent_trait_refs,
292            &tdecl_fmt,
293            args_fmt,
294            "trait parent clauses",
295            &GenericsSource::item(timpl.impl_trait.id),
296            |tclause, tref| self.assert_clause_matches(&tdecl_fmt, tclause, tref),
297        );
298        let types_match = timpl.types.len() == tdecl.types.len()
299            && tdecl
300                .types
301                .iter()
302                .zip(timpl.types.iter())
303                .all(|(dname, (iname, _))| dname == iname);
304        if !types_match {
305            self.error(
306                "The associated types supplied by the trait impl don't match the trait decl.",
307            )
308        }
309        let consts_match = timpl.consts.len() == tdecl.consts.len()
310            && tdecl
311                .types
312                .iter()
313                .zip(timpl.types.iter())
314                .all(|(dname, (iname, _))| dname == iname);
315        if !consts_match {
316            self.error(
317                "The associated consts supplied by the trait impl don't match the trait decl.",
318            )
319        }
320        let methods_match = timpl.methods.len() == tdecl.methods.len();
321        if !methods_match && self.phase != "after translation" {
322            let decl_methods = tdecl
323                .methods()
324                .map(|(name, _)| format!("- {name}"))
325                .join("\n");
326            let impl_methods = timpl
327                .methods()
328                .map(|(name, _)| format!("- {name}"))
329                .join("\n");
330            self.error(format!(
331                "The methods supplied by the trait impl don't match the trait decl.\n\
332                Trait methods:\n{decl_methods}\n\
333                Impl methods:\n{impl_methods}"
334            ))
335        }
336    }
337
338    // Track span for more precise error messages.
339    fn visit_ullbc_statement(&mut self, st: &ullbc_ast::Statement) -> ControlFlow<Self::Break> {
340        let old_span = self.span;
341        self.span = st.span;
342        self.visit_inner(st)?;
343        self.span = old_span;
344        Continue(())
345    }
346    fn visit_llbc_statement(&mut self, st: &llbc_ast::Statement) -> ControlFlow<Self::Break> {
347        let old_span = self.span;
348        self.span = st.span;
349        self.visit_inner(st)?;
350        self.span = old_span;
351        Continue(())
352    }
353}
354
355// The argument is a name to disambiguate the two times we run this check.
356pub struct Check(pub &'static str);
357impl TransformPass for Check {
358    fn transform_ctx(&self, ctx: &mut TransformCtx) {
359        for item in ctx.translated.all_items() {
360            // Hack: the items generated by monomorphisation have incorrect generics.
361            if item
362                .item_meta()
363                .name
364                .name
365                .last()
366                .unwrap()
367                .is_monomorphized()
368            {
369                continue;
370            }
371            let mut visitor = CheckGenericsVisitor {
372                ctx,
373                phase: self.0,
374                span: item.item_meta().span,
375                binder_stack: BindingStack::new(item.generic_params().clone()),
376                visit_stack: Default::default(),
377            };
378            let _ = item.drive(&mut visitor);
379        }
380    }
381}