1use std::collections::{HashMap, HashSet, VecDeque};
14use std::fmt::Display;
15use std::mem;
16
17use derive_generic_visitor::Visitor;
18use index_vec::Idx;
19
20use crate::ast::types_utils::TyVisitable;
21use crate::ast::visitor::{VisitWithBinderDepth, VisitorWithBinderDepth};
22use crate::formatter::IntoFormatter;
23use crate::options::MonomorphizeMut;
24use crate::pretty::FmtWithCtx;
25use crate::register_error;
26use crate::transform::ctx::TransformPass;
27use crate::{transform::TransformCtx, ullbc_ast::*};
28
/// The "mutability shape" of a list of generic arguments, as computed by
/// `MutabilityShapeBuilder::compute_shape`.
///
/// It is made of the arguments with every non-infected type and every region replaced, where
/// possible, by a variable bound by this binder's own params. Infected types stay in place.
///
/// Uses of an item that have the same shape share a single partial instantiation: see
/// `PartialMonomorphizer::partial_mono_shapes`, which is keyed by `(ItemId, MutabilityShape)`.
type MutabilityShape = Binder<GenericArgs>;
30
#[derive(Visitor)]
/// Visitor that computes the mutability shape of a list of generic arguments (see
/// `compute_shape`).
///
/// It replaces every non-infected type and every region with a fresh parameter of the shape,
/// and records the values it replaced.
struct MutabilityShapeBuilder<'pm, 'ctx> {
    /// The pass state. Used to decide which types are infected and to look up item params.
    pm: &'pm PartialMonomorphizer<'ctx>,
    /// The generic params of the shape under construction.
    params: GenericParams,
    /// The values that were replaced by shape variables, in the same order as the
    /// corresponding params.
    ///
    /// These are the arguments to apply to the shape to get back the original arguments.
    extracted: GenericArgs,
    /// Binder depth of the position currently being visited, relative to the shape contents.
    /// It starts at zero.
    binder_depth: DeBruijnId,
}
42
impl<'pm, 'ctx> MutabilityShapeBuilder<'pm, 'ctx> {
    /// Splits `args` into a shape and the arguments to apply to that shape. `args` are the
    /// arguments of a reference to an item whose generic params are `target_params`.
    ///
    /// What the visitor does:
    /// - Types that `pm` does not consider infected, and all regions, are replaced by fresh
    ///   shape parameters, unless they mention variables bound inside `args`.
    /// - The values it replaces are collected into the returned `GenericArgs`.
    /// - Infected types are kept in the shape, and the visitor descends into them.
    ///
    /// Returns `(shape, shape_args)`:
    /// - `shape` binds the rewritten arguments under the new params.
    /// - `shape_args` instantiates those params back to the original values.
    fn compute_shape(
        pm: &'pm PartialMonomorphizer<'ctx>,
        target_params: &GenericParams,
        args: &GenericArgs,
    ) -> (MutabilityShape, GenericArgs) {
        let mut shape_contents = args.clone();
        // Start from the target's params with all variables removed: the shape gets its own
        // variables as they are discovered during the visit. The predicates (trait clauses,
        // trait type constraints, outlives) are inherited and fixed up below.
        let mut builder = Self {
            pm,
            params: GenericParams {
                regions: IndexMap::new(),
                types: IndexMap::new(),
                const_generics: IndexMap::new(),
                ..target_params.clone()
            },
            // The trait refs of `args` are passed through unchanged as shape arguments. Taking
            // them out of `shape_contents` also means the visit below does not see them.
            extracted: GenericArgs {
                regions: IndexMap::new(),
                types: IndexMap::new(),
                const_generics: IndexMap::new(),
                trait_refs: mem::take(&mut shape_contents.trait_refs),
            },
            binder_depth: DeBruijnId::zero(),
        };

        // Replace non-infected types and regions with fresh shape variables.
        let _ = VisitWithBinderDepth::new(&mut builder).visit(&mut shape_contents);

        let shape_params = {
            let mut shape_params = builder.params;
            // The first `slot_count()` predicates of each kind come from `target_params`, so they
            // are expressed over the target's variables. Substitute them with the (rewritten)
            // shape contents so that they refer to the shape's variables instead.
            // Predicates after that index were merged in by `exit_ty_kind` and are left as-is.
            shape_params.trait_clauses = shape_params.trait_clauses.map_indexed(|i, x| {
                if i.index() < target_params.trait_clauses.slot_count() {
                    x.substitute_explicits(&shape_contents)
                } else {
                    x
                }
            });
            shape_params.trait_type_constraints =
                shape_params.trait_type_constraints.map_indexed(|i, x| {
                    if i.index() < target_params.trait_type_constraints.slot_count() {
                        x.substitute_explicits(&shape_contents)
                    } else {
                        x
                    }
                });
            shape_params.regions_outlive = shape_params
                .regions_outlive
                .into_iter()
                .enumerate()
                .map(|(i, x)| {
                    if i < target_params.regions_outlive.len() {
                        x.substitute_explicits(&shape_contents)
                    } else {
                        x
                    }
                })
                .collect();
            shape_params.types_outlive = shape_params
                .types_outlive
                .into_iter()
                .enumerate()
                .map(|(i, x)| {
                    if i < target_params.types_outlive.len() {
                        x.substitute_explicits(&shape_contents)
                    } else {
                        x
                    }
                })
                .collect();
            shape_params
        };

        // Inside the shape, trait refs point at the shape's own trait clauses. Only the clauses
        // inherited from the target are kept, because the shape contents are arguments for the
        // target item.
        shape_contents.trait_refs = shape_params.identity_args().trait_refs;
        shape_contents
            .trait_refs
            .truncate(target_params.trait_clauses.slot_count());

        let shape_args = builder.extracted;
        let shape = Binder::new(BinderKind::Other, shape_params, shape_contents);
        (shape, shape_args)
    }

    /// Abstracts `val` into a fresh shape parameter.
    ///
    /// If `val` can be moved out of the current `binder_depth` binders, this:
    /// - records the shifted value in `extracted`;
    /// - adds a new parameter built with `mk_param`;
    /// - overwrites `val` with a variable bound at the current depth that refers to that
    ///   parameter.
    ///
    /// Otherwise (`val` mentions variables bound inside the visited value), `val` is left
    /// untouched.
    fn replace_with_fresh_var<Id, Param, Arg>(
        &mut self,
        val: &mut Arg,
        mk_param: impl FnOnce(Id) -> Param,
        mk_value: impl FnOnce(DeBruijnVar<Id>) -> Arg,
    ) where
        Id: Idx + Display,
        Arg: TyVisitable + Clone,
        GenericParams: HasIdxMapOf<Id, Output = Param>,
        GenericArgs: HasIdxMapOf<Id, Output = Arg>,
    {
        let Some(shifted_val) = val.clone().move_from_under_binders(self.binder_depth) else {
            return;
        };
        // `compute_shape` starts both maps empty, and they are always pushed together here.
        // So the index of the extracted value matches the id of the new parameter.
        self.extracted.get_idx_map_mut().push(shifted_val);
        let id = self.params.get_idx_map_mut().push_with(mk_param);
        *val = mk_value(DeBruijnVar::bound(self.binder_depth, id));
    }
}
177
/// Exposes `binder_depth` to the `VisitWithBinderDepth` wrapper, which presumably updates it as
/// the visit enters and leaves binders.
impl<'pm, 'ctx> VisitorWithBinderDepth for MutabilityShapeBuilder<'pm, 'ctx> {
    fn binder_depth_mut(&mut self) -> &mut DeBruijnId {
        &mut self.binder_depth
    }
}
183
184impl<'pm, 'ctx> VisitAstMut for MutabilityShapeBuilder<'pm, 'ctx> {
185 fn visit<'a, T: AstVisitable>(&'a mut self, x: &mut T) -> ControlFlow<Self::Break> {
186 VisitWithBinderDepth::new(self).visit(x)
187 }
188
189 fn enter_ty(&mut self, ty: &mut Ty) {
190 if !self.pm.is_infected(ty) {
191 self.replace_with_fresh_var(
192 ty,
193 |id| TypeParam::new(id, format!("T{id}")),
194 |v| v.into(),
195 );
196 }
197 }
198 fn exit_ty_kind(&mut self, kind: &mut TyKind) {
199 if let TyKind::Adt(TypeDeclRef {
200 id: TypeId::Adt(id),
201 generics,
202 }) = kind
203 {
204 let Some(target_params) = self.pm.generic_params.get(&(*id).into()) else {
209 return;
210 };
211 let Some(shifted_generics) =
212 generics.clone().move_from_under_binders(self.binder_depth)
213 else {
214 return;
216 };
217
218 let num_clauses_before_merge = self.params.trait_clauses.slot_count();
220 self.params.merge_predicates_from(
221 target_params
222 .clone()
223 .substitute_explicits(&shifted_generics),
224 );
225
226 self.extracted
228 .trait_refs
229 .extend(shifted_generics.trait_refs);
230
231 for (target_clause_id, tref) in generics.trait_refs.iter_mut_indexed() {
233 let clause_id = target_clause_id + num_clauses_before_merge;
234 *tref =
235 self.params.trait_clauses[clause_id].identity_tref_at_depth(self.binder_depth);
236 }
237 }
238 }
239 fn enter_region(&mut self, r: &mut Region) {
240 self.replace_with_fresh_var(r, |id| RegionParam::new(id, None), |v| v.into());
241 }
242 fn visit_trait_ref(&mut self, _tref: &mut TraitRef) -> ControlFlow<Self::Break> {
249 ControlFlow::Continue(())
252 }
253
254 fn visit_constant_expr(
255 &mut self,
256 _: &mut ConstantExpr,
257 ) -> ::std::ops::ControlFlow<Self::Break> {
258 ControlFlow::Continue(())
259 }
260}
261
#[derive(Visitor)]
/// State of the partial monomorphization pass.
///
/// References to items whose generic arguments contain `&mut`-infected types are redirected
/// to copies of those items specialized to the mutability shape of the arguments.
struct PartialMonomorphizer<'a> {
    ctx: &'a mut TransformCtx,
    /// Current span, exposed through `VisitorWithSpan`. Errors are registered at this span.
    span: Span,
    /// Whether references to type declarations are also instantiated. This is set for
    /// `MonomorphizeMut::All`.
    instantiate_types: bool,
    /// Type declarations that contain a `&mut`, either directly or through other type
    /// declarations. Type instantiations created by this pass are added here as well.
    infected_types: HashSet<TypeDeclId>,
    /// Generic params of every item. This includes instantiations created by this pass, whose
    /// declarations may not have been built yet.
    generic_params: HashMap<ItemId, GenericParams>,
    /// Maps an (original item, shape) pair to the id of its instantiation, so that each shape
    /// is instantiated at most once per item.
    partial_mono_shapes: SeqHashMap<(ItemId, MutabilityShape), ItemId>,
    /// Inverse of `partial_mono_shapes`: maps an instantiation back to its original item and
    /// shape.
    reverse_shape_map: HashMap<ItemId, (ItemId, MutabilityShape)>,
    /// Worklist of items still to be processed.
    to_process: VecDeque<ItemId>,
}
285
286impl<'a> PartialMonomorphizer<'a> {
287 pub fn new(ctx: &'a mut TransformCtx, instantiate_types: bool) -> Self {
288 use petgraph::graphmap::DiGraphMap;
289 use petgraph::visit::Dfs;
290 use petgraph::visit::Walker;
291
292 let infected_types: HashSet<_> = {
294 let mut graph: DiGraphMap<Option<TypeDeclId>, ()> = Default::default();
299 for (id, tdecl) in ctx.translated.type_decls.iter_indexed() {
300 tdecl.dyn_visit(|x: &Ty| match x.kind() {
301 TyKind::Ref(_, _, RefKind::Mut) => {
302 graph.add_edge(None, Some(id), ());
303 }
304 TyKind::Adt(tref) if let TypeId::Adt(other_id) = tref.id => {
305 graph.add_edge(Some(other_id), Some(id), ());
306 }
307 _ => {}
308 });
309 }
310 let start = graph.add_node(None);
311 Dfs::new(&graph, start)
312 .iter(&graph)
313 .filter_map(|opt_id| opt_id)
314 .collect()
315 };
316
317 let generic_params: HashMap<ItemId, GenericParams> = ctx
319 .translated
320 .all_items()
321 .map(|item| (item.id(), item.generic_params().clone()))
322 .collect();
323
324 let to_process = ctx.translated.all_ids().collect();
326 PartialMonomorphizer {
327 ctx,
328 span: Span::dummy(),
329 instantiate_types,
330 infected_types,
331 generic_params,
332 to_process,
333 partial_mono_shapes: SeqHashMap::default(),
334 reverse_shape_map: Default::default(),
335 }
336 }
337
338 fn is_infected(&self, ty: &Ty) -> bool {
341 match ty.kind() {
342 TyKind::Ref(_, _, RefKind::Mut) => true,
343 TyKind::Ref(_, ty, _)
344 | TyKind::RawPtr(ty, _)
345 | TyKind::Array(ty, _)
346 | TyKind::Slice(ty) => self.is_infected(ty),
347 TyKind::Adt(tref) => {
348 let ty_infected =
349 matches!(&tref.id, TypeId::Adt(id) if self.infected_types.contains(id));
350 let args_infected = if tref.id.is_adt() && self.instantiate_types {
351 false
356 } else {
357 tref.generics.types.iter().any(|ty| self.is_infected(ty))
358 };
359 ty_infected || args_infected
360 }
361 TyKind::FnDef(..) | TyKind::FnPtr(..) => {
362 register_error!(
363 self.ctx,
364 self.span,
365 "function pointers are unsupported with `--monomorphize-mut`"
366 );
367 false
368 }
369 TyKind::DynTrait(_) => {
370 register_error!(
371 self.ctx,
372 self.span,
373 "`dyn Trait` is unsupported with `--monomorphize-mut`"
374 );
375 false
376 }
377 TyKind::TypeVar(..)
378 | TyKind::Literal(..)
379 | TyKind::Never
380 | TyKind::TraitType(..)
381 | TyKind::PtrMetadata(..)
382 | TyKind::Error(_) => false,
383 }
384 }
385
386 fn process_generics(&mut self, id: ItemId, generics: &GenericArgs) -> Option<DeclRef<ItemId>> {
390 if !generics.types.iter().any(|ty| self.is_infected(ty)) {
391 return None;
392 }
393
394 let mut new_generics;
397 let (id, generics) = if let Some(&(base_id, ref shape)) = self.reverse_shape_map.get(&id) {
398 new_generics = shape.clone().apply(generics);
399 let _ = self.visit(&mut new_generics); (base_id, &new_generics)
401 } else {
402 (id, generics)
403 };
404
405 let item_params = self.generic_params.get(&id)?;
407 let (shape, shape_args) =
408 MutabilityShapeBuilder::compute_shape(self, item_params, generics);
409
410 let new_params = shape.params.clone();
412 let key: (ItemId, MutabilityShape) = (id, shape);
413 let new_id = *self
414 .partial_mono_shapes
415 .entry(key.clone())
416 .or_insert_with(|| {
417 let new_id = match id {
418 ItemId::Type(_) => {
419 let new_id = self.ctx.translated.type_decls.reserve_slot();
420 self.infected_types.insert(new_id);
421 new_id.into()
422 }
423 ItemId::Fun(_) => self.ctx.translated.fun_decls.reserve_slot().into(),
424 ItemId::Global(_) => self.ctx.translated.global_decls.reserve_slot().into(),
425 ItemId::TraitDecl(_) => self.ctx.translated.trait_decls.reserve_slot().into(),
426 ItemId::TraitImpl(_) => self.ctx.translated.trait_impls.reserve_slot().into(),
427 };
428 self.generic_params.insert(new_id, new_params);
429 self.reverse_shape_map.insert(new_id, key);
430 self.to_process.push_back(new_id);
431 new_id
432 });
433
434 let fmt_ctx = self.ctx.into_fmt();
435 trace!(
436 "processing {}{}\n output: {}{}",
437 id.with_ctx(&fmt_ctx),
438 generics.with_ctx(&fmt_ctx),
439 new_id.with_ctx(&fmt_ctx),
440 shape_args.with_ctx(&fmt_ctx),
441 );
442 Some(DeclRef {
443 id: new_id,
444 generics: Box::new(shape_args),
445 trait_ref: None,
446 })
447 }
448
449 pub fn process_item(&mut self, item: &mut ItemRefMut<'_>) {
453 let _ = item.drive_mut(self);
454 }
455
456 pub fn create_pending_instantiation(&mut self, new_id: ItemId) -> ItemByVal {
461 let (orig_id, shape) = &self.reverse_shape_map[&new_id];
462 let mut decl = self
463 .ctx
464 .translated
465 .get_item(*orig_id)
466 .unwrap()
467 .clone()
468 .substitute(&shape.skip_binder);
469
470 let mut decl_mut = decl.as_mut();
471 decl_mut.set_id(new_id);
472 *decl_mut.generic_params() = shape.params.clone();
473
474 let name_ref = &mut decl_mut.item_meta().name;
475 *name_ref = mem::take(name_ref).instantiate(shape.clone());
476 self.ctx
477 .translated
478 .item_names
479 .insert(new_id, decl.as_ref().item_meta().name.clone());
480
481 decl
482 }
483}
484
/// Exposes `span` to the `VisitWithSpan` wrapper, which presumably keeps it pointing at the
/// location being visited. `is_infected` registers its errors at this span.
impl VisitorWithSpan for PartialMonomorphizer<'_> {
    fn current_span(&mut self) -> &mut Span {
        &mut self.span
    }
}
490impl VisitAstMut for PartialMonomorphizer<'_> {
491 fn visit<'a, T: AstVisitable>(&'a mut self, x: &mut T) -> ControlFlow<Self::Break> {
492 VisitWithSpan::new(self).visit(x)
494 }
495
496 fn exit_type_decl_ref(&mut self, x: &mut TypeDeclRef) {
497 if self.instantiate_types
498 && let TypeId::Adt(id) = x.id
499 && let Some(new_decl_ref) = self.process_generics(id.into(), &x.generics)
500 {
501 *x = new_decl_ref.try_into().unwrap()
502 }
503 }
504 fn exit_fn_ptr(&mut self, x: &mut FnPtr) {
505 if let FnPtrKind::Fun(FunId::Regular(id)) = *x.kind
508 && let Some(new_decl_ref) = self.process_generics(id.into(), &x.generics)
509 {
510 *x = new_decl_ref.try_into().unwrap()
511 }
512 }
513 fn exit_fun_decl_ref(&mut self, x: &mut FunDeclRef) {
514 if let Some(new_decl_ref) = self.process_generics(x.id.into(), &x.generics) {
515 *x = new_decl_ref.try_into().unwrap()
516 }
517 }
518 fn exit_global_decl_ref(&mut self, x: &mut GlobalDeclRef) {
519 if let Some(new_decl_ref) = self.process_generics(x.id.into(), &x.generics) {
520 *x = new_decl_ref.try_into().unwrap()
521 }
522 }
523 fn exit_trait_decl_ref(&mut self, x: &mut TraitDeclRef) {
524 if let Some(new_decl_ref) = self.process_generics(x.id.into(), &x.generics) {
525 *x = new_decl_ref.try_into().unwrap()
526 }
527 }
528 fn exit_trait_impl_ref(&mut self, x: &mut TraitImplRef) {
529 if let Some(new_decl_ref) = self.process_generics(x.id.into(), &x.generics) {
530 *x = new_decl_ref.try_into().unwrap()
531 }
532 }
533}
534
535pub struct Transform;
536impl TransformPass for Transform {
537 fn transform_ctx(&self, ctx: &mut TransformCtx) {
538 let Some(include_types) = ctx.options.monomorphize_mut else {
539 return;
540 };
541 let mut visitor =
543 PartialMonomorphizer::new(ctx, matches!(include_types, MonomorphizeMut::All));
544 while let Some(id) = visitor.to_process.pop_front() {
545 let mut decl = if visitor.reverse_shape_map.get(&id).is_some() {
548 visitor.create_pending_instantiation(id)
550 } else {
551 match visitor.ctx.translated.remove_item(id) {
554 Some(decl) => decl,
555 None => continue,
556 }
557 };
558 visitor.process_item(&mut decl.as_mut());
561 visitor.ctx.translated.set_item_slot(id, decl);
563 }
564 }
565}