1use derive_generic_visitor::*;
3use index_vec::Idx;
4use itertools::Itertools;
5use std::{borrow::Cow, fmt::Display};
6
7use crate::{
8 ast::*,
9 errors::Level,
10 formatter::{FmtCtx, IntoFormatter, PushBinder},
11 pretty::FmtWithCtx,
12};
13
14use super::{ctx::TransformPass, TransformCtx};
15
/// Visitor that checks generic arguments against the parameters they
/// instantiate, and checks that every bound variable (region, type,
/// const-generic, trait clause) resolves to a binder currently in scope.
/// Inconsistencies are reported through `ctx.span_err`.
#[derive(Visitor)]
struct CheckGenericsVisitor<'a> {
    ctx: &'a TransformCtx,
    /// Label of the phase this check runs in. It is printed in error messages
    /// and also gates the trait-impl method-count check.
    phase: &'static str,
    /// Span attached to reported errors; replaced by the enclosing statement's
    /// span while a statement is being visited.
    span: Span,
    /// Generic parameters of every binder enclosing the current position,
    /// used to resolve bound variables.
    binder_stack: BindingStack<GenericParams>,
    /// Names of the AST nodes currently being visited (outermost first);
    /// printed innermost-first in error messages.
    visit_stack: Vec<&'static str>,
}
29
30impl CheckGenericsVisitor<'_> {
31 fn error(&self, message: impl Display) {
32 let msg = format!(
33 "Found inconsistent generics {}:\n{message}\n\
34 Visitor stack:\n {}\n\
35 Binding stack (depth {}):\n {}",
36 self.phase,
37 self.visit_stack.iter().rev().join("\n "),
38 self.binder_stack.len(),
39 self.binder_stack
40 .iter_enumerated()
41 .map(|(i, params)| format!("{i}: {params}"))
42 .join("\n "),
43 );
44 self.ctx.span_err(self.span, &msg, Level::ERROR);
46 }
47
48 fn val_fmt_ctx(&self) -> FmtCtx<'_> {
52 let mut fmt = self.ctx.into_fmt();
53 fmt.generics = self.binder_stack.map_ref(Cow::Borrowed);
54 fmt
55 }
56
57 fn zip_assert_match<I, A, B, FmtA, FmtB>(
58 &self,
59 a: &Vector<I, A>,
60 b: &Vector<I, B>,
61 a_fmt: &FmtA,
62 b_fmt: &FmtB,
63 kind: &str,
64 check_inner: impl Fn(&A, &B),
65 ) where
66 I: Idx,
67 A: for<'a> FmtWithCtx<FmtA>,
68 B: for<'a> FmtWithCtx<FmtB>,
69 {
70 if a.elem_count() == b.elem_count() {
71 a.iter().zip(b.iter()).for_each(|(x, y)| check_inner(x, y));
72 } else {
73 let a = a.iter().map(|x| x.fmt_with_ctx(a_fmt)).join(", ");
74 let b = b.iter().map(|x| x.fmt_with_ctx(b_fmt)).join(", ");
75 self.error(format!(
76 "Mismatched {kind}:\
77 \nexpected: [{a}]\
78 \n got: [{b}]"
79 ))
80 }
81 }
82
83 fn assert_clause_matches(
84 &self,
85 params_fmt: &FmtCtx<'_>,
86 tclause: &TraitClause,
87 tref: &TraitRef,
88 ) {
89 let clause_trait_id = tclause.trait_.skip_binder.trait_id;
90 let ref_trait_id = tref.trait_decl_ref.skip_binder.trait_id;
91 if clause_trait_id != ref_trait_id {
92 let args_fmt = &self.val_fmt_ctx();
93 let tclause = tclause.fmt_with_ctx(params_fmt);
94 let tref_pred = tref.trait_decl_ref.fmt_with_ctx(args_fmt);
95 let tref = tref.fmt_with_ctx(args_fmt);
96 self.error(format!(
97 "Mismatched trait clause:\
98 \nexpected: {tclause}\
99 \n got: {tref}: {tref_pred}"
100 ));
101 }
102 }
103
104 fn assert_matches(&self, params_fmt: &FmtCtx<'_>, params: &GenericParams, args: &GenericArgs) {
105 let args_fmt = &self.val_fmt_ctx();
106 self.zip_assert_match(
107 ¶ms.regions,
108 &args.regions,
109 params_fmt,
110 args_fmt,
111 "regions",
112 |_, _| {},
113 );
114 self.zip_assert_match(
115 ¶ms.types,
116 &args.types,
117 params_fmt,
118 args_fmt,
119 "type generics",
120 |_, _| {},
121 );
122 self.zip_assert_match(
123 ¶ms.const_generics,
124 &args.const_generics,
125 params_fmt,
126 args_fmt,
127 "const generics",
128 |_, _| {},
129 );
130 self.zip_assert_match(
131 ¶ms.trait_clauses,
132 &args.trait_refs,
133 params_fmt,
134 args_fmt,
135 "trait clauses",
136 |tclause, tref| self.assert_clause_matches(params_fmt, tclause, tref),
137 );
138 }
139}
140
141impl VisitAst for CheckGenericsVisitor<'_> {
142 fn visit<'a, T: AstVisitable>(&'a mut self, x: &T) -> ControlFlow<Self::Break> {
143 self.visit_stack.push(x.name());
144 x.drive(self)?; self.visit_stack.pop();
146 Continue(())
147 }
148
149 fn visit_binder<T: AstVisitable>(&mut self, binder: &Binder<T>) -> ControlFlow<Self::Break> {
150 self.binder_stack.push(binder.params.clone());
151 self.visit_inner(binder)?;
152 self.binder_stack.pop();
153 Continue(())
154 }
155 fn visit_region_binder<T: AstVisitable>(
156 &mut self,
157 binder: &RegionBinder<T>,
158 ) -> ControlFlow<Self::Break> {
159 self.binder_stack.push(GenericParams {
160 regions: binder.regions.clone(),
161 ..Default::default()
162 });
163 self.visit_inner(binder)?;
164 self.binder_stack.pop();
165 Continue(())
166 }
167
168 fn enter_region(&mut self, x: &Region) {
169 if let Region::Var(var) = x {
170 if self.binder_stack.get_var(*var).is_none() {
171 self.error(format!("Found incorrect region var: {var}"));
172 }
173 }
174 }
175 fn enter_ty_kind(&mut self, x: &TyKind) {
176 if let TyKind::TypeVar(var) = x {
177 if self.binder_stack.get_var(*var).is_none() {
178 self.error(format!("Found incorrect type var: {var}"));
179 }
180 }
181 }
182 fn enter_const_generic(&mut self, x: &ConstGeneric) {
183 if let ConstGeneric::Var(var) = x {
184 if self.binder_stack.get_var(*var).is_none() {
185 self.error(format!("Found incorrect const-generic var: {var}"));
186 }
187 }
188 }
189 fn enter_trait_ref_kind(&mut self, x: &TraitRefKind) {
190 if let TraitRefKind::Clause(var) = x {
191 if self.binder_stack.get_var(*var).is_none() {
192 self.error(format!("Found incorrect clause var: {var}"));
193 }
194 }
195 }
196
197 fn visit_aggregate_kind(&mut self, agg: &AggregateKind) -> ControlFlow<Self::Break> {
198 match agg {
199 AggregateKind::Adt(..) | AggregateKind::Array(..) | AggregateKind::RawPtr(..) => {
200 self.visit_inner(agg)?
201 }
202 AggregateKind::Closure(_id, args) => {
203 self.visit_inner(args)?
207 }
208 }
209 Continue(())
210 }
211
212 fn enter_generic_args(&mut self, args: &GenericArgs) {
213 let fmt1;
214 let fmt2;
215 let (params, params_fmt) = match &args.target {
216 GenericsSource::Item(item_id) => {
217 let Some(item) = self.ctx.translated.get_item(*item_id) else {
218 return;
219 };
220 let params = item.generic_params();
221 fmt1 = self.ctx.into_fmt();
222 let fmt = fmt1.push_binder(Cow::Borrowed(params));
223 (params, fmt)
224 }
225 GenericsSource::Method(trait_id, method_name) => {
226 let Some(trait_decl) = self.ctx.translated.trait_decls.get(*trait_id) else {
227 return;
228 };
229 let Some((_, bound_fn)) = trait_decl.methods().find(|(n, _)| n == method_name)
230 else {
231 return;
232 };
233 let params = &bound_fn.params;
234 fmt1 = self.ctx.into_fmt();
235 fmt2 = fmt1.push_binder(Cow::Borrowed(&trait_decl.generics));
236 let fmt = fmt2.push_binder(Cow::Borrowed(params));
237 (params, fmt)
238 }
239 GenericsSource::Builtin => return,
240 GenericsSource::Other => {
241 self.error("`GenericsSource::Other` should not exist in the charon AST");
242 return;
243 }
244 };
245 self.assert_matches(¶ms_fmt, params, args);
246 }
247
248 fn enter_trait_impl(&mut self, timpl: &TraitImpl) {
250 let Some(tdecl) = self
251 .ctx
252 .translated
253 .trait_decls
254 .get(timpl.impl_trait.trait_id)
255 else {
256 return;
257 };
258 assert!(timpl.type_clauses.is_empty());
260 assert!(tdecl.type_clauses.is_empty());
261
262 let fmt1 = self.ctx.into_fmt();
263 let tdecl_fmt = fmt1.push_binder(Cow::Borrowed(&tdecl.generics));
264 let args_fmt = &self.val_fmt_ctx();
265 self.zip_assert_match(
266 &tdecl.parent_clauses,
267 &timpl.parent_trait_refs,
268 &tdecl_fmt,
269 args_fmt,
270 "trait parent clauses",
271 |tclause, tref| self.assert_clause_matches(&tdecl_fmt, tclause, tref),
272 );
273 let types_match = timpl.types.len() == tdecl.types.len()
274 && tdecl
275 .types
276 .iter()
277 .zip(timpl.types.iter())
278 .all(|(dname, (iname, _))| dname == iname);
279 if !types_match {
280 self.error(
281 "The associated types supplied by the trait impl don't match the trait decl.",
282 )
283 }
284 let consts_match = timpl.consts.len() == tdecl.consts.len()
285 && tdecl
286 .types
287 .iter()
288 .zip(timpl.types.iter())
289 .all(|(dname, (iname, _))| dname == iname);
290 if !consts_match {
291 self.error(
292 "The associated consts supplied by the trait impl don't match the trait decl.",
293 )
294 }
295 let methods_match = timpl.methods.len() == tdecl.methods.len();
296 if !methods_match && self.phase != "after translation" {
297 let decl_methods = tdecl
298 .methods()
299 .map(|(name, _)| format!("- {name}"))
300 .join("\n");
301 let impl_methods = timpl
302 .methods()
303 .map(|(name, _)| format!("- {name}"))
304 .join("\n");
305 self.error(format!(
306 "The methods supplied by the trait impl don't match the trait decl.\n\
307 Trait methods:\n{decl_methods}\n\
308 Impl methods:\n{impl_methods}"
309 ))
310 }
311 }
312
313 fn visit_ullbc_statement(&mut self, st: &ullbc_ast::Statement) -> ControlFlow<Self::Break> {
314 let old_span = self.span;
316 self.span = st.span;
317 self.visit_inner(st)?;
318 self.span = old_span;
319 Continue(())
320 }
321
322 fn visit_llbc_statement(&mut self, st: &llbc_ast::Statement) -> ControlFlow<Self::Break> {
323 let old_span = self.span;
325 self.span = st.span;
326 self.visit_inner(st)?;
327 self.span = old_span;
328 Continue(())
329 }
330}
331
/// Pass that checks generics consistency over every translated item.
/// The wrapped string labels the phase in which the check runs.
pub struct Check(pub &'static str);
334impl TransformPass for Check {
335 fn transform_ctx(&self, ctx: &mut TransformCtx) {
336 for item in ctx.translated.all_items() {
337 let mut visitor = CheckGenericsVisitor {
338 ctx,
339 phase: self.0,
340 span: item.item_meta().span,
341 binder_stack: BindingStack::new(item.generic_params().clone()),
342 visit_stack: Default::default(),
343 };
344 let _ = item.drive(&mut visitor);
345 }
346 }
347}