charon_lib/transform/normalize/
partial_monomorphization.rs1use std::collections::{HashMap, HashSet, VecDeque};
14use std::fmt::Display;
15use std::mem;
16
17use derive_generic_visitor::Visitor;
18use index_vec::Idx;
19
20use crate::ast::types_utils::TyVisitable;
21use crate::ast::visitor::{VisitWithBinderDepth, VisitorWithBinderDepth};
22use crate::formatter::IntoFormatter;
23use crate::ids::IndexVec;
24use crate::options::MonomorphizeMut;
25use crate::pretty::FmtWithCtx;
26use crate::register_error;
27use crate::transform::ctx::TransformPass;
28use crate::{transform::TransformCtx, ullbc_ast::*};
29
30type MutabilityShape = Binder<GenericArgs>;
31
32#[derive(Visitor)]
34struct MutabilityShapeBuilder<'pm, 'ctx> {
35 pm: &'pm PartialMonomorphizer<'ctx>,
36 params: GenericParams,
38 extracted: GenericArgs,
40 binder_depth: DeBruijnId,
42}
43
44impl<'pm, 'ctx> MutabilityShapeBuilder<'pm, 'ctx> {
45 fn compute_shape(
66 pm: &'pm PartialMonomorphizer<'ctx>,
67 target_params: &GenericParams,
68 args: &GenericArgs,
69 ) -> (MutabilityShape, GenericArgs) {
70 let mut shape_contents = args.clone();
75 let mut builder = Self {
76 pm,
77 params: GenericParams {
78 regions: IndexVec::new(),
79 types: IndexVec::new(),
80 const_generics: IndexVec::new(),
81 ..target_params.clone()
82 },
83 extracted: GenericArgs {
84 regions: IndexVec::new(),
85 types: IndexVec::new(),
86 const_generics: IndexVec::new(),
87 trait_refs: mem::take(&mut shape_contents.trait_refs),
88 },
89 binder_depth: DeBruijnId::zero(),
90 };
91
92 let _ = VisitWithBinderDepth::new(&mut builder).visit(&mut shape_contents);
95
96 let shape_params = {
97 let mut shape_params = builder.params;
98 shape_params.trait_clauses = shape_params.trait_clauses.map_indexed(|i, x| {
102 if i.index() < target_params.trait_clauses.len() {
103 x.substitute_explicits(&shape_contents)
104 } else {
105 x
106 }
107 });
108 shape_params.trait_type_constraints =
109 shape_params.trait_type_constraints.map_indexed(|i, x| {
110 if i.index() < target_params.trait_type_constraints.len() {
111 x.substitute_explicits(&shape_contents)
112 } else {
113 x
114 }
115 });
116 shape_params.regions_outlive = shape_params
117 .regions_outlive
118 .into_iter()
119 .enumerate()
120 .map(|(i, x)| {
121 if i < target_params.regions_outlive.len() {
122 x.substitute_explicits(&shape_contents)
123 } else {
124 x
125 }
126 })
127 .collect();
128 shape_params.types_outlive = shape_params
129 .types_outlive
130 .into_iter()
131 .enumerate()
132 .map(|(i, x)| {
133 if i < target_params.types_outlive.len() {
134 x.substitute_explicits(&shape_contents)
135 } else {
136 x
137 }
138 })
139 .collect();
140 shape_params
141 };
142
143 shape_contents.trait_refs = shape_params.identity_args().trait_refs;
146 shape_contents
147 .trait_refs
148 .truncate(target_params.trait_clauses.len());
149
150 let shape_args = builder.extracted;
151 let shape = Binder::new(BinderKind::Other, shape_params, shape_contents);
152 (shape, shape_args)
153 }
154
155 fn replace_with_fresh_var<Id, Param, Arg>(
157 &mut self,
158 val: &mut Arg,
159 mk_param: impl FnOnce(Id) -> Param,
160 mk_value: impl FnOnce(DeBruijnVar<Id>) -> Arg,
161 ) where
162 Id: Idx + Display,
163 Arg: TyVisitable + Clone,
164 GenericParams: HasIdxVecOf<Id, Output = Param>,
165 GenericArgs: HasIdxVecOf<Id, Output = Arg>,
166 {
167 let Some(shifted_val) = val.clone().move_from_under_binders(self.binder_depth) else {
168 return;
170 };
171 self.extracted.get_idx_vec_mut().push(shifted_val);
173 let id = self.params.get_idx_vec_mut().push_with(mk_param);
175 *val = mk_value(DeBruijnVar::bound(self.binder_depth, id));
176 }
177}
178
179impl<'pm, 'ctx> VisitorWithBinderDepth for MutabilityShapeBuilder<'pm, 'ctx> {
180 fn binder_depth_mut(&mut self) -> &mut DeBruijnId {
181 &mut self.binder_depth
182 }
183}
184
185impl<'pm, 'ctx> VisitAstMut for MutabilityShapeBuilder<'pm, 'ctx> {
186 fn visit<T: AstVisitable>(&mut self, x: &mut T) -> ControlFlow<Self::Break> {
187 VisitWithBinderDepth::new(self).visit(x)
188 }
189
190 fn enter_ty(&mut self, ty: &mut Ty) {
191 if !self.pm.is_infected(ty) {
192 self.replace_with_fresh_var(
193 ty,
194 |id| TypeParam::new(id, format!("T{id}")),
195 |v| v.into(),
196 );
197 }
198 }
199 fn exit_ty_kind(&mut self, kind: &mut TyKind) {
200 if let TyKind::Adt(TypeDeclRef {
201 id: TypeId::Adt(id),
202 generics,
203 }) = kind
204 {
205 let Some(target_params) = self.pm.generic_params.get(&(*id).into()) else {
210 return;
211 };
212 let Some(shifted_generics) =
213 generics.clone().move_from_under_binders(self.binder_depth)
214 else {
215 return;
217 };
218
219 let num_clauses_before_merge = self.params.trait_clauses.len();
221 self.params.merge_predicates_from(
222 target_params
223 .clone()
224 .substitute_explicits(&shifted_generics),
225 );
226
227 self.extracted
229 .trait_refs
230 .extend(shifted_generics.trait_refs);
231
232 for (target_clause_id, tref) in generics.trait_refs.iter_mut_enumerated() {
234 let clause_id = target_clause_id + num_clauses_before_merge;
235 *tref =
236 self.params.trait_clauses[clause_id].identity_tref_at_depth(self.binder_depth);
237 }
238 }
239 }
240 fn enter_region(&mut self, r: &mut Region) {
241 self.replace_with_fresh_var(r, |id| RegionParam::new(id, None), |v| v.into());
242 }
243 fn visit_trait_ref(&mut self, _tref: &mut TraitRef) -> ControlFlow<Self::Break> {
250 ControlFlow::Continue(())
253 }
254
255 fn visit_constant_expr(
256 &mut self,
257 _: &mut ConstantExpr,
258 ) -> ::std::ops::ControlFlow<Self::Break> {
259 ControlFlow::Continue(())
260 }
261}
262
263#[derive(Visitor)]
264struct PartialMonomorphizer<'a> {
265 ctx: &'a mut TransformCtx,
266 span: Span,
268 specialize_adts: bool,
270 infected_types: HashSet<TypeDeclId>,
272 generic_params: HashMap<ItemId, GenericParams>,
277 partial_mono_shapes: SeqHashMap<(ItemId, MutabilityShape), ItemId>,
281 reverse_shape_map: HashMap<ItemId, (ItemId, MutabilityShape)>,
283 to_process: VecDeque<ItemId>,
285}
286
287impl<'a> PartialMonomorphizer<'a> {
288 pub fn new(ctx: &'a mut TransformCtx, specialize_adts: bool) -> Self {
289 let infected_types: HashSet<_> = ctx
292 .translated
293 .type_decls
294 .iter()
295 .filter(|tdecl| {
296 tdecl
297 .generics
298 .regions
299 .iter()
300 .any(|r| r.mutability.is_mutable())
301 })
302 .map(|tdecl| tdecl.def_id)
303 .collect();
304
305 let generic_params: HashMap<ItemId, GenericParams> = ctx
307 .translated
308 .all_items()
309 .map(|item| (item.id(), item.generic_params().clone()))
310 .collect();
311
312 let to_process = ctx.translated.all_ids().collect();
314 PartialMonomorphizer {
315 ctx,
316 span: Span::dummy(),
317 specialize_adts,
318 infected_types,
319 generic_params,
320 to_process,
321 partial_mono_shapes: SeqHashMap::default(),
322 reverse_shape_map: Default::default(),
323 }
324 }
325
326 fn is_infected(&self, ty: &Ty) -> bool {
329 match ty.kind() {
330 TyKind::Ref(_, _, RefKind::Mut) => true,
331 TyKind::Ref(_, ty, _)
332 | TyKind::RawPtr(ty, _)
333 | TyKind::Array(ty, _)
334 | TyKind::Slice(ty) => self.is_infected(ty),
335 TyKind::Adt(tref) if let TypeId::Adt(id) = tref.id => {
336 let ty_infected = self.infected_types.contains(&id);
337 let args_infected = if self.specialize_adts {
338 false
343 } else {
344 tref.generics.types.iter().any(|ty| self.is_infected(ty))
345 };
346 ty_infected || args_infected
347 }
348 TyKind::Adt(..) => false,
349 TyKind::FnDef(..) | TyKind::FnPtr(..) => false,
353 TyKind::DynTrait(_) => {
354 register_error!(
355 self.ctx,
356 self.span,
357 "`dyn Trait` is unsupported with `--monomorphize-mut`"
358 );
359 false
360 }
361 TyKind::TypeVar(..)
362 | TyKind::Literal(..)
363 | TyKind::Never
364 | TyKind::TraitType(..)
365 | TyKind::PtrMetadata(..)
366 | TyKind::Error(_) => false,
367 }
368 }
369
370 fn process_generics(&mut self, id: ItemId, generics: &GenericArgs) -> Option<DeclRef<ItemId>> {
374 if !generics.types.iter().any(|ty| self.is_infected(ty)) {
375 return None;
376 }
377
378 let mut new_generics;
381 let (id, generics) = if let Some(&(base_id, ref shape)) = self.reverse_shape_map.get(&id) {
382 new_generics = shape.clone().apply(generics);
383 let _ = self.visit(&mut new_generics); (base_id, &new_generics)
385 } else {
386 (id, generics)
387 };
388
389 let item_params = self.generic_params.get(&id)?;
391 let (shape, shape_args) =
392 MutabilityShapeBuilder::compute_shape(self, item_params, generics);
393
394 let new_params = shape.params.clone();
396 let key: (ItemId, MutabilityShape) = (id, shape);
397 let new_id = *self
398 .partial_mono_shapes
399 .entry(key.clone())
400 .or_insert_with(|| {
401 let new_id = match id {
402 ItemId::Type(_) => {
403 let new_id = self.ctx.translated.type_decls.reserve_slot();
404 self.infected_types.insert(new_id);
405 new_id.into()
406 }
407 ItemId::Fun(_) => self.ctx.translated.fun_decls.reserve_slot().into(),
408 ItemId::Global(_) => self.ctx.translated.global_decls.reserve_slot().into(),
409 ItemId::TraitDecl(_) => self.ctx.translated.trait_decls.reserve_slot().into(),
410 ItemId::TraitImpl(_) => self.ctx.translated.trait_impls.reserve_slot().into(),
411 };
412 self.generic_params.insert(new_id, new_params);
413 self.reverse_shape_map.insert(new_id, key);
414 self.to_process.push_back(new_id);
415 new_id
416 });
417
418 let fmt_ctx = self.ctx.into_fmt();
419 trace!(
420 "processing {}{}\n output: {}{}",
421 id.with_ctx(&fmt_ctx),
422 generics.with_ctx(&fmt_ctx),
423 new_id.with_ctx(&fmt_ctx),
424 shape_args.with_ctx(&fmt_ctx),
425 );
426 Some(DeclRef {
427 id: new_id,
428 generics: Box::new(shape_args),
429 trait_ref: None,
430 })
431 }
432
433 pub fn process_item(&mut self, item: &mut ItemRefMut<'_>) {
437 let _ = item.drive_mut(self);
438 }
439
440 pub fn create_pending_instantiation(&mut self, new_id: ItemId) -> ItemByVal {
445 let (orig_id, shape) = &self.reverse_shape_map[&new_id];
446 let mut decl = self
447 .ctx
448 .translated
449 .get_item(*orig_id)
450 .unwrap()
451 .to_owned()
452 .substitute_with_self(&shape.skip_binder, &TraitRefKind::SelfId);
453
454 let mut decl_mut = decl.as_mut();
455 decl_mut.set_id(new_id);
456 *decl_mut.generic_params() = shape.params.clone();
457
458 let name_ref = &mut decl_mut.item_meta().name;
459 *name_ref = mem::take::<crate::ast::Name>(name_ref).instantiate(shape.clone());
460 self.ctx
461 .translated
462 .item_names
463 .insert(new_id, decl.as_ref().item_meta().name.clone());
464
465 decl
466 }
467}
468
469impl VisitorWithSpan for PartialMonomorphizer<'_> {
470 fn current_span(&mut self) -> &mut Span {
471 &mut self.span
472 }
473}
474impl VisitAstMut for PartialMonomorphizer<'_> {
475 fn visit<T: AstVisitable>(&mut self, x: &mut T) -> ControlFlow<Self::Break> {
476 VisitWithSpan::new(self).visit(x)
478 }
479
480 fn exit_type_decl_ref(&mut self, x: &mut TypeDeclRef) {
481 if self.specialize_adts
482 && let TypeId::Adt(id) = x.id
483 && let Some(new_decl_ref) = self.process_generics(id.into(), &x.generics)
484 {
485 *x = new_decl_ref.try_into().unwrap()
486 }
487 }
488 fn exit_fn_ptr(&mut self, x: &mut FnPtr) {
489 if let FnPtrKind::Fun(FunId::Regular(id)) = *x.kind
492 && let Some(new_decl_ref) = self.process_generics(id.into(), &x.generics)
493 {
494 *x = new_decl_ref.try_into().unwrap()
495 }
496 }
497 fn exit_fun_decl_ref(&mut self, x: &mut FunDeclRef) {
498 if let Some(new_decl_ref) = self.process_generics(x.id.into(), &x.generics) {
499 *x = new_decl_ref.try_into().unwrap()
500 }
501 }
502 fn exit_global_decl_ref(&mut self, x: &mut GlobalDeclRef) {
503 if let Some(new_decl_ref) = self.process_generics(x.id.into(), &x.generics) {
504 *x = new_decl_ref.try_into().unwrap()
505 }
506 }
507 fn exit_trait_decl_ref(&mut self, x: &mut TraitDeclRef) {
508 if let Some(new_decl_ref) = self.process_generics(x.id.into(), &x.generics) {
509 *x = new_decl_ref.try_into().unwrap()
510 }
511 }
512 fn exit_trait_impl_ref(&mut self, x: &mut TraitImplRef) {
513 if let Some(new_decl_ref) = self.process_generics(x.id.into(), &x.generics) {
514 *x = new_decl_ref.try_into().unwrap()
515 }
516 }
517}
518
519pub struct Transform;
520impl TransformPass for Transform {
521 fn transform_ctx(&self, ctx: &mut TransformCtx) {
522 let Some(include_types) = ctx.options.monomorphize_mut else {
523 return;
524 };
525 let mut visitor =
527 PartialMonomorphizer::new(ctx, matches!(include_types, MonomorphizeMut::All));
528 while let Some(id) = visitor.to_process.pop_front() {
529 let mut decl = if visitor.reverse_shape_map.contains_key(&id) {
532 visitor.create_pending_instantiation(id)
534 } else {
535 match visitor.ctx.translated.remove_item_temporarily(id) {
538 Some(decl) => decl,
539 None => continue,
540 }
541 };
542 visitor.process_item(&mut decl.as_mut());
545 visitor.ctx.translated.set_item_slot(id, decl);
547 }
548 }
549}