1use proc_macro2::{Span, TokenStream};
2use quote::quote;
3use std::vec::Vec;
4use syn::{parse2, Data, DataEnum, DeriveInput, Fields, Ident, Type};
5
6use self::EnumMethodKind::*;
7
/// Converts a CamelCase identifier into snake_case.
///
/// Every uppercase character is replaced by the first character of its
/// lowercase mapping. An underscore is inserted in front of it only when the
/// preceding character was not uppercase, so runs of capitals stay glued
/// together (`"HTTPServer"` becomes `"httpserver"`). All other characters,
/// digits and underscores included, are copied through unchanged.
fn to_snake_case(s: &str) -> String {
    let mut out = String::new();

    // True once we have seen a character and that character was not uppercase.
    // Starts false so that a leading capital never produces a leading `_`.
    let mut prev_was_not_upper = false;

    for ch in s.chars() {
        let is_upper = ch.is_uppercase();
        if is_upper && prev_was_not_upper {
            out.push('_');
        }
        let emitted = if is_upper {
            ch.to_lowercase().next().unwrap()
        } else {
            ch
        };
        out.push(emitted);
        prev_was_not_upper = !is_upper;
    }

    out
}
37
/// Everything needed to emit one `match` arm for a single enum variant.
///
/// Values are produced by `generate_variant_match_patterns`, one per variant,
/// in declaration order.
struct MatchPattern {
    /// The variant's identifier rendered as a string.
    variant_name: String,
    /// Pattern tokens: the enum name, `::`, the variant name, followed by the
    /// field bindings (braced for named fields, parenthesised for unnamed
    /// fields, nothing for unit variants).
    pattern: TokenStream,
    /// Number of fields the variant declares.
    num_args: usize,
    /// The identifiers (`_x0`, `_x1`, ...) that `pattern` binds, one per field,
    /// in field order.
    pattern_vars: Vec<Ident>,
    /// The declared type of each field, in field order.
    arg_types: Vec<Type>,
}
53
54fn generate_variant_match_patterns(enum_name: &Ident, data: &DataEnum) -> Vec<MatchPattern> {
56 let mut patterns: Vec<MatchPattern> = vec![];
57 for variant in &data.variants {
58 let fields = &variant.fields;
61 let num_vars = fields.len();
62 let vars: Vec<Ident> = (0..num_vars)
63 .map(|i| Ident::new(&format!("_x{i}"), Span::mixed_site()))
64 .collect();
65 let vartypes: Vec<_> = fields.iter().map(|f| f.ty.clone()).collect();
66
67 let pattern_vars = match fields {
68 Fields::Named(_) => {
69 let field_names: Vec<_> = fields.iter().map(|f| &f.ident).collect();
70 quote!({ #(#field_names : #vars,)* })
71 }
72 Fields::Unnamed(_) => {
73 quote!((#(#vars,)*))
74 }
75 Fields::Unit => quote!(),
76 };
77 let variant_name = &variant.ident;
78 let pattern = quote!(#enum_name :: #variant_name #pattern_vars);
79
80 patterns.push(MatchPattern {
81 variant_name: variant.ident.to_string(),
82 pattern,
83 num_args: num_vars,
84 pattern_vars: vars,
85 arg_types: vartypes,
86 });
87 }
88
89 patterns
90}
91
92pub fn derive_variant_name(item: TokenStream) -> TokenStream {
95 let ast: DeriveInput = parse2(item).unwrap();
97
98 let Data::Enum(data) = &ast.data else {
100 panic!("VariantName macro can only be called on enums");
101 };
102 let patterns = generate_variant_match_patterns(&ast.ident, data);
103 let match_branches: Vec<TokenStream> = patterns
104 .into_iter()
105 .map(|mp| {
106 let pattern = &mp.pattern;
107 let name = &mp.variant_name;
108 quote!( #pattern => #name )
109 })
110 .collect();
111
112 let adt_name = &ast.ident;
113 let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
114 quote!(
115 impl #impl_generics #adt_name #ty_generics #where_clause {
116 pub fn variant_name(&self) -> &'static str {
117 match self {
118 #(#match_branches,)*
119 }
120 }
121 }
122 )
123}
124
125pub fn derive_variant_index_arity(item: TokenStream) -> TokenStream {
129 let ast: DeriveInput = parse2(item).unwrap();
131
132 let Data::Enum(data) = &ast.data else {
134 panic!("VariantIndex macro can only be called on enums");
135 };
136 let patterns = generate_variant_match_patterns(&ast.ident, data);
137 let match_branches: Vec<TokenStream> = patterns
138 .into_iter()
139 .enumerate()
140 .map(|(i, mp)| {
141 let pattern = &mp.pattern;
142 let i = i as u32;
143 let arity = mp.num_args;
144 quote!( #pattern => (#i, #arity) )
145 })
146 .collect();
147
148 let adt_name = &ast.ident;
149 let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
150 quote!(
151 impl #impl_generics #adt_name #ty_generics #where_clause {
152 pub fn variant_index_arity(&self) -> (u32, usize) {
153 match self {
154 #(#match_branches,)*
155 }
156 }
157 }
158 )
159}
160
/// The families of per-variant helper methods that
/// `derive_enum_variant_method` can generate.
#[derive(PartialEq, Eq)]
pub enum EnumMethodKind {
    /// `is_<variant>(&self) -> bool` predicates.
    EnumIsA,
    /// `as_<variant>(&self)` getters returning shared references to the fields.
    EnumAsGetters,
    /// `as_<variant>_mut(&mut self)` getters returning mutable references to the fields.
    EnumAsMutGetters,
    /// `to_<variant>(self)` getters that consume the value and return the fields.
    EnumToGetters,
}

impl EnumMethodKind {
    /// Returns the kind's own name as a string; used to build the panic
    /// message when a derive is applied to something that is not an enum.
    fn variant_name(&self) -> &'static str {
        match *self {
            Self::EnumIsA => "EnumIsA",
            Self::EnumAsGetters => "EnumAsGetters",
            Self::EnumAsMutGetters => "EnumAsMutGetters",
            Self::EnumToGetters => "EnumToGetters",
        }
    }
}
181
182pub fn derive_enum_variant_method(item: TokenStream, method_kind: EnumMethodKind) -> TokenStream {
185 let ast: DeriveInput = parse2(item).unwrap();
187
188 let adt_name = &ast.ident;
190
191 let Data::Enum(data) = &ast.data else {
193 panic!(
194 "{} macro can only be called on enums",
195 method_kind.variant_name()
196 );
197 };
198 let patterns = generate_variant_match_patterns(&ast.ident, data);
199 let methods: Vec<TokenStream> = patterns
200 .into_iter()
201 .map(|mp| {
202 let pattern = &mp.pattern;
203 let name_prefix = match method_kind {
204 EnumIsA => "is_",
205 EnumAsGetters | EnumAsMutGetters => "as_",
206 EnumToGetters => "to_",
207 };
208 let name_suffix = match method_kind {
209 EnumAsMutGetters => "_mut",
210 _ => "",
211 };
212 let ref_kind = match method_kind {
213 EnumAsGetters | EnumIsA => quote!(&),
214 EnumAsMutGetters => quote!(&mut),
215 EnumToGetters => quote!(),
216 };
217 let variant_name = to_snake_case(&mp.variant_name);
220 let method_name = format!("{name_prefix}{variant_name}{name_suffix}");
221 let method_name = Ident::new(&method_name, Span::call_site());
222 match method_kind {
223 EnumIsA => {
224 quote!(
225 pub fn #method_name(#ref_kind self) -> bool {
226 #[allow(unreachable)]
227 match self {
228 #pattern => true,
229 _ => false,
230 }
231 }
232 )
233 }
234 EnumAsGetters | EnumAsMutGetters | EnumToGetters => {
235 let vars = &mp.pattern_vars;
236 let tys = &mp.arg_types;
237 quote!(
238 pub fn #method_name(#ref_kind self) -> Option<( #(#ref_kind #tys),* )> {
239 #[allow(unreachable)]
240 match self {
241 #pattern => Some(( #(#vars),* )),
242 _ => None,
243 }
244 }
245 )
246 }
247 }
248 })
249 .collect();
250
251 let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
252 quote!(
253 impl #impl_generics #adt_name #ty_generics #where_clause {
254 #(#methods)*
255 }
256 )
257}
258
/// A two-word CamelCase name is lowercased and split at the inner capital.
#[test]
fn test_snake_case() {
    let converted = to_snake_case("ConstantValue");
    println!("{}", converted);
    assert!(converted == "constant_value");
}
265
/// Checks that generic parameters and `where` clauses of the input enum are
/// carried over to the generated `impl`, for both the index/arity derive and
/// the per-variant getter derive.
#[test]
fn test_generics() {
    let s = quote!(
        enum Foo<T: Clone>
        where
            T: Copy,
        {
            Variant1(T),
            Variant2 { x: u32 },
            Variant3,
        }
    );
    assert_tokens_eq::assert_tokens_eq!(
        derive_variant_index_arity(s.clone()),
        quote! {
            impl<T: Clone,> Foo<T,>
            where
                T: Copy,
            {
                pub fn variant_index_arity(&self) -> (u32, usize) {
                    match self {
                        Foo::Variant1(_x0) => (0u32, 1usize),
                        Foo::Variant2 { x: _x0 } => (1u32, 1usize),
                        Foo::Variant3 => (2u32, 0usize),
                    }
                }
            }
        }
    );
    // The getters produced by `derive_enum_variant_method` return an `Option`:
    // `Some(..)` for the matching variant and `None` otherwise, so that is what
    // the expected tokens must spell out.
    assert_tokens_eq::assert_tokens_eq!(
        derive_enum_variant_method(s, EnumAsMutGetters),
        quote! {
            impl<T: Clone,> Foo<T,>
            where
                T: Copy,
            {
                pub fn as_variant1_mut(&mut self) -> Option<(&mut T)> {
                    #[allow(unreachable)]
                    match self {
                        Foo::Variant1(_x0) => Some((_x0)),
                        _ => None,
                    }
                }
                pub fn as_variant2_mut(&mut self) -> Option<(&mut u32)> {
                    #[allow(unreachable)]
                    match self {
                        Foo::Variant2 { x: _x0 } => Some((_x0)),
                        _ => None,
                    }
                }
                pub fn as_variant3_mut(&mut self) -> Option<()> {
                    #[allow(unreachable)]
                    match self {
                        Foo::Variant3 => Some(()),
                        _ => None,
                    }
                }
            }
        }
    );
}