macros/
enum_helpers.rs

1use proc_macro2::{Span, TokenStream};
2use quote::quote;
3use std::vec::Vec;
4use syn::{parse2, Data, DataEnum, DeriveInput, Fields, Ident, Type};
5
6use self::EnumMethodKind::*;
7
/// Convert a CamelCase identifier to snake_case.
///
/// An underscore is inserted before an uppercase character whenever the
/// previous character was not uppercase; every uppercase character is then
/// replaced by its lowercase mapping. Digits and other non-uppercase
/// characters are copied unchanged.
///
/// We initially used the convert-case crate, but it converts names like "I32"
/// to "i_32", while we want to get "i32". We thus reimplemented our own converter
/// (which removes one dependency at the same time).
fn to_snake_case(s: &str) -> String {
    // The output is at least as long as the input in the common (ASCII) case,
    // so reserving `s.len()` avoids most reallocations.
    let mut snake_case = String::with_capacity(s.len());

    // We track whether the last treated character was *not* uppercase (rather
    // than whether it was uppercase) to prevent this kind of transformation:
    // "VARIANT" -> "v_a_r_i_a_n_t"
    // while still treating digits like lowercase letters, so that "I32"
    // becomes "i32". The flag starts as `false` so that no underscore is
    // emitted before the very first character.
    let mut last_is_lowercase = false;

    for c in s.chars() {
        if c.is_uppercase() {
            if last_is_lowercase {
                snake_case.push('_');
            }
            last_is_lowercase = false;
            // The lowercase mapping of one character may span several
            // characters (e.g. U+0130): keep all of them, not just the first.
            snake_case.extend(c.to_lowercase());
        } else {
            last_is_lowercase = true;
            snake_case.push(c);
        }
    }

    snake_case
}
37
/// A match pattern generated for one variant of an enumeration, together with
/// the pieces of information needed to build a match arm around it.
struct MatchPattern {
    /// The variant name, as a string.
    variant_name: String,
    /// The match pattern, which binds every field of the variant to a fresh
    /// variable.
    /// For instance: `List::Cons(_x0, _x1,)`
    pattern: TokenStream,
    /// The number of fields of the variant, which is also the number of
    /// variables bound by the match pattern.
    num_args: usize,
    /// The variables we introduced in the match pattern, in field order.
    /// `[_x0, _x1]` if the pattern is `List::Cons(_x0, _x1,)`.
    pattern_vars: Vec<Ident>,
    /// The types of the fields bound in the match pattern, in the same order
    /// as `pattern_vars`.
    arg_types: Vec<Type>,
}
53
54/// Generate matching patterns for an enumeration.
55fn generate_variant_match_patterns(enum_name: &Ident, data: &DataEnum) -> Vec<MatchPattern> {
56    let mut patterns: Vec<MatchPattern> = vec![];
57    for variant in &data.variants {
58        // Compute the pattern (without the variant constructor), the list
59        // of introduced arguments and the list of field types.
60        let fields = &variant.fields;
61        let num_vars = fields.len();
62        let vars: Vec<Ident> = (0..num_vars)
63            .map(|i| Ident::new(&format!("_x{i}"), Span::mixed_site()))
64            .collect();
65        let vartypes: Vec<_> = fields.iter().map(|f| f.ty.clone()).collect();
66
67        let pattern_vars = match fields {
68            Fields::Named(_) => {
69                let field_names: Vec<_> = fields.iter().map(|f| &f.ident).collect();
70                quote!({ #(#field_names : #vars,)* })
71            }
72            Fields::Unnamed(_) => {
73                quote!((#(#vars,)*))
74            }
75            Fields::Unit => quote!(),
76        };
77        let variant_name = &variant.ident;
78        let pattern = quote!(#enum_name :: #variant_name #pattern_vars);
79
80        patterns.push(MatchPattern {
81            variant_name: variant.ident.to_string(),
82            pattern,
83            num_args: num_vars,
84            pattern_vars: vars,
85            arg_types: vartypes,
86        });
87    }
88
89    patterns
90}
91
92/// Macro to derive a function `fn variant_name(&self) -> String` printing the
93/// constructor of an enumeration. Only works on enumerations, of course.
94pub fn derive_variant_name(item: TokenStream) -> TokenStream {
95    // Parse the input
96    let ast: DeriveInput = parse2(item).unwrap();
97
98    // Generate the code for the matches
99    let Data::Enum(data) = &ast.data else {
100        panic!("VariantName macro can only be called on enums");
101    };
102    let patterns = generate_variant_match_patterns(&ast.ident, data);
103    let match_branches: Vec<TokenStream> = patterns
104        .into_iter()
105        .map(|mp| {
106            let pattern = &mp.pattern;
107            let name = &mp.variant_name;
108            quote!( #pattern => #name )
109        })
110        .collect();
111
112    let adt_name = &ast.ident;
113    let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
114    quote!(
115        impl #impl_generics #adt_name #ty_generics #where_clause {
116            pub fn variant_name(&self) -> &'static str {
117                match self {
118                    #(#match_branches,)*
119                }
120            }
121        }
122    )
123}
124
125/// Macro to derive a function `fn variant_index_arity(&self) -> (u32, usize)`
126/// the pair (variant index, variant arity).
127/// Only works on enumerations, of course.
128pub fn derive_variant_index_arity(item: TokenStream) -> TokenStream {
129    // Parse the input
130    let ast: DeriveInput = parse2(item).unwrap();
131
132    // Generate the code for the matches
133    let Data::Enum(data) = &ast.data else {
134        panic!("VariantIndex macro can only be called on enums");
135    };
136    let patterns = generate_variant_match_patterns(&ast.ident, data);
137    let match_branches: Vec<TokenStream> = patterns
138        .into_iter()
139        .enumerate()
140        .map(|(i, mp)| {
141            let pattern = &mp.pattern;
142            let i = i as u32;
143            let arity = mp.num_args;
144            quote!( #pattern => (#i, #arity) )
145        })
146        .collect();
147
148    let adt_name = &ast.ident;
149    let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
150    quote!(
151        impl #impl_generics #adt_name #ty_generics #where_clause {
152            pub fn variant_index_arity(&self) -> (u32, usize) {
153                match self {
154                    #(#match_branches,)*
155                }
156            }
157        }
158    )
159}
160
/// The family of per-variant methods which `derive_enum_variant_method`
/// generates. Below, `<variant>` stands for the snake_case name of a variant.
#[derive(PartialEq, Eq)]
pub enum EnumMethodKind {
    /// Generate `is_<variant>(&self) -> bool` methods.
    EnumIsA,
    /// Generate `as_<variant>(&self)` methods, which return shared references
    /// to the fields of the variant wrapped in an `Option`.
    EnumAsGetters,
    /// Generate `as_<variant>_mut(&mut self)` methods, which return mutable
    /// references to the fields of the variant wrapped in an `Option`.
    EnumAsMutGetters,
    /// Generate `to_<variant>(self)` methods, which return the fields of the
    /// variant by value wrapped in an `Option`.
    EnumToGetters,
}
168
169impl EnumMethodKind {
170    /// We have to write this by hand: we can't use the macros defined above on
171    /// the declarations of this file...
172    fn variant_name(&self) -> &'static str {
173        match self {
174            EnumIsA => "EnumIsA",
175            EnumAsGetters => "EnumAsGetters",
176            EnumAsMutGetters => "EnumAsMutGetters",
177            EnumToGetters => "EnumToGetters",
178        }
179    }
180}
181
182/// Generic helper for `EnumIsA` and `EnumAsGetters`.
183/// This generates one function per variant.
184pub fn derive_enum_variant_method(item: TokenStream, method_kind: EnumMethodKind) -> TokenStream {
185    // Parse the input
186    let ast: DeriveInput = parse2(item).unwrap();
187
188    // Generate the code
189    let adt_name = &ast.ident;
190
191    // Generate the code for all the functions in the impl block
192    let Data::Enum(data) = &ast.data else {
193        panic!(
194            "{} macro can only be called on enums",
195            method_kind.variant_name()
196        );
197    };
198    let patterns = generate_variant_match_patterns(&ast.ident, data);
199    let methods: Vec<TokenStream> = patterns
200        .into_iter()
201        .map(|mp| {
202            let pattern = &mp.pattern;
203            let name_prefix = match method_kind {
204                EnumIsA => "is_",
205                EnumAsGetters | EnumAsMutGetters => "as_",
206                EnumToGetters => "to_",
207            };
208            let name_suffix = match method_kind {
209                EnumAsMutGetters => "_mut",
210                _ => "",
211            };
212            let ref_kind = match method_kind {
213                EnumAsGetters | EnumIsA => quote!(&),
214                EnumAsMutGetters => quote!(&mut),
215                EnumToGetters => quote!(),
216            };
217            // TODO: write our own to_snake_case function:
218            // names like "i32" become "i_32" with this one.
219            let variant_name = to_snake_case(&mp.variant_name);
220            let method_name = format!("{name_prefix}{variant_name}{name_suffix}");
221            let method_name = Ident::new(&method_name, Span::call_site());
222            match method_kind {
223                EnumIsA => {
224                    quote!(
225                        pub fn #method_name(#ref_kind self) -> bool {
226                            #[allow(unreachable)]
227                            match self {
228                                #pattern => true,
229                                _ => false,
230                            }
231                        }
232                    )
233                }
234                EnumAsGetters | EnumAsMutGetters | EnumToGetters => {
235                    let vars = &mp.pattern_vars;
236                    let tys = &mp.arg_types;
237                    quote!(
238                        pub fn #method_name(#ref_kind self) -> Option<( #(#ref_kind #tys),* )> {
239                            #[allow(unreachable)]
240                            match self {
241                                #pattern => Some(( #(#vars),* )),
242                                _ => None,
243                            }
244                        }
245                    )
246                }
247            }
248        })
249        .collect();
250
251    let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
252    quote!(
253        impl #impl_generics #adt_name #ty_generics #where_clause {
254            #(#methods)*
255        }
256    )
257}
258
/// Check `to_snake_case` on a regular CamelCase name and on the two special
/// cases it is designed for: names ending with digits and all-uppercase names.
#[test]
fn test_snake_case() {
    // `assert_eq!` prints both values upon failure, so there is no need to
    // print the result beforehand.
    assert_eq!(to_snake_case("ConstantValue"), "constant_value");
    assert_eq!(to_snake_case("I32"), "i32");
    assert_eq!(to_snake_case("VARIANT"), "variant");
}
265
/// Check that the generated `impl` blocks carry over the generic parameters
/// and the where clause of the enumeration, for both `derive_variant_index_arity`
/// and `derive_enum_variant_method`.
#[test]
fn test_generics() {
    let s = quote!(
        enum Foo<T: Clone>
        where
            T: Copy,
        {
            Variant1(T),
            Variant2 { x: u32 },
            Variant3,
        }
    );
    assert_tokens_eq::assert_tokens_eq!(
        derive_variant_index_arity(s.clone()),
        quote! {
            impl<T: Clone,> Foo<T,>
            where
                T: Copy,
            {
                pub fn variant_index_arity(&self) -> (u32, usize) {
                    match self {
                        Foo::Variant1(_x0) => (0u32, 1usize),
                        Foo::Variant2 { x: _x0 } => (1u32, 1usize),
                        Foo::Variant3 => (2u32, 0usize),
                    }
                }
            }
        }
    );
    // The getters return an `Option` (`None` when `self` is not the expected
    // variant): the expected code must mirror what `derive_enum_variant_method`
    // actually generates.
    assert_tokens_eq::assert_tokens_eq!(
        derive_enum_variant_method(s, EnumAsMutGetters),
        quote! {
            impl<T: Clone,> Foo<T,>
            where
                T: Copy,
            {
                pub fn as_variant1_mut(&mut self) -> Option<(&mut T)> {
                    #[allow(unreachable)]
                    match self {
                        Foo::Variant1(_x0) => Some((_x0)),
                        _ => None,
                    }
                }
                pub fn as_variant2_mut(&mut self) -> Option<(&mut u32)> {
                    #[allow(unreachable)]
                    match self {
                        Foo::Variant2 { x: _x0 } => Some((_x0)),
                        _ => None,
                    }
                }
                pub fn as_variant3_mut(&mut self) -> Option<()> {
                    #[allow(unreachable)]
                    match self {
                        Foo::Variant3 => Some(()),
                        _ => None,
                    }
                }
            }
        }
    );
}