generate_ml/
main.rs

1//! Generate ocaml deserialization code for our types.
2//!
3//! This binary runs charon on itself and generates the appropriate `<type>_of_json` functions for
4//! our types. The generated functions are inserted into `./generate-ml/GAstOfJson.template.ml` to
5//! construct the final `GAstOfJson.ml`.
6//!
7//! To run it, call `cargo run --bin generate-ml`. It is also run by `make generate-ml` in the
8//! crate root. Don't forget to format the output code after regenerating.
9#![feature(if_let_guard)]
10
11use anyhow::{bail, Context, Result};
12use assert_cmd::cargo::CommandCargoExt;
13use charon_lib::ast::*;
14use convert_case::{Case, Casing};
15use indoc::indoc;
16use itertools::Itertools;
17use std::collections::{HashMap, HashSet};
18use std::fmt::Write;
19use std::fs;
20use std::path::PathBuf;
21use std::process::Command;
22
23/// `Name` is a complex datastructure; to inspect it we serialize it a little bit.
24fn repr_name(_crate_data: &TranslatedCrate, n: &Name) -> String {
25    n.name
26        .iter()
27        .map(|path_elem| match path_elem {
28            PathElem::Ident(i, _) => i.clone(),
29            PathElem::Impl(..) => "<impl>".to_string(),
30            PathElem::Monomorphized(..) => "<mono>".to_string(),
31        })
32        .join("::")
33}
34
35fn make_ocaml_ident(name: &str) -> String {
36    let mut name = name.to_case(Case::Snake);
37    if matches!(
38        &*name,
39        "virtual"
40            | "bool"
41            | "char"
42            | "struct"
43            | "type"
44            | "let"
45            | "fun"
46            | "open"
47            | "rec"
48            | "assert"
49            | "float"
50            | "end"
51            | "include"
52            | "to"
53    ) {
54        name += "_";
55    }
56    name
57}
58fn type_name_to_ocaml_ident(item_meta: &ItemMeta) -> String {
59    let name = item_meta
60        .attr_info
61        .rename
62        .as_ref()
63        .unwrap_or(item_meta.name.name.last().unwrap().as_ident().unwrap().0);
64    make_ocaml_ident(name)
65}
66
67struct GenerateCtx<'a> {
68    crate_data: &'a TranslatedCrate,
69    name_to_type: HashMap<String, &'a TypeDecl>,
70    /// For each type, list the types it contains.
71    type_tree: HashMap<TypeDeclId, HashSet<TypeDeclId>>,
72    manual_type_impls: HashMap<TypeDeclId, String>,
73    manual_json_impls: HashMap<TypeDeclId, String>,
74    opaque_for_visitor: HashSet<TypeDeclId>,
75}
76
77impl<'a> GenerateCtx<'a> {
78    fn new(
79        crate_data: &'a TranslatedCrate,
80        manual_type_impls: &[(&str, &str)],
81        manual_json_impls: &[(&str, &str)],
82        opaque_for_visitor: &[&str],
83    ) -> Self {
84        let mut name_to_type: HashMap<String, &TypeDecl> = Default::default();
85        let mut type_tree = HashMap::default();
86        for ty in &crate_data.type_decls {
87            let long_name = repr_name(crate_data, &ty.item_meta.name);
88            if long_name.starts_with("charon_lib") {
89                let short_name = ty
90                    .item_meta
91                    .name
92                    .name
93                    .last()
94                    .unwrap()
95                    .as_ident()
96                    .unwrap()
97                    .0
98                    .clone();
99                name_to_type.insert(short_name, ty);
100            }
101            name_to_type.insert(long_name, ty);
102
103            let mut contained = HashSet::new();
104            ty.dyn_visit(|id: &TypeDeclId| {
105                contained.insert(*id);
106            });
107            type_tree.insert(ty.def_id, contained);
108        }
109
110        let mut ctx = GenerateCtx {
111            crate_data: &crate_data,
112            name_to_type,
113            type_tree,
114            manual_type_impls: Default::default(),
115            manual_json_impls: Default::default(),
116            opaque_for_visitor: Default::default(),
117        };
118        ctx.manual_type_impls = manual_type_impls
119            .iter()
120            .map(|(name, def)| (ctx.id_from_name(name), def.to_string()))
121            .collect();
122        ctx.manual_json_impls = manual_json_impls
123            .iter()
124            .map(|(name, def)| (ctx.id_from_name(name), def.to_string()))
125            .collect();
126        ctx.opaque_for_visitor = opaque_for_visitor
127            .iter()
128            .map(|name| ctx.id_from_name(name))
129            .collect();
130        ctx
131    }
132
133    fn id_from_name(&self, name: &str) -> TypeDeclId {
134        self.name_to_type
135            .get(name)
136            .expect(&format!("Name not found: `{name}`"))
137            .def_id
138    }
139
140    /// List the (recursive) children of this type.
141    fn children_of(&self, name: &str) -> HashSet<TypeDeclId> {
142        let start_id = self.id_from_name(name);
143        self.children_of_inner(vec![start_id])
144    }
145
146    /// List the (recursive) children of these types.
147    fn children_of_many(&self, names: &[&str]) -> HashSet<TypeDeclId> {
148        self.children_of_inner(names.iter().map(|name| self.id_from_name(name)).collect())
149    }
150
151    fn children_of_inner(&self, ty: Vec<TypeDeclId>) -> HashSet<TypeDeclId> {
152        let mut children = HashSet::new();
153        let mut stack = ty.to_vec();
154        while let Some(id) = stack.pop() {
155            if !children.contains(&id)
156                && self
157                    .crate_data
158                    .type_decls
159                    .get(id)
160                    .is_some_and(|decl| decl.item_meta.is_local)
161            {
162                children.insert(id);
163                if let Some(contained) = self.type_tree.get(&id) {
164                    stack.extend(contained);
165                }
166            }
167        }
168        children
169    }
170}
171
172/// Converts a type to the appropriate `*_of_json` call. In case of generics, this combines several
173/// functions, e.g. `list_of_json bool_of_json`.
174fn type_to_ocaml_call(ctx: &GenerateCtx, ty: &Ty) -> String {
175    match ty.kind() {
176        TyKind::Literal(LiteralTy::Bool) => "bool_of_json".to_string(),
177        TyKind::Literal(LiteralTy::Char) => "char_of_json".to_string(),
178        TyKind::Literal(LiteralTy::Integer(_)) => "int_of_json".to_string(),
179        TyKind::Literal(LiteralTy::Float(_)) => "float_of_json".to_string(),
180        TyKind::Adt(tref) => {
181            let mut expr = Vec::new();
182            for ty in &tref.generics.types {
183                expr.push(type_to_ocaml_call(ctx, ty))
184            }
185            match tref.id {
186                TypeId::Adt(id) => {
187                    let mut first = if let Some(tdecl) = ctx.crate_data.type_decls.get(id) {
188                        type_name_to_ocaml_ident(&tdecl.item_meta)
189                    } else {
190                        format!("missing_type_{id}")
191                    };
192                    if first == "vec" {
193                        first = "list".to_string();
194                        expr.pop(); // Remove the allocator generic param
195                    }
196                    expr.insert(0, first + "_of_json");
197                }
198                TypeId::Builtin(BuiltinTy::Box) => expr.insert(0, "box_of_json".to_owned()),
199                TypeId::Tuple => {
200                    let name = match tref.generics.types.elem_count() {
201                        2 => "pair_of_json".to_string(),
202                        3 => "triple_of_json".to_string(),
203                        len => format!("tuple_{len}_of_json"),
204                    };
205                    expr.insert(0, name);
206                }
207                _ => unimplemented!("{ty:?}"),
208            }
209            expr.into_iter().map(|f| format!("({f})")).join(" ")
210        }
211        TyKind::TypeVar(DeBruijnVar::Free(id)) => format!("arg{id}_of_json"),
212        _ => unimplemented!("{ty:?}"),
213    }
214}
215
216/// Converts a type to the appropriate ocaml name. In case of generics, this provides appropriate
217/// parameters.
218fn type_to_ocaml_name(ctx: &GenerateCtx, ty: &Ty) -> String {
219    match ty.kind() {
220        TyKind::Literal(LiteralTy::Bool) => "bool".to_string(),
221        TyKind::Literal(LiteralTy::Char) => "(Uchar.t [@visitors.opaque])".to_string(),
222        TyKind::Literal(LiteralTy::Integer(_)) => "int".to_string(),
223        TyKind::Literal(LiteralTy::Float(_)) => "float_of_json".to_string(),
224        TyKind::Adt(tref) => {
225            let mut args = tref
226                .generics
227                .types
228                .iter()
229                .map(|ty| type_to_ocaml_name(ctx, ty))
230                .map(|name| {
231                    if !name.chars().all(|c| c.is_alphanumeric()) {
232                        format!("({name})")
233                    } else {
234                        name
235                    }
236                })
237                .collect_vec();
238            match tref.id {
239                TypeId::Adt(id) => {
240                    let mut base_ty = if let Some(tdecl) = ctx.crate_data.type_decls.get(id) {
241                        type_name_to_ocaml_ident(&tdecl.item_meta)
242                    } else if let Some(name) = ctx.crate_data.item_name(id) {
243                        eprintln!(
244                            "Warning: type {} missing from llbc",
245                            repr_name(ctx.crate_data, name)
246                        );
247                        name.name
248                            .last()
249                            .unwrap()
250                            .as_ident()
251                            .unwrap()
252                            .0
253                            .to_lowercase()
254                    } else {
255                        format!("missing_type_{id}")
256                    };
257                    if base_ty == "vec" {
258                        base_ty = "list".to_string();
259                        args.pop(); // Remove the allocator generic param
260                    }
261                    if base_ty == "vector" {
262                        base_ty = "list".to_string();
263                        args.remove(0); // Remove the index generic param
264                    }
265                    let args = match args.as_slice() {
266                        [] => String::new(),
267                        [arg] => arg.clone(),
268                        args => format!("({})", args.iter().join(",")),
269                    };
270                    format!("{args} {base_ty}")
271                }
272                TypeId::Builtin(BuiltinTy::Box) => args[0].clone(),
273                TypeId::Tuple => args.iter().join("*"),
274                _ => unimplemented!("{ty:?}"),
275            }
276        }
277        TyKind::TypeVar(DeBruijnVar::Free(id)) => format!("'a{id}"),
278        _ => unimplemented!("{ty:?}"),
279    }
280}
281
282fn convert_vars<'a>(ctx: &GenerateCtx, fields: impl IntoIterator<Item = &'a Field>) -> String {
283    fields
284        .into_iter()
285        .filter(|f| !f.is_opaque())
286        .map(|f| {
287            let name = make_ocaml_ident(f.name.as_deref().unwrap());
288            let rename = make_ocaml_ident(f.renamed_name().unwrap());
289            let convert = type_to_ocaml_call(ctx, &f.ty);
290            format!("let* {rename} = {convert} ctx {name} in")
291        })
292        .join("\n")
293}
294
295fn build_branch<'a>(
296    ctx: &GenerateCtx,
297    pat: &str,
298    fields: impl IntoIterator<Item = &'a Field>,
299    construct: &str,
300) -> String {
301    let convert = convert_vars(ctx, fields);
302    format!("| {pat} -> {convert} Ok ({construct})")
303}
304
305fn build_function(_ctx: &GenerateCtx, decl: &TypeDecl, branches: &str) -> String {
306    let ty_name = type_name_to_ocaml_ident(&decl.item_meta);
307    let signature = if decl.generics.types.is_empty() {
308        format!("{ty_name}_of_json (ctx : of_json_ctx) (js : json) : ({ty_name}, string) result =")
309    } else {
310        let types = &decl.generics.types;
311        let gen_vars_space = types
312            .iter()
313            .enumerate()
314            .map(|(i, _)| format!("'a{i}"))
315            .join(" ");
316        let gen_vars_comma = types
317            .iter()
318            .enumerate()
319            .map(|(i, _)| format!("'a{i}"))
320            .join(", ");
321
322        let mut args = Vec::new();
323        let mut ty_args = Vec::new();
324        for (i, _) in types.iter().enumerate() {
325            args.push(format!("arg{i}_of_json"));
326            ty_args.push(format!("(of_json_ctx -> json -> ('a{i}, string) result)"));
327        }
328        args.push("ctx".to_string());
329        ty_args.push("of_json_ctx".to_string());
330        args.push("js".to_string());
331        ty_args.push("json".to_string());
332
333        let ty_args = ty_args.into_iter().join(" -> ");
334        let args = args.into_iter().join(" ");
335        let fun_ty =
336            format!("{gen_vars_space}. {ty_args} -> (({gen_vars_comma}) {ty_name}, string) result");
337        format!("{ty_name}_of_json : {fun_ty} = fun {args} ->")
338    };
339    format!(
340        r#"
341        and {signature}
342          combine_error_msgs js __FUNCTION__
343            (match js with{branches} | _ -> Error "")
344        "#
345    )
346}
347
348fn type_decl_to_json_deserializer(ctx: &GenerateCtx, decl: &TypeDecl) -> String {
349    let return_ty = type_name_to_ocaml_ident(&decl.item_meta);
350    let return_ty = if decl.generics.types.is_empty() {
351        return_ty
352    } else {
353        format!("_ {return_ty}")
354    };
355
356    let branches = match &decl.kind {
357        _ if let Some(def) = ctx.manual_json_impls.get(&decl.def_id) => def.clone(),
358        TypeDeclKind::Struct(fields) if fields.is_empty() => {
359            build_branch(ctx, "`Null", fields, "()")
360        }
361        TypeDeclKind::Struct(fields)
362            if fields.elem_count() == 1
363                && fields[0].name.as_ref().is_some_and(|name| name == "_raw") =>
364        {
365            // These are the special strongly-typed integers.
366            let short_name = decl
367                .item_meta
368                .name
369                .name
370                .last()
371                .unwrap()
372                .as_ident()
373                .unwrap()
374                .0
375                .clone();
376            format!("| x -> {short_name}.id_of_json ctx x")
377        }
378        TypeDeclKind::Struct(fields)
379            if fields.elem_count() == 1
380                && (fields[0].name.is_none()
381                    || decl
382                        .item_meta
383                        .attr_info
384                        .attributes
385                        .iter()
386                        .filter_map(|a| a.as_unknown())
387                        .any(|a| a.to_string() == "serde(transparent)")) =>
388        {
389            let ty = &fields[0].ty;
390            let call = type_to_ocaml_call(ctx, ty);
391            format!("| x -> {call} ctx x")
392        }
393        TypeDeclKind::Alias(ty) => {
394            let call = type_to_ocaml_call(ctx, ty);
395            format!("| x -> {call} ctx x")
396        }
397        TypeDeclKind::Struct(fields) if fields.iter().all(|f| f.name.is_none()) => {
398            let mut fields = fields.clone();
399            for (i, f) in fields.iter_mut().enumerate() {
400                f.name = Some(format!("x{i}"));
401            }
402            let pat: String = fields
403                .iter()
404                .map(|f| f.name.as_deref().unwrap())
405                .map(|n| make_ocaml_ident(n))
406                .join(";");
407            let pat = format!("`List [ {pat} ]");
408            let construct = fields
409                .iter()
410                .map(|f| f.renamed_name().unwrap())
411                .map(|n| make_ocaml_ident(n))
412                .join(", ");
413            let construct = format!("( {construct} )");
414            build_branch(ctx, &pat, &fields, &construct)
415        }
416        TypeDeclKind::Struct(fields) => {
417            let fields = fields
418                .iter()
419                .filter(|field| {
420                    !field
421                        .attr_info
422                        .attributes
423                        .iter()
424                        .filter_map(|a| a.as_unknown())
425                        .any(|a| a.to_string() == "serde(skip)")
426                })
427                .collect_vec();
428            let pat: String = fields
429                .iter()
430                .map(|f| {
431                    let name = f.name.as_ref().unwrap();
432                    let var = if f.is_opaque() {
433                        "_"
434                    } else {
435                        &make_ocaml_ident(name)
436                    };
437                    format!("(\"{name}\", {var});")
438                })
439                .join("\n");
440            let pat = format!("`Assoc [ {pat} ]");
441            let construct = fields
442                .iter()
443                .filter(|f| !f.is_opaque())
444                .map(|f| f.renamed_name().unwrap())
445                .map(|n| make_ocaml_ident(n))
446                .join("; ");
447            let construct = format!("({{ {construct} }} : {return_ty})");
448            build_branch(ctx, &pat, fields, &construct)
449        }
450        TypeDeclKind::Enum(variants) => {
451            variants
452                .iter()
453                .filter(|v| !v.is_opaque())
454                .map(|variant| {
455                    let name = &variant.name;
456                    let rename = variant.renamed_name();
457                    if variant.fields.is_empty() {
458                        // Unit variant
459                        let pat = format!("`String \"{name}\"");
460                        build_branch(ctx, &pat, &variant.fields, rename)
461                    } else {
462                        let mut fields = variant.fields.clone();
463                        let inner_pat = if fields.iter().all(|f| f.name.is_none()) {
464                            // Tuple variant
465                            if variant.fields.elem_count() == 1 {
466                                let var = make_ocaml_ident(&variant.name);
467                                fields[0].name = Some(var.clone());
468                                var
469                            } else {
470                                for (i, f) in fields.iter_mut().enumerate() {
471                                    f.name = Some(format!("x_{i}"));
472                                }
473                                let pat =
474                                    fields.iter().map(|f| f.name.as_ref().unwrap()).join("; ");
475                                format!("`List [ {pat} ]")
476                            }
477                        } else {
478                            // Struct variant
479                            let pat = fields
480                                .iter()
481                                .map(|f| {
482                                    let name = f.name.as_ref().unwrap();
483                                    let var = if f.is_opaque() {
484                                        "_"
485                                    } else {
486                                        &make_ocaml_ident(name)
487                                    };
488                                    format!("(\"{name}\", {var});")
489                                })
490                                .join(" ");
491                            format!("`Assoc [ {pat} ]")
492                        };
493                        let pat = format!("`Assoc [ (\"{name}\", {inner_pat}) ]");
494                        let construct_fields = fields
495                            .iter()
496                            .map(|f| f.name.as_ref().unwrap())
497                            .map(|n| make_ocaml_ident(n))
498                            .join(", ");
499                        let construct = format!("{rename} ({construct_fields})");
500                        build_branch(ctx, &pat, &fields, &construct)
501                    }
502                })
503                .join("\n")
504        }
505        TypeDeclKind::Union(..) => todo!(),
506        TypeDeclKind::Opaque => todo!(),
507        TypeDeclKind::Error(_) => todo!(),
508    };
509    build_function(ctx, decl, &branches)
510}
511
512fn extract_doc_comments(attr_info: &AttrInfo) -> String {
513    attr_info
514        .attributes
515        .iter()
516        .filter_map(|a| a.as_doc_comment())
517        .join("\n")
518}
519
/// Make an OCaml doc comment (`(** ... *)`) that contains the given string.
///
/// Markdown code is rewritten into odoc markup:
/// - fenced code blocks become `{@rust[ ... ]}`, dropping the fence's info string (e.g. `text`,
///   `rust`, `rust,ignore`);
/// - inline code spans become `[...]`; backticks inside a fenced block are left alone.
///
/// For a multi-line comment, one leading space is removed from each line. Every line except the
/// first and empty ones is then indented by `indent_level` levels of two spaces plus the width
/// of the `(**` marker. An empty comment yields an empty string.
fn build_doc_comment(comment: String, indent_level: usize) -> String {
    /// Rewrites markdown code delimiters into odoc ones, line by line. The state is kept across
    /// lines because fenced blocks and inline code spans can both span several lines.
    #[derive(Default)]
    struct Exchanger {
        is_in_open_escaped_block: bool,
        is_in_open_inline_escape: bool,
    }
    impl Exchanger {
        pub fn exchange_escape_delimiters(&mut self, line: &str) -> String {
            if let Some((leading, rest)) = line.split_once("```") {
                // Opening or closing a fenced code block.
                let mut result = leading.to_owned();
                if self.is_in_open_escaped_block {
                    result.push_str("]}");
                    result.push_str(rest);
                    self.is_in_open_escaped_block = false;
                } else {
                    // The rest of an opening fence line is the block's info string, not
                    // content, so it is dropped.
                    result.push_str("{@rust[");
                    self.is_in_open_escaped_block = true;
                }
                result
            } else if self.is_in_open_escaped_block {
                // Inside a code block, backticks are code, not inline-code delimiters.
                line.to_owned()
            } else if line.contains('`') {
                // Inline code spans may continue on the next line, so whether one is open is
                // tracked across lines.
                let mut parts = line.split('`');
                // Text before the first backtick is copied unchanged.
                let mut result = parts.next().unwrap_or_default().to_owned();
                for part in parts {
                    result.push(if self.is_in_open_inline_escape { ']' } else { '[' });
                    result.push_str(part);
                    self.is_in_open_inline_escape = !self.is_in_open_inline_escape;
                }
                result
            } else {
                line.to_owned()
            }
        }
    }

    if comment.is_empty() {
        return comment;
    }
    let mut exchanger = Exchanger::default();
    if !comment.contains('\n') {
        let fixed_comment = exchanger.exchange_escape_delimiters(&comment);
        return format!("(**{fixed_comment} *)");
    }

    let indent = "  ".repeat(indent_level);
    let body = comment
        .lines()
        .enumerate()
        .map(|(i, line)| {
            // Remove one leading space if there is one (there usually is because we write `///
            // comment` and not `///comment`).
            let line = line.strip_prefix(' ').unwrap_or(line);
            let fixed_line = exchanger.exchange_escape_delimiters(line);

            // The first line follows the `(**` marker, it does not need to be indented.
            // Neither do empty lines.
            if i == 0 || fixed_line.is_empty() {
                fixed_line
            } else {
                format!("{indent}    {fixed_line}")
            }
        })
        .collect::<Vec<_>>()
        .join("\n");
    format!("(** {body}\n{indent} *)")
}
605
606fn build_type(_ctx: &GenerateCtx, decl: &TypeDecl, co_rec: bool, body: &str) -> String {
607    let ty_name = type_name_to_ocaml_ident(&decl.item_meta);
608    let generics = decl
609        .generics
610        .types
611        .iter()
612        .enumerate()
613        .map(|(i, _)| format!("'a{i}"))
614        .collect_vec();
615    let generics = match generics.as_slice() {
616        [] => String::new(),
617        [ty] => ty.clone(),
618        generics => format!("({})", generics.iter().join(",")),
619    };
620    let comment = extract_doc_comments(&decl.item_meta.attr_info);
621    let comment = build_doc_comment(comment, 0);
622    let keyword = if co_rec { "and" } else { "type" };
623    format!("\n{comment} {keyword} {generics} {ty_name} = {body}")
624}
625
626/// Generate an ocaml type declaration that mirrors `decl`.
627///
628/// `co_rec` indicates whether this definition is co-recursive with the ones that come before (i.e.
629/// should be declared with `and` instead of `type`).
630fn type_decl_to_ocaml_decl(ctx: &GenerateCtx, decl: &TypeDecl, co_rec: bool) -> String {
631    let opaque = if ctx.opaque_for_visitor.contains(&decl.def_id) {
632        "[@visitors.opaque]"
633    } else {
634        ""
635    };
636    let body = match &decl.kind {
637        _ if let Some(def) = ctx.manual_type_impls.get(&decl.def_id) => def.clone(),
638        TypeDeclKind::Alias(ty) => {
639            let ty = type_to_ocaml_name(ctx, ty);
640            format!("{ty} {opaque}")
641        }
642        TypeDeclKind::Struct(fields) if fields.is_empty() => "unit".to_string(),
643        TypeDeclKind::Struct(fields)
644            if fields.elem_count() == 1
645                && fields[0].name.as_ref().is_some_and(|name| name == "_raw") =>
646        {
647            // These are the special strongly-typed integers.
648            let short_name = decl
649                .item_meta
650                .name
651                .name
652                .last()
653                .unwrap()
654                .as_ident()
655                .unwrap()
656                .0
657                .clone();
658            format!("{short_name}.id [@visitors.opaque]")
659        }
660        TypeDeclKind::Struct(fields)
661            if fields.elem_count() == 1
662                && (fields[0].name.is_none()
663                    || decl
664                        .item_meta
665                        .attr_info
666                        .attributes
667                        .iter()
668                        .filter_map(|a| a.as_unknown())
669                        .any(|a| a.to_string() == "serde(transparent)")) =>
670        {
671            let ty = type_to_ocaml_name(ctx, &fields[0].ty);
672            format!("{ty} {opaque}")
673        }
674        TypeDeclKind::Struct(fields) if fields.iter().all(|f| f.name.is_none()) => fields
675            .iter()
676            .filter(|f| !f.is_opaque())
677            .map(|f| {
678                let ty = type_to_ocaml_name(ctx, &f.ty);
679                format!("{ty} {opaque}")
680            })
681            .join("*"),
682        TypeDeclKind::Struct(fields) => {
683            let fields = fields
684                .iter()
685                .filter(|f| !f.is_opaque())
686                .map(|f| {
687                    let name = f.renamed_name().unwrap();
688                    let ty = type_to_ocaml_name(ctx, &f.ty);
689                    let comment = extract_doc_comments(&f.attr_info);
690                    let comment = build_doc_comment(comment, 2);
691                    format!("{name} : {ty} {opaque} {comment}")
692                })
693                .join(";");
694            format!("{{ {fields} }}")
695        }
696        TypeDeclKind::Enum(variants) => {
697            variants
698                .iter()
699                .filter(|v| !v.is_opaque())
700                .map(|variant| {
701                    let mut attr_info = variant.attr_info.clone();
702                    let rename = variant.renamed_name();
703                    let ty = if variant.fields.is_empty() {
704                        // Unit variant
705                        String::new()
706                    } else {
707                        if variant.fields.iter().all(|f| f.name.is_some()) {
708                            let fields = variant
709                                .fields
710                                .iter()
711                                .map(|f| {
712                                    let comment = extract_doc_comments(&f.attr_info);
713                                    let description = if comment.is_empty() {
714                                        comment
715                                    } else {
716                                        format!(": {comment}")
717                                    };
718                                    format!("\n - [{}]{description}", f.name.as_ref().unwrap())
719                                })
720                                .join("");
721                            let field_descriptions = format!("\n Fields:{fields}");
722                            // Add a constructed doc-comment
723                            attr_info
724                                .attributes
725                                .push(Attribute::DocComment(field_descriptions));
726                        }
727                        let fields = variant
728                            .fields
729                            .iter()
730                            .map(|f| {
731                                let ty = type_to_ocaml_name(ctx, &f.ty);
732                                format!("{ty} {opaque}")
733                            })
734                            .join("*");
735                        format!(" of {fields}")
736                    };
737                    let comment = extract_doc_comments(&attr_info);
738                    let comment = build_doc_comment(comment, 3);
739                    format!("\n\n | {rename}{ty} {comment}")
740                })
741                .join("")
742        }
743        TypeDeclKind::Union(..) => todo!(),
744        TypeDeclKind::Opaque => todo!(),
745        TypeDeclKind::Error(_) => todo!(),
746    };
747    build_type(ctx, decl, co_rec, &body)
748}
749
750fn generate_visitor_bases(
751    _ctx: &GenerateCtx,
752    name: &str,
753    inherits: &[&str],
754    reduce: bool,
755    ty_names: &[String],
756) -> String {
757    let mut out = String::new();
758    let make_inherit = |variety| {
759        if !inherits.is_empty() {
760            inherits
761                .iter()
762                .map(|ancestor| {
763                    if let [module, name] = ancestor.split(".").collect_vec().as_slice() {
764                        format!("inherit [_] {module}.{variety}_{name}")
765                    } else {
766                        format!("inherit [_] {variety}_{ancestor}")
767                    }
768                })
769                .join("\n")
770        } else {
771            format!("inherit [_] VisitorsRuntime.{variety}")
772        }
773    };
774
775    let iter_methods = ty_names
776        .iter()
777        .map(|ty| format!("method visit_{ty} : 'env -> {ty} -> unit = fun _ _ -> ()"))
778        .format("\n");
779    let _ = write!(
780        &mut out,
781        "
782        class ['self] iter_{name} =
783          object (self : 'self)
784            {}
785            {iter_methods}
786          end
787        ",
788        make_inherit("iter")
789    );
790
791    let map_methods = ty_names
792        .iter()
793        .map(|ty| format!("method visit_{ty} : 'env -> {ty} -> {ty} = fun _ x -> x"))
794        .format("\n");
795    let _ = write!(
796        &mut out,
797        "
798        class ['self] map_{name} =
799          object (self : 'self)
800            {}
801            {map_methods}
802          end
803        ",
804        make_inherit("map")
805    );
806
807    if reduce {
808        let reduce_methods = ty_names
809            .iter()
810            .map(|ty| format!("method visit_{ty} : 'env -> {ty} -> 'a = fun _ _ -> self#zero"))
811            .format("\n");
812        let _ = write!(
813            &mut out,
814            "
815            class virtual ['self] reduce_{name} =
816              object (self : 'self)
817                {}
818                {reduce_methods}
819              end
820            ",
821            make_inherit("reduce")
822        );
823
824        let mapreduce_methods = ty_names
825            .iter()
826            .map(|ty| {
827                format!("method visit_{ty} : 'env -> {ty} -> {ty} * 'a = fun _ x -> (x, self#zero)")
828            })
829            .format("\n");
830        let _ = write!(
831            &mut out,
832            "
833            class virtual ['self] mapreduce_{name} =
834              object (self : 'self)
835                {}
836                {mapreduce_methods}
837              end
838            ",
839            make_inherit("mapreduce")
840        );
841    }
842
843    out
844}
845
/// Configuration for deriving OCaml `visitors` classes on a group of generated type declarations.
#[derive(Copy, Clone)]
struct DeriveVisitors {
    /// Base name of the derived visitor classes; each is named `<variety>_<name>`.
    name: &'static str,
    /// Visitors the derived classes inherit from, without the `<variety>_` prefix.
    ancestors: &'static [&'static str],
    /// Also derive `reduce` and `mapreduce` visitors, not just `iter` and `map`.
    reduce: bool,
    /// Extra type names that get default visitor methods in an intermediate `<name>_base`
    /// ancestor class.
    extra_types: &'static [&'static str],
}

/// The kind of code generation to perform.
#[derive(Copy, Clone)]
enum GenerationKind {
    /// One `<type>_of_json` deserialization function per type.
    OfJson,
    /// One OCaml type declaration per type, optionally deriving visitors for the group.
    TypeDecl(Option<DeriveVisitors>),
}
860
861/// Replace markers in `template` with auto-generated code.
862struct GenerateCodeFor {
863    template: PathBuf,
864    target: PathBuf,
865    /// Each list corresponds to a marker. We replace the ith `__REPLACE{i}__` marker with
866    /// generated code for each definition in the ith list.
867    ///
868    /// Eventually we should reorder definitions so the generated ones are all in one block.
869    /// Keeping the order is important while we migrate away from hand-written code.
870    markers: Vec<(GenerationKind, HashSet<TypeDeclId>)>,
871}
872
873impl GenerateCodeFor {
874    fn generate(&self, ctx: &GenerateCtx) -> Result<()> {
875        let mut template = fs::read_to_string(&self.template)
876            .with_context(|| format!("Failed to read template file {}", self.template.display()))?;
877        for (i, (kind, names)) in self.markers.iter().enumerate() {
878            let tys = names
879                .iter()
880                .map(|&id| &ctx.crate_data[id])
881                .sorted_by_key(|tdecl| {
882                    tdecl
883                        .item_meta
884                        .name
885                        .name
886                        .last()
887                        .unwrap()
888                        .as_ident()
889                        .unwrap()
890                });
891            let generated = match kind {
892                GenerationKind::OfJson => {
893                    let fns = tys
894                        .map(|ty| type_decl_to_json_deserializer(ctx, ty))
895                        .format("\n");
896                    format!("let rec ___ = ()\n{fns}")
897                }
898                GenerationKind::TypeDecl(visitors) => {
899                    let mut decls = tys
900                        .enumerate()
901                        .map(|(i, ty)| {
902                            let co_recursive = i != 0;
903                            type_decl_to_ocaml_decl(ctx, ty, co_recursive)
904                        })
905                        .join("\n");
906                    if let Some(visitors) = visitors {
907                        let DeriveVisitors {
908                            name,
909                            mut ancestors,
910                            reduce,
911                            extra_types,
912                        } = visitors;
913                        let varieties: &[_] = if *reduce {
914                            &["iter", "map", "reduce", "mapreduce"]
915                        } else {
916                            &["iter", "map"]
917                        };
918                        let intermediate_visitor_name;
919                        let intermediate_visitor_name_slice;
920                        if !extra_types.is_empty() {
921                            intermediate_visitor_name = format!("{name}_base");
922                            let intermediate_visitor = generate_visitor_bases(
923                                ctx,
924                                &intermediate_visitor_name,
925                                ancestors,
926                                *reduce,
927                                extra_types
928                                    .iter()
929                                    .map(|s| s.to_string())
930                                    .collect_vec()
931                                    .as_slice(),
932                            );
933                            intermediate_visitor_name_slice = [intermediate_visitor_name.as_str()];
934                            ancestors = &intermediate_visitor_name_slice;
935                            decls = format!("(* Ancestors for the {name} visitors *){intermediate_visitor}\n{decls}");
936                        }
937                        let visitors = varieties
938                            .iter()
939                            .map(|variety| {
940                                let nude = if !ancestors.is_empty() {
941                                    format!("nude = true (* Don't inherit VisitorsRuntime *);")
942                                } else {
943                                    String::new()
944                                };
945                                let ancestors = format!(
946                                    "ancestors = [ {} ];",
947                                    ancestors
948                                        .iter()
949                                        .map(|a| format!("\"{variety}_{a}\""))
950                                        .join(";")
951                                );
952                                format!(
953                                    r#"
954                                    visitors {{
955                                        name = "{variety}_{name}";
956                                        monomorphic = ["env"];
957                                        variety = "{variety}";
958                                        {ancestors}
959                                        {nude}
960                                    }}
961                                "#
962                                )
963                            })
964                            .format(", ");
965                        let _ = write!(&mut decls, "\n[@@deriving show, eq, ord, {visitors}]");
966                    };
967                    decls
968                }
969            };
970            let placeholder = format!("(* __REPLACE{i}__ *)");
971            template = template.replace(&placeholder, &generated);
972        }
973
974        fs::write(&self.target, template)
975            .with_context(|| format!("Failed to write generated file {}", self.target.display()))?;
976        Ok(())
977    }
978}
979
980fn main() -> Result<()> {
981    let dir = PathBuf::from("src/bin/generate-ml");
982    let charon_llbc = dir.join("charon-itself.ullbc");
983    let reuse_llbc = std::env::var("CHARON_ML_REUSE_LLBC").is_ok(); // Useful when developping
984    if !reuse_llbc {
985        // Call charon on itself
986        let mut cmd = Command::cargo_bin("charon")?;
987        cmd.arg("cargo");
988        cmd.arg("--hide-marker-traits");
989        cmd.arg("--ullbc");
990        cmd.arg("--start-from=charon_lib::ast::krate::TranslatedCrate");
991        cmd.arg("--start-from=charon_lib::ast::ullbc_ast::BodyContents");
992        cmd.arg("--exclude=charon_lib::common::hash_consing::HashConsed");
993        cmd.arg("--dest-file");
994        cmd.arg(&charon_llbc);
995        cmd.arg("--");
996        cmd.arg("--lib");
997        let output = cmd.output()?;
998
999        if !output.status.success() {
1000            let stderr = String::from_utf8(output.stderr.clone())?;
1001            bail!("Compilation failed: {stderr}")
1002        }
1003    }
1004
1005    let crate_data: TranslatedCrate = charon_lib::deserialize_llbc(&charon_llbc)?;
1006    let output_dir = if std::env::var("IN_CI").as_deref() == Ok("1") {
1007        dir.join("generated")
1008    } else {
1009        dir.join("../../../../charon-ml/src/generated")
1010    };
1011    generate_ml(crate_data, dir.join("templates"), output_dir)
1012}
1013
/// Generate the `charon-ml` OCaml files from the type declarations in `crate_data`.
///
/// For each `GenerateCodeFor` entry below, this reads a template from `template_dir`, fills in
/// its markers, and writes the result under `output_dir`.
///
/// The type-declaration files are listed first. Each type is emitted at most once: it goes in
/// the first marker (in listing order) whose listed types include it among their children, as
/// computed by `GenerateCtx::children_of_many`.
///
/// The three `*OfJson.ml` files get deserializers for the general, LLBC-specific and
/// ULLBC-specific type sets respectively.
///
/// # Errors
///
/// Returns the first error produced by `GenerateCodeFor::generate`.
fn generate_ml(
    crate_data: TranslatedCrate,
    template_dir: PathBuf,
    output_dir: PathBuf,
) -> anyhow::Result<()> {
    // Hand-written OCaml type definitions, as `(rust type name, OCaml type body)` pairs, passed
    // to `GenerateCtx::new` (presumably used instead of a generated body; see `GenerateCtx`).
    let manual_type_impls = &[
        // Hand-written because we replace the `FileId` with the corresponding file.
        ("FileId", "file"),
        // Hand-written because the rust version is an enum with custom (de)serialization
        // functions.
        (
            "ScalarValue",
            indoc!(
                "
                (* Note that we use unbounded integers everywhere.
                   We then harcode the boundaries for the different types.
                 *)
                { value : big_int; int_ty : integer_type }
                "
            ),
        ),
        // Handwritten because we use `indexed_var` as a hack to be able to reuse field names.
        // TODO: remove the need for this hack.
        ("RegionVar", "(region_id, string option) indexed_var"),
        ("TypeVar", "(type_var_id, string) indexed_var"),
    ];
    // Hand-written bodies (OCaml match arms) of `<type>_of_json` functions, as
    // `(rust type name, body)` pairs, also passed to `GenerateCtx::new`.
    let manual_json_impls = &[
        // Hand-written because we filter out `None` values.
        (
            "Vector",
            indoc!(
                r#"
                | js ->
                    let* list = list_of_json (option_of_json arg1_of_json) ctx js in
                    Ok (List.filter_map (fun x -> x) list)
                "#
            ),
        ),
        // Hand-written because we replace the `FileId` with the corresponding file name.
        (
            "FileId",
            indoc!(
                r#"
                | json ->
                    let* file_id = FileId.id_of_json ctx json in
                    let file = FileId.Map.find file_id ctx in
                    Ok file
                "#,
            ),
        ),
        // Hand-written because the rust version is an enum with custom (de)serialization
        // functions.
        (
            "ScalarValue",
            indoc!(
                r#"
                | `Assoc [ (ty, bi) ] ->
                    let big_int_of_json (js : json) : (big_int, string) result =
                      combine_error_msgs js __FUNCTION__
                        (match js with
                        | `Int i -> Ok (Z.of_int i)
                        | `String is -> Ok (Z.of_string is)
                        | _ -> Error "")
                    in
                    let* value = big_int_of_json bi in
                    let* int_ty = integer_type_of_json ctx (`String ty) in
                    let sv = { value; int_ty } in
                    if not (check_scalar_value_in_range sv) then
                      raise (Failure ("Scalar value not in range: " ^ show_scalar_value sv));
                    Ok sv
                "#
            ),
        ),
    ];
    // Types for which we don't want to generate a type at all.
    // (They seed `processed_tys` below, so no type-declaration marker ever picks them up.)
    let dont_generate_ty = &[
        "ItemOpacity",
        "PredicateOrigin",
        "TraitTypeConstraintId",
        "Ty",
        "Vector",
    ];
    // Types that we don't want visitors to go into.
    let opaque_for_visitor = &["Name"];
    let ctx = GenerateCtx::new(
        &crate_data,
        manual_type_impls,
        manual_json_impls,
        opaque_for_visitor,
    );

    // Compute the sets of types to be put in each module.
    // These types are removed from every json-deserializer type set computed below.
    let manually_implemented: HashSet<_> = [
        "ItemOpacity",
        "PredicateOrigin",
        "Ty", // We exclude it since `TyKind` is renamed to `ty`
        "Opaque",
        "Body",
        "FunDecl",
        "TranslatedCrate",
    ]
    .iter()
    .map(|name| ctx.id_from_name(name))
    .collect();

    // Compute type sets for json deserializers.
    let (gast_types, llbc_types, ullbc_types) = {
        let llbc_types: HashSet<_> = ctx.children_of("charon_lib::ast::llbc_ast::Statement");
        let ullbc_types: HashSet<_> = ctx.children_of("charon_lib::ast::ullbc_ast::BodyContents");
        let all_types: HashSet<_> = ctx.children_of("TranslatedCrate");

        // Types reachable from both body representations are shared; only what remains on each
        // side is body-specific.
        let shared_types: HashSet<_> = llbc_types.intersection(&ullbc_types).copied().collect();
        let llbc_types: HashSet<_> = llbc_types.difference(&shared_types).copied().collect();
        let ullbc_types: HashSet<_> = ullbc_types.difference(&shared_types).copied().collect();

        // Everything else reachable from `TranslatedCrate` goes in the general set.
        let body_specific_types: HashSet<_> = llbc_types.union(&ullbc_types).copied().collect();
        let gast_types: HashSet<_> = all_types
            .difference(&body_specific_types)
            .copied()
            .collect();

        let gast_types: HashSet<_> = gast_types
            .difference(&manually_implemented)
            .copied()
            .collect();
        let llbc_types: HashSet<_> = llbc_types
            .difference(&manually_implemented)
            .copied()
            .collect();
        let ullbc_types: HashSet<_> = ullbc_types
            .difference(&manually_implemented)
            .copied()
            .collect();
        (gast_types, llbc_types, ullbc_types)
    };

    // Types already assigned to some marker (or deliberately never generated).
    let mut processed_tys: HashSet<TypeDeclId> = dont_generate_ty
        .iter()
        .map(|name| ctx.id_from_name(name))
        .collect();
    // Each call to this will return the children of the listed types that haven't been returned
    // yet. By calling it in dependency order, this allows to organize types into files without
    // having to list them all.
    let mut markers_from_children = |ctx: &GenerateCtx, markers: &[_]| {
        markers
            .iter()
            .copied()
            .map(|(kind, type_names)| {
                let types: HashSet<_> = ctx.children_of_many(type_names);
                let unprocessed_types: HashSet<_> =
                    types.difference(&processed_tys).copied().collect();
                processed_tys.extend(unprocessed_types.iter().copied());
                (kind, unprocessed_types)
            })
            .collect()
    };

    // The `vec!` elements are evaluated in order, so the `markers_from_children` calls below run
    // top to bottom: the order of files and of markers within a file decides where each type
    // ends up.
    #[rustfmt::skip]
    let generate_code_for = vec![
        GenerateCodeFor {
            template: template_dir.join("Meta.ml"),
            target: output_dir.join("Generated_Meta.ml"),
            markers: markers_from_children(&ctx, &[
                (GenerationKind::TypeDecl(None), &[
                    "File",
                    "Span",
                    "AttrInfo",
                ]),
            ]),
        },
        GenerateCodeFor {
            template: template_dir.join("Values.ml"),
            target: output_dir.join("Generated_Values.ml"),
            markers: markers_from_children(&ctx, &[
                (GenerationKind::TypeDecl(Some(DeriveVisitors {
                    ancestors: &["big_int"],
                    name: "literal",
                    reduce: true,
                    extra_types: &[],
                })), &[
                    "Literal",
                    "IntegerTy",
                    "LiteralTy",
                ]),
            ]),
        },
        GenerateCodeFor {
            template: template_dir.join("Types.ml"),
            target: output_dir.join("Generated_Types.ml"),
            markers: markers_from_children(&ctx, &[
                (GenerationKind::TypeDecl(Some(DeriveVisitors {
                    ancestors: &["literal"],
                    name: "const_generic",
                    reduce: true,
                    extra_types: &[],
                })), &[
                    "TypeVarId",
                    "ConstGeneric",
                    "TraitClauseId",
                    "DeBruijnVar",
                    "AnyTransId",
                ]),
                // Can't merge into above because aeneas uses the above alongside their own partial
                // copy of `ty`, which causes method type clashes.
                (GenerationKind::TypeDecl(Some(DeriveVisitors {
                    ancestors: &["ty_base_base"],
                    name: "ty",
                    reduce: false,
                    extra_types: &[],
                })), &[
                    "TyKind",
                    "TraitImplRef",
                    "FunDeclRef",
                    "GlobalDeclRef",
                ]),
                // TODO: can't merge into above because of field name clashes (`types`, `regions` etc).
                (GenerationKind::TypeDecl(Some(DeriveVisitors {
                    ancestors: &["ty"],
                    name: "type_decl",
                    reduce: false,
                    extra_types: &[
                        "span", "attr_info"
                    ],
                })), &[
                    "Binder",
                    "AbortKind",
                    "TypeDecl",
                ]),
            ]),
        },
        GenerateCodeFor {
            template: template_dir.join("Expressions.ml"),
            target: output_dir.join("Generated_Expressions.ml"),
            markers: markers_from_children(&ctx, &[
                (GenerationKind::TypeDecl(Some(DeriveVisitors {
                    ancestors: &["type_decl"],
                    name: "rvalue",
                    reduce: false,
                    extra_types: &[],
                })), &[
                    "Rvalue",
                ]),
            ]),
        },
        GenerateCodeFor {
            template: template_dir.join("GAst.ml"),
            target: output_dir.join("Generated_GAst.ml"),
            markers: markers_from_children(&ctx, &[
                (GenerationKind::TypeDecl(Some(DeriveVisitors {
                    ancestors: &["rvalue"],
                    name: "fun_sig",
                    reduce: false,
                    extra_types: &[],
                })), &[
                    "Call",
                    "Assert",
                    "ItemKind",
                    "Locals",
                    "FunSig",
                    "CopyNonOverlapping",
                ]),
                // These have to be kept separate to avoid field name clashes
                (GenerationKind::TypeDecl(Some(DeriveVisitors {
                    ancestors: &["fun_sig"],
                    name: "global_decl",
                    reduce: false,
                    extra_types: &[],
                })), &[
                    "GlobalDecl",
                ]),
                (GenerationKind::TypeDecl(Some(DeriveVisitors {
                    ancestors: &["global_decl"],
                    name: "trait_decl",
                    reduce: false,
                    extra_types: &[],
                })), &[
                    "TraitDecl",
                ]),
                (GenerationKind::TypeDecl(Some(DeriveVisitors {
                    ancestors: &["trait_decl"],
                    name: "trait_impl",
                    reduce: false,
                    extra_types: &[],
                })), &[
                    "TraitImpl",
                ]),
                (GenerationKind::TypeDecl(None), &[
                    "CliOpts",
                    "GExprBody",
                    "DeclarationGroup",
                ]),
            ]),
        },
        GenerateCodeFor {
            template: template_dir.join("LlbcAst.ml"),
            target: output_dir.join("Generated_LlbcAst.ml"),
            markers: markers_from_children(&ctx, &[
                (GenerationKind::TypeDecl(Some(DeriveVisitors {
                    name: "statement_base",
                    ancestors: &["trait_impl"],
                    reduce: false,
                    extra_types: &[],
                })), &[
                    "charon_lib::ast::llbc_ast::Statement",
                ]),
            ]),
        },
        GenerateCodeFor {
            template: template_dir.join("UllbcAst.ml"),
            target: output_dir.join("Generated_UllbcAst.ml"),
            markers: markers_from_children(&ctx, &[
                (GenerationKind::TypeDecl(Some(DeriveVisitors {
                    ancestors: &["trait_impl"],
                    name: "statement",
                    reduce: false,
                    extra_types: &[],
                })), &[
                    "charon_lib::ast::ullbc_ast::Statement",
                    "charon_lib::ast::ullbc_ast::SwitchTargets",
                ]),
                // TODO: Can't merge with above because of field name clashes (`content` and `span`).
                (GenerationKind::TypeDecl(Some(DeriveVisitors {
                    ancestors: &["statement"],
                    name: "ullbc_ast",
                    reduce: false,
                    extra_types: &[],
                })), &[
                    "charon_lib::ast::ullbc_ast::BodyContents",
                ]),
            ]),
        },
        GenerateCodeFor {
            template: template_dir.join("GAstOfJson.ml"),
            target: output_dir.join("Generated_GAstOfJson.ml"),
            markers: vec![(GenerationKind::OfJson, gast_types)],
        },
        GenerateCodeFor {
            template: template_dir.join("LlbcOfJson.ml"),
            target: output_dir.join("Generated_LlbcOfJson.ml"),
            markers: vec![(GenerationKind::OfJson, llbc_types)],
        },
        GenerateCodeFor {
            template: template_dir.join("UllbcOfJson.ml"),
            target: output_dir.join("Generated_UllbcOfJson.ml"),
            markers: vec![(GenerationKind::OfJson, ullbc_types)],
        },
    ];
    // Write every file, stopping at the first failure.
    for file in generate_code_for {
        file.generate(&ctx)?;
    }
    Ok(())
}