charon_lib/transform/normalize/
partial_monomorphization.rs1use std::collections::{HashMap, HashSet, VecDeque};
14use std::fmt::Display;
15use std::mem;
16
17use derive_generic_visitor::Visitor;
18use index_vec::Idx;
19
20use crate::ast::types_utils::TyVisitable;
21use crate::ast::visitor::{VisitWithBinderDepth, VisitorWithBinderDepth};
22use crate::formatter::IntoFormatter;
23use crate::options::MonomorphizeMut;
24use crate::pretty::FmtWithCtx;
25use crate::register_error;
26use crate::transform::ctx::TransformPass;
27use crate::{transform::TransformCtx, ullbc_ast::*};
28
/// The "mutability shape" of a list of generic arguments: the arguments with every region and
/// every type that does not involve a mutable borrow (see `PartialMonomorphizer::is_infected`)
/// replaced by a fresh variable bound by this binder. References whose arguments have the same
/// shape share the same specialized copy of the target item (see `partial_mono_shapes`).
type MutabilityShape = Binder<GenericArgs>;
30
/// Visitor that turns a copy of some `GenericArgs` into a [`MutabilityShape`] by replacing
/// abstractable parts with fresh bound variables. Driven by `compute_shape`.
#[derive(Visitor)]
struct MutabilityShapeBuilder<'pm, 'ctx> {
    /// The monomorphizer; used to decide which types are infected and to look up the generic
    /// parameters of ADTs encountered during the traversal.
    pm: &'pm PartialMonomorphizer<'ctx>,
    /// Parameters of the shape binder under construction; every fresh variable gets an entry
    /// here, and predicates of encountered ADTs are merged in.
    params: GenericParams,
    /// The values that were abstracted away, in the same order as the parameters pushed to
    /// `params`; these become the arguments passed to the specialized item.
    extracted: GenericArgs,
    /// How many binders the traversal is currently nested under, counted from the shape binder.
    binder_depth: DeBruijnId,
}
42
impl<'pm, 'ctx> MutabilityShapeBuilder<'pm, 'ctx> {
    /// Computes the mutability shape of `args`, which are arguments for an item whose generic
    /// parameters are `target_params`.
    ///
    /// Returns `(shape, shape_args)`:
    /// - `shape` binds one fresh parameter per region and per non-infected type found in `args`.
    ///   Its predicates are those of `target_params` rewritten in terms of the shape, followed by
    ///   any predicates merged in from ADTs kept inside the shape (see `exit_ty_kind`).
    /// - `shape_args` are the values that were abstracted away (plus the original trait refs),
    ///   i.e. the arguments to pass to the item specialized for `shape`.
    fn compute_shape(
        pm: &'pm PartialMonomorphizer<'ctx>,
        target_params: &GenericParams,
        args: &GenericArgs,
    ) -> (MutabilityShape, GenericArgs) {
        let mut shape_contents = args.clone();
        // The shape keeps all of the target's predicates (`..target_params.clone()`), but its
        // regions/types/const generics start empty and are created on demand by the visitor.
        // Because the target's trait clauses are kept, the original trait refs are moved as-is
        // into the extracted arguments; the shape's own trait refs are rebuilt below.
        let mut builder = Self {
            pm,
            params: GenericParams {
                regions: IndexMap::new(),
                types: IndexMap::new(),
                const_generics: IndexMap::new(),
                ..target_params.clone()
            },
            extracted: GenericArgs {
                regions: IndexMap::new(),
                types: IndexMap::new(),
                const_generics: IndexMap::new(),
                trait_refs: mem::take(&mut shape_contents.trait_refs),
            },
            binder_depth: DeBruijnId::zero(),
        };

        // Replace regions and non-infected types with fresh variables; this fills
        // `builder.params` and `builder.extracted` in lockstep.
        // NOTE(review): const-generic arguments are never abstracted (`params.const_generics`
        // starts empty and no hook pushes to it) — confirm they cannot mention variables of the
        // caller's scope.
        let _ = VisitWithBinderDepth::new(&mut builder).visit(&mut shape_contents);

        let shape_params = {
            let mut shape_params = builder.params;
            // Only the predicates inherited from `target_params` (the first `slot_count()` /
            // `len()` entries) still mention the target's parameters, so only those are
            // substituted with the shape's contents. Entries appended afterwards come from
            // `exit_ty_kind`, which substitutes them before merging.
            shape_params.trait_clauses = shape_params.trait_clauses.map_indexed(|i, x| {
                if i.index() < target_params.trait_clauses.slot_count() {
                    x.substitute_explicits(&shape_contents)
                } else {
                    x
                }
            });
            shape_params.trait_type_constraints =
                shape_params.trait_type_constraints.map_indexed(|i, x| {
                    if i.index() < target_params.trait_type_constraints.slot_count() {
                        x.substitute_explicits(&shape_contents)
                    } else {
                        x
                    }
                });
            shape_params.regions_outlive = shape_params
                .regions_outlive
                .into_iter()
                .enumerate()
                .map(|(i, x)| {
                    if i < target_params.regions_outlive.len() {
                        x.substitute_explicits(&shape_contents)
                    } else {
                        x
                    }
                })
                .collect();
            shape_params.types_outlive = shape_params
                .types_outlive
                .into_iter()
                .enumerate()
                .map(|(i, x)| {
                    if i < target_params.types_outlive.len() {
                        x.substitute_explicits(&shape_contents)
                    } else {
                        x
                    }
                })
                .collect();
            shape_params
        };

        // Inside the shape, the target's trait clauses are satisfied by the shape's own clauses:
        // use identity trait refs, keeping only those for the clauses inherited from the target.
        shape_contents.trait_refs = shape_params.identity_args().trait_refs;
        shape_contents
            .trait_refs
            .truncate(target_params.trait_clauses.slot_count());

        let shape_args = builder.extracted;
        let shape = Binder::new(BinderKind::Other, shape_params, shape_contents);
        (shape, shape_args)
    }

    /// Replaces `val` with a fresh variable bound by the shape binder.
    ///
    /// The original value, shifted out of the binders currently traversed, is pushed to
    /// `extracted`, and a new parameter built by `mk_param` is pushed to `params`; `mk_value`
    /// turns the resulting variable back into an `Arg`. If `move_from_under_binders` fails
    /// (presumably because `val` mentions a variable bound by one of those inner binders),
    /// `val` is left untouched.
    fn replace_with_fresh_var<Id, Param, Arg>(
        &mut self,
        val: &mut Arg,
        mk_param: impl FnOnce(Id) -> Param,
        mk_value: impl FnOnce(DeBruijnVar<Id>) -> Arg,
    ) where
        Id: Idx + Display,
        Arg: TyVisitable + Clone,
        GenericParams: HasIdxMapOf<Id, Output = Param>,
        GenericArgs: HasIdxMapOf<Id, Output = Arg>,
    {
        let Some(shifted_val) = val.clone().move_from_under_binders(self.binder_depth) else {
            return;
        };
        self.extracted.get_idx_map_mut().push(shifted_val);
        let id = self.params.get_idx_map_mut().push_with(mk_param);
        // Refer to the shape binder from the current depth.
        *val = mk_value(DeBruijnVar::bound(self.binder_depth, id));
    }
}
177
178impl<'pm, 'ctx> VisitorWithBinderDepth for MutabilityShapeBuilder<'pm, 'ctx> {
179 fn binder_depth_mut(&mut self) -> &mut DeBruijnId {
180 &mut self.binder_depth
181 }
182}
183
184impl<'pm, 'ctx> VisitAstMut for MutabilityShapeBuilder<'pm, 'ctx> {
185 fn visit<'a, T: AstVisitable>(&'a mut self, x: &mut T) -> ControlFlow<Self::Break> {
186 VisitWithBinderDepth::new(self).visit(x)
187 }
188
189 fn enter_ty(&mut self, ty: &mut Ty) {
190 if !self.pm.is_infected(ty) {
191 self.replace_with_fresh_var(
192 ty,
193 |id| TypeParam::new(id, format!("T{id}")),
194 |v| v.into(),
195 );
196 }
197 }
198 fn exit_ty_kind(&mut self, kind: &mut TyKind) {
199 if let TyKind::Adt(TypeDeclRef {
200 id: TypeId::Adt(id),
201 generics,
202 }) = kind
203 {
204 let Some(target_params) = self.pm.generic_params.get(&(*id).into()) else {
209 return;
210 };
211 let Some(shifted_generics) =
212 generics.clone().move_from_under_binders(self.binder_depth)
213 else {
214 return;
216 };
217
218 let num_clauses_before_merge = self.params.trait_clauses.slot_count();
220 self.params.merge_predicates_from(
221 target_params
222 .clone()
223 .substitute_explicits(&shifted_generics),
224 );
225
226 self.extracted
228 .trait_refs
229 .extend(shifted_generics.trait_refs);
230
231 for (target_clause_id, tref) in generics.trait_refs.iter_mut_indexed() {
233 let clause_id = target_clause_id + num_clauses_before_merge;
234 *tref =
235 self.params.trait_clauses[clause_id].identity_tref_at_depth(self.binder_depth);
236 }
237 }
238 }
239 fn enter_region(&mut self, r: &mut Region) {
240 self.replace_with_fresh_var(r, |id| RegionParam::new(id, None), |v| v.into());
241 }
242 fn visit_trait_ref(&mut self, _tref: &mut TraitRef) -> ControlFlow<Self::Break> {
249 ControlFlow::Continue(())
252 }
253
254 fn visit_constant_expr(
255 &mut self,
256 _: &mut ConstantExpr,
257 ) -> ::std::ops::ControlFlow<Self::Break> {
258 ControlFlow::Continue(())
259 }
260}
261
/// State of the partial-monomorphization pass: rewrites item references whose generic arguments
/// involve mutable borrows into references to specialized copies of those items.
#[derive(Visitor)]
struct PartialMonomorphizer<'a> {
    ctx: &'a mut TransformCtx,
    /// Span of the location currently being visited (maintained through `VisitWithSpan`);
    /// used when reporting errors.
    span: Span,
    /// Whether references to ADTs are specialized too (`MonomorphizeMut::All`). When false, an
    /// ADT type counts as infected if any of its type arguments is.
    specialize_adts: bool,
    /// Type declarations considered infected: those with a region parameter whose
    /// `mutability.is_mutable()`, plus specialized type copies created by this pass.
    infected_types: HashSet<TypeDeclId>,
    /// Generic parameters of every item, snapshotted at construction, plus those of the
    /// specialized copies created by this pass.
    generic_params: HashMap<ItemId, GenericParams>,
    /// Specialization cache: (base item, shape) -> id of the specialized copy.
    partial_mono_shapes: SeqHashMap<(ItemId, MutabilityShape), ItemId>,
    /// Inverse of `partial_mono_shapes`: specialized copy id -> (base item, shape).
    reverse_shape_map: HashMap<ItemId, (ItemId, MutabilityShape)>,
    /// Work list: initially every item id, then specialized copies as they get created.
    to_process: VecDeque<ItemId>,
}
285
286impl<'a> PartialMonomorphizer<'a> {
287 pub fn new(ctx: &'a mut TransformCtx, specialize_adts: bool) -> Self {
288 let infected_types: HashSet<_> = ctx
291 .translated
292 .type_decls
293 .iter()
294 .filter(|tdecl| {
295 tdecl
296 .generics
297 .regions
298 .iter()
299 .any(|r| r.mutability.is_mutable())
300 })
301 .map(|tdecl| tdecl.def_id)
302 .collect();
303
304 let generic_params: HashMap<ItemId, GenericParams> = ctx
306 .translated
307 .all_items()
308 .map(|item| (item.id(), item.generic_params().clone()))
309 .collect();
310
311 let to_process = ctx.translated.all_ids().collect();
313 PartialMonomorphizer {
314 ctx,
315 span: Span::dummy(),
316 specialize_adts,
317 infected_types,
318 generic_params,
319 to_process,
320 partial_mono_shapes: SeqHashMap::default(),
321 reverse_shape_map: Default::default(),
322 }
323 }
324
325 fn is_infected(&self, ty: &Ty) -> bool {
328 match ty.kind() {
329 TyKind::Ref(_, _, RefKind::Mut) => true,
330 TyKind::Ref(_, ty, _)
331 | TyKind::RawPtr(ty, _)
332 | TyKind::Array(ty, _)
333 | TyKind::Slice(ty) => self.is_infected(ty),
334 TyKind::Adt(tref) if let TypeId::Adt(id) = tref.id => {
335 let ty_infected = self.infected_types.contains(&id);
336 let args_infected = if self.specialize_adts {
337 false
342 } else {
343 tref.generics.types.iter().any(|ty| self.is_infected(ty))
344 };
345 ty_infected || args_infected
346 }
347 TyKind::Adt(..) => false,
348 TyKind::FnDef(..) | TyKind::FnPtr(..) => false,
352 TyKind::DynTrait(_) => {
353 register_error!(
354 self.ctx,
355 self.span,
356 "`dyn Trait` is unsupported with `--monomorphize-mut`"
357 );
358 false
359 }
360 TyKind::TypeVar(..)
361 | TyKind::Literal(..)
362 | TyKind::Never
363 | TyKind::TraitType(..)
364 | TyKind::PtrMetadata(..)
365 | TyKind::Error(_) => false,
366 }
367 }
368
369 fn process_generics(&mut self, id: ItemId, generics: &GenericArgs) -> Option<DeclRef<ItemId>> {
373 if !generics.types.iter().any(|ty| self.is_infected(ty)) {
374 return None;
375 }
376
377 let mut new_generics;
380 let (id, generics) = if let Some(&(base_id, ref shape)) = self.reverse_shape_map.get(&id) {
381 new_generics = shape.clone().apply(generics);
382 let _ = self.visit(&mut new_generics); (base_id, &new_generics)
384 } else {
385 (id, generics)
386 };
387
388 let item_params = self.generic_params.get(&id)?;
390 let (shape, shape_args) =
391 MutabilityShapeBuilder::compute_shape(self, item_params, generics);
392
393 let new_params = shape.params.clone();
395 let key: (ItemId, MutabilityShape) = (id, shape);
396 let new_id = *self
397 .partial_mono_shapes
398 .entry(key.clone())
399 .or_insert_with(|| {
400 let new_id = match id {
401 ItemId::Type(_) => {
402 let new_id = self.ctx.translated.type_decls.reserve_slot();
403 self.infected_types.insert(new_id);
404 new_id.into()
405 }
406 ItemId::Fun(_) => self.ctx.translated.fun_decls.reserve_slot().into(),
407 ItemId::Global(_) => self.ctx.translated.global_decls.reserve_slot().into(),
408 ItemId::TraitDecl(_) => self.ctx.translated.trait_decls.reserve_slot().into(),
409 ItemId::TraitImpl(_) => self.ctx.translated.trait_impls.reserve_slot().into(),
410 };
411 self.generic_params.insert(new_id, new_params);
412 self.reverse_shape_map.insert(new_id, key);
413 self.to_process.push_back(new_id);
414 new_id
415 });
416
417 let fmt_ctx = self.ctx.into_fmt();
418 trace!(
419 "processing {}{}\n output: {}{}",
420 id.with_ctx(&fmt_ctx),
421 generics.with_ctx(&fmt_ctx),
422 new_id.with_ctx(&fmt_ctx),
423 shape_args.with_ctx(&fmt_ctx),
424 );
425 Some(DeclRef {
426 id: new_id,
427 generics: Box::new(shape_args),
428 trait_ref: None,
429 })
430 }
431
432 pub fn process_item(&mut self, item: &mut ItemRefMut<'_>) {
436 let _ = item.drive_mut(self);
437 }
438
439 pub fn create_pending_instantiation(&mut self, new_id: ItemId) -> ItemByVal {
444 let (orig_id, shape) = &self.reverse_shape_map[&new_id];
445 let mut decl = self
446 .ctx
447 .translated
448 .get_item(*orig_id)
449 .unwrap()
450 .clone()
451 .substitute_with_self(&shape.skip_binder, &TraitRefKind::SelfId);
452
453 let mut decl_mut = decl.as_mut();
454 decl_mut.set_id(new_id);
455 *decl_mut.generic_params() = shape.params.clone();
456
457 let name_ref = &mut decl_mut.item_meta().name;
458 *name_ref = mem::take(name_ref).instantiate(shape.clone());
459 self.ctx
460 .translated
461 .item_names
462 .insert(new_id, decl.as_ref().item_meta().name.clone());
463
464 decl
465 }
466}
467
468impl VisitorWithSpan for PartialMonomorphizer<'_> {
469 fn current_span(&mut self) -> &mut Span {
470 &mut self.span
471 }
472}
473impl VisitAstMut for PartialMonomorphizer<'_> {
474 fn visit<'a, T: AstVisitable>(&'a mut self, x: &mut T) -> ControlFlow<Self::Break> {
475 VisitWithSpan::new(self).visit(x)
477 }
478
479 fn exit_type_decl_ref(&mut self, x: &mut TypeDeclRef) {
480 if self.specialize_adts
481 && let TypeId::Adt(id) = x.id
482 && let Some(new_decl_ref) = self.process_generics(id.into(), &x.generics)
483 {
484 *x = new_decl_ref.try_into().unwrap()
485 }
486 }
487 fn exit_fn_ptr(&mut self, x: &mut FnPtr) {
488 if let FnPtrKind::Fun(FunId::Regular(id)) = *x.kind
491 && let Some(new_decl_ref) = self.process_generics(id.into(), &x.generics)
492 {
493 *x = new_decl_ref.try_into().unwrap()
494 }
495 }
496 fn exit_fun_decl_ref(&mut self, x: &mut FunDeclRef) {
497 if let Some(new_decl_ref) = self.process_generics(x.id.into(), &x.generics) {
498 *x = new_decl_ref.try_into().unwrap()
499 }
500 }
501 fn exit_global_decl_ref(&mut self, x: &mut GlobalDeclRef) {
502 if let Some(new_decl_ref) = self.process_generics(x.id.into(), &x.generics) {
503 *x = new_decl_ref.try_into().unwrap()
504 }
505 }
506 fn exit_trait_decl_ref(&mut self, x: &mut TraitDeclRef) {
507 if let Some(new_decl_ref) = self.process_generics(x.id.into(), &x.generics) {
508 *x = new_decl_ref.try_into().unwrap()
509 }
510 }
511 fn exit_trait_impl_ref(&mut self, x: &mut TraitImplRef) {
512 if let Some(new_decl_ref) = self.process_generics(x.id.into(), &x.generics) {
513 *x = new_decl_ref.try_into().unwrap()
514 }
515 }
516}
517
518pub struct Transform;
519impl TransformPass for Transform {
520 fn transform_ctx(&self, ctx: &mut TransformCtx) {
521 let Some(include_types) = ctx.options.monomorphize_mut else {
522 return;
523 };
524 let mut visitor =
526 PartialMonomorphizer::new(ctx, matches!(include_types, MonomorphizeMut::All));
527 while let Some(id) = visitor.to_process.pop_front() {
528 let mut decl = if visitor.reverse_shape_map.get(&id).is_some() {
531 visitor.create_pending_instantiation(id)
533 } else {
534 match visitor.ctx.translated.remove_item(id) {
537 Some(decl) => decl,
538 None => continue,
539 }
540 };
541 visitor.process_item(&mut decl.as_mut());
544 visitor.ctx.translated.set_item_slot(id, decl);
546 }
547 }
548}