charon_lib/transform/normalize/
partial_monomorphization.rs1use std::collections::{HashMap, HashSet, VecDeque};
14use std::fmt::Display;
15use std::mem;
16
17use derive_generic_visitor::Visitor;
18use index_vec::Idx;
19use indexmap::IndexMap;
20
21use crate::ast::types_utils::TyVisitable;
22use crate::ast::visitor::{VisitWithBinderDepth, VisitorWithBinderDepth};
23use crate::formatter::IntoFormatter;
24use crate::options::MonomorphizeMut;
25use crate::pretty::FmtWithCtx;
26use crate::register_error;
27use crate::transform::ctx::TransformPass;
28use crate::{transform::TransformCtx, ullbc_ast::*};
29
30type MutabilityShape = Binder<GenericArgs>;
31
32#[derive(Visitor)]
34struct MutabilityShapeBuilder<'pm, 'ctx> {
35 pm: &'pm PartialMonomorphizer<'ctx>,
36 params: GenericParams,
38 extracted: GenericArgs,
40 binder_depth: DeBruijnId,
42}
43
44impl<'pm, 'ctx> MutabilityShapeBuilder<'pm, 'ctx> {
45 fn compute_shape(
66 pm: &'pm PartialMonomorphizer<'ctx>,
67 target_params: &GenericParams,
68 args: &GenericArgs,
69 ) -> (MutabilityShape, GenericArgs) {
70 let mut shape_contents = args.clone();
75 let mut builder = Self {
76 pm,
77 params: GenericParams {
78 regions: Vector::new(),
79 types: Vector::new(),
80 const_generics: Vector::new(),
81 ..target_params.clone()
82 },
83 extracted: GenericArgs {
84 regions: Vector::new(),
85 types: Vector::new(),
86 const_generics: Vector::new(),
87 trait_refs: mem::take(&mut shape_contents.trait_refs),
88 },
89 binder_depth: DeBruijnId::zero(),
90 };
91
92 let _ = VisitWithBinderDepth::new(&mut builder).visit(&mut shape_contents);
95
96 let shape_params = {
97 let mut shape_params = builder.params;
98 shape_params.trait_clauses = shape_params.trait_clauses.map_indexed(|i, x| {
102 if i.index() < target_params.trait_clauses.slot_count() {
103 x.substitute_explicits(&shape_contents)
104 } else {
105 x
106 }
107 });
108 shape_params.trait_type_constraints =
109 shape_params.trait_type_constraints.map_indexed(|i, x| {
110 if i.index() < target_params.trait_type_constraints.slot_count() {
111 x.substitute_explicits(&shape_contents)
112 } else {
113 x
114 }
115 });
116 shape_params.regions_outlive = shape_params
117 .regions_outlive
118 .into_iter()
119 .enumerate()
120 .map(|(i, x)| {
121 if i < target_params.regions_outlive.len() {
122 x.substitute_explicits(&shape_contents)
123 } else {
124 x
125 }
126 })
127 .collect();
128 shape_params.types_outlive = shape_params
129 .types_outlive
130 .into_iter()
131 .enumerate()
132 .map(|(i, x)| {
133 if i < target_params.types_outlive.len() {
134 x.substitute_explicits(&shape_contents)
135 } else {
136 x
137 }
138 })
139 .collect();
140 shape_params
141 };
142
143 shape_contents.trait_refs = shape_params.identity_args().trait_refs;
146 shape_contents
147 .trait_refs
148 .truncate(target_params.trait_clauses.slot_count());
149
150 let shape_args = builder.extracted;
151 let shape = Binder::new(BinderKind::Other, shape_params, shape_contents);
152 (shape, shape_args)
153 }
154
155 fn replace_with_fresh_var<Id, Param, Arg>(
157 &mut self,
158 val: &mut Arg,
159 mk_param: impl FnOnce(Id) -> Param,
160 mk_value: impl FnOnce(DeBruijnVar<Id>) -> Arg,
161 ) where
162 Id: Idx + Display,
163 Arg: TyVisitable + Clone,
164 GenericParams: HasVectorOf<Id, Output = Param>,
165 GenericArgs: HasVectorOf<Id, Output = Arg>,
166 {
167 let Some(shifted_val) = val.clone().move_from_under_binders(self.binder_depth) else {
168 return;
170 };
171 self.extracted.get_vector_mut().push(shifted_val);
173 let id = self.params.get_vector_mut().push_with(mk_param);
175 *val = mk_value(DeBruijnVar::bound(self.binder_depth, id));
176 }
177}
178
179impl<'pm, 'ctx> VisitorWithBinderDepth for MutabilityShapeBuilder<'pm, 'ctx> {
180 fn binder_depth_mut(&mut self) -> &mut DeBruijnId {
181 &mut self.binder_depth
182 }
183}
184
185impl<'pm, 'ctx> VisitAstMut for MutabilityShapeBuilder<'pm, 'ctx> {
186 fn visit<'a, T: AstVisitable>(&'a mut self, x: &mut T) -> ControlFlow<Self::Break> {
187 VisitWithBinderDepth::new(self).visit(x)
188 }
189
190 fn enter_ty(&mut self, ty: &mut Ty) {
191 if !self.pm.is_infected(ty) {
192 self.replace_with_fresh_var(
193 ty,
194 |id| TypeParam::new(id, format!("T{id}")),
195 |v| v.into(),
196 );
197 }
198 }
199 fn exit_ty_kind(&mut self, kind: &mut TyKind) {
200 if let TyKind::Adt(TypeDeclRef {
201 id: TypeId::Adt(id),
202 generics,
203 }) = kind
204 {
205 let Some(target_params) = self.pm.generic_params.get(&(*id).into()) else {
210 return;
211 };
212 let Some(shifted_generics) =
213 generics.clone().move_from_under_binders(self.binder_depth)
214 else {
215 return;
217 };
218
219 let num_clauses_before_merge = self.params.trait_clauses.slot_count();
221 self.params.merge_predicates_from(
222 target_params
223 .clone()
224 .substitute_explicits(&shifted_generics),
225 );
226
227 self.extracted
229 .trait_refs
230 .extend(shifted_generics.trait_refs);
231
232 for (target_clause_id, tref) in generics.trait_refs.iter_mut_indexed() {
234 let clause_id = target_clause_id + num_clauses_before_merge;
235 *tref =
236 self.params.trait_clauses[clause_id].identity_tref_at_depth(self.binder_depth);
237 }
238 }
239 }
240 fn enter_region(&mut self, r: &mut Region) {
241 self.replace_with_fresh_var(r, |id| RegionParam::new(id, None), |v| v.into());
242 }
243 fn visit_trait_ref(&mut self, _tref: &mut TraitRef) -> ControlFlow<Self::Break> {
250 ControlFlow::Continue(())
253 }
254}
255
256#[derive(Visitor)]
257struct PartialMonomorphizer<'a> {
258 ctx: &'a mut TransformCtx,
259 span: Span,
261 instantiate_types: bool,
263 infected_types: HashSet<TypeDeclId>,
265 generic_params: HashMap<ItemId, GenericParams>,
270 partial_mono_shapes: IndexMap<(ItemId, MutabilityShape), ItemId>,
274 reverse_shape_map: HashMap<ItemId, (ItemId, MutabilityShape)>,
276 to_process: VecDeque<ItemId>,
278}
279
280impl<'a> PartialMonomorphizer<'a> {
281 pub fn new(ctx: &'a mut TransformCtx, instantiate_types: bool) -> Self {
282 use petgraph::graphmap::DiGraphMap;
283 use petgraph::visit::Dfs;
284 use petgraph::visit::Walker;
285
286 let infected_types: HashSet<_> = {
288 let mut graph: DiGraphMap<Option<TypeDeclId>, ()> = Default::default();
293 for (id, tdecl) in ctx.translated.type_decls.iter_indexed() {
294 tdecl.dyn_visit(|x: &Ty| match x.kind() {
295 TyKind::Ref(_, _, RefKind::Mut) => {
296 graph.add_edge(None, Some(id), ());
297 }
298 TyKind::Adt(tref) if let TypeId::Adt(other_id) = tref.id => {
299 graph.add_edge(Some(other_id), Some(id), ());
300 }
301 _ => {}
302 });
303 }
304 let start = graph.add_node(None);
305 Dfs::new(&graph, start)
306 .iter(&graph)
307 .filter_map(|opt_id| opt_id)
308 .collect()
309 };
310
311 let generic_params: HashMap<ItemId, GenericParams> = ctx
313 .translated
314 .all_items()
315 .map(|item| (item.id(), item.generic_params().clone()))
316 .collect();
317
318 let to_process = ctx.translated.all_ids().collect();
320 PartialMonomorphizer {
321 ctx,
322 span: Span::dummy(),
323 instantiate_types,
324 infected_types,
325 generic_params,
326 to_process,
327 partial_mono_shapes: IndexMap::default(),
328 reverse_shape_map: Default::default(),
329 }
330 }
331
332 fn is_infected(&self, ty: &Ty) -> bool {
335 match ty.kind() {
336 TyKind::Ref(_, _, RefKind::Mut) => true,
337 TyKind::Ref(_, ty, _) | TyKind::RawPtr(ty, _) => self.is_infected(ty),
338 TyKind::Adt(tref) => {
339 let ty_infected =
340 matches!(&tref.id, TypeId::Adt(id) if self.infected_types.contains(id));
341 let args_infected = if tref.id.is_adt() && self.instantiate_types {
342 false
347 } else {
348 tref.generics.types.iter().any(|ty| self.is_infected(ty))
349 };
350 ty_infected || args_infected
351 }
352 TyKind::FnDef(..) | TyKind::FnPtr(..) => {
353 register_error!(
354 self.ctx,
355 self.span,
356 "function pointers are unsupported with `--monomorphize-mut`"
357 );
358 false
359 }
360 TyKind::DynTrait(_) => {
361 register_error!(
362 self.ctx,
363 self.span,
364 "`dyn Trait` is unsupported with `--monomorphize-mut`"
365 );
366 false
367 }
368 TyKind::TypeVar(..)
369 | TyKind::Literal(..)
370 | TyKind::Never
371 | TyKind::TraitType(..)
372 | TyKind::PtrMetadata(..)
373 | TyKind::Error(_) => false,
374 }
375 }
376
377 fn process_generics(&mut self, id: ItemId, generics: &GenericArgs) -> Option<DeclRef<ItemId>> {
381 if !generics.types.iter().any(|ty| self.is_infected(ty)) {
382 return None;
383 }
384
385 let mut new_generics;
388 let (id, generics) = if let Some(&(base_id, ref shape)) = self.reverse_shape_map.get(&id) {
389 new_generics = shape.clone().apply(generics);
390 let _ = self.visit(&mut new_generics); (base_id, &new_generics)
392 } else {
393 (id, generics)
394 };
395
396 let item_params = self.generic_params.get(&id)?;
398 let (shape, shape_args) =
399 MutabilityShapeBuilder::compute_shape(self, item_params, generics);
400
401 let new_params = shape.params.clone();
403 let key: (ItemId, MutabilityShape) = (id, shape);
404 let new_id = *self
405 .partial_mono_shapes
406 .entry(key.clone())
407 .or_insert_with(|| {
408 let new_id = match id {
409 ItemId::Type(_) => {
410 let new_id = self.ctx.translated.type_decls.reserve_slot();
411 self.infected_types.insert(new_id);
412 new_id.into()
413 }
414 ItemId::Fun(_) => self.ctx.translated.fun_decls.reserve_slot().into(),
415 ItemId::Global(_) => self.ctx.translated.global_decls.reserve_slot().into(),
416 ItemId::TraitDecl(_) => self.ctx.translated.trait_decls.reserve_slot().into(),
417 ItemId::TraitImpl(_) => self.ctx.translated.trait_impls.reserve_slot().into(),
418 };
419 self.generic_params.insert(new_id, new_params);
420 self.reverse_shape_map.insert(new_id, key);
421 self.to_process.push_back(new_id);
422 new_id
423 });
424
425 let fmt_ctx = self.ctx.into_fmt();
426 trace!(
427 "processing {}{}\n output: {}{}",
428 id.with_ctx(&fmt_ctx),
429 generics.with_ctx(&fmt_ctx),
430 new_id.with_ctx(&fmt_ctx),
431 shape_args.with_ctx(&fmt_ctx),
432 );
433 Some(DeclRef {
434 id: new_id,
435 generics: Box::new(shape_args),
436 })
437 }
438
439 pub fn process_item(&mut self, item: &mut ItemRefMut<'_>) {
443 let _ = item.drive_mut(self);
444 }
445
446 pub fn create_pending_instantiation(&mut self, new_id: ItemId) -> ItemByVal {
451 let (orig_id, shape) = &self.reverse_shape_map[&new_id];
452 let mut decl = self
453 .ctx
454 .translated
455 .get_item(*orig_id)
456 .unwrap()
457 .clone()
458 .substitute(&shape.skip_binder);
459
460 let mut decl_mut = decl.as_mut();
461 decl_mut.set_id(new_id);
462 *decl_mut.generic_params() = shape.params.clone();
463
464 let name_ref = &mut decl_mut.item_meta().name;
465 *name_ref = mem::take(name_ref).instantiate(shape.clone());
466 self.ctx
467 .translated
468 .item_names
469 .insert(new_id, decl.as_ref().item_meta().name.clone());
470
471 decl
472 }
473}
474
475impl VisitorWithSpan for PartialMonomorphizer<'_> {
476 fn current_span(&mut self) -> &mut Span {
477 &mut self.span
478 }
479}
480impl VisitAstMut for PartialMonomorphizer<'_> {
481 fn visit<'a, T: AstVisitable>(&'a mut self, x: &mut T) -> ControlFlow<Self::Break> {
482 VisitWithSpan::new(self).visit(x)
484 }
485
486 fn exit_type_decl_ref(&mut self, x: &mut TypeDeclRef) {
487 if self.instantiate_types
488 && let TypeId::Adt(id) = x.id
489 && let Some(new_decl_ref) = self.process_generics(id.into(), &x.generics)
490 {
491 *x = new_decl_ref.try_into().unwrap()
492 }
493 }
494 fn exit_fn_ptr(&mut self, x: &mut FnPtr) {
495 if let FnPtrKind::Fun(FunId::Regular(id)) = *x.kind
498 && let Some(new_decl_ref) = self.process_generics(id.into(), &x.generics)
499 {
500 *x = new_decl_ref.try_into().unwrap()
501 }
502 }
503 fn exit_fun_decl_ref(&mut self, x: &mut FunDeclRef) {
504 if let Some(new_decl_ref) = self.process_generics(x.id.into(), &x.generics) {
505 *x = new_decl_ref.try_into().unwrap()
506 }
507 }
508 fn exit_global_decl_ref(&mut self, x: &mut GlobalDeclRef) {
509 if let Some(new_decl_ref) = self.process_generics(x.id.into(), &x.generics) {
510 *x = new_decl_ref.try_into().unwrap()
511 }
512 }
513 fn exit_trait_decl_ref(&mut self, x: &mut TraitDeclRef) {
514 if let Some(new_decl_ref) = self.process_generics(x.id.into(), &x.generics) {
515 *x = new_decl_ref.try_into().unwrap()
516 }
517 }
518 fn exit_trait_impl_ref(&mut self, x: &mut TraitImplRef) {
519 if let Some(new_decl_ref) = self.process_generics(x.id.into(), &x.generics) {
520 *x = new_decl_ref.try_into().unwrap()
521 }
522 }
523}
524
525pub struct Transform;
526impl TransformPass for Transform {
527 fn transform_ctx(&self, ctx: &mut TransformCtx) {
528 let Some(include_types) = ctx.options.monomorphize_mut else {
529 return;
530 };
531 let mut visitor =
533 PartialMonomorphizer::new(ctx, matches!(include_types, MonomorphizeMut::All));
534 while let Some(id) = visitor.to_process.pop_front() {
535 let mut decl = if visitor.reverse_shape_map.get(&id).is_some() {
538 visitor.create_pending_instantiation(id)
540 } else {
541 match visitor.ctx.translated.remove_item(id) {
544 Some(decl) => decl,
545 None => continue,
546 }
547 };
548 visitor.process_item(&mut decl.as_mut());
551 visitor.ctx.translated.set_item_slot(id, decl);
553 }
554 }
555}