rustc_type_ir_macros/
lib.rs1use quote::{ToTokens, quote};
2use syn::visit_mut::VisitMut;
3use syn::{Attribute, parse_quote};
4use synstructure::decl_derive;
5
6decl_derive!(
7 [TypeVisitable_Generic, attributes(type_visitable)] => type_visitable_derive
8);
9decl_derive!(
10 [TypeFoldable_Generic, attributes(type_foldable)] => type_foldable_derive
11);
12decl_derive!(
13 [Lift_Generic] => lift_derive
14);
15
16fn has_ignore_attr(attrs: &[Attribute], name: &'static str, meta: &'static str) -> bool {
17 let mut ignored = false;
18 attrs.iter().for_each(|attr| {
19 if !attr.path().is_ident(name) {
20 return;
21 }
22 let _ = attr.parse_nested_meta(|nested| {
23 if nested.path.is_ident(meta) {
24 ignored = true;
25 }
26 Ok(())
27 });
28 });
29
30 ignored
31}
32
33fn type_visitable_derive(mut s: synstructure::Structure<'_>) -> proc_macro2::TokenStream {
34 if let syn::Data::Union(_) = s.ast().data {
35 panic!("cannot derive on union")
36 }
37
38 if !s.ast().generics.type_params().any(|ty| ty.ident == "I") {
39 s.add_impl_generic(parse_quote! { I });
40 }
41
42 s.filter(|bi| !has_ignore_attr(&bi.ast().attrs, "type_visitable", "ignore"));
43
44 s.add_where_predicate(parse_quote! { I: Interner });
45 s.add_bounds(synstructure::AddBounds::Fields);
46 let body_visit = s.each(|bind| {
47 quote! {
48 match ::rustc_type_ir::VisitorResult::branch(
49 ::rustc_type_ir::TypeVisitable::visit_with(#bind, __visitor)
50 ) {
51 ::core::ops::ControlFlow::Continue(()) => {},
52 ::core::ops::ControlFlow::Break(r) => {
53 return ::rustc_type_ir::VisitorResult::from_residual(r);
54 },
55 }
56 }
57 });
58 s.bind_with(|_| synstructure::BindStyle::Move);
59
60 s.bound_impl(
61 quote!(::rustc_type_ir::TypeVisitable<I>),
62 quote! {
63 fn visit_with<__V: ::rustc_type_ir::TypeVisitor<I>>(
64 &self,
65 __visitor: &mut __V
66 ) -> __V::Result {
67 match *self { #body_visit }
68 <__V::Result as ::rustc_type_ir::VisitorResult>::output()
69 }
70 },
71 )
72}
73
74fn type_foldable_derive(mut s: synstructure::Structure<'_>) -> proc_macro2::TokenStream {
75 if let syn::Data::Union(_) = s.ast().data {
76 panic!("cannot derive on union")
77 }
78
79 if !s.ast().generics.type_params().any(|ty| ty.ident == "I") {
80 s.add_impl_generic(parse_quote! { I });
81 }
82
83 s.add_where_predicate(parse_quote! { I: Interner });
84 s.add_bounds(synstructure::AddBounds::Fields);
85 s.bind_with(|_| synstructure::BindStyle::Move);
86 let body_try_fold = s.each_variant(|vi| {
87 let bindings = vi.bindings();
88 vi.construct(|_, index| {
89 let bind = &bindings[index];
90
91 if has_ignore_attr(&bind.ast().attrs, "type_foldable", "identity") {
93 bind.to_token_stream()
94 } else {
95 quote! {
96 ::rustc_type_ir::TypeFoldable::try_fold_with(#bind, __folder)?
97 }
98 }
99 })
100 });
101
102 let body_fold = s.each_variant(|vi| {
103 let bindings = vi.bindings();
104 vi.construct(|_, index| {
105 let bind = &bindings[index];
106
107 if has_ignore_attr(&bind.ast().attrs, "type_foldable", "identity") {
109 bind.to_token_stream()
110 } else {
111 quote! {
112 ::rustc_type_ir::TypeFoldable::fold_with(#bind, __folder)
113 }
114 }
115 })
116 });
117
118 s.filter(|bi| !has_ignore_attr(&bi.ast().attrs, "type_foldable", "identity"));
122 s.add_bounds(synstructure::AddBounds::Fields);
123 s.bound_impl(
124 quote!(::rustc_type_ir::TypeFoldable<I>),
125 quote! {
126 fn try_fold_with<__F: ::rustc_type_ir::FallibleTypeFolder<I>>(
127 self,
128 __folder: &mut __F
129 ) -> Result<Self, __F::Error> {
130 Ok(match self { #body_try_fold })
131 }
132
133 fn fold_with<__F: ::rustc_type_ir::TypeFolder<I>>(
134 self,
135 __folder: &mut __F
136 ) -> Self {
137 match self { #body_fold }
138 }
139 },
140 )
141}
142
143fn lift_derive(mut s: synstructure::Structure<'_>) -> proc_macro2::TokenStream {
144 if let syn::Data::Union(_) = s.ast().data {
145 panic!("cannot derive on union")
146 }
147
148 if !s.ast().generics.type_params().any(|ty| ty.ident == "I") {
149 s.add_impl_generic(parse_quote! { I });
150 }
151
152 s.add_bounds(synstructure::AddBounds::None);
153 s.add_where_predicate(parse_quote! { I: Interner });
154 s.add_impl_generic(parse_quote! { J });
155 s.add_where_predicate(parse_quote! { J: Interner });
156
157 let mut wc = vec![];
158 s.bind_with(|_| synstructure::BindStyle::Move);
159 let body_fold = s.each_variant(|vi| {
160 let bindings = vi.bindings();
161 vi.construct(|field, index| {
162 let ty = field.ty.clone();
163 let lifted_ty = lift(ty.clone());
164 wc.push(parse_quote! { #ty: ::rustc_type_ir::lift::Lift<J, Lifted = #lifted_ty> });
165 let bind = &bindings[index];
166 quote! {
167 #bind.lift_to_interner(interner)?
168 }
169 })
170 });
171 for wc in wc {
172 s.add_where_predicate(wc);
173 }
174
175 let (_, ty_generics, _) = s.ast().generics.split_for_impl();
176 let name = s.ast().ident.clone();
177 let self_ty: syn::Type = parse_quote! { #name #ty_generics };
178 let lifted_ty = lift(self_ty);
179
180 s.bound_impl(
181 quote!(::rustc_type_ir::lift::Lift<J>),
182 quote! {
183 type Lifted = #lifted_ty;
184
185 fn lift_to_interner(
186 self,
187 interner: J,
188 ) -> Option<Self::Lifted> {
189 Some(match self { #body_fold })
190 }
191 },
192 )
193}
194
195fn lift(mut ty: syn::Type) -> syn::Type {
196 struct ItoJ;
197 impl VisitMut for ItoJ {
198 fn visit_type_path_mut(&mut self, i: &mut syn::TypePath) {
199 if i.qself.is_none() {
200 if let Some(first) = i.path.segments.first_mut() {
201 if first.ident == "I" {
202 *first = parse_quote! { J };
203 }
204 }
205 }
206 syn::visit_mut::visit_type_path_mut(self, i);
207 }
208 }
209
210 ItoJ.visit_type_mut(&mut ty);
211
212 ty
213}