1use derive_generic_visitor::*;
3use index_vec::Idx;
4use itertools::Itertools;
5use std::{borrow::Cow, fmt::Display};
6
7use crate::{
8 ast::*,
9 errors::Level,
10 formatter::{AstFormatter, FmtCtx, IntoFormatter},
11 pretty::FmtWithCtx,
12 transform::utils::GenericsSource,
13};
14
15use super::{TransformCtx, ctx::TransformPass};
16
/// Visitor that checks generic arguments against the generic parameters of the items they
/// refer to, and checks that every bound variable (region, type, const generic, trait clause)
/// resolves in the binder stack at the point where it is used.
#[derive(Visitor)]
struct CheckGenericsVisitor<'a> {
    /// Translation context: used to look up referenced items and to report errors.
    ctx: &'a TransformCtx,
    /// Name of the phase at which this check runs; included in every error message.
    phase: &'static str,
    /// Span of the node currently being visited (exposed via `VisitorWithSpan`); errors are
    /// reported at this span.
    span: Span,
    /// Generic parameters of the binders enclosing the current node (exposed via
    /// `VisitorWithBinderStack`); bound variables are resolved against it.
    binder_stack: BindingStack<GenericParams>,
    /// Names of the AST nodes currently being visited, outermost first; printed (innermost
    /// first) in error messages.
    visit_stack: Vec<&'static str>,
}
30
31impl VisitorWithSpan for CheckGenericsVisitor<'_> {
32 fn current_span(&mut self) -> &mut Span {
33 &mut self.span
34 }
35}
36impl VisitorWithBinderStack for CheckGenericsVisitor<'_> {
37 fn binder_stack_mut(&mut self) -> &mut BindingStack<GenericParams> {
38 &mut self.binder_stack
39 }
40}
41
42impl CheckGenericsVisitor<'_> {
43 fn error(&self, message: impl Display) {
44 let msg = format!(
45 "Found inconsistent generics {}:\n{message}\n\
46 Visitor stack:\n {}\n\
47 Binding stack (depth {}):\n {}",
48 self.phase,
49 self.visit_stack.iter().rev().join("\n "),
50 self.binder_stack.len(),
51 self.binder_stack
52 .iter_enumerated()
53 .map(|(i, params)| format!("{i}: {params}"))
54 .join("\n "),
55 );
56 self.ctx.span_err(self.span, &msg, Level::ERROR);
58 }
59
60 fn val_fmt_ctx(&self) -> FmtCtx<'_> {
64 let mut fmt = self.ctx.into_fmt();
65 fmt.generics = self.binder_stack.map_ref(Cow::Borrowed);
66 fmt
67 }
68
69 fn zip_assert_match<I, A, B, FmtA, FmtB>(
70 &self,
71 a: &Vector<I, A>,
72 b: &Vector<I, B>,
73 a_fmt: &FmtA,
74 b_fmt: &FmtB,
75 kind: &str,
76 target: &GenericsSource,
77 check_inner: impl Fn(&A, &B),
78 ) where
79 I: Idx,
80 FmtA: AstFormatter,
81 A: FmtWithCtx<FmtA>,
82 B: FmtWithCtx<FmtB>,
83 {
84 if a.elem_count() == b.elem_count() {
85 a.iter().zip(b.iter()).for_each(|(x, y)| check_inner(x, y));
86 } else {
87 let a = a.iter().map(|x| x.with_ctx(a_fmt)).join(", ");
88 let b = b.iter().map(|x| x.with_ctx(b_fmt)).join(", ");
89 let target = target.with_ctx(a_fmt);
90 self.error(format!(
91 "Mismatched {kind}:\
92 \ntarget: {target}\
93 \nexpected: [{a}]\
94 \n got: [{b}]"
95 ))
96 }
97 }
98
99 fn assert_clause_matches(
100 &self,
101 params_fmt: &FmtCtx<'_>,
102 tclause: &TraitClause,
103 tref: &TraitRef,
104 ) {
105 let clause_trait_id = tclause.trait_.skip_binder.id;
106 let ref_trait_id = tref.trait_decl_ref.skip_binder.id;
107 if clause_trait_id != ref_trait_id {
108 let args_fmt = &self.val_fmt_ctx();
109 let tclause = tclause.with_ctx(params_fmt);
110 let tref_pred = tref.trait_decl_ref.with_ctx(args_fmt);
111 let tref = tref.with_ctx(args_fmt);
112 self.error(format!(
113 "Mismatched trait clause:\
114 \nexpected: {tclause}\
115 \n got: {tref}: {tref_pred}"
116 ));
117 }
118 }
119
120 fn assert_matches(
121 &self,
122 params_fmt: &FmtCtx<'_>,
123 params: &GenericParams,
124 args: &GenericArgs,
125 target: &GenericsSource,
126 ) {
127 let args_fmt = &self.val_fmt_ctx();
128 self.zip_assert_match(
129 ¶ms.regions,
130 &args.regions,
131 params_fmt,
132 args_fmt,
133 "regions",
134 target,
135 |_, _| {},
136 );
137 self.zip_assert_match(
138 ¶ms.types,
139 &args.types,
140 params_fmt,
141 args_fmt,
142 "type generics",
143 target,
144 |_, _| {},
145 );
146 self.zip_assert_match(
147 ¶ms.const_generics,
148 &args.const_generics,
149 params_fmt,
150 args_fmt,
151 "const generics",
152 target,
153 |_, _| {},
154 );
155 self.zip_assert_match(
156 ¶ms.trait_clauses,
157 &args.trait_refs,
158 params_fmt,
159 args_fmt,
160 "trait clauses",
161 target,
162 |tclause, tref| self.assert_clause_matches(params_fmt, tclause, tref),
163 );
164 }
165
166 fn assert_matches_item(&self, id: impl Into<AnyTransId>, args: &GenericArgs) {
167 let id = id.into();
168 let Some(item) = self.ctx.translated.get_item(id) else {
169 return;
170 };
171 let params = item.generic_params();
172 let fmt1 = self.ctx.into_fmt();
173 let fmt = fmt1.push_binder(Cow::Borrowed(params));
174 self.assert_matches(&fmt, params, args, &GenericsSource::item(id));
175 }
176
177 fn assert_matches_method(
178 &self,
179 trait_id: TraitDeclId,
180 method_name: &TraitItemName,
181 args: &GenericArgs,
182 ) {
183 let target = &GenericsSource::Method(trait_id, method_name.clone());
184 let Some(trait_decl) = self.ctx.translated.trait_decls.get(trait_id) else {
185 return;
186 };
187 let Some((_, bound_fn)) = trait_decl.methods().find(|(n, _)| n == method_name) else {
188 return;
189 };
190 let params = &bound_fn.params;
191 let fmt1 = self.ctx.into_fmt();
192 let fmt2 = fmt1.push_binder(Cow::Borrowed(&trait_decl.generics));
193 let fmt = fmt2.push_binder(Cow::Borrowed(params));
194 self.assert_matches(&fmt, params, args, target);
195 }
196}
197
198impl VisitAst for CheckGenericsVisitor<'_> {
199 fn visit<'a, T: AstVisitable>(&'a mut self, x: &T) -> ControlFlow<Self::Break> {
200 self.visit_stack.push(x.name());
201 VisitWithSpan::new(VisitWithBinderStack::new(self)).visit(x)?;
202 self.visit_stack.pop();
203 Continue(())
204 }
205
206 fn enter_region(&mut self, x: &Region) {
208 if let Region::Var(var) = x {
209 if self.binder_stack.get_var(*var).is_none() {
210 self.error(format!("Found incorrect region var: {var}"));
211 }
212 }
213 }
214 fn enter_ty_kind(&mut self, x: &TyKind) {
215 if let TyKind::TypeVar(var) = x {
216 if self.binder_stack.get_var(*var).is_none() {
217 self.error(format!("Found incorrect type var: {var}"));
218 }
219 }
220 }
221 fn enter_const_generic(&mut self, x: &ConstGeneric) {
222 if let ConstGeneric::Var(var) = x {
223 if self.binder_stack.get_var(*var).is_none() {
224 self.error(format!("Found incorrect const-generic var: {var}"));
225 }
226 }
227 }
228 fn enter_trait_ref_kind(&mut self, x: &TraitRefKind) {
229 match x {
230 TraitRefKind::Clause(var) => {
231 if self.binder_stack.get_var(*var).is_none() {
232 self.error(format!("Found incorrect clause var: {var}"));
233 }
234 }
235 TraitRefKind::BuiltinOrAuto {
236 trait_decl_ref,
237 parent_trait_refs,
238 types,
239 } => {
240 let trait_id = trait_decl_ref.skip_binder.id;
241 let target = GenericsSource::item(trait_id);
242 let Some(tdecl) = self.ctx.translated.trait_decls.get(trait_id) else {
243 return;
244 };
245 if tdecl
246 .item_meta
247 .lang_item
248 .as_deref()
249 .is_some_and(|s| matches!(s, "pointee_trait" | "discriminant_kind"))
250 {
251 return;
253 }
254 let fmt = &self.ctx.into_fmt();
255 let args_fmt = &self.val_fmt_ctx();
256 self.zip_assert_match(
257 &tdecl.parent_clauses,
258 parent_trait_refs,
259 fmt,
260 args_fmt,
261 "builtin trait parent clauses",
262 &target,
263 |tclause, tref| self.assert_clause_matches(&fmt, tclause, tref),
264 );
265 let types_match = types.len() == tdecl.types.len()
266 && tdecl
267 .types
268 .iter()
269 .zip(types.iter())
270 .all(|(dname, (iname, _, _))| dname == iname);
271 if !types_match {
272 let target = target.with_ctx(args_fmt);
273 let a = tdecl.types.iter().format(", ");
274 let b = types
275 .iter()
276 .map(|(_, ty, _)| ty.with_ctx(args_fmt))
277 .format(", ");
278 self.error(format!(
279 "Mismatched types in builtin trait ref:\
280 \ntarget: {target}\
281 \nexpected: [{a}]\
282 \n got: [{b}]"
283 ));
284 }
285 }
286 _ => {}
287 }
288 }
289
290 fn enter_type_decl_ref(&mut self, x: &TypeDeclRef) {
292 match x.id {
293 TypeId::Adt(id) => self.assert_matches_item(id, &x.generics),
294 TypeId::Tuple => {}
296 TypeId::Builtin(_) => {}
297 }
298 }
299 fn enter_fun_decl_ref(&mut self, x: &FunDeclRef) {
300 self.assert_matches_item(x.id, &x.generics);
301 }
302 fn enter_fn_ptr(&mut self, x: &FnPtr) {
303 match x.func.as_ref() {
304 FunIdOrTraitMethodRef::Fun(FunId::Regular(id)) => {
305 self.assert_matches_item(*id, &x.generics)
306 }
307 FunIdOrTraitMethodRef::Fun(FunId::Builtin(_)) => {}
309 FunIdOrTraitMethodRef::Trait(trait_ref, method_name, _) => {
310 let trait_id = trait_ref.trait_decl_ref.skip_binder.id;
311 self.assert_matches_method(trait_id, method_name, &x.generics);
312 }
313 }
314 }
315 fn enter_global_decl_ref(&mut self, x: &GlobalDeclRef) {
316 self.assert_matches_item(x.id, &x.generics);
317 }
318 fn enter_trait_decl_ref(&mut self, x: &TraitDeclRef) {
319 self.assert_matches_item(x.id, &x.generics);
320 }
321 fn enter_trait_impl_ref(&mut self, x: &TraitImplRef) {
322 self.assert_matches_item(x.id, &x.generics);
323 }
324 fn enter_trait_impl(&mut self, timpl: &TraitImpl) {
325 let Some(tdecl) = self.ctx.translated.trait_decls.get(timpl.impl_trait.id) else {
326 return;
327 };
328 assert!(timpl.type_clauses.is_empty());
330 assert!(tdecl.type_clauses.is_empty());
331
332 let fmt1 = self.ctx.into_fmt();
333 let tdecl_fmt = fmt1.push_binder(Cow::Borrowed(&tdecl.generics));
334 let args_fmt = &self.val_fmt_ctx();
335 self.zip_assert_match(
336 &tdecl.parent_clauses,
337 &timpl.parent_trait_refs,
338 &tdecl_fmt,
339 args_fmt,
340 "trait parent clauses",
341 &GenericsSource::item(timpl.impl_trait.id),
342 |tclause, tref| self.assert_clause_matches(&tdecl_fmt, tclause, tref),
343 );
344 let types_match = timpl.types.len() == tdecl.types.len()
345 && tdecl
346 .types
347 .iter()
348 .zip(timpl.types.iter())
349 .all(|(dname, (iname, _))| dname == iname);
350 if !types_match {
351 self.error(
352 "The associated types supplied by the trait impl don't match the trait decl.",
353 )
354 }
355 let consts_match = timpl.consts.len() == tdecl.consts.len()
356 && tdecl
357 .types
358 .iter()
359 .zip(timpl.types.iter())
360 .all(|(dname, (iname, _))| dname == iname);
361 if !consts_match {
362 self.error(
363 "The associated consts supplied by the trait impl don't match the trait decl.",
364 )
365 }
366 let methods_match = timpl.methods.len() == tdecl.methods.len();
367 if !methods_match && self.phase != "after translation" {
368 let decl_methods = tdecl
369 .methods()
370 .map(|(name, _)| format!("- {name}"))
371 .join("\n");
372 let impl_methods = timpl
373 .methods()
374 .map(|(name, _)| format!("- {name}"))
375 .join("\n");
376 self.error(format!(
377 "The methods supplied by the trait impl don't match the trait decl.\n\
378 Trait methods:\n{decl_methods}\n\
379 Impl methods:\n{impl_methods}"
380 ))
381 }
382 }
383}
384
385pub struct Check(pub &'static str);
387impl TransformPass for Check {
388 fn transform_ctx(&self, ctx: &mut TransformCtx) {
389 for item in ctx.translated.all_items() {
390 if item
392 .item_meta()
393 .name
394 .name
395 .last()
396 .unwrap()
397 .is_monomorphized()
398 {
399 continue;
400 }
401 let mut visitor = CheckGenericsVisitor {
402 ctx,
403 phase: self.0,
404 span: Span::dummy(),
405 binder_stack: BindingStack::empty(),
406 visit_stack: Default::default(),
407 };
408 let _ = item.drive(&mut visitor);
409 }
410 }
411}