charon_lib/transform/
check_generics.rs

1//! Check that all supplied generic types match the corresponding generic parameters.
2use derive_generic_visitor::*;
3use index_vec::Idx;
4use itertools::Itertools;
5use std::{borrow::Cow, fmt::Display};
6
7use crate::{
8    ast::*,
9    errors::Level,
10    formatter::{AstFormatter, FmtCtx, IntoFormatter},
11    pretty::FmtWithCtx,
12    transform::utils::GenericsSource,
13};
14
15use super::{TransformCtx, ctx::TransformPass};
16
17#[derive(Visitor)]
18struct CheckGenericsVisitor<'a> {
19    ctx: &'a TransformCtx,
20    phase: &'static str,
21    /// Tracks an enclosing span for error reporting.
22    span: Span,
23    /// Track the binders seen so far.
24    // We can't keep the params by reference because the visitors don't tell us that everything
25    // we're visiting has lifetime `'a`.
26    binder_stack: BindingStack<GenericParams>,
27    /// Remember the names of the types visited up to here.
28    visit_stack: Vec<&'static str>,
29}
30
31impl VisitorWithSpan for CheckGenericsVisitor<'_> {
32    fn current_span(&mut self) -> &mut Span {
33        &mut self.span
34    }
35}
36impl VisitorWithBinderStack for CheckGenericsVisitor<'_> {
37    fn binder_stack_mut(&mut self) -> &mut BindingStack<GenericParams> {
38        &mut self.binder_stack
39    }
40}
41
42impl CheckGenericsVisitor<'_> {
43    fn error(&self, message: impl Display) {
44        let msg = format!(
45            "Found inconsistent generics {}:\n{message}\n\
46            Visitor stack:\n  {}\n\
47            Binding stack (depth {}):\n  {}",
48            self.phase,
49            self.visit_stack.iter().rev().join("\n  "),
50            self.binder_stack.len(),
51            self.binder_stack
52                .iter_enumerated()
53                .map(|(i, params)| format!("{i}: {params}"))
54                .join("\n  "),
55        );
56        // This is a fatal error: the output llbc is inconsistent and should not be used.
57        self.ctx.span_err(self.span, &msg, Level::ERROR);
58    }
59
60    /// For pretty error printing. This can print values that we encounter because we track binders
61    /// properly. This doesn't have the right binders to print values we get from somewhere else
62    /// (namely the `GenericParam`s we get from elsewhere in the crate).
63    fn val_fmt_ctx(&self) -> FmtCtx<'_> {
64        let mut fmt = self.ctx.into_fmt();
65        fmt.generics = self.binder_stack.map_ref(Cow::Borrowed);
66        fmt
67    }
68
69    fn zip_assert_match<I, A, B, FmtA, FmtB>(
70        &self,
71        a: &Vector<I, A>,
72        b: &Vector<I, B>,
73        a_fmt: &FmtA,
74        b_fmt: &FmtB,
75        kind: &str,
76        target: &GenericsSource,
77        check_inner: impl Fn(&A, &B),
78    ) where
79        I: Idx,
80        FmtA: AstFormatter,
81        A: FmtWithCtx<FmtA>,
82        B: FmtWithCtx<FmtB>,
83    {
84        if a.elem_count() == b.elem_count() {
85            a.iter().zip(b.iter()).for_each(|(x, y)| check_inner(x, y));
86        } else {
87            let a = a.iter().map(|x| x.with_ctx(a_fmt)).join(", ");
88            let b = b.iter().map(|x| x.with_ctx(b_fmt)).join(", ");
89            let target = target.with_ctx(a_fmt);
90            self.error(format!(
91                "Mismatched {kind}:\
92                \ntarget: {target}\
93                \nexpected: [{a}]\
94                \n     got: [{b}]"
95            ))
96        }
97    }
98
99    fn assert_clause_matches(
100        &self,
101        params_fmt: &FmtCtx<'_>,
102        tclause: &TraitClause,
103        tref: &TraitRef,
104    ) {
105        let clause_trait_id = tclause.trait_.skip_binder.id;
106        let ref_trait_id = tref.trait_decl_ref.skip_binder.id;
107        if clause_trait_id != ref_trait_id {
108            let args_fmt = &self.val_fmt_ctx();
109            let tclause = tclause.with_ctx(params_fmt);
110            let tref_pred = tref.trait_decl_ref.with_ctx(args_fmt);
111            let tref = tref.with_ctx(args_fmt);
112            self.error(format!(
113                "Mismatched trait clause:\
114                \nexpected: {tclause}\
115                \n     got: {tref}: {tref_pred}"
116            ));
117        }
118    }
119
120    fn assert_matches(
121        &self,
122        params_fmt: &FmtCtx<'_>,
123        params: &GenericParams,
124        args: &GenericArgs,
125        target: &GenericsSource,
126    ) {
127        let args_fmt = &self.val_fmt_ctx();
128        self.zip_assert_match(
129            &params.regions,
130            &args.regions,
131            params_fmt,
132            args_fmt,
133            "regions",
134            target,
135            |_, _| {},
136        );
137        self.zip_assert_match(
138            &params.types,
139            &args.types,
140            params_fmt,
141            args_fmt,
142            "type generics",
143            target,
144            |_, _| {},
145        );
146        self.zip_assert_match(
147            &params.const_generics,
148            &args.const_generics,
149            params_fmt,
150            args_fmt,
151            "const generics",
152            target,
153            |_, _| {},
154        );
155        self.zip_assert_match(
156            &params.trait_clauses,
157            &args.trait_refs,
158            params_fmt,
159            args_fmt,
160            "trait clauses",
161            target,
162            |tclause, tref| self.assert_clause_matches(params_fmt, tclause, tref),
163        );
164    }
165
166    fn assert_matches_item(&self, id: impl Into<AnyTransId>, args: &GenericArgs) {
167        let id = id.into();
168        let Some(item) = self.ctx.translated.get_item(id) else {
169            return;
170        };
171        let params = item.generic_params();
172        let fmt1 = self.ctx.into_fmt();
173        let fmt = fmt1.push_binder(Cow::Borrowed(params));
174        self.assert_matches(&fmt, params, args, &GenericsSource::item(id));
175    }
176
177    fn assert_matches_method(
178        &self,
179        trait_id: TraitDeclId,
180        method_name: &TraitItemName,
181        args: &GenericArgs,
182    ) {
183        let target = &GenericsSource::Method(trait_id, method_name.clone());
184        let Some(trait_decl) = self.ctx.translated.trait_decls.get(trait_id) else {
185            return;
186        };
187        let Some((_, bound_fn)) = trait_decl.methods().find(|(n, _)| n == method_name) else {
188            return;
189        };
190        let params = &bound_fn.params;
191        let fmt1 = self.ctx.into_fmt();
192        let fmt2 = fmt1.push_binder(Cow::Borrowed(&trait_decl.generics));
193        let fmt = fmt2.push_binder(Cow::Borrowed(params));
194        self.assert_matches(&fmt, params, args, target);
195    }
196}
197
198impl VisitAst for CheckGenericsVisitor<'_> {
199    fn visit<'a, T: AstVisitable>(&'a mut self, x: &T) -> ControlFlow<Self::Break> {
200        self.visit_stack.push(x.name());
201        VisitWithSpan::new(VisitWithBinderStack::new(self)).visit(x)?;
202        self.visit_stack.pop();
203        Continue(())
204    }
205
206    // Check that generics are correctly bound.
207    fn enter_region(&mut self, x: &Region) {
208        if let Region::Var(var) = x {
209            if self.binder_stack.get_var(*var).is_none() {
210                self.error(format!("Found incorrect region var: {var}"));
211            }
212        }
213    }
214    fn enter_ty_kind(&mut self, x: &TyKind) {
215        if let TyKind::TypeVar(var) = x {
216            if self.binder_stack.get_var(*var).is_none() {
217                self.error(format!("Found incorrect type var: {var}"));
218            }
219        }
220    }
221    fn enter_const_generic(&mut self, x: &ConstGeneric) {
222        if let ConstGeneric::Var(var) = x {
223            if self.binder_stack.get_var(*var).is_none() {
224                self.error(format!("Found incorrect const-generic var: {var}"));
225            }
226        }
227    }
228    fn enter_trait_ref_kind(&mut self, x: &TraitRefKind) {
229        match x {
230            TraitRefKind::Clause(var) => {
231                if self.binder_stack.get_var(*var).is_none() {
232                    self.error(format!("Found incorrect clause var: {var}"));
233                }
234            }
235            TraitRefKind::BuiltinOrAuto {
236                trait_decl_ref,
237                parent_trait_refs,
238                types,
239            } => {
240                let trait_id = trait_decl_ref.skip_binder.id;
241                let target = GenericsSource::item(trait_id);
242                let Some(tdecl) = self.ctx.translated.trait_decls.get(trait_id) else {
243                    return;
244                };
245                if tdecl
246                    .item_meta
247                    .lang_item
248                    .as_deref()
249                    .is_some_and(|s| matches!(s, "pointee_trait" | "discriminant_kind"))
250                {
251                    // These traits have builtin assoc types that we can't resolve.
252                    return;
253                }
254                let fmt = &self.ctx.into_fmt();
255                let args_fmt = &self.val_fmt_ctx();
256                self.zip_assert_match(
257                    &tdecl.parent_clauses,
258                    parent_trait_refs,
259                    fmt,
260                    args_fmt,
261                    "builtin trait parent clauses",
262                    &target,
263                    |tclause, tref| self.assert_clause_matches(&fmt, tclause, tref),
264                );
265                let types_match = types.len() == tdecl.types.len()
266                    && tdecl
267                        .types
268                        .iter()
269                        .zip(types.iter())
270                        .all(|(dname, (iname, _, _))| dname == iname);
271                if !types_match {
272                    let target = target.with_ctx(args_fmt);
273                    let a = tdecl.types.iter().format(", ");
274                    let b = types
275                        .iter()
276                        .map(|(_, ty, _)| ty.with_ctx(args_fmt))
277                        .format(", ");
278                    self.error(format!(
279                        "Mismatched types in builtin trait ref:\
280                        \ntarget: {target}\
281                        \nexpected: [{a}]\
282                        \n     got: [{b}]"
283                    ));
284                }
285            }
286            _ => {}
287        }
288    }
289
290    // Check that generics match the parameters of the target item.
291    fn enter_type_decl_ref(&mut self, x: &TypeDeclRef) {
292        match x.id {
293            TypeId::Adt(id) => self.assert_matches_item(id, &x.generics),
294            // TODO: check builtin generics.
295            TypeId::Tuple => {}
296            TypeId::Builtin(_) => {}
297        }
298    }
299    fn enter_fun_decl_ref(&mut self, x: &FunDeclRef) {
300        self.assert_matches_item(x.id, &x.generics);
301    }
302    fn enter_fn_ptr(&mut self, x: &FnPtr) {
303        match x.func.as_ref() {
304            FunIdOrTraitMethodRef::Fun(FunId::Regular(id)) => {
305                self.assert_matches_item(*id, &x.generics)
306            }
307            // TODO: check builtin generics.
308            FunIdOrTraitMethodRef::Fun(FunId::Builtin(_)) => {}
309            FunIdOrTraitMethodRef::Trait(trait_ref, method_name, _) => {
310                let trait_id = trait_ref.trait_decl_ref.skip_binder.id;
311                self.assert_matches_method(trait_id, method_name, &x.generics);
312            }
313        }
314    }
315    fn enter_global_decl_ref(&mut self, x: &GlobalDeclRef) {
316        self.assert_matches_item(x.id, &x.generics);
317    }
318    fn enter_trait_decl_ref(&mut self, x: &TraitDeclRef) {
319        self.assert_matches_item(x.id, &x.generics);
320    }
321    fn enter_trait_impl_ref(&mut self, x: &TraitImplRef) {
322        self.assert_matches_item(x.id, &x.generics);
323    }
324    fn enter_trait_impl(&mut self, timpl: &TraitImpl) {
325        let Some(tdecl) = self.ctx.translated.trait_decls.get(timpl.impl_trait.id) else {
326            return;
327        };
328        // See `lift_associated_item_clauses`
329        assert!(timpl.type_clauses.is_empty());
330        assert!(tdecl.type_clauses.is_empty());
331
332        let fmt1 = self.ctx.into_fmt();
333        let tdecl_fmt = fmt1.push_binder(Cow::Borrowed(&tdecl.generics));
334        let args_fmt = &self.val_fmt_ctx();
335        self.zip_assert_match(
336            &tdecl.parent_clauses,
337            &timpl.parent_trait_refs,
338            &tdecl_fmt,
339            args_fmt,
340            "trait parent clauses",
341            &GenericsSource::item(timpl.impl_trait.id),
342            |tclause, tref| self.assert_clause_matches(&tdecl_fmt, tclause, tref),
343        );
344        let types_match = timpl.types.len() == tdecl.types.len()
345            && tdecl
346                .types
347                .iter()
348                .zip(timpl.types.iter())
349                .all(|(dname, (iname, _))| dname == iname);
350        if !types_match {
351            self.error(
352                "The associated types supplied by the trait impl don't match the trait decl.",
353            )
354        }
355        let consts_match = timpl.consts.len() == tdecl.consts.len()
356            && tdecl
357                .types
358                .iter()
359                .zip(timpl.types.iter())
360                .all(|(dname, (iname, _))| dname == iname);
361        if !consts_match {
362            self.error(
363                "The associated consts supplied by the trait impl don't match the trait decl.",
364            )
365        }
366        let methods_match = timpl.methods.len() == tdecl.methods.len();
367        if !methods_match && self.phase != "after translation" {
368            let decl_methods = tdecl
369                .methods()
370                .map(|(name, _)| format!("- {name}"))
371                .join("\n");
372            let impl_methods = timpl
373                .methods()
374                .map(|(name, _)| format!("- {name}"))
375                .join("\n");
376            self.error(format!(
377                "The methods supplied by the trait impl don't match the trait decl.\n\
378                Trait methods:\n{decl_methods}\n\
379                Impl methods:\n{impl_methods}"
380            ))
381        }
382    }
383}
384
385// The argument is a name to disambiguate the two times we run this check.
386pub struct Check(pub &'static str);
387impl TransformPass for Check {
388    fn transform_ctx(&self, ctx: &mut TransformCtx) {
389        for item in ctx.translated.all_items() {
390            // Hack: the items generated by monomorphisation have incorrect generics.
391            if item
392                .item_meta()
393                .name
394                .name
395                .last()
396                .unwrap()
397                .is_monomorphized()
398            {
399                continue;
400            }
401            let mut visitor = CheckGenericsVisitor {
402                ctx,
403                phase: self.0,
404                span: Span::dummy(),
405                binder_stack: BindingStack::empty(),
406                visit_stack: Default::default(),
407            };
408            let _ = item.drive(&mut visitor);
409        }
410    }
411}