generate_ml/
main.rs

1//! Generate ocaml deserialization code for our types.
2//!
3//! This binary runs charon on itself and generates the appropriate `<type>_of_json` functions for
4//! our types. The generated functions are inserted into `./generate-ml/GAstOfJson.template.ml` to
5//! construct the final `GAstOfJson.ml`.
6//!
7//! To run it, call `cargo run --bin generate-ml`. It is also run by `make generate-ml` in the
8//! crate root. Don't forget to format the output code after regenerating.
9#![feature(if_let_guard)]
10
11use anyhow::{Context, Result, bail};
12use assert_cmd::cargo::CommandCargoExt;
13use charon_lib::ast::*;
14use convert_case::{Case, Casing};
15use indoc::indoc;
16use itertools::Itertools;
17use std::collections::{HashMap, HashSet};
18use std::fmt::Write;
19use std::fs;
20use std::path::PathBuf;
21use std::process::Command;
22
23/// `Name` is a complex datastructure; to inspect it we serialize it a little bit.
24fn repr_name(_crate_data: &TranslatedCrate, n: &Name) -> String {
25    n.name
26        .iter()
27        .map(|path_elem| match path_elem {
28            PathElem::Ident(i, _) => i.clone(),
29            PathElem::Impl(..) => "<impl>".to_string(),
30            PathElem::Monomorphized(..) => "<mono>".to_string(),
31        })
32        .join("::")
33}
34
35fn make_ocaml_ident(name: &str) -> String {
36    let mut name = name.to_case(Case::Snake);
37    if matches!(
38        &*name,
39        "virtual"
40            | "bool"
41            | "char"
42            | "struct"
43            | "type"
44            | "let"
45            | "fun"
46            | "open"
47            | "rec"
48            | "assert"
49            | "float"
50            | "end"
51            | "include"
52            | "to"
53    ) {
54        name += "_";
55    }
56    name
57}
58fn type_name_to_ocaml_ident(item_meta: &ItemMeta) -> String {
59    let name = item_meta
60        .attr_info
61        .rename
62        .as_ref()
63        .unwrap_or(item_meta.name.name.last().unwrap().as_ident().unwrap().0);
64    make_ocaml_ident(name)
65}
66
67struct GenerateCtx<'a> {
68    crate_data: &'a TranslatedCrate,
69    name_to_type: HashMap<String, &'a TypeDecl>,
70    /// For each type, list the types it contains.
71    type_tree: HashMap<TypeDeclId, HashSet<TypeDeclId>>,
72    manual_type_impls: HashMap<TypeDeclId, String>,
73    manual_json_impls: HashMap<TypeDeclId, String>,
74    opaque_for_visitor: HashSet<TypeDeclId>,
75}
76
77impl<'a> GenerateCtx<'a> {
78    fn new(
79        crate_data: &'a TranslatedCrate,
80        manual_type_impls: &[(&str, &str)],
81        manual_json_impls: &[(&str, &str)],
82        opaque_for_visitor: &[&str],
83    ) -> Self {
84        let mut name_to_type: HashMap<String, &TypeDecl> = Default::default();
85        let mut type_tree = HashMap::default();
86        for ty in &crate_data.type_decls {
87            let long_name = repr_name(crate_data, &ty.item_meta.name);
88            if long_name.starts_with("charon_lib") {
89                let short_name = ty
90                    .item_meta
91                    .name
92                    .name
93                    .last()
94                    .unwrap()
95                    .as_ident()
96                    .unwrap()
97                    .0
98                    .clone();
99                name_to_type.insert(short_name, ty);
100            }
101            name_to_type.insert(long_name, ty);
102
103            let mut contained = HashSet::new();
104            ty.dyn_visit(|id: &TypeDeclId| {
105                contained.insert(*id);
106            });
107            type_tree.insert(ty.def_id, contained);
108        }
109
110        let mut ctx = GenerateCtx {
111            crate_data: &crate_data,
112            name_to_type,
113            type_tree,
114            manual_type_impls: Default::default(),
115            manual_json_impls: Default::default(),
116            opaque_for_visitor: Default::default(),
117        };
118        ctx.manual_type_impls = manual_type_impls
119            .iter()
120            .map(|(name, def)| (ctx.id_from_name(name), def.to_string()))
121            .collect();
122        ctx.manual_json_impls = manual_json_impls
123            .iter()
124            .map(|(name, def)| (ctx.id_from_name(name), def.to_string()))
125            .collect();
126        ctx.opaque_for_visitor = opaque_for_visitor
127            .iter()
128            .map(|name| ctx.id_from_name(name))
129            .collect();
130        ctx
131    }
132
133    fn id_from_name(&self, name: &str) -> TypeDeclId {
134        self.name_to_type
135            .get(name)
136            .expect(&format!("Name not found: `{name}`"))
137            .def_id
138    }
139
140    /// List the (recursive) children of this type.
141    fn children_of(&self, name: &str) -> HashSet<TypeDeclId> {
142        let start_id = self.id_from_name(name);
143        self.children_of_inner(vec![start_id])
144    }
145
146    /// List the (recursive) children of these types.
147    fn children_of_many(&self, names: &[&str]) -> HashSet<TypeDeclId> {
148        self.children_of_inner(names.iter().map(|name| self.id_from_name(name)).collect())
149    }
150
151    fn children_of_inner(&self, ty: Vec<TypeDeclId>) -> HashSet<TypeDeclId> {
152        let mut children = HashSet::new();
153        let mut stack = ty.to_vec();
154        while let Some(id) = stack.pop() {
155            if !children.contains(&id)
156                && self
157                    .crate_data
158                    .type_decls
159                    .get(id)
160                    .is_some_and(|decl| decl.item_meta.is_local)
161            {
162                children.insert(id);
163                if let Some(contained) = self.type_tree.get(&id) {
164                    stack.extend(contained);
165                }
166            }
167        }
168        children
169    }
170}
171
172/// Converts a type to the appropriate `*_of_json` call. In case of generics, this combines several
173/// functions, e.g. `list_of_json bool_of_json`.
174fn type_to_ocaml_call(ctx: &GenerateCtx, ty: &Ty) -> String {
175    match ty.kind() {
176        TyKind::Literal(LiteralTy::Bool) => "bool_of_json".to_string(),
177        TyKind::Literal(LiteralTy::Char) => "char_of_json".to_string(),
178        TyKind::Literal(LiteralTy::Int(int_ty)) => match int_ty {
179            // Even though OCaml ints are only 63 bits, only scalars with their 128 bits should be able to become too large
180            IntTy::I128 => "big_int_of_json".to_string(),
181            _ => "int_of_json".to_string(),
182        },
183        TyKind::Literal(LiteralTy::UInt(uint_ty)) => match uint_ty {
184            // Even though OCaml ints are only 63 bits, only scalars with their 128 bits should be able to become too large
185            UIntTy::U128 => "big_int_of_json".to_string(),
186            _ => "int_of_json".to_string(),
187        },
188        TyKind::Literal(LiteralTy::Float(_)) => "float_of_json".to_string(),
189        TyKind::Adt(tref) => {
190            let mut expr = Vec::new();
191            for ty in &tref.generics.types {
192                expr.push(type_to_ocaml_call(ctx, ty))
193            }
194            match tref.id {
195                TypeId::Adt(id) => {
196                    let mut first = if let Some(tdecl) = ctx.crate_data.type_decls.get(id) {
197                        type_name_to_ocaml_ident(&tdecl.item_meta)
198                    } else {
199                        format!("missing_type_{id}")
200                    };
201                    if first == "vec" {
202                        first = "list".to_string();
203                    }
204                    expr.insert(0, first + "_of_json");
205                }
206                TypeId::Builtin(BuiltinTy::Box) => expr.insert(0, "box_of_json".to_owned()),
207                TypeId::Tuple => {
208                    let name = match tref.generics.types.elem_count() {
209                        2 => "pair_of_json".to_string(),
210                        3 => "triple_of_json".to_string(),
211                        len => format!("tuple_{len}_of_json"),
212                    };
213                    expr.insert(0, name);
214                }
215                _ => unimplemented!("{ty:?}"),
216            }
217            expr.into_iter().map(|f| format!("({f})")).join(" ")
218        }
219        TyKind::TypeVar(DeBruijnVar::Free(id)) => format!("arg{id}_of_json"),
220        _ => unimplemented!("{ty:?}"),
221    }
222}
223
224/// Converts a type to the appropriate ocaml name. In case of generics, this provides appropriate
225/// parameters.
226fn type_to_ocaml_name(ctx: &GenerateCtx, ty: &Ty) -> String {
227    match ty.kind() {
228        TyKind::Literal(LiteralTy::Bool) => "bool".to_string(),
229        TyKind::Literal(LiteralTy::Char) => "(Uchar.t [@visitors.opaque])".to_string(),
230        TyKind::Literal(LiteralTy::Int(int_ty)) => match int_ty {
231            // Even though OCaml ints are only 63 bits, only scalars with their 128 bits should be able to become too large
232            IntTy::I128 => "big_int".to_string(),
233            _ => "int".to_string(),
234        },
235        TyKind::Literal(LiteralTy::UInt(uint_ty)) => match uint_ty {
236            // Even though OCaml ints are only 63 bits, only scalars with their 128 bits should be able to become too large
237            UIntTy::U128 => "big_int".to_string(),
238            _ => "int".to_string(),
239        },
240        TyKind::Literal(LiteralTy::Float(_)) => "float_of_json".to_string(),
241        TyKind::Adt(tref) => {
242            let mut args = tref
243                .generics
244                .types
245                .iter()
246                .map(|ty| type_to_ocaml_name(ctx, ty))
247                .map(|name| {
248                    if !name.chars().all(|c| c.is_alphanumeric()) {
249                        format!("({name})")
250                    } else {
251                        name
252                    }
253                })
254                .collect_vec();
255            match tref.id {
256                TypeId::Adt(id) => {
257                    let mut base_ty = if let Some(tdecl) = ctx.crate_data.type_decls.get(id) {
258                        type_name_to_ocaml_ident(&tdecl.item_meta)
259                    } else if let Some(name) = ctx.crate_data.item_name(id) {
260                        eprintln!(
261                            "Warning: type {} missing from llbc",
262                            repr_name(ctx.crate_data, name)
263                        );
264                        name.name
265                            .last()
266                            .unwrap()
267                            .as_ident()
268                            .unwrap()
269                            .0
270                            .to_lowercase()
271                    } else {
272                        format!("missing_type_{id}")
273                    };
274                    if base_ty == "vec" {
275                        base_ty = "list".to_string();
276                    }
277                    if base_ty == "vector" {
278                        base_ty = "list".to_string();
279                        args.remove(0); // Remove the index generic param
280                    }
281                    let args = match args.as_slice() {
282                        [] => String::new(),
283                        [arg] => arg.clone(),
284                        args => format!("({})", args.iter().join(",")),
285                    };
286                    format!("{args} {base_ty}")
287                }
288                TypeId::Builtin(BuiltinTy::Box) => args[0].clone(),
289                TypeId::Tuple => args.iter().join("*"),
290                _ => unimplemented!("{ty:?}"),
291            }
292        }
293        TyKind::TypeVar(DeBruijnVar::Free(id)) => format!("'a{id}"),
294        _ => unimplemented!("{ty:?}"),
295    }
296}
297
298fn convert_vars<'a>(ctx: &GenerateCtx, fields: impl IntoIterator<Item = &'a Field>) -> String {
299    fields
300        .into_iter()
301        .filter(|f| !f.is_opaque())
302        .map(|f| {
303            let name = make_ocaml_ident(f.name.as_deref().unwrap());
304            let rename = make_ocaml_ident(f.renamed_name().unwrap());
305            let convert = type_to_ocaml_call(ctx, &f.ty);
306            format!("let* {rename} = {convert} ctx {name} in")
307        })
308        .join("\n")
309}
310
311fn build_branch<'a>(
312    ctx: &GenerateCtx,
313    pat: &str,
314    fields: impl IntoIterator<Item = &'a Field>,
315    construct: &str,
316) -> String {
317    let convert = convert_vars(ctx, fields);
318    format!("| {pat} -> {convert} Ok ({construct})")
319}
320
321fn build_function(_ctx: &GenerateCtx, decl: &TypeDecl, branches: &str) -> String {
322    let ty_name = type_name_to_ocaml_ident(&decl.item_meta);
323    let signature = if decl.generics.types.is_empty() {
324        format!("{ty_name}_of_json (ctx : of_json_ctx) (js : json) : ({ty_name}, string) result =")
325    } else {
326        let types = &decl.generics.types;
327        let gen_vars_space = types
328            .iter()
329            .enumerate()
330            .map(|(i, _)| format!("'a{i}"))
331            .join(" ");
332        let gen_vars_comma = types
333            .iter()
334            .enumerate()
335            .map(|(i, _)| format!("'a{i}"))
336            .join(", ");
337
338        let mut args = Vec::new();
339        let mut ty_args = Vec::new();
340        for (i, _) in types.iter().enumerate() {
341            args.push(format!("arg{i}_of_json"));
342            ty_args.push(format!("(of_json_ctx -> json -> ('a{i}, string) result)"));
343        }
344        args.push("ctx".to_string());
345        ty_args.push("of_json_ctx".to_string());
346        args.push("js".to_string());
347        ty_args.push("json".to_string());
348
349        let ty_args = ty_args.into_iter().join(" -> ");
350        let args = args.into_iter().join(" ");
351        let fun_ty =
352            format!("{gen_vars_space}. {ty_args} -> (({gen_vars_comma}) {ty_name}, string) result");
353        format!("{ty_name}_of_json : {fun_ty} = fun {args} ->")
354    };
355    format!(
356        r#"
357        and {signature}
358          combine_error_msgs js __FUNCTION__
359            (match js with{branches} | _ -> Error "")
360        "#
361    )
362}
363
364fn type_decl_to_json_deserializer(ctx: &GenerateCtx, decl: &TypeDecl) -> String {
365    let return_ty = type_name_to_ocaml_ident(&decl.item_meta);
366    let return_ty = if decl.generics.types.is_empty() {
367        return_ty
368    } else {
369        format!("_ {return_ty}")
370    };
371
372    let branches = match &decl.kind {
373        _ if let Some(def) = ctx.manual_json_impls.get(&decl.def_id) => def.clone(),
374        TypeDeclKind::Struct(fields) if fields.is_empty() => {
375            build_branch(ctx, "`Null", fields, "()")
376        }
377        TypeDeclKind::Struct(fields)
378            if fields.elem_count() == 1
379                && fields[0].name.as_ref().is_some_and(|name| name == "_raw") =>
380        {
381            // These are the special strongly-typed integers.
382            let short_name = decl
383                .item_meta
384                .name
385                .name
386                .last()
387                .unwrap()
388                .as_ident()
389                .unwrap()
390                .0
391                .clone();
392            format!("| x -> {short_name}.id_of_json ctx x")
393        }
394        TypeDeclKind::Struct(fields)
395            if fields.elem_count() == 1
396                && (fields[0].name.is_none()
397                    || decl
398                        .item_meta
399                        .attr_info
400                        .attributes
401                        .iter()
402                        .filter_map(|a| a.as_unknown())
403                        .any(|a| a.to_string() == "serde(transparent)")) =>
404        {
405            let ty = &fields[0].ty;
406            let call = type_to_ocaml_call(ctx, ty);
407            format!("| x -> {call} ctx x")
408        }
409        TypeDeclKind::Alias(ty) => {
410            let call = type_to_ocaml_call(ctx, ty);
411            format!("| x -> {call} ctx x")
412        }
413        TypeDeclKind::Struct(fields) if fields.iter().all(|f| f.name.is_none()) => {
414            let mut fields = fields.clone();
415            for (i, f) in fields.iter_mut().enumerate() {
416                f.name = Some(format!("x{i}"));
417            }
418            let pat: String = fields
419                .iter()
420                .map(|f| f.name.as_deref().unwrap())
421                .map(|n| make_ocaml_ident(n))
422                .join(";");
423            let pat = format!("`List [ {pat} ]");
424            let construct = fields
425                .iter()
426                .map(|f| f.renamed_name().unwrap())
427                .map(|n| make_ocaml_ident(n))
428                .join(", ");
429            let construct = format!("( {construct} )");
430            build_branch(ctx, &pat, &fields, &construct)
431        }
432        TypeDeclKind::Struct(fields) => {
433            let fields = fields
434                .iter()
435                .filter(|field| {
436                    !field
437                        .attr_info
438                        .attributes
439                        .iter()
440                        .filter_map(|a| a.as_unknown())
441                        .any(|a| a.to_string() == "serde(skip)")
442                })
443                .collect_vec();
444            let pat: String = fields
445                .iter()
446                .map(|f| {
447                    let name = f.name.as_ref().unwrap();
448                    let var = if f.is_opaque() {
449                        "_"
450                    } else {
451                        &make_ocaml_ident(name)
452                    };
453                    format!("(\"{name}\", {var});")
454                })
455                .join("\n");
456            let pat = format!("`Assoc [ {pat} ]");
457            let construct = fields
458                .iter()
459                .filter(|f| !f.is_opaque())
460                .map(|f| f.renamed_name().unwrap())
461                .map(|n| make_ocaml_ident(n))
462                .join("; ");
463            let construct = format!("({{ {construct} }} : {return_ty})");
464            build_branch(ctx, &pat, fields, &construct)
465        }
466        TypeDeclKind::Enum(variants) => {
467            variants
468                .iter()
469                .filter(|v| !v.is_opaque())
470                .map(|variant| {
471                    let name = &variant.name;
472                    let rename = variant.renamed_name();
473                    if variant.fields.is_empty() {
474                        // Unit variant
475                        let pat = format!("`String \"{name}\"");
476                        build_branch(ctx, &pat, &variant.fields, rename)
477                    } else {
478                        let mut fields = variant.fields.clone();
479                        let inner_pat = if fields.iter().all(|f| f.name.is_none()) {
480                            // Tuple variant
481                            if variant.fields.elem_count() == 1 {
482                                let var = make_ocaml_ident(&variant.name);
483                                fields[0].name = Some(var.clone());
484                                var
485                            } else {
486                                for (i, f) in fields.iter_mut().enumerate() {
487                                    f.name = Some(format!("x_{i}"));
488                                }
489                                let pat =
490                                    fields.iter().map(|f| f.name.as_ref().unwrap()).join("; ");
491                                format!("`List [ {pat} ]")
492                            }
493                        } else {
494                            // Struct variant
495                            let pat = fields
496                                .iter()
497                                .map(|f| {
498                                    let name = f.name.as_ref().unwrap();
499                                    let var = if f.is_opaque() {
500                                        "_"
501                                    } else {
502                                        &make_ocaml_ident(name)
503                                    };
504                                    format!("(\"{name}\", {var});")
505                                })
506                                .join(" ");
507                            format!("`Assoc [ {pat} ]")
508                        };
509                        let pat = format!("`Assoc [ (\"{name}\", {inner_pat}) ]");
510                        let construct_fields = fields
511                            .iter()
512                            .map(|f| f.name.as_ref().unwrap())
513                            .map(|n| make_ocaml_ident(n))
514                            .join(", ");
515                        let construct = format!("{rename} ({construct_fields})");
516                        build_branch(ctx, &pat, &fields, &construct)
517                    }
518                })
519                .join("\n")
520        }
521        TypeDeclKind::Union(..) => todo!(),
522        TypeDeclKind::Opaque => todo!(),
523        TypeDeclKind::Error(_) => todo!(),
524    };
525    build_function(ctx, decl, &branches)
526}
527
528fn extract_doc_comments(attr_info: &AttrInfo) -> String {
529    attr_info
530        .attributes
531        .iter()
532        .filter_map(|a| a.as_doc_comment())
533        .join("\n")
534}
535
/// Make a doc comment that contains the given string, indenting it if necessary.
///
/// Markdown escapes are rewritten along the way: a ```` ``` ```` fence opens or closes a
/// `{@rust[ ... ]}` block (the `text` and `rust,ignore` info strings are dropped), and inline
/// backticks alternately become `[` and `]`.
///
/// - An empty `comment` is returned unchanged.
/// - A single-line comment becomes `(**<comment> *)`.
/// - A multi-line comment has one leading space stripped from each line; every non-empty line
///   after the first is indented by `indent_level` two-space levels plus four spaces, and the
///   closing ` *)` goes on its own line.
fn build_doc_comment(comment: String, indent_level: usize) -> String {
    /// Rewrites escape delimiters line by line, remembering which escapes are still open since
    /// both kinds can span several lines.
    #[derive(Default)]
    struct Exchanger {
        is_in_open_escaped_block: bool,
        is_in_open_inline_escape: bool,
    }
    impl Exchanger {
        fn exchange_escape_delimiters(&mut self, line: &str) -> String {
            if let Some((leading, rest)) = line.split_once("```") {
                // Handle multi-line escaped blocks. Only the first fence of the line is
                // considered.
                // Strip all (hard-coded) possible ways to open the block.
                let rest = rest
                    .strip_prefix("text")
                    .or_else(|| rest.strip_prefix("rust,ignore"))
                    .unwrap_or(rest);
                let delimiter = if self.is_in_open_escaped_block {
                    "]}"
                } else {
                    "{@rust["
                };
                self.is_in_open_escaped_block = !self.is_in_open_escaped_block;
                format!("{leading}{delimiter}{rest}")
            } else {
                // Handle inline escaped strings. These can occur in multiple lines, so we track
                // them globally. A line without any backtick yields a single part and is
                // returned as is.
                let mut parts = line.split('`');
                // Nothing to insert before the first part.
                let mut result = parts.next().unwrap_or("").to_owned();
                for part in parts {
                    result.push(if self.is_in_open_inline_escape { ']' } else { '[' });
                    result.push_str(part);
                    self.is_in_open_inline_escape = !self.is_in_open_inline_escape;
                }
                result
            }
        }
    }

    if comment.is_empty() {
        return comment;
    }
    if !comment.contains('\n') {
        let fixed_comment = Exchanger::default().exchange_escape_delimiters(&comment);
        return format!("(**{fixed_comment} *)");
    }

    let indent = "  ".repeat(indent_level);
    let mut exchanger = Exchanger::default();
    let lines: Vec<String> = comment
        .lines()
        .enumerate()
        .map(|(i, line)| {
            // Remove one leading space if there is one (there usually is because we write `///
            // comment` and not `///comment`).
            let line = line.strip_prefix(' ').unwrap_or(line);
            let fixed_line = exchanger.exchange_escape_delimiters(line);

            // The first line follows the `(**` marker, it does not need to be indented.
            // Neither do empty lines.
            if i == 0 || fixed_line.is_empty() {
                fixed_line
            } else {
                format!("{indent}    {fixed_line}")
            }
        })
        .collect();
    let comment = lines.join("\n");
    format!("(** {comment}\n{indent} *)")
}
621
622fn build_type(_ctx: &GenerateCtx, decl: &TypeDecl, co_rec: bool, body: &str) -> String {
623    let ty_name = type_name_to_ocaml_ident(&decl.item_meta);
624    let generics = decl
625        .generics
626        .types
627        .iter()
628        .enumerate()
629        .map(|(i, _)| format!("'a{i}"))
630        .collect_vec();
631    let generics = match generics.as_slice() {
632        [] => String::new(),
633        [ty] => ty.clone(),
634        generics => format!("({})", generics.iter().join(",")),
635    };
636    let comment = extract_doc_comments(&decl.item_meta.attr_info);
637    let comment = build_doc_comment(comment, 0);
638    let keyword = if co_rec { "and" } else { "type" };
639    format!("\n{comment} {keyword} {generics} {ty_name} = {body}")
640}
641
642/// Generate an ocaml type declaration that mirrors `decl`.
643///
644/// `co_rec` indicates whether this definition is co-recursive with the ones that come before (i.e.
645/// should be declared with `and` instead of `type`).
646fn type_decl_to_ocaml_decl(ctx: &GenerateCtx, decl: &TypeDecl, co_rec: bool) -> String {
647    let opaque = if ctx.opaque_for_visitor.contains(&decl.def_id) {
648        "[@visitors.opaque]"
649    } else {
650        ""
651    };
652    let body = match &decl.kind {
653        _ if let Some(def) = ctx.manual_type_impls.get(&decl.def_id) => def.clone(),
654        TypeDeclKind::Alias(ty) => {
655            let ty = type_to_ocaml_name(ctx, ty);
656            format!("{ty} {opaque}")
657        }
658        TypeDeclKind::Struct(fields) if fields.is_empty() => "unit".to_string(),
659        TypeDeclKind::Struct(fields)
660            if fields.elem_count() == 1
661                && fields[0].name.as_ref().is_some_and(|name| name == "_raw") =>
662        {
663            // These are the special strongly-typed integers.
664            let short_name = decl
665                .item_meta
666                .name
667                .name
668                .last()
669                .unwrap()
670                .as_ident()
671                .unwrap()
672                .0
673                .clone();
674            format!("{short_name}.id [@visitors.opaque]")
675        }
676        TypeDeclKind::Struct(fields)
677            if fields.elem_count() == 1
678                && (fields[0].name.is_none()
679                    || decl
680                        .item_meta
681                        .attr_info
682                        .attributes
683                        .iter()
684                        .filter_map(|a| a.as_unknown())
685                        .any(|a| a.to_string() == "serde(transparent)")) =>
686        {
687            let ty = type_to_ocaml_name(ctx, &fields[0].ty);
688            format!("{ty} {opaque}")
689        }
690        TypeDeclKind::Struct(fields) if fields.iter().all(|f| f.name.is_none()) => fields
691            .iter()
692            .filter(|f| !f.is_opaque())
693            .map(|f| {
694                let ty = type_to_ocaml_name(ctx, &f.ty);
695                format!("{ty} {opaque}")
696            })
697            .join("*"),
698        TypeDeclKind::Struct(fields) => {
699            let fields = fields
700                .iter()
701                .filter(|f| !f.is_opaque())
702                .map(|f| {
703                    let name = f.renamed_name().unwrap();
704                    let ty = type_to_ocaml_name(ctx, &f.ty);
705                    let comment = extract_doc_comments(&f.attr_info);
706                    let comment = build_doc_comment(comment, 2);
707                    format!("{name} : {ty} {opaque} {comment}")
708                })
709                .join(";");
710            format!("{{ {fields} }}")
711        }
712        TypeDeclKind::Enum(variants) => {
713            variants
714                .iter()
715                .filter(|v| !v.is_opaque())
716                .map(|variant| {
717                    let mut attr_info = variant.attr_info.clone();
718                    let rename = variant.renamed_name();
719                    let ty = if variant.fields.is_empty() {
720                        // Unit variant
721                        String::new()
722                    } else {
723                        if variant.fields.iter().all(|f| f.name.is_some()) {
724                            let fields = variant
725                                .fields
726                                .iter()
727                                .map(|f| {
728                                    let comment = extract_doc_comments(&f.attr_info);
729                                    let description = if comment.is_empty() {
730                                        comment
731                                    } else {
732                                        format!(": {comment}")
733                                    };
734                                    format!("\n - [{}]{description}", f.name.as_ref().unwrap())
735                                })
736                                .join("");
737                            let field_descriptions = format!("\n Fields:{fields}");
738                            // Add a constructed doc-comment
739                            attr_info
740                                .attributes
741                                .push(Attribute::DocComment(field_descriptions));
742                        }
743                        let fields = variant
744                            .fields
745                            .iter()
746                            .map(|f| {
747                                let ty = type_to_ocaml_name(ctx, &f.ty);
748                                format!("{ty} {opaque}")
749                            })
750                            .join("*");
751                        format!(" of {fields}")
752                    };
753                    let comment = extract_doc_comments(&attr_info);
754                    let comment = build_doc_comment(comment, 3);
755                    format!("\n\n | {rename}{ty} {comment}")
756                })
757                .join("")
758        }
759        TypeDeclKind::Union(..) => todo!(),
760        TypeDeclKind::Opaque => todo!(),
761        TypeDeclKind::Error(_) => todo!(),
762    };
763    build_type(ctx, decl, co_rec, &body)
764}
765
766fn generate_visitor_bases(
767    _ctx: &GenerateCtx,
768    name: &str,
769    inherits: &[&str],
770    reduce: bool,
771    ty_names: &[String],
772) -> String {
773    let mut out = String::new();
774    let make_inherit = |variety| {
775        if !inherits.is_empty() {
776            inherits
777                .iter()
778                .map(|ancestor| {
779                    if let [module, name] = ancestor.split(".").collect_vec().as_slice() {
780                        format!("inherit [_] {module}.{variety}_{name}")
781                    } else {
782                        format!("inherit [_] {variety}_{ancestor}")
783                    }
784                })
785                .join("\n")
786        } else {
787            format!("inherit [_] VisitorsRuntime.{variety}")
788        }
789    };
790
791    let iter_methods = ty_names
792        .iter()
793        .map(|ty| format!("method visit_{ty} : 'env -> {ty} -> unit = fun _ _ -> ()"))
794        .format("\n");
795    let _ = write!(
796        &mut out,
797        "
798        class ['self] iter_{name} =
799          object (self : 'self)
800            {}
801            {iter_methods}
802          end
803        ",
804        make_inherit("iter")
805    );
806
807    let map_methods = ty_names
808        .iter()
809        .map(|ty| format!("method visit_{ty} : 'env -> {ty} -> {ty} = fun _ x -> x"))
810        .format("\n");
811    let _ = write!(
812        &mut out,
813        "
814        class ['self] map_{name} =
815          object (self : 'self)
816            {}
817            {map_methods}
818          end
819        ",
820        make_inherit("map")
821    );
822
823    if reduce {
824        let reduce_methods = ty_names
825            .iter()
826            .map(|ty| format!("method visit_{ty} : 'env -> {ty} -> 'a = fun _ _ -> self#zero"))
827            .format("\n");
828        let _ = write!(
829            &mut out,
830            "
831            class virtual ['self] reduce_{name} =
832              object (self : 'self)
833                {}
834                {reduce_methods}
835              end
836            ",
837            make_inherit("reduce")
838        );
839
840        let mapreduce_methods = ty_names
841            .iter()
842            .map(|ty| {
843                format!("method visit_{ty} : 'env -> {ty} -> {ty} * 'a = fun _ x -> (x, self#zero)")
844            })
845            .format("\n");
846        let _ = write!(
847            &mut out,
848            "
849            class virtual ['self] mapreduce_{name} =
850              object (self : 'self)
851                {}
852                {mapreduce_methods}
853              end
854            ",
855            make_inherit("mapreduce")
856        );
857    }
858
859    out
860}
861
/// Settings for the `visitors` deriving attributes attached to a group of generated type
/// declarations.
#[derive(Clone, Copy)]
struct DeriveVisitors {
    /// Base name of the visitors: each derived class is called `{variety}_{name}`.
    name: &'static str,
    /// Base names of the visitor classes to inherit from; the variety is prepended to each one.
    ancestors: &'static [&'static str],
    /// Whether to derive the `reduce` and `mapreduce` varieties on top of `iter` and `map`.
    reduce: bool,
    /// Names of types that get a default `visit_*` method in an intermediate `{name}_base` class.
    /// When non-empty, that class becomes the ancestor of the derived visitors.
    extra_types: &'static [&'static str],
}
869
870/// The kind of code generation to perform.
871#[derive(Clone, Copy)]
872enum GenerationKind {
873    OfJson,
874    TypeDecl(Option<DeriveVisitors>),
875}
876
877/// Replace markers in `template` with auto-generated code.
878struct GenerateCodeFor {
879    template: PathBuf,
880    target: PathBuf,
881    /// Each list corresponds to a marker. We replace the ith `__REPLACE{i}__` marker with
882    /// generated code for each definition in the ith list.
883    ///
884    /// Eventually we should reorder definitions so the generated ones are all in one block.
885    /// Keeping the order is important while we migrate away from hand-written code.
886    markers: Vec<(GenerationKind, HashSet<TypeDeclId>)>,
887}
888
889impl GenerateCodeFor {
890    fn generate(&self, ctx: &GenerateCtx) -> Result<()> {
891        let mut template = fs::read_to_string(&self.template)
892            .with_context(|| format!("Failed to read template file {}", self.template.display()))?;
893        for (i, (kind, names)) in self.markers.iter().enumerate() {
894            let tys = names
895                .iter()
896                .map(|&id| &ctx.crate_data[id])
897                .sorted_by_key(|tdecl| {
898                    tdecl
899                        .item_meta
900                        .name
901                        .name
902                        .last()
903                        .unwrap()
904                        .as_ident()
905                        .unwrap()
906                });
907            let generated = match kind {
908                GenerationKind::OfJson => {
909                    let fns = tys
910                        .map(|ty| type_decl_to_json_deserializer(ctx, ty))
911                        .format("\n");
912                    format!("let rec ___ = ()\n{fns}")
913                }
914                GenerationKind::TypeDecl(visitors) => {
915                    let mut decls = tys
916                        .enumerate()
917                        .map(|(i, ty)| {
918                            let co_recursive = i != 0;
919                            type_decl_to_ocaml_decl(ctx, ty, co_recursive)
920                        })
921                        .join("\n");
922                    if let Some(visitors) = visitors {
923                        let &DeriveVisitors {
924                            name,
925                            mut ancestors,
926                            reduce,
927                            extra_types,
928                        } = visitors;
929                        let varieties: &[_] = if reduce {
930                            &["iter", "map", "reduce", "mapreduce"]
931                        } else {
932                            &["iter", "map"]
933                        };
934                        let intermediate_visitor_name;
935                        let intermediate_visitor_name_slice;
936                        if !extra_types.is_empty() {
937                            intermediate_visitor_name = format!("{name}_base");
938                            let intermediate_visitor = generate_visitor_bases(
939                                ctx,
940                                &intermediate_visitor_name,
941                                ancestors,
942                                reduce,
943                                extra_types
944                                    .iter()
945                                    .map(|s| s.to_string())
946                                    .collect_vec()
947                                    .as_slice(),
948                            );
949                            intermediate_visitor_name_slice = [intermediate_visitor_name.as_str()];
950                            ancestors = &intermediate_visitor_name_slice;
951                            decls = format!(
952                                "(* Ancestors for the {name} visitors *){intermediate_visitor}\n{decls}"
953                            );
954                        }
955                        let visitors = varieties
956                            .iter()
957                            .map(|variety| {
958                                let nude = if !ancestors.is_empty() {
959                                    format!("nude = true (* Don't inherit VisitorsRuntime *);")
960                                } else {
961                                    String::new()
962                                };
963                                let ancestors = format!(
964                                    "ancestors = [ {} ];",
965                                    ancestors
966                                        .iter()
967                                        .map(|a| format!("\"{variety}_{a}\""))
968                                        .join(";")
969                                );
970                                format!(
971                                    r#"
972                                    visitors {{
973                                        name = "{variety}_{name}";
974                                        monomorphic = ["env"];
975                                        variety = "{variety}";
976                                        {ancestors}
977                                        {nude}
978                                    }}
979                                "#
980                                )
981                            })
982                            .format(", ");
983                        let _ = write!(&mut decls, "\n[@@deriving show, eq, ord, {visitors}]");
984                    };
985                    decls
986                }
987            };
988            let placeholder = format!("(* __REPLACE{i}__ *)");
989            template = template.replace(&placeholder, &generated);
990        }
991
992        fs::write(&self.target, template)
993            .with_context(|| format!("Failed to write generated file {}", self.target.display()))?;
994        Ok(())
995    }
996}
997
998fn main() -> Result<()> {
999    let dir = PathBuf::from("src/bin/generate-ml");
1000    let charon_llbc = dir.join("charon-itself.ullbc");
1001    let reuse_llbc = std::env::var("CHARON_ML_REUSE_LLBC").is_ok(); // Useful when developping
1002    if !reuse_llbc {
1003        // Call charon on itself
1004        let mut cmd = Command::cargo_bin("charon")?;
1005        cmd.arg("cargo");
1006        cmd.arg("--hide-marker-traits");
1007        cmd.arg("--hide-allocator");
1008        cmd.arg("--ullbc");
1009        cmd.arg("--start-from=charon_lib::ast::krate::TranslatedCrate");
1010        cmd.arg("--start-from=charon_lib::ast::ullbc_ast::BodyContents");
1011        cmd.arg("--exclude=charon_lib::common::hash_consing::HashConsed");
1012        cmd.arg("--dest-file");
1013        cmd.arg(&charon_llbc);
1014        cmd.arg("--");
1015        cmd.arg("--lib");
1016        let output = cmd.output()?;
1017
1018        if !output.status.success() {
1019            let stderr = String::from_utf8(output.stderr.clone())?;
1020            bail!("Compilation failed: {stderr}")
1021        }
1022    }
1023
1024    let crate_data: TranslatedCrate = charon_lib::deserialize_llbc(&charon_llbc)?;
1025    let output_dir = if std::env::var("IN_CI").as_deref() == Ok("1") {
1026        dir.join("generated")
1027    } else {
1028        dir.join("../../../../charon-ml/src/generated")
1029    };
1030    generate_ml(crate_data, dir.join("templates"), output_dir)
1031}
1032
1033fn generate_ml(
1034    crate_data: TranslatedCrate,
1035    template_dir: PathBuf,
1036    output_dir: PathBuf,
1037) -> anyhow::Result<()> {
1038    let manual_type_impls = &[
1039        // Hand-written because we replace the `FileId` with the corresponding file.
1040        ("FileId", "file"),
1041        // Handwritten because we use `indexed_var` as a hack to be able to reuse field names.
1042        // TODO: remove the need for this hack.
1043        ("RegionVar", "(region_id, string option) indexed_var"),
1044        ("TypeVar", "(type_var_id, string) indexed_var"),
1045    ];
1046    let manual_json_impls = &[
1047        // Hand-written because we filter out `None` values.
1048        (
1049            "Vector",
1050            indoc!(
1051                r#"
1052                | js ->
1053                    let* list = list_of_json (option_of_json arg1_of_json) ctx js in
1054                    Ok (List.filter_map (fun x -> x) list)
1055                "#
1056            ),
1057        ),
1058        // Hand-written because we replace the `FileId` with the corresponding file name.
1059        (
1060            "FileId",
1061            indoc!(
1062                r#"
1063                | json ->
1064                    let* file_id = FileId.id_of_json ctx json in
1065                    let file = FileId.Map.find file_id ctx in
1066                    Ok file
1067                "#,
1068            ),
1069        ),
1070    ];
1071    // Types for which we don't want to generate a type at all.
1072    let dont_generate_ty = &[
1073        "ItemOpacity",
1074        "PredicateOrigin",
1075        "TraitTypeConstraintId",
1076        "Ty",
1077        "Vector",
1078    ];
1079    // Types that we don't want visitors to go into.
1080    let opaque_for_visitor = &["Name"];
1081    let ctx = GenerateCtx::new(
1082        &crate_data,
1083        manual_type_impls,
1084        manual_json_impls,
1085        opaque_for_visitor,
1086    );
1087
1088    // Compute the sets of types to be put in each module.
1089    let manually_implemented: HashSet<_> = [
1090        "ItemOpacity",
1091        "PredicateOrigin",
1092        "Ty", // We exclude it since `TyKind` is renamed to `ty`
1093        "Opaque",
1094        "Body",
1095        "FunDecl",
1096        "TranslatedCrate",
1097    ]
1098    .iter()
1099    .map(|name| ctx.id_from_name(name))
1100    .collect();
1101
1102    // Compute type sets for json deserializers.
1103    let (gast_types, llbc_types, ullbc_types) = {
1104        let llbc_types: HashSet<_> = ctx.children_of("charon_lib::ast::llbc_ast::Statement");
1105        let ullbc_types: HashSet<_> = ctx.children_of("charon_lib::ast::ullbc_ast::BodyContents");
1106        let all_types: HashSet<_> = ctx.children_of("TranslatedCrate");
1107
1108        let shared_types: HashSet<_> = llbc_types.intersection(&ullbc_types).copied().collect();
1109        let llbc_types: HashSet<_> = llbc_types.difference(&shared_types).copied().collect();
1110        let ullbc_types: HashSet<_> = ullbc_types.difference(&shared_types).copied().collect();
1111
1112        let body_specific_types: HashSet<_> = llbc_types.union(&ullbc_types).copied().collect();
1113        let gast_types: HashSet<_> = all_types
1114            .difference(&body_specific_types)
1115            .copied()
1116            .collect();
1117
1118        let gast_types: HashSet<_> = gast_types
1119            .difference(&manually_implemented)
1120            .copied()
1121            .collect();
1122        let llbc_types: HashSet<_> = llbc_types
1123            .difference(&manually_implemented)
1124            .copied()
1125            .collect();
1126        let ullbc_types: HashSet<_> = ullbc_types
1127            .difference(&manually_implemented)
1128            .copied()
1129            .collect();
1130        (gast_types, llbc_types, ullbc_types)
1131    };
1132
1133    let mut processed_tys: HashSet<TypeDeclId> = dont_generate_ty
1134        .iter()
1135        .map(|name| ctx.id_from_name(name))
1136        .collect();
1137    // Each call to this will return the children of the listed types that haven't been returned
1138    // yet. By calling it in dependency order, this allows to organize types into files without
1139    // having to list them all.
1140    let mut markers_from_children = |ctx: &GenerateCtx, markers: &[_]| {
1141        markers
1142            .iter()
1143            .copied()
1144            .map(|(kind, type_names)| {
1145                let types: HashSet<_> = ctx.children_of_many(type_names);
1146                let unprocessed_types: HashSet<_> =
1147                    types.difference(&processed_tys).copied().collect();
1148                processed_tys.extend(unprocessed_types.iter().copied());
1149                (kind, unprocessed_types)
1150            })
1151            .collect()
1152    };
1153
1154    #[rustfmt::skip]
1155    let generate_code_for = vec![
1156        GenerateCodeFor {
1157            template: template_dir.join("Meta.ml"),
1158            target: output_dir.join("Generated_Meta.ml"),
1159            markers: markers_from_children(&ctx, &[
1160                (GenerationKind::TypeDecl(None), &[
1161                    "File",
1162                    "Span",
1163                    "AttrInfo",
1164                ]),
1165            ]),
1166        },
1167        GenerateCodeFor {
1168            template: template_dir.join("Values.ml"),
1169            target: output_dir.join("Generated_Values.ml"),
1170            markers: markers_from_children(&ctx, &[
1171                (GenerationKind::TypeDecl(Some(DeriveVisitors {
1172                    ancestors: &["big_int"],
1173                    name: "literal",
1174                    reduce: true,
1175                    extra_types: &[],
1176                })), &[
1177                    "Literal",
1178                    "IntegerTy",
1179                    "LiteralTy",
1180                ]),
1181            ]),
1182        },
1183        GenerateCodeFor {
1184            template: template_dir.join("Types.ml"),
1185            target: output_dir.join("Generated_Types.ml"),
1186            markers: markers_from_children(&ctx, &[
1187                (GenerationKind::TypeDecl(Some(DeriveVisitors {
1188                    ancestors: &["literal"],
1189                    name: "const_generic",
1190                    reduce: true,
1191                    extra_types: &[],
1192                })), &[
1193                    "TypeVarId",
1194                    "ConstGeneric",
1195                    "TraitClauseId",
1196                    "DeBruijnVar",
1197                    "AnyTransId",
1198                ]),
1199                // Can't merge into above because aeneas uses the above alongside their own partial
1200                // copy of `ty`, which causes method type clashes.
1201                (GenerationKind::TypeDecl(Some(DeriveVisitors {
1202                    ancestors: &["ty_base_base"],
1203                    name: "ty",
1204                    reduce: false,
1205                    extra_types: &["span"],
1206                })), &[
1207                    "TyKind",
1208                    "TraitImplRef",
1209                    "FunDeclRef",
1210                    "GlobalDeclRef",
1211                ]),
1212                // TODO: can't merge into above because of field name clashes (`types`, `regions` etc).
1213                (GenerationKind::TypeDecl(Some(DeriveVisitors {
1214                    ancestors: &["ty"],
1215                    name: "type_decl",
1216                    reduce: false,
1217                    extra_types: &[
1218                        "attr_info"
1219                    ],
1220                })), &[
1221                    "Binder",
1222                    "AbortKind",
1223                    "TypeDecl",
1224                ]),
1225            ]),
1226        },
1227        GenerateCodeFor {
1228            template: template_dir.join("Expressions.ml"),
1229            target: output_dir.join("Generated_Expressions.ml"),
1230            markers: markers_from_children(&ctx, &[
1231                (GenerationKind::TypeDecl(Some(DeriveVisitors {
1232                    ancestors: &["type_decl"],
1233                    name: "rvalue",
1234                    reduce: false,
1235                    extra_types: &[],
1236                })), &[
1237                    "Rvalue",
1238                ]),
1239            ]),
1240        },
1241        GenerateCodeFor {
1242            template: template_dir.join("GAst.ml"),
1243            target: output_dir.join("Generated_GAst.ml"),
1244            markers: markers_from_children(&ctx, &[
1245                (GenerationKind::TypeDecl(Some(DeriveVisitors {
1246                    ancestors: &["rvalue"],
1247                    name: "fun_sig",
1248                    reduce: false,
1249                    extra_types: &[],
1250                })), &[
1251                    "Call",
1252                    "Assert",
1253                    "ItemKind",
1254                    "Locals",
1255                    "FunSig",
1256                    "CopyNonOverlapping",
1257                ]),
1258                // These have to be kept separate to avoid field name clashes
1259                (GenerationKind::TypeDecl(Some(DeriveVisitors {
1260                    ancestors: &["fun_sig"],
1261                    name: "global_decl",
1262                    reduce: false,
1263                    extra_types: &[],
1264                })), &[
1265                    "GlobalDecl",
1266                ]),
1267                (GenerationKind::TypeDecl(Some(DeriveVisitors {
1268                    ancestors: &["global_decl"],
1269                    name: "trait_decl",
1270                    reduce: false,
1271                    extra_types: &[],
1272                })), &[
1273                    "TraitDecl",
1274                ]),
1275                (GenerationKind::TypeDecl(Some(DeriveVisitors {
1276                    ancestors: &["trait_decl"],
1277                    name: "trait_impl",
1278                    reduce: false,
1279                    extra_types: &[],
1280                })), &[
1281                    "TraitImpl",
1282                ]),
1283                (GenerationKind::TypeDecl(None), &[
1284                    "CliOpts",
1285                    "GExprBody",
1286                    "DeclarationGroup",
1287                ]),
1288            ]),
1289        },
1290        GenerateCodeFor {
1291            template: template_dir.join("LlbcAst.ml"),
1292            target: output_dir.join("Generated_LlbcAst.ml"),
1293            markers: markers_from_children(&ctx, &[
1294                (GenerationKind::TypeDecl(Some(DeriveVisitors {
1295                    name: "statement_base",
1296                    ancestors: &["trait_impl"],
1297                    reduce: false,
1298                    extra_types: &[],
1299                })), &[
1300                    "charon_lib::ast::llbc_ast::Statement",
1301                ]),
1302            ]),
1303        },
1304        GenerateCodeFor {
1305            template: template_dir.join("UllbcAst.ml"),
1306            target: output_dir.join("Generated_UllbcAst.ml"),
1307            markers: markers_from_children(&ctx, &[
1308                (GenerationKind::TypeDecl(Some(DeriveVisitors {
1309                    ancestors: &["trait_impl"],
1310                    name: "statement",
1311                    reduce: false,
1312                    extra_types: &[],
1313                })), &[
1314                    "charon_lib::ast::ullbc_ast::Statement",
1315                    "charon_lib::ast::ullbc_ast::SwitchTargets",
1316                ]),
1317                // TODO: Can't merge with above because of field name clashes (`content` and `span`).
1318                (GenerationKind::TypeDecl(Some(DeriveVisitors {
1319                    ancestors: &["statement"],
1320                    name: "ullbc_ast",
1321                    reduce: false,
1322                    extra_types: &[],
1323                })), &[
1324                    "charon_lib::ast::ullbc_ast::BodyContents",
1325                ]),
1326            ]),
1327        },
1328        GenerateCodeFor {
1329            template: template_dir.join("GAstOfJson.ml"),
1330            target: output_dir.join("Generated_GAstOfJson.ml"),
1331            markers: vec![(GenerationKind::OfJson, gast_types)],
1332        },
1333        GenerateCodeFor {
1334            template: template_dir.join("LlbcOfJson.ml"),
1335            target: output_dir.join("Generated_LlbcOfJson.ml"),
1336            markers: vec![(GenerationKind::OfJson, llbc_types)],
1337        },
1338        GenerateCodeFor {
1339            template: template_dir.join("UllbcOfJson.ml"),
1340            target: output_dir.join("Generated_UllbcOfJson.ml"),
1341            markers: vec![(GenerationKind::OfJson, ullbc_types)],
1342        },
1343    ];
1344    for file in generate_code_for {
1345        file.generate(&ctx)?;
1346    }
1347    Ok(())
1348}