1use derive_generic_visitor::*;
3use index_vec::Idx;
4use itertools::Itertools;
5use std::{borrow::Cow, fmt::Display};
6
7use crate::{
8 ast::*,
9 errors::Level,
10 formatter::{AstFormatter, FmtCtx, IntoFormatter},
11 pretty::FmtWithCtx,
12 transform::utils::GenericsSource,
13};
14
15use super::{ctx::TransformPass, TransformCtx};
16
/// AST visitor that checks generic parameters and generic arguments for consistency, reporting
/// every inconsistency it finds through `TransformCtx::span_err`.
#[derive(Visitor)]
struct CheckGenericsVisitor<'a> {
    ctx: &'a TransformCtx,
    /// Label for the phase in which this check runs; printed in error messages and compared
    /// against `"after translation"` to decide whether trait-impl method counts are checked.
    phase: &'static str,
    /// Span attached to reported errors; replaced by the span of the enclosing (U)LLBC
    /// statement while that statement is being visited.
    span: Span,
    /// Generic parameters of every binder currently in scope; bound region/type/const/clause
    /// variables are resolved against this stack.
    binder_stack: BindingStack<GenericParams>,
    /// Names of the AST nodes currently being visited (outermost first); printed, innermost
    /// first, in error messages.
    visit_stack: Vec<&'static str>,
}
30
31impl CheckGenericsVisitor<'_> {
32 fn error(&self, message: impl Display) {
33 let msg = format!(
34 "Found inconsistent generics {}:\n{message}\n\
35 Visitor stack:\n {}\n\
36 Binding stack (depth {}):\n {}",
37 self.phase,
38 self.visit_stack.iter().rev().join("\n "),
39 self.binder_stack.len(),
40 self.binder_stack
41 .iter_enumerated()
42 .map(|(i, params)| format!("{i}: {params}"))
43 .join("\n "),
44 );
45 self.ctx.span_err(self.span, &msg, Level::ERROR);
47 }
48
49 fn val_fmt_ctx(&self) -> FmtCtx<'_> {
53 let mut fmt = self.ctx.into_fmt();
54 fmt.generics = self.binder_stack.map_ref(Cow::Borrowed);
55 fmt
56 }
57
58 fn zip_assert_match<I, A, B, FmtA, FmtB>(
59 &self,
60 a: &Vector<I, A>,
61 b: &Vector<I, B>,
62 a_fmt: &FmtA,
63 b_fmt: &FmtB,
64 kind: &str,
65 target: &GenericsSource,
66 check_inner: impl Fn(&A, &B),
67 ) where
68 I: Idx,
69 FmtA: AstFormatter,
70 A: FmtWithCtx<FmtA>,
71 B: FmtWithCtx<FmtB>,
72 {
73 if a.elem_count() == b.elem_count() {
74 a.iter().zip(b.iter()).for_each(|(x, y)| check_inner(x, y));
75 } else {
76 let a = a.iter().map(|x| x.with_ctx(a_fmt)).join(", ");
77 let b = b.iter().map(|x| x.with_ctx(b_fmt)).join(", ");
78 let target = target.with_ctx(a_fmt);
79 self.error(format!(
80 "Mismatched {kind}:\
81 \ntarget: {target}\
82 \nexpected: [{a}]\
83 \n got: [{b}]"
84 ))
85 }
86 }
87
88 fn assert_clause_matches(
89 &self,
90 params_fmt: &FmtCtx<'_>,
91 tclause: &TraitClause,
92 tref: &TraitRef,
93 ) {
94 let clause_trait_id = tclause.trait_.skip_binder.id;
95 let ref_trait_id = tref.trait_decl_ref.skip_binder.id;
96 if clause_trait_id != ref_trait_id {
97 let args_fmt = &self.val_fmt_ctx();
98 let tclause = tclause.with_ctx(params_fmt);
99 let tref_pred = tref.trait_decl_ref.with_ctx(args_fmt);
100 let tref = tref.with_ctx(args_fmt);
101 self.error(format!(
102 "Mismatched trait clause:\
103 \nexpected: {tclause}\
104 \n got: {tref}: {tref_pred}"
105 ));
106 }
107 }
108
109 fn assert_matches(
110 &self,
111 params_fmt: &FmtCtx<'_>,
112 params: &GenericParams,
113 args: &GenericArgs,
114 target: &GenericsSource,
115 ) {
116 let args_fmt = &self.val_fmt_ctx();
117 self.zip_assert_match(
118 ¶ms.regions,
119 &args.regions,
120 params_fmt,
121 args_fmt,
122 "regions",
123 target,
124 |_, _| {},
125 );
126 self.zip_assert_match(
127 ¶ms.types,
128 &args.types,
129 params_fmt,
130 args_fmt,
131 "type generics",
132 target,
133 |_, _| {},
134 );
135 self.zip_assert_match(
136 ¶ms.const_generics,
137 &args.const_generics,
138 params_fmt,
139 args_fmt,
140 "const generics",
141 target,
142 |_, _| {},
143 );
144 self.zip_assert_match(
145 ¶ms.trait_clauses,
146 &args.trait_refs,
147 params_fmt,
148 args_fmt,
149 "trait clauses",
150 target,
151 |tclause, tref| self.assert_clause_matches(params_fmt, tclause, tref),
152 );
153 }
154
155 fn assert_matches_item(&self, id: impl Into<AnyTransId>, args: &GenericArgs) {
156 let id = id.into();
157 let Some(item) = self.ctx.translated.get_item(id) else {
158 return;
159 };
160 let params = item.generic_params();
161 let fmt1 = self.ctx.into_fmt();
162 let fmt = fmt1.push_binder(Cow::Borrowed(params));
163 self.assert_matches(&fmt, params, args, &GenericsSource::item(id));
164 }
165
166 fn assert_matches_method(
167 &self,
168 trait_id: TraitDeclId,
169 method_name: &TraitItemName,
170 args: &GenericArgs,
171 ) {
172 let target = &GenericsSource::Method(trait_id, method_name.clone());
173 let Some(trait_decl) = self.ctx.translated.trait_decls.get(trait_id) else {
174 return;
175 };
176 let Some((_, bound_fn)) = trait_decl.methods().find(|(n, _)| n == method_name) else {
177 return;
178 };
179 let params = &bound_fn.params;
180 let fmt1 = self.ctx.into_fmt();
181 let fmt2 = fmt1.push_binder(Cow::Borrowed(&trait_decl.generics));
182 let fmt = fmt2.push_binder(Cow::Borrowed(params));
183 self.assert_matches(&fmt, params, args, target);
184 }
185}
186
187impl VisitAst for CheckGenericsVisitor<'_> {
188 fn visit<'a, T: AstVisitable>(&'a mut self, x: &T) -> ControlFlow<Self::Break> {
189 self.visit_stack.push(x.name());
190 x.drive(self)?; self.visit_stack.pop();
192 Continue(())
193 }
194
195 fn visit_binder<T: AstVisitable>(&mut self, binder: &Binder<T>) -> ControlFlow<Self::Break> {
196 self.binder_stack.push(binder.params.clone());
197 self.visit_inner(binder)?;
198 self.binder_stack.pop();
199 Continue(())
200 }
201 fn visit_region_binder<T: AstVisitable>(
202 &mut self,
203 binder: &RegionBinder<T>,
204 ) -> ControlFlow<Self::Break> {
205 self.binder_stack.push(GenericParams {
206 regions: binder.regions.clone(),
207 ..Default::default()
208 });
209 self.visit_inner(binder)?;
210 self.binder_stack.pop();
211 Continue(())
212 }
213
214 fn enter_region(&mut self, x: &Region) {
216 if let Region::Var(var) = x {
217 if self.binder_stack.get_var(*var).is_none() {
218 self.error(format!("Found incorrect region var: {var}"));
219 }
220 }
221 }
222 fn enter_ty_kind(&mut self, x: &TyKind) {
223 if let TyKind::TypeVar(var) = x {
224 if self.binder_stack.get_var(*var).is_none() {
225 self.error(format!("Found incorrect type var: {var}"));
226 }
227 }
228 }
229 fn enter_const_generic(&mut self, x: &ConstGeneric) {
230 if let ConstGeneric::Var(var) = x {
231 if self.binder_stack.get_var(*var).is_none() {
232 self.error(format!("Found incorrect const-generic var: {var}"));
233 }
234 }
235 }
236 fn enter_trait_ref_kind(&mut self, x: &TraitRefKind) {
237 match x {
238 TraitRefKind::Clause(var) => {
239 if self.binder_stack.get_var(*var).is_none() {
240 self.error(format!("Found incorrect clause var: {var}"));
241 }
242 }
243 TraitRefKind::BuiltinOrAuto {
244 trait_decl_ref,
245 parent_trait_refs,
246 types,
247 } => {
248 let trait_id = trait_decl_ref.skip_binder.id;
249 let target = GenericsSource::item(trait_id);
250 let Some(tdecl) = self.ctx.translated.trait_decls.get(trait_id) else {
251 return;
252 };
253 if tdecl
254 .item_meta
255 .lang_item
256 .as_deref()
257 .is_some_and(|s| matches!(s, "pointee_trait" | "discriminant_kind"))
258 {
259 return;
261 }
262 let fmt = &self.ctx.into_fmt();
263 let args_fmt = &self.val_fmt_ctx();
264 self.zip_assert_match(
265 &tdecl.parent_clauses,
266 parent_trait_refs,
267 fmt,
268 args_fmt,
269 "builtin trait parent clauses",
270 &target,
271 |tclause, tref| self.assert_clause_matches(&fmt, tclause, tref),
272 );
273 let types_match = types.len() == tdecl.types.len()
274 && tdecl
275 .types
276 .iter()
277 .zip(types.iter())
278 .all(|(dname, (iname, _, _))| dname == iname);
279 if !types_match {
280 let target = target.with_ctx(args_fmt);
281 let a = tdecl.types.iter().format(", ");
282 let b = types
283 .iter()
284 .map(|(_, ty, _)| ty.with_ctx(args_fmt))
285 .format(", ");
286 self.error(format!(
287 "Mismatched types in builtin trait ref:\
288 \ntarget: {target}\
289 \nexpected: [{a}]\
290 \n got: [{b}]"
291 ));
292 }
293 }
294 _ => {}
295 }
296 }
297
298 fn enter_type_decl_ref(&mut self, x: &TypeDeclRef) {
300 match x.id {
301 TypeId::Adt(id) => self.assert_matches_item(id, &x.generics),
302 TypeId::Tuple => {}
304 TypeId::Builtin(_) => {}
305 }
306 }
307 fn enter_fun_decl_ref(&mut self, x: &FunDeclRef) {
308 self.assert_matches_item(x.id, &x.generics);
309 }
310 fn enter_fn_ptr(&mut self, x: &FnPtr) {
311 match x.func.as_ref() {
312 FunIdOrTraitMethodRef::Fun(FunId::Regular(id)) => {
313 self.assert_matches_item(*id, &x.generics)
314 }
315 FunIdOrTraitMethodRef::Fun(FunId::Builtin(_)) => {}
317 FunIdOrTraitMethodRef::Trait(trait_ref, method_name, _) => {
318 let trait_id = trait_ref.trait_decl_ref.skip_binder.id;
319 self.assert_matches_method(trait_id, method_name, &x.generics);
320 }
321 }
322 }
323 fn enter_global_decl_ref(&mut self, x: &GlobalDeclRef) {
324 self.assert_matches_item(x.id, &x.generics);
325 }
326 fn enter_trait_decl_ref(&mut self, x: &TraitDeclRef) {
327 self.assert_matches_item(x.id, &x.generics);
328 }
329 fn enter_trait_impl_ref(&mut self, x: &TraitImplRef) {
330 self.assert_matches_item(x.id, &x.generics);
331 }
332 fn enter_trait_impl(&mut self, timpl: &TraitImpl) {
333 let Some(tdecl) = self.ctx.translated.trait_decls.get(timpl.impl_trait.id) else {
334 return;
335 };
336 assert!(timpl.type_clauses.is_empty());
338 assert!(tdecl.type_clauses.is_empty());
339
340 let fmt1 = self.ctx.into_fmt();
341 let tdecl_fmt = fmt1.push_binder(Cow::Borrowed(&tdecl.generics));
342 let args_fmt = &self.val_fmt_ctx();
343 self.zip_assert_match(
344 &tdecl.parent_clauses,
345 &timpl.parent_trait_refs,
346 &tdecl_fmt,
347 args_fmt,
348 "trait parent clauses",
349 &GenericsSource::item(timpl.impl_trait.id),
350 |tclause, tref| self.assert_clause_matches(&tdecl_fmt, tclause, tref),
351 );
352 let types_match = timpl.types.len() == tdecl.types.len()
353 && tdecl
354 .types
355 .iter()
356 .zip(timpl.types.iter())
357 .all(|(dname, (iname, _))| dname == iname);
358 if !types_match {
359 self.error(
360 "The associated types supplied by the trait impl don't match the trait decl.",
361 )
362 }
363 let consts_match = timpl.consts.len() == tdecl.consts.len()
364 && tdecl
365 .types
366 .iter()
367 .zip(timpl.types.iter())
368 .all(|(dname, (iname, _))| dname == iname);
369 if !consts_match {
370 self.error(
371 "The associated consts supplied by the trait impl don't match the trait decl.",
372 )
373 }
374 let methods_match = timpl.methods.len() == tdecl.methods.len();
375 if !methods_match && self.phase != "after translation" {
376 let decl_methods = tdecl
377 .methods()
378 .map(|(name, _)| format!("- {name}"))
379 .join("\n");
380 let impl_methods = timpl
381 .methods()
382 .map(|(name, _)| format!("- {name}"))
383 .join("\n");
384 self.error(format!(
385 "The methods supplied by the trait impl don't match the trait decl.\n\
386 Trait methods:\n{decl_methods}\n\
387 Impl methods:\n{impl_methods}"
388 ))
389 }
390 }
391
392 fn visit_ullbc_statement(&mut self, st: &ullbc_ast::Statement) -> ControlFlow<Self::Break> {
394 let old_span = self.span;
395 self.span = st.span;
396 self.visit_inner(st)?;
397 self.span = old_span;
398 Continue(())
399 }
400 fn visit_llbc_statement(&mut self, st: &llbc_ast::Statement) -> ControlFlow<Self::Break> {
401 let old_span = self.span;
402 self.span = st.span;
403 self.visit_inner(st)?;
404 self.span = old_span;
405 Continue(())
406 }
407}
408
409pub struct Check(pub &'static str);
411impl TransformPass for Check {
412 fn transform_ctx(&self, ctx: &mut TransformCtx) {
413 for item in ctx.translated.all_items() {
414 if item
416 .item_meta()
417 .name
418 .name
419 .last()
420 .unwrap()
421 .is_monomorphized()
422 {
423 continue;
424 }
425 let mut visitor = CheckGenericsVisitor {
426 ctx,
427 phase: self.0,
428 span: item.item_meta().span,
429 binder_stack: BindingStack::new(item.generic_params().clone()),
430 visit_stack: Default::default(),
431 };
432 let _ = item.drive(&mut visitor);
433 }
434 }
435}