charon_lib/transform/
lift_associated_item_clauses.rs

//! Move clauses on non-GAT associated types to be parent clauses. The distinction is not
//! semantically meaningful. Ideally we would do this directly when translating, but this is
//! currently difficult; instead we do this as a post-processing pass.
4use std::collections::HashMap;
5use std::mem;
6
7use crate::{ast::*, ids::Vector};
8
9use super::{TransformCtx, ctx::TransformPass};
10
11pub struct Transform;
12impl TransformPass for Transform {
13    fn transform_ctx(&self, ctx: &mut TransformCtx) {
14        // For each trait, we move the item-local clauses to be top-level parent clauses, and
15        // record the mapping from the old to the new ids.
16        let trait_item_clause_ids: Vector<
17            TraitDeclId,
18            HashMap<TraitItemName, Vector<TraitClauseId, TraitClauseId>>,
19        > = ctx.translated.trait_decls.map_ref_mut(|decl| {
20            decl.types
21                .iter_mut()
22                .filter(|assoc_ty| !assoc_ty.params.has_explicits())
23                .map(|assoc_ty| {
24                    let id_map =
25                        mem::take(&mut assoc_ty.skip_binder.implied_clauses).map(|clause| {
26                            let mut clause = clause.move_from_under_binder().unwrap();
27                            decl.parent_clauses.push_with(|id| {
28                                clause.clause_id = id;
29                                clause
30                            })
31                        });
32                    if assoc_ty.params.trait_clauses.is_empty() {
33                        // Move non-trait-clause-predicates of non-GAT types to be predicates on
34                        // the trait itself.
35                        decl.generics.take_predicates_from(
36                            mem::take(&mut assoc_ty.params)
37                                .move_from_under_binder()
38                                .unwrap(),
39                        );
40                    }
41                    (assoc_ty.name().clone(), id_map)
42                })
43                .collect()
44        });
45
46        // Move the item-local trait refs to match what we did in the trait declarations.
47        for timpl in ctx.translated.trait_impls.iter_mut() {
48            for (_, assoc_ty) in &mut timpl.types {
49                if !assoc_ty.params.has_explicits() {
50                    for trait_ref in mem::take(&mut assoc_ty.skip_binder.implied_trait_refs) {
51                        let trait_ref = trait_ref.move_from_under_binder().unwrap();
52                        // Note: this assumes that we listed the types in the same order as in the
53                        // trait decl, which we do.
54                        timpl.parent_trait_refs.push(trait_ref);
55                    }
56                }
57            }
58        }
59
60        // Update trait refs.
61        ctx.translated.dyn_visit_mut(|trkind: &mut TraitRefKind| {
62            use TraitRefKind::*;
63            match trkind {
64                ItemClause(..) => take_mut::take(trkind, |trkind| {
65                    let ItemClause(tref, item_name, item_clause_id) = trkind else {
66                        unreachable!()
67                    };
68                    let new_id = (|| {
69                        let new_id = *trait_item_clause_ids
70                            .get(tref.trait_decl_ref.skip_binder.id)?
71                            .get(&item_name)?
72                            .get(item_clause_id)?;
73                        Some(new_id)
74                    })();
75                    match new_id {
76                        Some(new_id) => ParentClause(tref, new_id),
77                        None => ItemClause(tref, item_name, item_clause_id),
78                    }
79                }),
80                BuiltinOrAuto {
81                    parent_trait_refs,
82                    types,
83                    ..
84                } => {
85                    for (_, assoc_ty) in types {
86                        for tref in std::mem::take(&mut assoc_ty.implied_trait_refs) {
87                            // Note: this assumes that we listed the types in the same order as in
88                            // the trait decl, which we do.
89                            parent_trait_refs.push(tref);
90                        }
91                    }
92                }
93                _ => {}
94            }
95        });
96    }
97}