// rustc_type_ir_macros/lib.rs

1use quote::{ToTokens, quote};
2use syn::visit_mut::VisitMut;
3use syn::{Attribute, parse_quote};
4use synstructure::decl_derive;
5
6decl_derive!(
7    [TypeVisitable_Generic, attributes(type_visitable)] => type_visitable_derive
8);
9decl_derive!(
10    [TypeFoldable_Generic, attributes(type_foldable)] => type_foldable_derive
11);
12decl_derive!(
13    [Lift_Generic] => lift_derive
14);
15
16fn has_ignore_attr(attrs: &[Attribute], name: &'static str, meta: &'static str) -> bool {
17    let mut ignored = false;
18    attrs.iter().for_each(|attr| {
19        if !attr.path().is_ident(name) {
20            return;
21        }
22        let _ = attr.parse_nested_meta(|nested| {
23            if nested.path.is_ident(meta) {
24                ignored = true;
25            }
26            Ok(())
27        });
28    });
29
30    ignored
31}
32
33fn type_visitable_derive(mut s: synstructure::Structure<'_>) -> proc_macro2::TokenStream {
34    if let syn::Data::Union(_) = s.ast().data {
35        panic!("cannot derive on union")
36    }
37
38    if !s.ast().generics.type_params().any(|ty| ty.ident == "I") {
39        s.add_impl_generic(parse_quote! { I });
40    }
41
42    s.filter(|bi| !has_ignore_attr(&bi.ast().attrs, "type_visitable", "ignore"));
43
44    s.add_where_predicate(parse_quote! { I: Interner });
45    s.add_bounds(synstructure::AddBounds::Fields);
46    let body_visit = s.each(|bind| {
47        quote! {
48            match ::rustc_type_ir::VisitorResult::branch(
49                ::rustc_type_ir::TypeVisitable::visit_with(#bind, __visitor)
50            ) {
51                ::core::ops::ControlFlow::Continue(()) => {},
52                ::core::ops::ControlFlow::Break(r) => {
53                    return ::rustc_type_ir::VisitorResult::from_residual(r);
54                },
55            }
56        }
57    });
58    s.bind_with(|_| synstructure::BindStyle::Move);
59
60    s.bound_impl(
61        quote!(::rustc_type_ir::TypeVisitable<I>),
62        quote! {
63            fn visit_with<__V: ::rustc_type_ir::TypeVisitor<I>>(
64                &self,
65                __visitor: &mut __V
66            ) -> __V::Result {
67                match *self { #body_visit }
68                <__V::Result as ::rustc_type_ir::VisitorResult>::output()
69            }
70        },
71    )
72}
73
74fn type_foldable_derive(mut s: synstructure::Structure<'_>) -> proc_macro2::TokenStream {
75    if let syn::Data::Union(_) = s.ast().data {
76        panic!("cannot derive on union")
77    }
78
79    if !s.ast().generics.type_params().any(|ty| ty.ident == "I") {
80        s.add_impl_generic(parse_quote! { I });
81    }
82
83    s.add_where_predicate(parse_quote! { I: Interner });
84    s.add_bounds(synstructure::AddBounds::Fields);
85    s.bind_with(|_| synstructure::BindStyle::Move);
86    let body_try_fold = s.each_variant(|vi| {
87        let bindings = vi.bindings();
88        vi.construct(|_, index| {
89            let bind = &bindings[index];
90
91            // retain value of fields with #[type_foldable(identity)]
92            if has_ignore_attr(&bind.ast().attrs, "type_foldable", "identity") {
93                bind.to_token_stream()
94            } else {
95                quote! {
96                    ::rustc_type_ir::TypeFoldable::try_fold_with(#bind, __folder)?
97                }
98            }
99        })
100    });
101
102    let body_fold = s.each_variant(|vi| {
103        let bindings = vi.bindings();
104        vi.construct(|_, index| {
105            let bind = &bindings[index];
106
107            // retain value of fields with #[type_foldable(identity)]
108            if has_ignore_attr(&bind.ast().attrs, "type_foldable", "identity") {
109                bind.to_token_stream()
110            } else {
111                quote! {
112                    ::rustc_type_ir::TypeFoldable::fold_with(#bind, __folder)
113                }
114            }
115        })
116    });
117
118    // We filter fields which get ignored and don't require them to implement
119    // `TypeFoldable`. We do so after generating `body_fold` as we still need
120    // to generate code for them.
121    s.filter(|bi| !has_ignore_attr(&bi.ast().attrs, "type_foldable", "identity"));
122    s.add_bounds(synstructure::AddBounds::Fields);
123    s.bound_impl(
124        quote!(::rustc_type_ir::TypeFoldable<I>),
125        quote! {
126            fn try_fold_with<__F: ::rustc_type_ir::FallibleTypeFolder<I>>(
127                self,
128                __folder: &mut __F
129            ) -> Result<Self, __F::Error> {
130                Ok(match self { #body_try_fold })
131            }
132
133            fn fold_with<__F: ::rustc_type_ir::TypeFolder<I>>(
134                self,
135                __folder: &mut __F
136            ) -> Self {
137                match self { #body_fold }
138            }
139        },
140    )
141}
142
143fn lift_derive(mut s: synstructure::Structure<'_>) -> proc_macro2::TokenStream {
144    if let syn::Data::Union(_) = s.ast().data {
145        panic!("cannot derive on union")
146    }
147
148    if !s.ast().generics.type_params().any(|ty| ty.ident == "I") {
149        s.add_impl_generic(parse_quote! { I });
150    }
151
152    s.add_bounds(synstructure::AddBounds::None);
153    s.add_where_predicate(parse_quote! { I: Interner });
154    s.add_impl_generic(parse_quote! { J });
155    s.add_where_predicate(parse_quote! { J: Interner });
156
157    let mut wc = vec![];
158    s.bind_with(|_| synstructure::BindStyle::Move);
159    let body_fold = s.each_variant(|vi| {
160        let bindings = vi.bindings();
161        vi.construct(|field, index| {
162            let ty = field.ty.clone();
163            let lifted_ty = lift(ty.clone());
164            wc.push(parse_quote! { #ty: ::rustc_type_ir::lift::Lift<J, Lifted = #lifted_ty> });
165            let bind = &bindings[index];
166            quote! {
167                #bind.lift_to_interner(interner)?
168            }
169        })
170    });
171    for wc in wc {
172        s.add_where_predicate(wc);
173    }
174
175    let (_, ty_generics, _) = s.ast().generics.split_for_impl();
176    let name = s.ast().ident.clone();
177    let self_ty: syn::Type = parse_quote! { #name #ty_generics };
178    let lifted_ty = lift(self_ty);
179
180    s.bound_impl(
181        quote!(::rustc_type_ir::lift::Lift<J>),
182        quote! {
183            type Lifted = #lifted_ty;
184
185            fn lift_to_interner(
186                self,
187                interner: J,
188            ) -> Option<Self::Lifted> {
189                Some(match self { #body_fold })
190            }
191        },
192    )
193}
194
195fn lift(mut ty: syn::Type) -> syn::Type {
196    struct ItoJ;
197    impl VisitMut for ItoJ {
198        fn visit_type_path_mut(&mut self, i: &mut syn::TypePath) {
199            if i.qself.is_none() {
200                if let Some(first) = i.path.segments.first_mut() {
201                    if first.ident == "I" {
202                        *first = parse_quote! { J };
203                    }
204                }
205            }
206            syn::visit_mut::visit_type_path_mut(self, i);
207        }
208    }
209
210    ItoJ.visit_type_mut(&mut ty);
211
212    ty
213}