charon_lib/transform/
check_generics.rs

1//! Check that all supplied generic types match the corresponding generic parameters.
2use derive_generic_visitor::*;
3use index_vec::Idx;
4use itertools::Itertools;
5use std::{borrow::Cow, fmt::Display};
6
7use crate::{
8    ast::*,
9    errors::Level,
10    formatter::{AstFormatter, FmtCtx, IntoFormatter},
11    pretty::FmtWithCtx,
12    transform::utils::GenericsSource,
13};
14
15use crate::transform::{TransformCtx, ctx::TransformPass};
16
/// AST visitor that checks bound variables and supplied generic arguments inside one item,
/// reporting every inconsistency it finds through the translation context.
#[derive(Visitor)]
struct CheckGenericsVisitor<'a> {
    /// Translation context: used to look up declarations and to report errors.
    ctx: &'a TransformCtx,
    /// Label of the run this visitor belongs to; printed at the top of every error message.
    phase: &'static str,
    /// Tracks an enclosing span for error reporting.
    span: Span,
    /// Track the binders seen so far.
    // We can't keep the params by reference because the visitors don't tell us that everything
    // we're visiting has lifetime `'a`.
    binder_stack: BindingStack<GenericParams>,
    /// Remember the names of the types visited up to here (outermost first); printed in errors
    /// to show the path to the offending node.
    visit_stack: Vec<&'static str>,
}
30
31impl VisitorWithSpan for CheckGenericsVisitor<'_> {
32    fn current_span(&mut self) -> &mut Span {
33        &mut self.span
34    }
35}
36impl VisitorWithBinderStack for CheckGenericsVisitor<'_> {
37    fn binder_stack_mut(&mut self) -> &mut BindingStack<GenericParams> {
38        &mut self.binder_stack
39    }
40}
41
42impl CheckGenericsVisitor<'_> {
43    fn check_concretization_ty_match(&self, src_ty: &Ty, tar_ty: &Ty) {
44        match (src_ty.kind(), tar_ty.kind()) {
45            (TyKind::Ref(.., src_kind), TyKind::Ref(.., tar_kind)) => {
46                assert_eq!(src_kind, tar_kind);
47            }
48            (TyKind::RawPtr(.., src_kind), TyKind::RawPtr(.., tar_kind)) => {
49                assert_eq!(src_kind, tar_kind);
50            }
51            (
52                TyKind::Adt(TypeDeclRef { id: src_id, .. }),
53                TyKind::Adt(TypeDeclRef { id: tar_id, .. }),
54            ) => {
55                assert_eq!(src_id, tar_id);
56            }
57            _ => {
58                let fmt = &self.ctx.into_fmt();
59                self.error(format!(
60                    "Invalid concretization targets: from \"{}\" to \"{}\"",
61                    src_ty.with_ctx(fmt),
62                    tar_ty.with_ctx(fmt)
63                ));
64            }
65        }
66    }
67
68    fn error(&self, message: impl Display) {
69        let msg = format!(
70            "Found inconsistent generics {}:\n{message}\n\
71            Visitor stack:\n  {}\n\
72            Binding stack (depth {}):\n  {}",
73            self.phase,
74            self.visit_stack.iter().rev().join("\n  "),
75            self.binder_stack.len(),
76            self.binder_stack
77                .iter_enumerated()
78                .map(|(i, params)| format!("{i}: {params}"))
79                .join("\n  "),
80        );
81        // This is a fatal error: the output llbc is inconsistent and should not be used.
82        self.ctx.span_err(self.span, &msg, Level::ERROR);
83    }
84
85    /// For pretty error printing. This can print values that we encounter because we track binders
86    /// properly. This doesn't have the right binders to print values we get from somewhere else
87    /// (namely the `GenericParam`s we get from elsewhere in the crate).
88    fn val_fmt_ctx(&self) -> FmtCtx<'_> {
89        let mut fmt = self.ctx.into_fmt();
90        fmt.generics = self.binder_stack.map_ref(Cow::Borrowed);
91        fmt
92    }
93
94    fn zip_assert_match<I, A, B, FmtA, FmtB>(
95        &self,
96        a: &Vector<I, A>,
97        b: &Vector<I, B>,
98        a_fmt: &FmtA,
99        b_fmt: &FmtB,
100        kind: &str,
101        target: &GenericsSource,
102        check_inner: impl Fn(&A, &B),
103    ) where
104        I: Idx,
105        FmtA: AstFormatter,
106        A: FmtWithCtx<FmtA>,
107        B: FmtWithCtx<FmtB>,
108    {
109        if a.elem_count() == b.elem_count() {
110            a.iter().zip(b.iter()).for_each(|(x, y)| check_inner(x, y));
111        } else {
112            let a = a.iter().map(|x| x.with_ctx(a_fmt)).join(", ");
113            let b = b.iter().map(|x| x.with_ctx(b_fmt)).join(", ");
114            let target = target.with_ctx(a_fmt);
115            self.error(format!(
116                "Mismatched {kind}:\
117                \ntarget: {target}\
118                \nexpected: [{a}]\
119                \n     got: [{b}]"
120            ))
121        }
122    }
123
124    fn assert_clause_matches(
125        &self,
126        params_fmt: &FmtCtx<'_>,
127        tclause: &TraitParam,
128        tref: &TraitRef,
129    ) {
130        let clause_trait_id = tclause.trait_.skip_binder.id;
131        let ref_trait_id = tref.trait_decl_ref.skip_binder.id;
132        if clause_trait_id != ref_trait_id {
133            let args_fmt = &self.val_fmt_ctx();
134            let tclause = tclause.with_ctx(params_fmt);
135            let tref_pred = tref.trait_decl_ref.with_ctx(args_fmt);
136            let tref = tref.with_ctx(args_fmt);
137            self.error(format!(
138                "Mismatched trait clause:\
139                \nexpected: {tclause}\
140                \n     got: {tref}: {tref_pred}"
141            ));
142        }
143    }
144
145    fn assert_matches(
146        &self,
147        params_fmt: &FmtCtx<'_>,
148        params: &GenericParams,
149        args: &GenericArgs,
150        target: &GenericsSource,
151    ) {
152        let args_fmt = &self.val_fmt_ctx();
153        self.zip_assert_match(
154            &params.regions,
155            &args.regions,
156            params_fmt,
157            args_fmt,
158            "regions",
159            target,
160            |_, _| {},
161        );
162        self.zip_assert_match(
163            &params.types,
164            &args.types,
165            params_fmt,
166            args_fmt,
167            "type generics",
168            target,
169            |_, _| {},
170        );
171        self.zip_assert_match(
172            &params.const_generics,
173            &args.const_generics,
174            params_fmt,
175            args_fmt,
176            "const generics",
177            target,
178            |_, _| {},
179        );
180        self.zip_assert_match(
181            &params.trait_clauses,
182            &args.trait_refs,
183            params_fmt,
184            args_fmt,
185            "trait clauses",
186            target,
187            |tclause, tref| self.assert_clause_matches(params_fmt, tclause, tref),
188        );
189    }
190
191    fn assert_matches_item(&self, id: impl Into<ItemId>, args: &GenericArgs) {
192        let id = id.into();
193        let Some(item) = self.ctx.translated.get_item(id) else {
194            return;
195        };
196        let params = item.generic_params();
197        let fmt1 = self.ctx.into_fmt();
198        let fmt = fmt1.push_binder(Cow::Borrowed(params));
199        self.assert_matches(&fmt, params, args, &GenericsSource::item(id));
200    }
201
202    fn assert_matches_method(
203        &self,
204        trait_id: TraitDeclId,
205        method_name: &TraitItemName,
206        args: &GenericArgs,
207    ) {
208        let target = &GenericsSource::Method(trait_id, method_name.clone());
209        let Some(trait_decl) = self.ctx.translated.trait_decls.get(trait_id) else {
210            return;
211        };
212        let Some(bound_fn) = trait_decl.methods().find(|m| m.name() == method_name) else {
213            return;
214        };
215        let params = &bound_fn.params;
216        let fmt1 = self.ctx.into_fmt();
217        let fmt2 = fmt1.push_binder(Cow::Borrowed(&trait_decl.generics));
218        let fmt = fmt2.push_binder(Cow::Borrowed(params));
219        self.assert_matches(&fmt, params, args, target);
220    }
221}
222
223impl VisitAst for CheckGenericsVisitor<'_> {
224    fn visit<'a, T: AstVisitable>(&'a mut self, x: &T) -> ControlFlow<Self::Break> {
225        self.visit_stack.push(x.name());
226        VisitWithSpan::new(VisitWithBinderStack::new(self)).visit(x)?;
227        self.visit_stack.pop();
228        Continue(())
229    }
230
231    // Check that generics are correctly bound.
232    fn enter_region(&mut self, x: &Region) {
233        if let Region::Var(var) = x {
234            if self.binder_stack.get_var(*var).is_none() {
235                self.error(format!("Found incorrect region var: {var}"));
236            }
237        }
238    }
239    fn enter_ty_kind(&mut self, x: &TyKind) {
240        if let TyKind::TypeVar(var) = x {
241            if self.binder_stack.get_var(*var).is_none() {
242                self.error(format!("Found incorrect type var: {var}"));
243            }
244        }
245    }
246    fn enter_const_generic(&mut self, x: &ConstGeneric) {
247        if let ConstGeneric::Var(var) = x {
248            if self.binder_stack.get_var(*var).is_none() {
249                self.error(format!("Found incorrect const-generic var: {var}"));
250            }
251        }
252    }
253    fn enter_trait_ref(&mut self, x: &TraitRef) {
254        match &x.kind {
255            TraitRefKind::Clause(var) => {
256                if self.binder_stack.get_var(*var).is_none() {
257                    self.error(format!("Found incorrect clause var: {var}"));
258                }
259            }
260            TraitRefKind::BuiltinOrAuto {
261                parent_trait_refs,
262                types,
263            } => {
264                let trait_id = x.trait_decl_ref.skip_binder.id;
265                let target = GenericsSource::item(trait_id);
266                let Some(tdecl) = self.ctx.translated.trait_decls.get(trait_id) else {
267                    return;
268                };
269                if tdecl
270                    .item_meta
271                    .lang_item
272                    .as_deref()
273                    .is_some_and(|s| matches!(s, "pointee_trait" | "discriminant_kind"))
274                {
275                    // These traits have builtin assoc types that we can't resolve.
276                    return;
277                }
278                let fmt = &self.ctx.into_fmt();
279                let args_fmt = &self.val_fmt_ctx();
280                self.zip_assert_match(
281                    &tdecl.implied_clauses,
282                    parent_trait_refs,
283                    fmt,
284                    args_fmt,
285                    "builtin trait parent clauses",
286                    &target,
287                    |tclause, tref| self.assert_clause_matches(&fmt, tclause, tref),
288                );
289                let types_match = types.len() == tdecl.types.len()
290                    && tdecl
291                        .types
292                        .iter()
293                        .zip(types.iter())
294                        .all(|(dty, (iname, _))| dty.name() == iname);
295                if !types_match {
296                    let target = target.with_ctx(args_fmt);
297                    let a = tdecl.types.iter().map(|t| t.name()).format(", ");
298                    let b = types
299                        .iter()
300                        .map(|(_, assoc_ty)| assoc_ty.value.with_ctx(args_fmt))
301                        .format(", ");
302                    self.error(format!(
303                        "Mismatched types in builtin trait ref:\
304                        \ntarget: {target}\
305                        \nexpected: [{a}]\
306                        \n     got: [{b}]"
307                    ));
308                }
309            }
310            _ => {}
311        }
312    }
313
314    // Check that generics match the parameters of the target item.
315    fn enter_type_decl_ref(&mut self, x: &TypeDeclRef) {
316        match x.id {
317            TypeId::Adt(id) => self.assert_matches_item(id, &x.generics),
318            // TODO: check builtin generics.
319            TypeId::Tuple => {}
320            TypeId::Builtin(_) => {}
321        }
322    }
323    fn enter_fun_decl_ref(&mut self, x: &FunDeclRef) {
324        self.assert_matches_item(x.id, &x.generics);
325    }
326    fn enter_fn_ptr(&mut self, x: &FnPtr) {
327        match x.kind.as_ref() {
328            FnPtrKind::Fun(FunId::Regular(id)) => self.assert_matches_item(*id, &x.generics),
329            // TODO: check builtin generics.
330            FnPtrKind::Fun(FunId::Builtin(_)) => {}
331            FnPtrKind::Trait(trait_ref, method_name, _) => {
332                let trait_id = trait_ref.trait_decl_ref.skip_binder.id;
333                self.assert_matches_method(trait_id, method_name, &x.generics);
334            }
335        }
336    }
337    fn visit_rvalue(&mut self, x: &Rvalue) -> ::std::ops::ControlFlow<Self::Break> {
338        match x {
339            Rvalue::UnaryOp(UnOp::Cast(CastKind::Concretize(src, tar)), _) => {
340                self.check_concretization_ty_match(src, tar);
341            }
342            _ => {}
343        }
344        Continue(())
345    }
346    fn enter_global_decl_ref(&mut self, x: &GlobalDeclRef) {
347        self.assert_matches_item(x.id, &x.generics);
348    }
349    fn enter_trait_decl_ref(&mut self, x: &TraitDeclRef) {
350        self.assert_matches_item(x.id, &x.generics);
351    }
352    fn enter_trait_impl_ref(&mut self, x: &TraitImplRef) {
353        self.assert_matches_item(x.id, &x.generics);
354    }
355    fn enter_trait_impl(&mut self, timpl: &TraitImpl) {
356        let Some(tdecl) = self.ctx.translated.trait_decls.get(timpl.impl_trait.id) else {
357            return;
358        };
359
360        let fmt1 = self.ctx.into_fmt();
361        let tdecl_fmt = fmt1.push_binder(Cow::Borrowed(&tdecl.generics));
362        let args_fmt = &self.val_fmt_ctx();
363        self.zip_assert_match(
364            &tdecl.implied_clauses,
365            &timpl.implied_trait_refs,
366            &tdecl_fmt,
367            args_fmt,
368            "trait parent clauses",
369            &GenericsSource::item(timpl.impl_trait.id),
370            |tclause, tref| self.assert_clause_matches(&tdecl_fmt, tclause, tref),
371        );
372        // TODO: check type clauses
373        let types_match = timpl.types.len() == tdecl.types.len()
374            && tdecl
375                .types
376                .iter()
377                .zip(timpl.types.iter())
378                .all(|(dty, (iname, _))| dty.name() == iname);
379        if !types_match {
380            self.error(
381                "The associated types supplied by the trait impl don't match the trait decl.",
382            )
383        }
384        let consts_match = timpl.consts.len() == tdecl.consts.len()
385            && tdecl
386                .consts
387                .iter()
388                .zip(timpl.consts.iter())
389                .all(|(dconst, (iname, _))| &dconst.name == iname);
390        if !consts_match {
391            self.error(
392                "The associated consts supplied by the trait impl don't match the trait decl.",
393            )
394        }
395        let methods_match = timpl.methods.len() == tdecl.methods.len();
396        if !methods_match && self.phase != "after translation" {
397            let decl_methods = tdecl
398                .methods()
399                .map(|m| format!("- {}", m.name()))
400                .join("\n");
401            let impl_methods = timpl
402                .methods()
403                .map(|(name, _)| format!("- {name}"))
404                .join("\n");
405            self.error(format!(
406                "The methods supplied by the trait impl don't match the trait decl.\n\
407                Trait methods:\n{decl_methods}\n\
408                Impl methods:\n{impl_methods}"
409            ))
410        }
411    }
412}
413
414// The argument is a name to disambiguate the two times we run this check.
415pub struct Check(pub &'static str);
416impl TransformPass for Check {
417    fn transform_ctx(&self, ctx: &mut TransformCtx) {
418        for item in ctx.translated.all_items() {
419            // Hack: the items generated by monomorphisation have incorrect generics.
420            // TODO(dyn): remove once we support dyn in mono and we can remove the manual mono
421            // pass.
422            if item
423                .item_meta()
424                .name
425                .name
426                .last()
427                .unwrap()
428                .is_monomorphized()
429            {
430                continue;
431            }
432            let mut visitor = CheckGenericsVisitor {
433                ctx,
434                phase: self.0,
435                span: Span::dummy(),
436                binder_stack: BindingStack::empty(),
437                visit_stack: Default::default(),
438            };
439            let _ = item.drive(&mut visitor);
440        }
441    }
442}