charon_lib/transform/
check_generics.rs

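//! Check that all supplied generic arguments (regions, types, const generics and trait refs)
//! match the corresponding generic parameters, and that every bound variable refers to a binder
//! that is in scope.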
2use derive_generic_visitor::*;
3use index_vec::Idx;
4use itertools::Itertools;
5use std::{borrow::Cow, fmt::Display};
6
7use crate::{
8    ast::*,
9    errors::Level,
10    formatter::{AstFormatter, FmtCtx, IntoFormatter},
11    pretty::FmtWithCtx,
12    transform::utils::GenericsSource,
13};
14
15use super::{TransformCtx, ctx::TransformPass};
16
17#[derive(Visitor)]
18struct CheckGenericsVisitor<'a> {
19    ctx: &'a TransformCtx,
20    phase: &'static str,
21    /// Tracks an enclosing span for error reporting.
22    span: Span,
23    /// Track the binders seen so far.
24    // We can't keep the params by reference because the visitors don't tell us that everything
25    // we're visiting has lifetime `'a`.
26    binder_stack: BindingStack<GenericParams>,
27    /// Remember the names of the types visited up to here.
28    visit_stack: Vec<&'static str>,
29}
30
31impl VisitorWithSpan for CheckGenericsVisitor<'_> {
32    fn current_span(&mut self) -> &mut Span {
33        &mut self.span
34    }
35}
36impl VisitorWithBinderStack for CheckGenericsVisitor<'_> {
37    fn binder_stack_mut(&mut self) -> &mut BindingStack<GenericParams> {
38        &mut self.binder_stack
39    }
40}
41
42impl CheckGenericsVisitor<'_> {
43    fn error(&self, message: impl Display) {
44        let msg = format!(
45            "Found inconsistent generics {}:\n{message}\n\
46            Visitor stack:\n  {}\n\
47            Binding stack (depth {}):\n  {}",
48            self.phase,
49            self.visit_stack.iter().rev().join("\n  "),
50            self.binder_stack.len(),
51            self.binder_stack
52                .iter_enumerated()
53                .map(|(i, params)| format!("{i}: {params}"))
54                .join("\n  "),
55        );
56        // This is a fatal error: the output llbc is inconsistent and should not be used.
57        self.ctx.span_err(self.span, &msg, Level::ERROR);
58    }
59
60    /// For pretty error printing. This can print values that we encounter because we track binders
61    /// properly. This doesn't have the right binders to print values we get from somewhere else
62    /// (namely the `GenericParam`s we get from elsewhere in the crate).
63    fn val_fmt_ctx(&self) -> FmtCtx<'_> {
64        let mut fmt = self.ctx.into_fmt();
65        fmt.generics = self.binder_stack.map_ref(Cow::Borrowed);
66        fmt
67    }
68
69    fn zip_assert_match<I, A, B, FmtA, FmtB>(
70        &self,
71        a: &Vector<I, A>,
72        b: &Vector<I, B>,
73        a_fmt: &FmtA,
74        b_fmt: &FmtB,
75        kind: &str,
76        target: &GenericsSource,
77        check_inner: impl Fn(&A, &B),
78    ) where
79        I: Idx,
80        FmtA: AstFormatter,
81        A: FmtWithCtx<FmtA>,
82        B: FmtWithCtx<FmtB>,
83    {
84        if a.elem_count() == b.elem_count() {
85            a.iter().zip(b.iter()).for_each(|(x, y)| check_inner(x, y));
86        } else {
87            let a = a.iter().map(|x| x.with_ctx(a_fmt)).join(", ");
88            let b = b.iter().map(|x| x.with_ctx(b_fmt)).join(", ");
89            let target = target.with_ctx(a_fmt);
90            self.error(format!(
91                "Mismatched {kind}:\
92                \ntarget: {target}\
93                \nexpected: [{a}]\
94                \n     got: [{b}]"
95            ))
96        }
97    }
98
99    fn assert_clause_matches(
100        &self,
101        params_fmt: &FmtCtx<'_>,
102        tclause: &TraitClause,
103        tref: &TraitRef,
104    ) {
105        let clause_trait_id = tclause.trait_.skip_binder.id;
106        let ref_trait_id = tref.trait_decl_ref.skip_binder.id;
107        if clause_trait_id != ref_trait_id {
108            let args_fmt = &self.val_fmt_ctx();
109            let tclause = tclause.with_ctx(params_fmt);
110            let tref_pred = tref.trait_decl_ref.with_ctx(args_fmt);
111            let tref = tref.with_ctx(args_fmt);
112            self.error(format!(
113                "Mismatched trait clause:\
114                \nexpected: {tclause}\
115                \n     got: {tref}: {tref_pred}"
116            ));
117        }
118    }
119
120    fn assert_matches(
121        &self,
122        params_fmt: &FmtCtx<'_>,
123        params: &GenericParams,
124        args: &GenericArgs,
125        target: &GenericsSource,
126    ) {
127        let args_fmt = &self.val_fmt_ctx();
128        self.zip_assert_match(
129            &params.regions,
130            &args.regions,
131            params_fmt,
132            args_fmt,
133            "regions",
134            target,
135            |_, _| {},
136        );
137        self.zip_assert_match(
138            &params.types,
139            &args.types,
140            params_fmt,
141            args_fmt,
142            "type generics",
143            target,
144            |_, _| {},
145        );
146        self.zip_assert_match(
147            &params.const_generics,
148            &args.const_generics,
149            params_fmt,
150            args_fmt,
151            "const generics",
152            target,
153            |_, _| {},
154        );
155        self.zip_assert_match(
156            &params.trait_clauses,
157            &args.trait_refs,
158            params_fmt,
159            args_fmt,
160            "trait clauses",
161            target,
162            |tclause, tref| self.assert_clause_matches(params_fmt, tclause, tref),
163        );
164    }
165
166    fn assert_matches_item(&self, id: impl Into<AnyTransId>, args: &GenericArgs) {
167        let id = id.into();
168        let Some(item) = self.ctx.translated.get_item(id) else {
169            return;
170        };
171        let params = item.generic_params();
172        let fmt1 = self.ctx.into_fmt();
173        let fmt = fmt1.push_binder(Cow::Borrowed(params));
174        self.assert_matches(&fmt, params, args, &GenericsSource::item(id));
175    }
176
177    fn assert_matches_method(
178        &self,
179        trait_id: TraitDeclId,
180        method_name: &TraitItemName,
181        args: &GenericArgs,
182    ) {
183        let target = &GenericsSource::Method(trait_id, method_name.clone());
184        let Some(trait_decl) = self.ctx.translated.trait_decls.get(trait_id) else {
185            return;
186        };
187        let Some(bound_fn) = trait_decl.methods().find(|m| m.name() == method_name) else {
188            return;
189        };
190        let params = &bound_fn.params;
191        let fmt1 = self.ctx.into_fmt();
192        let fmt2 = fmt1.push_binder(Cow::Borrowed(&trait_decl.generics));
193        let fmt = fmt2.push_binder(Cow::Borrowed(params));
194        self.assert_matches(&fmt, params, args, target);
195    }
196}
197
198impl VisitAst for CheckGenericsVisitor<'_> {
199    fn visit<'a, T: AstVisitable>(&'a mut self, x: &T) -> ControlFlow<Self::Break> {
200        self.visit_stack.push(x.name());
201        VisitWithSpan::new(VisitWithBinderStack::new(self)).visit(x)?;
202        self.visit_stack.pop();
203        Continue(())
204    }
205
206    // Check that generics are correctly bound.
207    fn enter_region(&mut self, x: &Region) {
208        if let Region::Var(var) = x {
209            if self.binder_stack.get_var(*var).is_none() {
210                self.error(format!("Found incorrect region var: {var}"));
211            }
212        }
213    }
214    fn enter_ty_kind(&mut self, x: &TyKind) {
215        if let TyKind::TypeVar(var) = x {
216            if self.binder_stack.get_var(*var).is_none() {
217                self.error(format!("Found incorrect type var: {var}"));
218            }
219        }
220    }
221    fn enter_const_generic(&mut self, x: &ConstGeneric) {
222        if let ConstGeneric::Var(var) = x {
223            if self.binder_stack.get_var(*var).is_none() {
224                self.error(format!("Found incorrect const-generic var: {var}"));
225            }
226        }
227    }
228    fn enter_trait_ref(&mut self, x: &TraitRef) {
229        match &x.kind {
230            TraitRefKind::Clause(var) => {
231                if self.binder_stack.get_var(*var).is_none() {
232                    self.error(format!("Found incorrect clause var: {var}"));
233                }
234            }
235            TraitRefKind::BuiltinOrAuto {
236                parent_trait_refs,
237                types,
238            } => {
239                let trait_id = x.trait_decl_ref.skip_binder.id;
240                let target = GenericsSource::item(trait_id);
241                let Some(tdecl) = self.ctx.translated.trait_decls.get(trait_id) else {
242                    return;
243                };
244                if tdecl
245                    .item_meta
246                    .lang_item
247                    .as_deref()
248                    .is_some_and(|s| matches!(s, "pointee_trait" | "discriminant_kind"))
249                {
250                    // These traits have builtin assoc types that we can't resolve.
251                    return;
252                }
253                let fmt = &self.ctx.into_fmt();
254                let args_fmt = &self.val_fmt_ctx();
255                self.zip_assert_match(
256                    &tdecl.parent_clauses,
257                    parent_trait_refs,
258                    fmt,
259                    args_fmt,
260                    "builtin trait parent clauses",
261                    &target,
262                    |tclause, tref| self.assert_clause_matches(&fmt, tclause, tref),
263                );
264                let types_match = types.len() == tdecl.types.len()
265                    && tdecl
266                        .types
267                        .iter()
268                        .zip(types.iter())
269                        .all(|(dty, (iname, _))| dty.name() == iname);
270                if !types_match {
271                    let target = target.with_ctx(args_fmt);
272                    let a = tdecl.types.iter().map(|t| t.name()).format(", ");
273                    let b = types
274                        .iter()
275                        .map(|(_, assoc_ty)| assoc_ty.value.with_ctx(args_fmt))
276                        .format(", ");
277                    self.error(format!(
278                        "Mismatched types in builtin trait ref:\
279                        \ntarget: {target}\
280                        \nexpected: [{a}]\
281                        \n     got: [{b}]"
282                    ));
283                }
284            }
285            _ => {}
286        }
287    }
288
289    // Check that generics match the parameters of the target item.
290    fn enter_type_decl_ref(&mut self, x: &TypeDeclRef) {
291        match x.id {
292            TypeId::Adt(id) => self.assert_matches_item(id, &x.generics),
293            // TODO: check builtin generics.
294            TypeId::Tuple => {}
295            TypeId::Builtin(_) => {}
296        }
297    }
298    fn enter_fun_decl_ref(&mut self, x: &FunDeclRef) {
299        self.assert_matches_item(x.id, &x.generics);
300    }
301    fn enter_fn_ptr(&mut self, x: &FnPtr) {
302        match x.func.as_ref() {
303            FunIdOrTraitMethodRef::Fun(FunId::Regular(id)) => {
304                self.assert_matches_item(*id, &x.generics)
305            }
306            // TODO: check builtin generics.
307            FunIdOrTraitMethodRef::Fun(FunId::Builtin(_)) => {}
308            FunIdOrTraitMethodRef::Trait(trait_ref, method_name, _) => {
309                let trait_id = trait_ref.trait_decl_ref.skip_binder.id;
310                self.assert_matches_method(trait_id, method_name, &x.generics);
311            }
312        }
313    }
314    fn enter_global_decl_ref(&mut self, x: &GlobalDeclRef) {
315        self.assert_matches_item(x.id, &x.generics);
316    }
317    fn enter_trait_decl_ref(&mut self, x: &TraitDeclRef) {
318        self.assert_matches_item(x.id, &x.generics);
319    }
320    fn enter_trait_impl_ref(&mut self, x: &TraitImplRef) {
321        self.assert_matches_item(x.id, &x.generics);
322    }
323    fn enter_trait_impl(&mut self, timpl: &TraitImpl) {
324        let Some(tdecl) = self.ctx.translated.trait_decls.get(timpl.impl_trait.id) else {
325            return;
326        };
327
328        let fmt1 = self.ctx.into_fmt();
329        let tdecl_fmt = fmt1.push_binder(Cow::Borrowed(&tdecl.generics));
330        let args_fmt = &self.val_fmt_ctx();
331        self.zip_assert_match(
332            &tdecl.parent_clauses,
333            &timpl.parent_trait_refs,
334            &tdecl_fmt,
335            args_fmt,
336            "trait parent clauses",
337            &GenericsSource::item(timpl.impl_trait.id),
338            |tclause, tref| self.assert_clause_matches(&tdecl_fmt, tclause, tref),
339        );
340        // TODO: check type clauses
341        let types_match = timpl.types.len() == tdecl.types.len()
342            && tdecl
343                .types
344                .iter()
345                .zip(timpl.types.iter())
346                .all(|(dty, (iname, _))| dty.name() == iname);
347        if !types_match {
348            self.error(
349                "The associated types supplied by the trait impl don't match the trait decl.",
350            )
351        }
352        let consts_match = timpl.consts.len() == tdecl.consts.len()
353            && tdecl
354                .consts
355                .iter()
356                .zip(timpl.consts.iter())
357                .all(|(dconst, (iname, _))| &dconst.name == iname);
358        if !consts_match {
359            self.error(
360                "The associated consts supplied by the trait impl don't match the trait decl.",
361            )
362        }
363        let methods_match = timpl.methods.len() == tdecl.methods.len();
364        if !methods_match && self.phase != "after translation" {
365            let decl_methods = tdecl
366                .methods()
367                .map(|m| format!("- {}", m.name()))
368                .join("\n");
369            let impl_methods = timpl
370                .methods()
371                .map(|(name, _)| format!("- {name}"))
372                .join("\n");
373            self.error(format!(
374                "The methods supplied by the trait impl don't match the trait decl.\n\
375                Trait methods:\n{decl_methods}\n\
376                Impl methods:\n{impl_methods}"
377            ))
378        }
379    }
380}
381
382// The argument is a name to disambiguate the two times we run this check.
383pub struct Check(pub &'static str);
384impl TransformPass for Check {
385    fn transform_ctx(&self, ctx: &mut TransformCtx) {
386        for item in ctx.translated.all_items() {
387            // Hack: the items generated by monomorphisation have incorrect generics.
388            // TODO(dyn): remove once we support dyn in mono and we can remove the manual mono
389            // pass.
390            if item
391                .item_meta()
392                .name
393                .name
394                .last()
395                .unwrap()
396                .is_monomorphized()
397            {
398                continue;
399            }
400            let mut visitor = CheckGenericsVisitor {
401                ctx,
402                phase: self.0,
403                span: Span::dummy(),
404                binder_stack: BindingStack::empty(),
405                visit_stack: Default::default(),
406            };
407            let _ = item.drive(&mut visitor);
408        }
409    }
410}