1use std::collections::{HashMap, HashSet};
2use std::fmt::Debug;
3use std::mem;
4
5use hax::{BaseState, Symbol};
6use rustc_middle::ty;
7
8use super::translate_ctx::{ItemTransCtx, TraitImplSource, TransItemSourceKind};
9use charon_lib::ast::*;
10use charon_lib::common::CycleDetector;
11use charon_lib::ids::IndexVec;
12
13#[derive(Debug, Default)]
25pub(crate) struct BindingLevel {
26 pub params: GenericParams,
28 pub is_item_binder: bool,
32 pub early_region_vars: HashMap<hax::EarlyParamRegion, RegionId>,
39 pub bound_region_vars: Vec<RegionId>,
41 pub closure_call_method_region: Option<RegionId>,
43 pub type_vars_map: HashMap<u32, TypeVarId>,
45 pub const_generic_vars_map: HashMap<u32, ConstGenericVarId>,
47 pub closure_upvar_tys: Option<IndexVec<FieldId, Ty>>,
51 pub closure_upvar_regions: Vec<RegionId>,
53 pub used_region_names: HashSet<Symbol>,
56 pub type_trans_cache: HashMap<hax::Ty, Ty>,
60}
61
62fn translate_region_name(s: hax::Symbol) -> Option<String> {
64 let s = s.to_string();
65 if s == "'_" { None } else { Some(s) }
66}
67
68impl BindingLevel {
69 pub(crate) fn new(is_item_binder: bool) -> Self {
70 Self {
71 is_item_binder,
72 ..Default::default()
73 }
74 }
75
76 pub(crate) fn push_early_region(
78 &mut self,
79 region: hax::EarlyParamRegion,
80 mutability: LifetimeMutability,
81 ) -> RegionId {
82 let name = if self.used_region_names.insert(region.name) {
83 translate_region_name(region.name)
84 } else {
85 None
86 };
87 assert!(
89 self.bound_region_vars.is_empty(),
90 "Early regions must be translated before late ones"
91 );
92 let rid = self.params.regions.push_with(|index| RegionParam {
93 index,
94 name,
95 mutability,
96 });
97 self.early_region_vars.insert(region, rid);
98 rid
99 }
100
101 pub(crate) fn push_bound_region(&mut self, region: hax::BoundRegionKind) -> RegionId {
103 use hax::BoundRegionKind::*;
104 let name = match region {
105 Anon => None,
106 NamedForPrinting(symbol) | Named(_, symbol) => translate_region_name(symbol),
107 ClosureEnv => Some("@env".to_owned()),
108 };
109 let rid = self
110 .params
111 .regions
112 .push_with(|index| RegionParam::new(index, name));
113 self.bound_region_vars.push(rid);
114 rid
115 }
116
117 pub fn push_upvar_region(&mut self) -> RegionId {
119 let region_id = self
122 .params
123 .regions
124 .push_with(|index| RegionParam::new(index, None));
125 self.closure_upvar_regions.push(region_id);
126 region_id
127 }
128
129 pub(crate) fn push_type_var(&mut self, rid: u32, name: hax::Symbol) -> TypeVarId {
130 let mut name = name.to_string();
133 if name
134 .chars()
135 .any(|c| !(c.is_ascii_alphanumeric() || c == '_'))
136 {
137 name = format!("T{rid}")
138 }
139 let var_id = self
140 .params
141 .types
142 .push_with(|index| TypeParam { index, name });
143 self.type_vars_map.insert(rid, var_id);
144 var_id
145 }
146
147 pub(crate) fn push_const_generic_var(&mut self, rid: u32, ty: Ty, name: hax::Symbol) {
148 let var_id = self
149 .params
150 .const_generics
151 .push_with(|index| ConstGenericParam {
152 index,
153 name: name.to_string(),
154 ty,
155 });
156 self.const_generic_vars_map.insert(rid, var_id);
157 }
158
159 pub(crate) fn push_params_from_binder(&mut self, binder: hax::Binder<()>) -> Result<(), Error> {
161 assert!(
162 self.bound_region_vars.is_empty(),
163 "Trying to use two binders at the same binding level"
164 );
165 use hax::BoundVariableKind::*;
166 for p in binder.bound_vars {
167 match p {
168 Region(region) => {
169 self.push_bound_region(region);
170 }
171 Ty(_) => {
172 panic!("Unexpected locally bound type variable");
173 }
174 Const => {
175 panic!("Unexpected locally bound const generic variable");
176 }
177 }
178 }
179 Ok(())
180 }
181}
182
183impl<'tcx, 'ctx> ItemTransCtx<'tcx, 'ctx> {
184 pub(crate) fn the_only_binder(&self) -> &BindingLevel {
186 assert_eq!(self.binding_levels.len(), 1);
187 self.innermost_binder()
188 }
189 pub(crate) fn the_only_binder_mut(&mut self) -> &mut BindingLevel {
191 assert_eq!(self.binding_levels.len(), 1);
192 self.innermost_binder_mut()
193 }
194
195 pub(crate) fn outermost_binder(&self) -> &BindingLevel {
196 self.binding_levels.outermost()
197 }
198 pub(crate) fn outermost_binder_mut(&mut self) -> &mut BindingLevel {
199 self.binding_levels.outermost_mut()
200 }
201 pub(crate) fn innermost_binder(&self) -> &BindingLevel {
202 self.binding_levels.innermost()
203 }
204 pub(crate) fn innermost_binder_mut(&mut self) -> &mut BindingLevel {
205 self.binding_levels.innermost_mut()
206 }
207
208 pub(crate) fn outermost_generics(&self) -> &GenericParams {
209 &self.outermost_binder().params
210 }
211 #[expect(dead_code)]
212 pub(crate) fn outermost_generics_mut(&mut self) -> &mut GenericParams {
213 &mut self.outermost_binder_mut().params
214 }
215 pub(crate) fn innermost_generics(&self) -> &GenericParams {
216 &self.innermost_binder().params
217 }
218 pub(crate) fn innermost_generics_mut(&mut self) -> &mut GenericParams {
219 &mut self.innermost_binder_mut().params
220 }
221
222 pub(crate) fn lookup_bound_region(
223 &mut self,
224 span: Span,
225 dbid: hax::DebruijnIndex,
226 var: hax::BoundVar,
227 ) -> Result<RegionDbVar, Error> {
228 let dbid = DeBruijnId::new(dbid);
229 if let Some(rid) = self
230 .binding_levels
231 .get(dbid)
232 .and_then(|bl| bl.bound_region_vars.get(var))
233 {
234 Ok(DeBruijnVar::bound(dbid, *rid))
235 } else {
236 raise_error!(
237 self,
238 span,
239 "Unexpected error: could not find region '{dbid}_{var}"
240 )
241 }
242 }
243
244 pub(crate) fn lookup_param<Id: Copy>(
245 &mut self,
246 span: Span,
247 f: impl for<'a> Fn(&'a BindingLevel) -> Option<Id>,
248 mk_err: impl FnOnce() -> String,
249 ) -> Result<DeBruijnVar<Id>, Error> {
250 for (dbid, bl) in self.binding_levels.iter_enumerated() {
251 if let Some(id) = f(bl) {
252 return Ok(DeBruijnVar::bound(dbid, id));
253 }
254 }
255 let err = mk_err();
256 raise_error!(self, span, "Unexpected error: could not find {}", err)
257 }
258
259 pub(crate) fn lookup_early_region(
260 &mut self,
261 span: Span,
262 region: &hax::EarlyParamRegion,
263 ) -> Result<RegionDbVar, Error> {
264 self.lookup_param(
265 span,
266 |bl| bl.early_region_vars.get(region).copied(),
267 || format!("the region variable {region:?}"),
268 )
269 }
270
271 pub(crate) fn lookup_type_var(
272 &mut self,
273 span: Span,
274 param: &hax::ParamTy,
275 ) -> Result<TypeDbVar, Error> {
276 self.lookup_param(
277 span,
278 |bl| bl.type_vars_map.get(¶m.index).copied(),
279 || format!("the type variable {}", param.name),
280 )
281 }
282
283 pub(crate) fn lookup_const_generic_var(
284 &mut self,
285 span: Span,
286 param: &hax::ParamConst,
287 ) -> Result<ConstGenericDbVar, Error> {
288 self.lookup_param(
289 span,
290 |bl| bl.const_generic_vars_map.get(¶m.index).copied(),
291 || format!("the const generic variable {}", param.name),
292 )
293 }
294
295 pub(crate) fn lookup_clause_var(
296 &mut self,
297 span: Span,
298 mut id: usize,
299 ) -> Result<ClauseDbVar, Error> {
300 let innermost_item_binder_id = self
305 .binding_levels
306 .iter_enumerated()
307 .find(|(_, bl)| bl.is_item_binder)
308 .unwrap()
309 .0;
310 for (dbid, bl) in self.binding_levels.iter_enumerated().rev() {
312 let num_clauses_bound_at_this_level = bl.params.trait_clauses.elem_count();
313 if id < num_clauses_bound_at_this_level || dbid == innermost_item_binder_id {
314 let id = TraitClauseId::from_usize(id);
315 return Ok(DeBruijnVar::bound(dbid, id));
316 } else {
317 id -= num_clauses_bound_at_this_level
318 }
319 }
320 raise_error!(
322 self,
323 span,
324 "Unexpected error: could not find clause variable {}",
325 id
326 )
327 }
328
329 pub(crate) fn push_generic_params(&mut self, generics: &hax::TyGenerics) -> Result<(), Error> {
330 for param in &generics.params {
331 self.push_generic_param(param)?;
332 }
333 Ok(())
334 }
335
336 pub(crate) fn push_generic_param(&mut self, param: &hax::GenericParamDef) -> Result<(), Error> {
337 match ¶m.kind {
338 hax::GenericParamDefKind::Lifetime => {
339 let region = hax::EarlyParamRegion {
340 index: param.index,
341 name: param.name.clone(),
342 };
343 let mutability = self
344 .t_ctx
345 .lt_mutability_computer
346 .compute_lifetime_mutability(
347 &self.hax_state,
348 self.item_src.def_id(),
349 param.index,
350 );
351 let _ = self
352 .innermost_binder_mut()
353 .push_early_region(region, mutability);
354 }
355 hax::GenericParamDefKind::Type { .. } => {
356 let _ = self
357 .innermost_binder_mut()
358 .push_type_var(param.index, param.name);
359 }
360 hax::GenericParamDefKind::Const { ty, .. } => {
361 let span = self.def_span(¶m.def_id);
362 let ty = self.translate_ty(span, ty)?;
365 self.innermost_binder_mut()
366 .push_const_generic_var(param.index, ty, param.name);
367 }
368 }
369
370 Ok(())
371 }
372
373 fn push_late_bound_generics_for_def(
382 &mut self,
383 _span: Span,
384 def: &hax::FullDef,
385 ) -> Result<(), Error> {
386 if let hax::FullDefKind::Fn { sig, .. } | hax::FullDefKind::AssocFn { sig, .. } = def.kind()
387 {
388 let innermost_binder = self.innermost_binder_mut();
389 assert!(innermost_binder.bound_region_vars.is_empty());
390 innermost_binder.push_params_from_binder(sig.rebind(()))?;
391 }
392 Ok(())
393 }
394
395 #[tracing::instrument(skip(self, span, def))]
397 fn push_generics_for_def(&mut self, span: Span, def: &hax::FullDef) -> Result<(), Error> {
398 trace!("{:?}", def.param_env());
399 if let Some(parent_item) = def.typing_parent(self.hax_state()) {
402 let parent_def = self.hax_def(&parent_item)?;
403 self.push_generics_for_def(span, &parent_def)?;
404 }
405 self.push_generics_for_def_without_parents(span, def)?;
406 Ok(())
407 }
408
409 fn push_generics_for_def_without_parents(
412 &mut self,
413 _span: Span,
414 def: &hax::FullDef,
415 ) -> Result<(), Error> {
416 use hax::FullDefKind;
417 if let Some(param_env) = def.param_env() {
418 self.push_generic_params(¶m_env.generics)?;
420 let origin = match &def.kind {
422 FullDefKind::Adt { .. }
423 | FullDefKind::TyAlias { .. }
424 | FullDefKind::AssocTy { .. } => PredicateOrigin::WhereClauseOnType,
425 FullDefKind::Fn { .. }
426 | FullDefKind::AssocFn { .. }
427 | FullDefKind::Closure { .. }
428 | FullDefKind::Const { .. }
429 | FullDefKind::AssocConst { .. }
430 | FullDefKind::Static { .. } => PredicateOrigin::WhereClauseOnFn,
431 FullDefKind::TraitImpl { .. } | FullDefKind::InherentImpl { .. } => {
432 PredicateOrigin::WhereClauseOnImpl
433 }
434 FullDefKind::Trait { .. } | FullDefKind::TraitAlias { .. } => {
435 PredicateOrigin::WhereClauseOnTrait
436 }
437 _ => panic!("Unexpected def: {:?}", def.def_id().kind),
438 };
439 self.register_predicates(¶m_env.predicates, origin.clone())?;
440 }
441
442 Ok(())
443 }
444
445 pub fn translate_item_generics(
453 &mut self,
454 span: Span,
455 def: &hax::FullDef,
456 kind: &TransItemSourceKind,
457 ) -> Result<(), Error> {
458 assert!(self.binding_levels.len() == 0);
459 self.binding_levels.push(BindingLevel::new(true));
460 self.push_generics_for_def(span, def)?;
461 self.push_late_bound_generics_for_def(span, def)?;
462
463 if let hax::FullDefKind::Closure { args, .. } = def.kind() {
464 let upvar_tys = self.translate_closure_upvar_tys(span, args)?;
467 let upvar_tys = upvar_tys.replace_erased_regions(|| {
469 let region_id = self.the_only_binder_mut().push_upvar_region();
470 Region::Var(DeBruijnVar::new_at_zero(region_id))
471 });
472 self.the_only_binder_mut().closure_upvar_tys = Some(upvar_tys);
473
474 if let TransItemSourceKind::TraitImpl(TraitImplSource::Closure(..))
476 | TransItemSourceKind::ClosureMethod(..)
477 | TransItemSourceKind::ClosureAsFnCast = kind
478 {
479 self.the_only_binder_mut()
480 .push_params_from_binder(args.fn_sig.rebind(()))?;
481 }
482 if let TransItemSourceKind::ClosureMethod(ClosureKind::Fn | ClosureKind::FnMut) = kind {
483 let rid = self
485 .the_only_binder_mut()
486 .params
487 .regions
488 .push_with(|index| RegionParam::new(index, None));
489 self.the_only_binder_mut().closure_call_method_region = Some(rid);
490 }
491 }
492
493 self.innermost_binder_mut().params.check_consistency();
494 Ok(())
495 }
496
497 pub(crate) fn inside_binder<F, U>(
499 &mut self,
500 kind: BinderKind,
501 is_item_binder: bool,
502 f: F,
503 ) -> Result<Binder<U>, Error>
504 where
505 F: FnOnce(&mut Self) -> Result<U, Error>,
506 {
507 assert!(!self.binding_levels.is_empty());
508 self.binding_levels.push(BindingLevel::new(is_item_binder));
509
510 let res = f(self);
512
513 let params = self.binding_levels.pop().unwrap().params;
515
516 res.map(|skip_binder| Binder {
518 kind,
519 params,
520 skip_binder,
521 })
522 }
523
524 pub(crate) fn translate_binder_for_def<F, U>(
527 &mut self,
528 span: Span,
529 kind: BinderKind,
530 def: &hax::FullDef,
531 f: F,
532 ) -> Result<Binder<U>, Error>
533 where
534 F: FnOnce(&mut Self) -> Result<U, Error>,
535 {
536 let inner_hax_state = self.t_ctx.hax_state.clone().with_hax_owner(&def.def_id());
537 let outer_hax_state = mem::replace(&mut self.hax_state, inner_hax_state);
538 let ret = self.inside_binder(kind, true, |this| {
539 this.push_generics_for_def_without_parents(span, def)?;
540 this.push_late_bound_generics_for_def(span, def)?;
541 this.innermost_binder().params.check_consistency();
542 f(this)
543 });
544 self.hax_state = outer_hax_state;
545 ret
546 }
547
548 pub(crate) fn translate_region_binder<F, T, U>(
552 &mut self,
553 _span: Span,
554 binder: &hax::Binder<T>,
555 f: F,
556 ) -> Result<RegionBinder<U>, Error>
557 where
558 F: FnOnce(&mut Self, &T) -> Result<U, Error>,
559 {
560 let binder = self.inside_binder(BinderKind::Other, false, |this| {
561 this.innermost_binder_mut()
562 .push_params_from_binder(binder.rebind(()))?;
563 f(this, binder.hax_skip_binder_ref())
564 })?;
565 Ok(RegionBinder {
567 regions: binder.params.regions,
568 skip_binder: binder.skip_binder,
569 })
570 }
571
572 pub(crate) fn into_generics(mut self) -> GenericParams {
573 assert!(self.binding_levels.len() == 1);
574 self.binding_levels.pop().unwrap().params
575 }
576}
577
578#[derive(Default)]
580pub struct LifetimeMutabilityComputer {
581 lt_mutability: HashMap<hax::DefId, CycleDetector<HashSet<u32>>>,
582}
583
584impl LifetimeMutabilityComputer {
585 pub(crate) fn compute_lifetime_mutability<'tcx>(
587 &mut self,
588 s: &impl BaseState<'tcx>,
589 item: &hax::DefId,
590 index: u32,
591 ) -> LifetimeMutability {
592 match self.compute_lifetime_mutabilities(s, item) {
593 Some(set) => {
594 if set.contains(&index) {
595 LifetimeMutability::Mutable
596 } else {
597 LifetimeMutability::Shared
598 }
599 }
600 None => LifetimeMutability::Unknown,
601 }
602 }
603
604 fn compute_lifetime_mutabilities<'tcx>(
607 &mut self,
608 s: &impl BaseState<'tcx>,
609 item: &hax::DefId,
610 ) -> Option<&HashSet<u32>> {
611 if !matches!(
612 item.kind,
613 hax::DefKind::Struct | hax::DefKind::Enum | hax::DefKind::Union
614 ) {
615 return None;
616 }
617 if self
618 .lt_mutability
619 .entry(item.clone())
620 .or_default()
621 .start_processing()
622 {
623 use hax::SInto;
624 use ty::{TypeSuperVisitable, TypeVisitable};
625
626 struct LtMutabilityVisitor<'a, S> {
627 s: &'a S,
628 computer: &'a mut LifetimeMutabilityComputer,
629 set: HashSet<u32>,
630 }
631 impl<'tcx, S: BaseState<'tcx>> ty::TypeVisitor<ty::TyCtxt<'tcx>> for LtMutabilityVisitor<'_, S> {
632 fn visit_ty(&mut self, ty: ty::Ty<'tcx>) {
633 match ty.kind() {
634 ty::Ref(r, _, ty::Mutability::Mut)
635 if let ty::RegionKind::ReEarlyParam(r) = r.kind() =>
636 {
637 self.set.insert(r.index);
638 }
639 ty::Adt(adt, args) => {
640 let item = adt.did().sinto(self.s);
641 if let Some(mutabilities) =
642 self.computer.compute_lifetime_mutabilities(self.s, &item)
643 {
644 for arg in args.iter() {
645 if let Some(r) = arg.as_region()
646 && let ty::RegionKind::ReEarlyParam(r) = r.kind()
647 && mutabilities.contains(&r.index)
648 {
649 self.set.insert(r.index);
650 }
651 }
652 }
653 }
654 _ => {}
655 }
656 ty.super_visit_with(self)
657 }
658 }
659 let mut visitor = LtMutabilityVisitor {
660 s,
661 computer: self,
662 set: HashSet::new(),
663 };
664
665 let tcx = s.base().tcx;
666 let def_id = item.real_rust_def_id();
667 let adt_def = tcx.adt_def(def_id);
668 let generics = ty::GenericArgs::identity_for_item(tcx, def_id);
669 for variant in adt_def.variants() {
670 for field in &variant.fields {
671 field.ty(tcx, generics).visit_with(&mut visitor);
672 }
673 }
674 let set = visitor.set;
675
676 self.lt_mutability
677 .get_mut(item)
678 .unwrap()
679 .done_processing(set);
680 }
681 self.lt_mutability.get(item)?.as_processed()
682 }
683}