1use derive_generic_visitor::*;
3use index_vec::Idx;
4use itertools::Itertools;
5use std::{borrow::Cow, fmt::Display};
6
7use crate::{
8 ast::*,
9 errors::Level,
10 formatter::{AstFormatter, FmtCtx, IntoFormatter},
11 pretty::FmtWithCtx,
12};
13
14use super::{ctx::TransformPass, TransformCtx};
15
/// Visitor that checks generic consistency within an item: every bound variable (region,
/// type, const generic, trait clause) must resolve against a binder currently in scope, and
/// every set of generic arguments must match the parameters of the item it targets.
#[derive(Visitor)]
struct CheckGenericsVisitor<'a> {
    /// Translation context: used to look up translated items and to report errors.
    ctx: &'a TransformCtx,
    /// Label of this run of the check; included in error messages, and compared against
    /// `"after translation"` to decide whether the method-count check runs.
    phase: &'static str,
    /// Span at which errors are reported. Starts as the item's span and is temporarily set to
    /// the current statement's span while visiting statements.
    span: Span,
    /// Generic parameters of the enclosing binders. A binder's parameters are pushed when the
    /// binder is entered and popped when it is left; bound variables are looked up here.
    binder_stack: BindingStack<GenericParams>,
    /// Names of the AST nodes currently being visited (pushed on entry, popped on exit);
    /// printed in error messages to locate the problem.
    visit_stack: Vec<&'static str>,
}
29
30impl CheckGenericsVisitor<'_> {
31 fn error(&self, message: impl Display) {
32 let msg = format!(
33 "Found inconsistent generics {}:\n{message}\n\
34 Visitor stack:\n {}\n\
35 Binding stack (depth {}):\n {}",
36 self.phase,
37 self.visit_stack.iter().rev().join("\n "),
38 self.binder_stack.len(),
39 self.binder_stack
40 .iter_enumerated()
41 .map(|(i, params)| format!("{i}: {params}"))
42 .join("\n "),
43 );
44 self.ctx.span_err(self.span, &msg, Level::ERROR);
46 }
47
48 fn val_fmt_ctx(&self) -> FmtCtx<'_> {
52 let mut fmt = self.ctx.into_fmt();
53 fmt.generics = self.binder_stack.map_ref(Cow::Borrowed);
54 fmt
55 }
56
57 fn zip_assert_match<I, A, B, FmtA, FmtB>(
58 &self,
59 a: &Vector<I, A>,
60 b: &Vector<I, B>,
61 a_fmt: &FmtA,
62 b_fmt: &FmtB,
63 kind: &str,
64 target: &GenericsSource,
65 check_inner: impl Fn(&A, &B),
66 ) where
67 I: Idx,
68 FmtA: AstFormatter,
69 A: FmtWithCtx<FmtA>,
70 B: FmtWithCtx<FmtB>,
71 {
72 if a.elem_count() == b.elem_count() {
73 a.iter().zip(b.iter()).for_each(|(x, y)| check_inner(x, y));
74 } else {
75 let a = a.iter().map(|x| x.with_ctx(a_fmt)).join(", ");
76 let b = b.iter().map(|x| x.with_ctx(b_fmt)).join(", ");
77 let target = target.with_ctx(a_fmt);
78 self.error(format!(
79 "Mismatched {kind}:\
80 \ntarget: {target}\
81 \nexpected: [{a}]\
82 \n got: [{b}]"
83 ))
84 }
85 }
86
87 fn assert_clause_matches(
88 &self,
89 params_fmt: &FmtCtx<'_>,
90 tclause: &TraitClause,
91 tref: &TraitRef,
92 ) {
93 let clause_trait_id = tclause.trait_.skip_binder.trait_id;
94 let ref_trait_id = tref.trait_decl_ref.skip_binder.trait_id;
95 if clause_trait_id != ref_trait_id {
96 let args_fmt = &self.val_fmt_ctx();
97 let tclause = tclause.with_ctx(params_fmt);
98 let tref_pred = tref.trait_decl_ref.with_ctx(args_fmt);
99 let tref = tref.with_ctx(args_fmt);
100 self.error(format!(
101 "Mismatched trait clause:\
102 \nexpected: {tclause}\
103 \n got: {tref}: {tref_pred}"
104 ));
105 }
106 }
107
108 fn assert_matches(&self, params_fmt: &FmtCtx<'_>, params: &GenericParams, args: &GenericArgs) {
109 let args_fmt = &self.val_fmt_ctx();
110 self.zip_assert_match(
111 ¶ms.regions,
112 &args.regions,
113 params_fmt,
114 args_fmt,
115 "regions",
116 &args.target,
117 |_, _| {},
118 );
119 self.zip_assert_match(
120 ¶ms.types,
121 &args.types,
122 params_fmt,
123 args_fmt,
124 "type generics",
125 &args.target,
126 |_, _| {},
127 );
128 self.zip_assert_match(
129 ¶ms.const_generics,
130 &args.const_generics,
131 params_fmt,
132 args_fmt,
133 "const generics",
134 &args.target,
135 |_, _| {},
136 );
137 self.zip_assert_match(
138 ¶ms.trait_clauses,
139 &args.trait_refs,
140 params_fmt,
141 args_fmt,
142 "trait clauses",
143 &args.target,
144 |tclause, tref| self.assert_clause_matches(params_fmt, tclause, tref),
145 );
146 }
147}
148
149impl VisitAst for CheckGenericsVisitor<'_> {
150 fn visit<'a, T: AstVisitable>(&'a mut self, x: &T) -> ControlFlow<Self::Break> {
151 self.visit_stack.push(x.name());
152 x.drive(self)?; self.visit_stack.pop();
154 Continue(())
155 }
156
157 fn visit_binder<T: AstVisitable>(&mut self, binder: &Binder<T>) -> ControlFlow<Self::Break> {
158 self.binder_stack.push(binder.params.clone());
159 self.visit_inner(binder)?;
160 self.binder_stack.pop();
161 Continue(())
162 }
163 fn visit_region_binder<T: AstVisitable>(
164 &mut self,
165 binder: &RegionBinder<T>,
166 ) -> ControlFlow<Self::Break> {
167 self.binder_stack.push(GenericParams {
168 regions: binder.regions.clone(),
169 ..Default::default()
170 });
171 self.visit_inner(binder)?;
172 self.binder_stack.pop();
173 Continue(())
174 }
175
176 fn enter_region(&mut self, x: &Region) {
177 if let Region::Var(var) = x {
178 if self.binder_stack.get_var(*var).is_none() {
179 self.error(format!("Found incorrect region var: {var}"));
180 }
181 }
182 }
183 fn enter_ty_kind(&mut self, x: &TyKind) {
184 if let TyKind::TypeVar(var) = x {
185 if self.binder_stack.get_var(*var).is_none() {
186 self.error(format!("Found incorrect type var: {var}"));
187 }
188 }
189 }
190 fn enter_const_generic(&mut self, x: &ConstGeneric) {
191 if let ConstGeneric::Var(var) = x {
192 if self.binder_stack.get_var(*var).is_none() {
193 self.error(format!("Found incorrect const-generic var: {var}"));
194 }
195 }
196 }
197 fn enter_trait_ref_kind(&mut self, x: &TraitRefKind) {
198 if let TraitRefKind::Clause(var) = x {
199 if self.binder_stack.get_var(*var).is_none() {
200 self.error(format!("Found incorrect clause var: {var}"));
201 }
202 }
203 }
204
205 fn visit_aggregate_kind(&mut self, agg: &AggregateKind) -> ControlFlow<Self::Break> {
206 match agg {
207 AggregateKind::Adt(..) | AggregateKind::Array(..) | AggregateKind::RawPtr(..) => {
208 self.visit_inner(agg)?
209 }
210 }
211 Continue(())
212 }
213
214 fn enter_generic_args(&mut self, args: &GenericArgs) {
215 let fmt1;
216 let fmt2;
217 let (params, params_fmt) = match &args.target {
218 GenericsSource::Item(item_id) => {
219 let Some(item) = self.ctx.translated.get_item(*item_id) else {
220 return;
221 };
222 let params = item.generic_params();
223 fmt1 = self.ctx.into_fmt();
224 let fmt = fmt1.push_binder(Cow::Borrowed(params));
225 (params, fmt)
226 }
227 GenericsSource::Method(trait_id, method_name) => {
228 let Some(trait_decl) = self.ctx.translated.trait_decls.get(*trait_id) else {
229 return;
230 };
231 let Some((_, bound_fn)) = trait_decl.methods().find(|(n, _)| n == method_name)
232 else {
233 return;
234 };
235 let params = &bound_fn.params;
236 fmt1 = self.ctx.into_fmt();
237 fmt2 = fmt1.push_binder(Cow::Borrowed(&trait_decl.generics));
238 let fmt = fmt2.push_binder(Cow::Borrowed(params));
239 (params, fmt)
240 }
241 GenericsSource::Builtin => return,
242 GenericsSource::Other => {
243 self.error("`GenericsSource::Other` should not exist in the charon AST");
244 return;
245 }
246 };
247 self.assert_matches(¶ms_fmt, params, args);
248 }
249
250 fn enter_trait_impl(&mut self, timpl: &TraitImpl) {
252 let Some(tdecl) = self
253 .ctx
254 .translated
255 .trait_decls
256 .get(timpl.impl_trait.trait_id)
257 else {
258 return;
259 };
260 assert!(timpl.type_clauses.is_empty());
262 assert!(tdecl.type_clauses.is_empty());
263
264 let fmt1 = self.ctx.into_fmt();
265 let tdecl_fmt = fmt1.push_binder(Cow::Borrowed(&tdecl.generics));
266 let args_fmt = &self.val_fmt_ctx();
267 self.zip_assert_match(
268 &tdecl.parent_clauses,
269 &timpl.parent_trait_refs,
270 &tdecl_fmt,
271 args_fmt,
272 "trait parent clauses",
273 &GenericsSource::item(timpl.impl_trait.trait_id),
274 |tclause, tref| self.assert_clause_matches(&tdecl_fmt, tclause, tref),
275 );
276 let types_match = timpl.types.len() == tdecl.types.len()
277 && tdecl
278 .types
279 .iter()
280 .zip(timpl.types.iter())
281 .all(|(dname, (iname, _))| dname == iname);
282 if !types_match {
283 self.error(
284 "The associated types supplied by the trait impl don't match the trait decl.",
285 )
286 }
287 let consts_match = timpl.consts.len() == tdecl.consts.len()
288 && tdecl
289 .types
290 .iter()
291 .zip(timpl.types.iter())
292 .all(|(dname, (iname, _))| dname == iname);
293 if !consts_match {
294 self.error(
295 "The associated consts supplied by the trait impl don't match the trait decl.",
296 )
297 }
298 let methods_match = timpl.methods.len() == tdecl.methods.len();
299 if !methods_match && self.phase != "after translation" {
300 let decl_methods = tdecl
301 .methods()
302 .map(|(name, _)| format!("- {name}"))
303 .join("\n");
304 let impl_methods = timpl
305 .methods()
306 .map(|(name, _)| format!("- {name}"))
307 .join("\n");
308 self.error(format!(
309 "The methods supplied by the trait impl don't match the trait decl.\n\
310 Trait methods:\n{decl_methods}\n\
311 Impl methods:\n{impl_methods}"
312 ))
313 }
314 }
315
316 fn visit_ullbc_statement(&mut self, st: &ullbc_ast::Statement) -> ControlFlow<Self::Break> {
317 let old_span = self.span;
319 self.span = st.span;
320 self.visit_inner(st)?;
321 self.span = old_span;
322 Continue(())
323 }
324
325 fn visit_llbc_statement(&mut self, st: &llbc_ast::Statement) -> ControlFlow<Self::Break> {
326 let old_span = self.span;
328 self.span = st.span;
329 self.visit_inner(st)?;
330 self.span = old_span;
331 Continue(())
332 }
333}
334
335pub struct Check(pub &'static str);
337impl TransformPass for Check {
338 fn transform_ctx(&self, ctx: &mut TransformCtx) {
339 for item in ctx.translated.all_items() {
340 let mut visitor = CheckGenericsVisitor {
341 ctx,
342 phase: self.0,
343 span: item.item_meta().span,
344 binder_stack: BindingStack::new(item.generic_params().clone()),
345 visit_stack: Default::default(),
346 };
347 let _ = item.drive(&mut visitor);
348 }
349 }
350}