1use derive_generic_visitor::*;
3use index_vec::Idx;
4use itertools::Itertools;
5use std::{borrow::Cow, fmt::Display};
6
7use crate::{
8 ast::*,
9 errors::Level,
10 formatter::{AstFormatter, FmtCtx, IntoFormatter},
11 pretty::FmtWithCtx,
12 transform::utils::GenericsSource,
13};
14
15use super::{TransformCtx, ctx::TransformPass};
16
17#[derive(Visitor)]
18struct CheckGenericsVisitor<'a> {
19 ctx: &'a TransformCtx,
20 phase: &'static str,
21 span: Span,
23 binder_stack: BindingStack<GenericParams>,
27 visit_stack: Vec<&'static str>,
29}
30
31impl VisitorWithSpan for CheckGenericsVisitor<'_> {
32 fn current_span(&mut self) -> &mut Span {
33 &mut self.span
34 }
35}
36impl VisitorWithBinderStack for CheckGenericsVisitor<'_> {
37 fn binder_stack_mut(&mut self) -> &mut BindingStack<GenericParams> {
38 &mut self.binder_stack
39 }
40}
41
42impl CheckGenericsVisitor<'_> {
43 fn error(&self, message: impl Display) {
44 let msg = format!(
45 "Found inconsistent generics {}:\n{message}\n\
46 Visitor stack:\n {}\n\
47 Binding stack (depth {}):\n {}",
48 self.phase,
49 self.visit_stack.iter().rev().join("\n "),
50 self.binder_stack.len(),
51 self.binder_stack
52 .iter_enumerated()
53 .map(|(i, params)| format!("{i}: {params}"))
54 .join("\n "),
55 );
56 self.ctx.span_err(self.span, &msg, Level::ERROR);
58 }
59
60 fn val_fmt_ctx(&self) -> FmtCtx<'_> {
64 let mut fmt = self.ctx.into_fmt();
65 fmt.generics = self.binder_stack.map_ref(Cow::Borrowed);
66 fmt
67 }
68
69 fn zip_assert_match<I, A, B, FmtA, FmtB>(
70 &self,
71 a: &Vector<I, A>,
72 b: &Vector<I, B>,
73 a_fmt: &FmtA,
74 b_fmt: &FmtB,
75 kind: &str,
76 target: &GenericsSource,
77 check_inner: impl Fn(&A, &B),
78 ) where
79 I: Idx,
80 FmtA: AstFormatter,
81 A: FmtWithCtx<FmtA>,
82 B: FmtWithCtx<FmtB>,
83 {
84 if a.elem_count() == b.elem_count() {
85 a.iter().zip(b.iter()).for_each(|(x, y)| check_inner(x, y));
86 } else {
87 let a = a.iter().map(|x| x.with_ctx(a_fmt)).join(", ");
88 let b = b.iter().map(|x| x.with_ctx(b_fmt)).join(", ");
89 let target = target.with_ctx(a_fmt);
90 self.error(format!(
91 "Mismatched {kind}:\
92 \ntarget: {target}\
93 \nexpected: [{a}]\
94 \n got: [{b}]"
95 ))
96 }
97 }
98
99 fn assert_clause_matches(
100 &self,
101 params_fmt: &FmtCtx<'_>,
102 tclause: &TraitClause,
103 tref: &TraitRef,
104 ) {
105 let clause_trait_id = tclause.trait_.skip_binder.id;
106 let ref_trait_id = tref.trait_decl_ref.skip_binder.id;
107 if clause_trait_id != ref_trait_id {
108 let args_fmt = &self.val_fmt_ctx();
109 let tclause = tclause.with_ctx(params_fmt);
110 let tref_pred = tref.trait_decl_ref.with_ctx(args_fmt);
111 let tref = tref.with_ctx(args_fmt);
112 self.error(format!(
113 "Mismatched trait clause:\
114 \nexpected: {tclause}\
115 \n got: {tref}: {tref_pred}"
116 ));
117 }
118 }
119
120 fn assert_matches(
121 &self,
122 params_fmt: &FmtCtx<'_>,
123 params: &GenericParams,
124 args: &GenericArgs,
125 target: &GenericsSource,
126 ) {
127 let args_fmt = &self.val_fmt_ctx();
128 self.zip_assert_match(
129 ¶ms.regions,
130 &args.regions,
131 params_fmt,
132 args_fmt,
133 "regions",
134 target,
135 |_, _| {},
136 );
137 self.zip_assert_match(
138 ¶ms.types,
139 &args.types,
140 params_fmt,
141 args_fmt,
142 "type generics",
143 target,
144 |_, _| {},
145 );
146 self.zip_assert_match(
147 ¶ms.const_generics,
148 &args.const_generics,
149 params_fmt,
150 args_fmt,
151 "const generics",
152 target,
153 |_, _| {},
154 );
155 self.zip_assert_match(
156 ¶ms.trait_clauses,
157 &args.trait_refs,
158 params_fmt,
159 args_fmt,
160 "trait clauses",
161 target,
162 |tclause, tref| self.assert_clause_matches(params_fmt, tclause, tref),
163 );
164 }
165
166 fn assert_matches_item(&self, id: impl Into<AnyTransId>, args: &GenericArgs) {
167 let id = id.into();
168 let Some(item) = self.ctx.translated.get_item(id) else {
169 return;
170 };
171 let params = item.generic_params();
172 let fmt1 = self.ctx.into_fmt();
173 let fmt = fmt1.push_binder(Cow::Borrowed(params));
174 self.assert_matches(&fmt, params, args, &GenericsSource::item(id));
175 }
176
177 fn assert_matches_method(
178 &self,
179 trait_id: TraitDeclId,
180 method_name: &TraitItemName,
181 args: &GenericArgs,
182 ) {
183 let target = &GenericsSource::Method(trait_id, method_name.clone());
184 let Some(trait_decl) = self.ctx.translated.trait_decls.get(trait_id) else {
185 return;
186 };
187 let Some(bound_fn) = trait_decl.methods().find(|m| m.name() == method_name) else {
188 return;
189 };
190 let params = &bound_fn.params;
191 let fmt1 = self.ctx.into_fmt();
192 let fmt2 = fmt1.push_binder(Cow::Borrowed(&trait_decl.generics));
193 let fmt = fmt2.push_binder(Cow::Borrowed(params));
194 self.assert_matches(&fmt, params, args, target);
195 }
196}
197
198impl VisitAst for CheckGenericsVisitor<'_> {
199 fn visit<'a, T: AstVisitable>(&'a mut self, x: &T) -> ControlFlow<Self::Break> {
200 self.visit_stack.push(x.name());
201 VisitWithSpan::new(VisitWithBinderStack::new(self)).visit(x)?;
202 self.visit_stack.pop();
203 Continue(())
204 }
205
206 fn enter_region(&mut self, x: &Region) {
208 if let Region::Var(var) = x {
209 if self.binder_stack.get_var(*var).is_none() {
210 self.error(format!("Found incorrect region var: {var}"));
211 }
212 }
213 }
214 fn enter_ty_kind(&mut self, x: &TyKind) {
215 if let TyKind::TypeVar(var) = x {
216 if self.binder_stack.get_var(*var).is_none() {
217 self.error(format!("Found incorrect type var: {var}"));
218 }
219 }
220 }
221 fn enter_const_generic(&mut self, x: &ConstGeneric) {
222 if let ConstGeneric::Var(var) = x {
223 if self.binder_stack.get_var(*var).is_none() {
224 self.error(format!("Found incorrect const-generic var: {var}"));
225 }
226 }
227 }
228 fn enter_trait_ref(&mut self, x: &TraitRef) {
229 match &x.kind {
230 TraitRefKind::Clause(var) => {
231 if self.binder_stack.get_var(*var).is_none() {
232 self.error(format!("Found incorrect clause var: {var}"));
233 }
234 }
235 TraitRefKind::BuiltinOrAuto {
236 parent_trait_refs,
237 types,
238 } => {
239 let trait_id = x.trait_decl_ref.skip_binder.id;
240 let target = GenericsSource::item(trait_id);
241 let Some(tdecl) = self.ctx.translated.trait_decls.get(trait_id) else {
242 return;
243 };
244 if tdecl
245 .item_meta
246 .lang_item
247 .as_deref()
248 .is_some_and(|s| matches!(s, "pointee_trait" | "discriminant_kind"))
249 {
250 return;
252 }
253 let fmt = &self.ctx.into_fmt();
254 let args_fmt = &self.val_fmt_ctx();
255 self.zip_assert_match(
256 &tdecl.parent_clauses,
257 parent_trait_refs,
258 fmt,
259 args_fmt,
260 "builtin trait parent clauses",
261 &target,
262 |tclause, tref| self.assert_clause_matches(&fmt, tclause, tref),
263 );
264 let types_match = types.len() == tdecl.types.len()
265 && tdecl
266 .types
267 .iter()
268 .zip(types.iter())
269 .all(|(dty, (iname, _))| dty.name() == iname);
270 if !types_match {
271 let target = target.with_ctx(args_fmt);
272 let a = tdecl.types.iter().map(|t| t.name()).format(", ");
273 let b = types
274 .iter()
275 .map(|(_, assoc_ty)| assoc_ty.value.with_ctx(args_fmt))
276 .format(", ");
277 self.error(format!(
278 "Mismatched types in builtin trait ref:\
279 \ntarget: {target}\
280 \nexpected: [{a}]\
281 \n got: [{b}]"
282 ));
283 }
284 }
285 _ => {}
286 }
287 }
288
289 fn enter_type_decl_ref(&mut self, x: &TypeDeclRef) {
291 match x.id {
292 TypeId::Adt(id) => self.assert_matches_item(id, &x.generics),
293 TypeId::Tuple => {}
295 TypeId::Builtin(_) => {}
296 }
297 }
298 fn enter_fun_decl_ref(&mut self, x: &FunDeclRef) {
299 self.assert_matches_item(x.id, &x.generics);
300 }
301 fn enter_fn_ptr(&mut self, x: &FnPtr) {
302 match x.func.as_ref() {
303 FunIdOrTraitMethodRef::Fun(FunId::Regular(id)) => {
304 self.assert_matches_item(*id, &x.generics)
305 }
306 FunIdOrTraitMethodRef::Fun(FunId::Builtin(_)) => {}
308 FunIdOrTraitMethodRef::Trait(trait_ref, method_name, _) => {
309 let trait_id = trait_ref.trait_decl_ref.skip_binder.id;
310 self.assert_matches_method(trait_id, method_name, &x.generics);
311 }
312 }
313 }
314 fn enter_global_decl_ref(&mut self, x: &GlobalDeclRef) {
315 self.assert_matches_item(x.id, &x.generics);
316 }
317 fn enter_trait_decl_ref(&mut self, x: &TraitDeclRef) {
318 self.assert_matches_item(x.id, &x.generics);
319 }
320 fn enter_trait_impl_ref(&mut self, x: &TraitImplRef) {
321 self.assert_matches_item(x.id, &x.generics);
322 }
323 fn enter_trait_impl(&mut self, timpl: &TraitImpl) {
324 let Some(tdecl) = self.ctx.translated.trait_decls.get(timpl.impl_trait.id) else {
325 return;
326 };
327
328 let fmt1 = self.ctx.into_fmt();
329 let tdecl_fmt = fmt1.push_binder(Cow::Borrowed(&tdecl.generics));
330 let args_fmt = &self.val_fmt_ctx();
331 self.zip_assert_match(
332 &tdecl.parent_clauses,
333 &timpl.parent_trait_refs,
334 &tdecl_fmt,
335 args_fmt,
336 "trait parent clauses",
337 &GenericsSource::item(timpl.impl_trait.id),
338 |tclause, tref| self.assert_clause_matches(&tdecl_fmt, tclause, tref),
339 );
340 let types_match = timpl.types.len() == tdecl.types.len()
342 && tdecl
343 .types
344 .iter()
345 .zip(timpl.types.iter())
346 .all(|(dty, (iname, _))| dty.name() == iname);
347 if !types_match {
348 self.error(
349 "The associated types supplied by the trait impl don't match the trait decl.",
350 )
351 }
352 let consts_match = timpl.consts.len() == tdecl.consts.len()
353 && tdecl
354 .consts
355 .iter()
356 .zip(timpl.consts.iter())
357 .all(|(dconst, (iname, _))| &dconst.name == iname);
358 if !consts_match {
359 self.error(
360 "The associated consts supplied by the trait impl don't match the trait decl.",
361 )
362 }
363 let methods_match = timpl.methods.len() == tdecl.methods.len();
364 if !methods_match && self.phase != "after translation" {
365 let decl_methods = tdecl
366 .methods()
367 .map(|m| format!("- {}", m.name()))
368 .join("\n");
369 let impl_methods = timpl
370 .methods()
371 .map(|(name, _)| format!("- {name}"))
372 .join("\n");
373 self.error(format!(
374 "The methods supplied by the trait impl don't match the trait decl.\n\
375 Trait methods:\n{decl_methods}\n\
376 Impl methods:\n{impl_methods}"
377 ))
378 }
379 }
380}
381
382pub struct Check(pub &'static str);
384impl TransformPass for Check {
385 fn transform_ctx(&self, ctx: &mut TransformCtx) {
386 for item in ctx.translated.all_items() {
387 if item
391 .item_meta()
392 .name
393 .name
394 .last()
395 .unwrap()
396 .is_monomorphized()
397 {
398 continue;
399 }
400 let mut visitor = CheckGenericsVisitor {
401 ctx,
402 phase: self.0,
403 span: Span::dummy(),
404 binder_stack: BindingStack::empty(),
405 visit_stack: Default::default(),
406 };
407 let _ = item.drive(&mut visitor);
408 }
409 }
410}