1use std::collections::{HashMap, HashSet, VecDeque};
14use std::fmt::Display;
15use std::mem;
16
17use derive_generic_visitor::Visitor;
18use index_vec::Idx;
19
20use crate::ast::types_utils::TyVisitable;
21use crate::ast::visitor::{VisitWithBinderDepth, VisitorWithBinderDepth};
22use crate::formatter::IntoFormatter;
23use crate::options::MonomorphizeMut;
24use crate::pretty::FmtWithCtx;
25use crate::register_error;
26use crate::transform::ctx::TransformPass;
27use crate::{transform::TransformCtx, ullbc_ast::*};
28
29type MutabilityShape = Binder<GenericArgs>;
30
31#[derive(Visitor)]
33struct MutabilityShapeBuilder<'pm, 'ctx> {
34 pm: &'pm PartialMonomorphizer<'ctx>,
35 params: GenericParams,
37 extracted: GenericArgs,
39 binder_depth: DeBruijnId,
41}
42
43impl<'pm, 'ctx> MutabilityShapeBuilder<'pm, 'ctx> {
44 fn compute_shape(
65 pm: &'pm PartialMonomorphizer<'ctx>,
66 target_params: &GenericParams,
67 args: &GenericArgs,
68 ) -> (MutabilityShape, GenericArgs) {
69 let mut shape_contents = args.clone();
74 let mut builder = Self {
75 pm,
76 params: GenericParams {
77 regions: IndexMap::new(),
78 types: IndexMap::new(),
79 const_generics: IndexMap::new(),
80 ..target_params.clone()
81 },
82 extracted: GenericArgs {
83 regions: IndexMap::new(),
84 types: IndexMap::new(),
85 const_generics: IndexMap::new(),
86 trait_refs: mem::take(&mut shape_contents.trait_refs),
87 },
88 binder_depth: DeBruijnId::zero(),
89 };
90
91 let _ = VisitWithBinderDepth::new(&mut builder).visit(&mut shape_contents);
94
95 let shape_params = {
96 let mut shape_params = builder.params;
97 shape_params.trait_clauses = shape_params.trait_clauses.map_indexed(|i, x| {
101 if i.index() < target_params.trait_clauses.slot_count() {
102 x.substitute_explicits(&shape_contents)
103 } else {
104 x
105 }
106 });
107 shape_params.trait_type_constraints =
108 shape_params.trait_type_constraints.map_indexed(|i, x| {
109 if i.index() < target_params.trait_type_constraints.slot_count() {
110 x.substitute_explicits(&shape_contents)
111 } else {
112 x
113 }
114 });
115 shape_params.regions_outlive = shape_params
116 .regions_outlive
117 .into_iter()
118 .enumerate()
119 .map(|(i, x)| {
120 if i < target_params.regions_outlive.len() {
121 x.substitute_explicits(&shape_contents)
122 } else {
123 x
124 }
125 })
126 .collect();
127 shape_params.types_outlive = shape_params
128 .types_outlive
129 .into_iter()
130 .enumerate()
131 .map(|(i, x)| {
132 if i < target_params.types_outlive.len() {
133 x.substitute_explicits(&shape_contents)
134 } else {
135 x
136 }
137 })
138 .collect();
139 shape_params
140 };
141
142 shape_contents.trait_refs = shape_params.identity_args().trait_refs;
145 shape_contents
146 .trait_refs
147 .truncate(target_params.trait_clauses.slot_count());
148
149 let shape_args = builder.extracted;
150 let shape = Binder::new(BinderKind::Other, shape_params, shape_contents);
151 (shape, shape_args)
152 }
153
154 fn replace_with_fresh_var<Id, Param, Arg>(
156 &mut self,
157 val: &mut Arg,
158 mk_param: impl FnOnce(Id) -> Param,
159 mk_value: impl FnOnce(DeBruijnVar<Id>) -> Arg,
160 ) where
161 Id: Idx + Display,
162 Arg: TyVisitable + Clone,
163 GenericParams: HasIdxMapOf<Id, Output = Param>,
164 GenericArgs: HasIdxMapOf<Id, Output = Arg>,
165 {
166 let Some(shifted_val) = val.clone().move_from_under_binders(self.binder_depth) else {
167 return;
169 };
170 self.extracted.get_idx_map_mut().push(shifted_val);
172 let id = self.params.get_idx_map_mut().push_with(mk_param);
174 *val = mk_value(DeBruijnVar::bound(self.binder_depth, id));
175 }
176}
177
178impl<'pm, 'ctx> VisitorWithBinderDepth for MutabilityShapeBuilder<'pm, 'ctx> {
179 fn binder_depth_mut(&mut self) -> &mut DeBruijnId {
180 &mut self.binder_depth
181 }
182}
183
184impl<'pm, 'ctx> VisitAstMut for MutabilityShapeBuilder<'pm, 'ctx> {
185 fn visit<'a, T: AstVisitable>(&'a mut self, x: &mut T) -> ControlFlow<Self::Break> {
186 VisitWithBinderDepth::new(self).visit(x)
187 }
188
189 fn enter_ty(&mut self, ty: &mut Ty) {
190 if !self.pm.is_infected(ty) {
191 self.replace_with_fresh_var(
192 ty,
193 |id| TypeParam::new(id, format!("T{id}")),
194 |v| v.into(),
195 );
196 }
197 }
198 fn exit_ty_kind(&mut self, kind: &mut TyKind) {
199 if let TyKind::Adt(TypeDeclRef {
200 id: TypeId::Adt(id),
201 generics,
202 }) = kind
203 {
204 let Some(target_params) = self.pm.generic_params.get(&(*id).into()) else {
209 return;
210 };
211 let Some(shifted_generics) =
212 generics.clone().move_from_under_binders(self.binder_depth)
213 else {
214 return;
216 };
217
218 let num_clauses_before_merge = self.params.trait_clauses.slot_count();
220 self.params.merge_predicates_from(
221 target_params
222 .clone()
223 .substitute_explicits(&shifted_generics),
224 );
225
226 self.extracted
228 .trait_refs
229 .extend(shifted_generics.trait_refs);
230
231 for (target_clause_id, tref) in generics.trait_refs.iter_mut_indexed() {
233 let clause_id = target_clause_id + num_clauses_before_merge;
234 *tref =
235 self.params.trait_clauses[clause_id].identity_tref_at_depth(self.binder_depth);
236 }
237 }
238 }
239 fn enter_region(&mut self, r: &mut Region) {
240 self.replace_with_fresh_var(r, |id| RegionParam::new(id, None), |v| v.into());
241 }
242 fn visit_trait_ref(&mut self, _tref: &mut TraitRef) -> ControlFlow<Self::Break> {
249 ControlFlow::Continue(())
252 }
253}
254
255#[derive(Visitor)]
256struct PartialMonomorphizer<'a> {
257 ctx: &'a mut TransformCtx,
258 span: Span,
260 instantiate_types: bool,
262 infected_types: HashSet<TypeDeclId>,
264 generic_params: HashMap<ItemId, GenericParams>,
269 partial_mono_shapes: SeqHashMap<(ItemId, MutabilityShape), ItemId>,
273 reverse_shape_map: HashMap<ItemId, (ItemId, MutabilityShape)>,
275 to_process: VecDeque<ItemId>,
277}
278
279impl<'a> PartialMonomorphizer<'a> {
280 pub fn new(ctx: &'a mut TransformCtx, instantiate_types: bool) -> Self {
281 use petgraph::graphmap::DiGraphMap;
282 use petgraph::visit::Dfs;
283 use petgraph::visit::Walker;
284
285 let infected_types: HashSet<_> = {
287 let mut graph: DiGraphMap<Option<TypeDeclId>, ()> = Default::default();
292 for (id, tdecl) in ctx.translated.type_decls.iter_indexed() {
293 tdecl.dyn_visit(|x: &Ty| match x.kind() {
294 TyKind::Ref(_, _, RefKind::Mut) => {
295 graph.add_edge(None, Some(id), ());
296 }
297 TyKind::Adt(tref) if let TypeId::Adt(other_id) = tref.id => {
298 graph.add_edge(Some(other_id), Some(id), ());
299 }
300 _ => {}
301 });
302 }
303 let start = graph.add_node(None);
304 Dfs::new(&graph, start)
305 .iter(&graph)
306 .filter_map(|opt_id| opt_id)
307 .collect()
308 };
309
310 let generic_params: HashMap<ItemId, GenericParams> = ctx
312 .translated
313 .all_items()
314 .map(|item| (item.id(), item.generic_params().clone()))
315 .collect();
316
317 let to_process = ctx.translated.all_ids().collect();
319 PartialMonomorphizer {
320 ctx,
321 span: Span::dummy(),
322 instantiate_types,
323 infected_types,
324 generic_params,
325 to_process,
326 partial_mono_shapes: SeqHashMap::default(),
327 reverse_shape_map: Default::default(),
328 }
329 }
330
331 fn is_infected(&self, ty: &Ty) -> bool {
334 match ty.kind() {
335 TyKind::Ref(_, _, RefKind::Mut) => true,
336 TyKind::Ref(_, ty, _) | TyKind::RawPtr(ty, _) => self.is_infected(ty),
337 TyKind::Adt(tref) => {
338 let ty_infected =
339 matches!(&tref.id, TypeId::Adt(id) if self.infected_types.contains(id));
340 let args_infected = if tref.id.is_adt() && self.instantiate_types {
341 false
346 } else {
347 tref.generics.types.iter().any(|ty| self.is_infected(ty))
348 };
349 ty_infected || args_infected
350 }
351 TyKind::FnDef(..) | TyKind::FnPtr(..) => {
352 register_error!(
353 self.ctx,
354 self.span,
355 "function pointers are unsupported with `--monomorphize-mut`"
356 );
357 false
358 }
359 TyKind::DynTrait(_) => {
360 register_error!(
361 self.ctx,
362 self.span,
363 "`dyn Trait` is unsupported with `--monomorphize-mut`"
364 );
365 false
366 }
367 TyKind::TypeVar(..)
368 | TyKind::Literal(..)
369 | TyKind::Never
370 | TyKind::TraitType(..)
371 | TyKind::PtrMetadata(..)
372 | TyKind::Error(_) => false,
373 }
374 }
375
376 fn process_generics(&mut self, id: ItemId, generics: &GenericArgs) -> Option<DeclRef<ItemId>> {
380 if !generics.types.iter().any(|ty| self.is_infected(ty)) {
381 return None;
382 }
383
384 let mut new_generics;
387 let (id, generics) = if let Some(&(base_id, ref shape)) = self.reverse_shape_map.get(&id) {
388 new_generics = shape.clone().apply(generics);
389 let _ = self.visit(&mut new_generics); (base_id, &new_generics)
391 } else {
392 (id, generics)
393 };
394
395 let item_params = self.generic_params.get(&id)?;
397 let (shape, shape_args) =
398 MutabilityShapeBuilder::compute_shape(self, item_params, generics);
399
400 let new_params = shape.params.clone();
402 let key: (ItemId, MutabilityShape) = (id, shape);
403 let new_id = *self
404 .partial_mono_shapes
405 .entry(key.clone())
406 .or_insert_with(|| {
407 let new_id = match id {
408 ItemId::Type(_) => {
409 let new_id = self.ctx.translated.type_decls.reserve_slot();
410 self.infected_types.insert(new_id);
411 new_id.into()
412 }
413 ItemId::Fun(_) => self.ctx.translated.fun_decls.reserve_slot().into(),
414 ItemId::Global(_) => self.ctx.translated.global_decls.reserve_slot().into(),
415 ItemId::TraitDecl(_) => self.ctx.translated.trait_decls.reserve_slot().into(),
416 ItemId::TraitImpl(_) => self.ctx.translated.trait_impls.reserve_slot().into(),
417 };
418 self.generic_params.insert(new_id, new_params);
419 self.reverse_shape_map.insert(new_id, key);
420 self.to_process.push_back(new_id);
421 new_id
422 });
423
424 let fmt_ctx = self.ctx.into_fmt();
425 trace!(
426 "processing {}{}\n output: {}{}",
427 id.with_ctx(&fmt_ctx),
428 generics.with_ctx(&fmt_ctx),
429 new_id.with_ctx(&fmt_ctx),
430 shape_args.with_ctx(&fmt_ctx),
431 );
432 Some(DeclRef {
433 id: new_id,
434 generics: Box::new(shape_args),
435 trait_ref: None,
436 })
437 }
438
439 pub fn process_item(&mut self, item: &mut ItemRefMut<'_>) {
443 let _ = item.drive_mut(self);
444 }
445
446 pub fn create_pending_instantiation(&mut self, new_id: ItemId) -> ItemByVal {
451 let (orig_id, shape) = &self.reverse_shape_map[&new_id];
452 let mut decl = self
453 .ctx
454 .translated
455 .get_item(*orig_id)
456 .unwrap()
457 .clone()
458 .substitute(&shape.skip_binder);
459
460 let mut decl_mut = decl.as_mut();
461 decl_mut.set_id(new_id);
462 *decl_mut.generic_params() = shape.params.clone();
463
464 let name_ref = &mut decl_mut.item_meta().name;
465 *name_ref = mem::take(name_ref).instantiate(shape.clone());
466 self.ctx
467 .translated
468 .item_names
469 .insert(new_id, decl.as_ref().item_meta().name.clone());
470
471 decl
472 }
473}
474
475impl VisitorWithSpan for PartialMonomorphizer<'_> {
476 fn current_span(&mut self) -> &mut Span {
477 &mut self.span
478 }
479}
480impl VisitAstMut for PartialMonomorphizer<'_> {
481 fn visit<'a, T: AstVisitable>(&'a mut self, x: &mut T) -> ControlFlow<Self::Break> {
482 VisitWithSpan::new(self).visit(x)
484 }
485
486 fn exit_type_decl_ref(&mut self, x: &mut TypeDeclRef) {
487 if self.instantiate_types
488 && let TypeId::Adt(id) = x.id
489 && let Some(new_decl_ref) = self.process_generics(id.into(), &x.generics)
490 {
491 *x = new_decl_ref.try_into().unwrap()
492 }
493 }
494 fn exit_fn_ptr(&mut self, x: &mut FnPtr) {
495 if let FnPtrKind::Fun(FunId::Regular(id)) = *x.kind
498 && let Some(new_decl_ref) = self.process_generics(id.into(), &x.generics)
499 {
500 *x = new_decl_ref.try_into().unwrap()
501 }
502 }
503 fn exit_fun_decl_ref(&mut self, x: &mut FunDeclRef) {
504 if let Some(new_decl_ref) = self.process_generics(x.id.into(), &x.generics) {
505 *x = new_decl_ref.try_into().unwrap()
506 }
507 }
508 fn exit_global_decl_ref(&mut self, x: &mut GlobalDeclRef) {
509 if let Some(new_decl_ref) = self.process_generics(x.id.into(), &x.generics) {
510 *x = new_decl_ref.try_into().unwrap()
511 }
512 }
513 fn exit_trait_decl_ref(&mut self, x: &mut TraitDeclRef) {
514 if let Some(new_decl_ref) = self.process_generics(x.id.into(), &x.generics) {
515 *x = new_decl_ref.try_into().unwrap()
516 }
517 }
518 fn exit_trait_impl_ref(&mut self, x: &mut TraitImplRef) {
519 if let Some(new_decl_ref) = self.process_generics(x.id.into(), &x.generics) {
520 *x = new_decl_ref.try_into().unwrap()
521 }
522 }
523}
524
525pub struct Transform;
526impl TransformPass for Transform {
527 fn transform_ctx(&self, ctx: &mut TransformCtx) {
528 let Some(include_types) = ctx.options.monomorphize_mut else {
529 return;
530 };
531 let mut visitor =
533 PartialMonomorphizer::new(ctx, matches!(include_types, MonomorphizeMut::All));
534 while let Some(id) = visitor.to_process.pop_front() {
535 let mut decl = if visitor.reverse_shape_map.get(&id).is_some() {
538 visitor.create_pending_instantiation(id)
540 } else {
541 match visitor.ctx.translated.remove_item(id) {
544 Some(decl) => decl,
545 None => continue,
546 }
547 };
548 visitor.process_item(&mut decl.as_mut());
551 visitor.ctx.translated.set_item_slot(id, decl);
553 }
554 }
555}