generate_ml/
main.rs

//! Generate OCaml deserialization code (and matching OCaml type declarations) for our types.
2//!
3//! This binary runs charon on itself and generates the appropriate `<type>_of_json` functions for
4//! our types. The generated functions are inserted into `./generate-ml/GAstOfJson.template.ml` to
5//! construct the final `GAstOfJson.ml`.
6//!
7//! To run it, call `cargo run --bin generate-ml`. It is also run by `make generate-ml` in the
8//! crate root. Don't forget to format the output code after regenerating.
9#![feature(if_let_guard)]
10
11use anyhow::{Context, Result, bail};
12use assert_cmd::cargo::CommandCargoExt;
13use charon_lib::ast::*;
14use convert_case::{Case, Casing};
15use indoc::indoc;
16use itertools::Itertools;
17use std::collections::{HashMap, HashSet};
18use std::fmt::Write;
19use std::fs;
20use std::path::PathBuf;
21use std::process::Command;
22
23/// `Name` is a complex datastructure; to inspect it we serialize it a little bit.
24fn repr_name(_crate_data: &TranslatedCrate, n: &Name) -> String {
25    n.name
26        .iter()
27        .map(|path_elem| match path_elem {
28            PathElem::Ident(i, _) => i.clone(),
29            PathElem::Impl(..) => "<impl>".to_string(),
30            PathElem::Instantiated(..) => "<mono>".to_string(),
31        })
32        .join("::")
33}
34
35fn make_ocaml_ident(name: &str) -> String {
36    let mut name = name.to_case(Case::Snake);
37    if matches!(
38        &*name,
39        "virtual"
40            | "bool"
41            | "char"
42            | "struct"
43            | "type"
44            | "let"
45            | "fun"
46            | "open"
47            | "rec"
48            | "assert"
49            | "float"
50            | "end"
51            | "include"
52            | "to"
53    ) {
54        name += "_";
55    }
56    name
57}
58fn type_name_to_ocaml_ident(item_meta: &ItemMeta) -> String {
59    let name = item_meta
60        .attr_info
61        .rename
62        .as_ref()
63        .unwrap_or(item_meta.name.name.last().unwrap().as_ident().unwrap().0);
64    make_ocaml_ident(name)
65}
66
/// Shared state for code generation: the translated crate plus lookup tables derived from it.
struct GenerateCtx<'a> {
    /// The translated crate whose type declarations are inspected.
    crate_data: &'a TranslatedCrate,
    /// Maps a type name to its declaration. Every type is registered under its `::`-separated
    /// path (as rendered by `repr_name`). Types whose path starts with `charon_lib` are also
    /// registered under their bare last path segment.
    name_to_type: HashMap<String, &'a TypeDecl>,
    /// For each type, list the types it contains.
    /// (Every `TypeDeclId` found by `dyn_visit` inside its declaration.)
    type_tree: HashMap<TypeDeclId, HashSet<TypeDeclId>>,
    /// Hand-written OCaml type bodies that replace the generated body for these types.
    manual_type_impls: HashMap<TypeDeclId, String>,
    /// Hand-written `match js with` branches that replace the generated `*_of_json` branches.
    manual_json_impls: HashMap<TypeDeclId, String>,
    /// Types whose generated OCaml field types get a `[@visitors.opaque]` annotation.
    opaque_for_visitor: HashSet<TypeDeclId>,
}
76
77impl<'a> GenerateCtx<'a> {
78    fn new(
79        crate_data: &'a TranslatedCrate,
80        manual_type_impls: &[(&str, &str)],
81        manual_json_impls: &[(&str, &str)],
82        opaque_for_visitor: &[&str],
83    ) -> Self {
84        let mut name_to_type: HashMap<String, &TypeDecl> = Default::default();
85        let mut type_tree = HashMap::default();
86        for ty in &crate_data.type_decls {
87            let long_name = repr_name(crate_data, &ty.item_meta.name);
88            if long_name.starts_with("charon_lib") {
89                let short_name = ty
90                    .item_meta
91                    .name
92                    .name
93                    .last()
94                    .unwrap()
95                    .as_ident()
96                    .unwrap()
97                    .0
98                    .clone();
99                name_to_type.insert(short_name, ty);
100            }
101            name_to_type.insert(long_name, ty);
102
103            let mut contained = HashSet::new();
104            ty.dyn_visit(|id: &TypeDeclId| {
105                contained.insert(*id);
106            });
107            type_tree.insert(ty.def_id, contained);
108        }
109
110        let mut ctx = GenerateCtx {
111            crate_data: &crate_data,
112            name_to_type,
113            type_tree,
114            manual_type_impls: Default::default(),
115            manual_json_impls: Default::default(),
116            opaque_for_visitor: Default::default(),
117        };
118        ctx.manual_type_impls = manual_type_impls
119            .iter()
120            .map(|(name, def)| (ctx.id_from_name(name), def.to_string()))
121            .collect();
122        ctx.manual_json_impls = manual_json_impls
123            .iter()
124            .map(|(name, def)| (ctx.id_from_name(name), def.to_string()))
125            .collect();
126        ctx.opaque_for_visitor = opaque_for_visitor
127            .iter()
128            .map(|name| ctx.id_from_name(name))
129            .collect();
130        ctx
131    }
132
133    fn id_from_name(&self, name: &str) -> TypeDeclId {
134        self.name_to_type
135            .get(name)
136            .expect(&format!("Name not found: `{name}`"))
137            .def_id
138    }
139
140    /// List the (recursive) children of this type.
141    fn children_of(&self, name: &str) -> HashSet<TypeDeclId> {
142        let start_id = self.id_from_name(name);
143        self.children_of_inner(vec![start_id])
144    }
145
146    /// List the (recursive) children of these types.
147    fn children_of_many(&self, names: &[&str]) -> HashSet<TypeDeclId> {
148        self.children_of_inner(names.iter().map(|name| self.id_from_name(name)).collect())
149    }
150
151    fn children_of_inner(&self, ty: Vec<TypeDeclId>) -> HashSet<TypeDeclId> {
152        let mut children = HashSet::new();
153        let mut stack = ty.to_vec();
154        while let Some(id) = stack.pop() {
155            if !children.contains(&id)
156                && self
157                    .crate_data
158                    .type_decls
159                    .get(id)
160                    .is_some_and(|decl| decl.item_meta.is_local)
161            {
162                children.insert(id);
163                if let Some(contained) = self.type_tree.get(&id) {
164                    stack.extend(contained);
165                }
166            }
167        }
168        children
169    }
170}
171
172/// Converts a type to the appropriate `*_of_json` call. In case of generics, this combines several
173/// functions, e.g. `list_of_json bool_of_json`.
174fn type_to_ocaml_call(ctx: &GenerateCtx, ty: &Ty) -> String {
175    match ty.kind() {
176        TyKind::Literal(LiteralTy::Bool) => "bool_of_json".to_string(),
177        TyKind::Literal(LiteralTy::Char) => "char_of_json".to_string(),
178        TyKind::Literal(LiteralTy::Int(int_ty)) => match int_ty {
179            // Even though OCaml ints are only 63 bits, only scalars with their 128 bits should be able to become too large
180            IntTy::I128 => "big_int_of_json".to_string(),
181            _ => "int_of_json".to_string(),
182        },
183        TyKind::Literal(LiteralTy::UInt(uint_ty)) => match uint_ty {
184            // Even though OCaml ints are only 63 bits, only scalars with their 128 bits should be able to become too large
185            UIntTy::U128 => "big_int_of_json".to_string(),
186            _ => "int_of_json".to_string(),
187        },
188        TyKind::Literal(LiteralTy::Float(_)) => "float_of_json".to_string(),
189        TyKind::Adt(tref) => {
190            let mut expr = Vec::new();
191            for ty in &tref.generics.types {
192                expr.push(type_to_ocaml_call(ctx, ty))
193            }
194            match tref.id {
195                TypeId::Adt(id) => {
196                    let mut first = if let Some(tdecl) = ctx.crate_data.type_decls.get(id) {
197                        type_name_to_ocaml_ident(&tdecl.item_meta)
198                    } else {
199                        format!("missing_type_{id}")
200                    };
201                    if first == "vec" {
202                        first = "list".to_string();
203                    }
204                    if first == "ustr" {
205                        first = "string".to_string();
206                    }
207                    expr.insert(0, first + "_of_json");
208                }
209                TypeId::Builtin(BuiltinTy::Box) => expr.insert(0, "box_of_json".to_owned()),
210                TypeId::Tuple => {
211                    let name = match tref.generics.types.elem_count() {
212                        2 => "pair_of_json".to_string(),
213                        3 => "triple_of_json".to_string(),
214                        len => format!("tuple_{len}_of_json"),
215                    };
216                    expr.insert(0, name);
217                }
218                _ => unimplemented!("{ty:?}"),
219            }
220            expr.into_iter().map(|f| format!("({f})")).join(" ")
221        }
222        TyKind::TypeVar(DeBruijnVar::Free(id)) => format!("arg{id}_of_json"),
223        _ => unimplemented!("{ty:?}"),
224    }
225}
226
227/// Converts a type to the appropriate ocaml name. In case of generics, this provides appropriate
228/// parameters.
229fn type_to_ocaml_name(ctx: &GenerateCtx, ty: &Ty) -> String {
230    match ty.kind() {
231        TyKind::Literal(LiteralTy::Bool) => "bool".to_string(),
232        TyKind::Literal(LiteralTy::Char) => "(Uchar.t [@visitors.opaque])".to_string(),
233        TyKind::Literal(LiteralTy::Int(int_ty)) => match int_ty {
234            // Even though OCaml ints are only 63 bits, only scalars with their 128 bits should be able to become too large
235            IntTy::I128 => "big_int".to_string(),
236            _ => "int".to_string(),
237        },
238        TyKind::Literal(LiteralTy::UInt(uint_ty)) => match uint_ty {
239            // Even though OCaml ints are only 63 bits, only scalars with their 128 bits should be able to become too large
240            UIntTy::U128 => "big_int".to_string(),
241            _ => "int".to_string(),
242        },
243        TyKind::Literal(LiteralTy::Float(_)) => "float_of_json".to_string(),
244        TyKind::Adt(tref) => {
245            let mut args = tref
246                .generics
247                .types
248                .iter()
249                .map(|ty| type_to_ocaml_name(ctx, ty))
250                .map(|name| {
251                    if !name.chars().all(|c| c.is_alphanumeric()) {
252                        format!("({name})")
253                    } else {
254                        name
255                    }
256                })
257                .collect_vec();
258            match tref.id {
259                TypeId::Adt(id) => {
260                    let mut base_ty = if let Some(tdecl) = ctx.crate_data.type_decls.get(id) {
261                        type_name_to_ocaml_ident(&tdecl.item_meta)
262                    } else if let Some(name) = ctx.crate_data.item_name(id) {
263                        eprintln!(
264                            "Warning: type {} missing from llbc",
265                            repr_name(ctx.crate_data, name)
266                        );
267                        name.name
268                            .last()
269                            .unwrap()
270                            .as_ident()
271                            .unwrap()
272                            .0
273                            .to_lowercase()
274                    } else {
275                        format!("missing_type_{id}")
276                    };
277                    if base_ty == "vec" {
278                        base_ty = "list".to_string();
279                    }
280                    if base_ty == "ustr" {
281                        base_ty = "string".to_string();
282                    }
283                    if base_ty == "vector" {
284                        base_ty = "list".to_string();
285                        args.remove(0); // Remove the index generic param
286                    }
287                    let args = match args.as_slice() {
288                        [] => String::new(),
289                        [arg] => arg.clone(),
290                        args => format!("({})", args.iter().join(",")),
291                    };
292                    format!("{args} {base_ty}")
293                }
294                TypeId::Builtin(BuiltinTy::Box) => args[0].clone(),
295                TypeId::Tuple => args.iter().join("*"),
296                _ => unimplemented!("{ty:?}"),
297            }
298        }
299        TyKind::TypeVar(DeBruijnVar::Free(id)) => format!("'a{id}"),
300        _ => unimplemented!("{ty:?}"),
301    }
302}
303
304fn convert_vars<'a>(ctx: &GenerateCtx, fields: impl IntoIterator<Item = &'a Field>) -> String {
305    fields
306        .into_iter()
307        .filter(|f| !f.is_opaque())
308        .map(|f| {
309            let name = make_ocaml_ident(f.name.as_deref().unwrap());
310            let rename = make_ocaml_ident(f.renamed_name().unwrap());
311            let convert = type_to_ocaml_call(ctx, &f.ty);
312            format!("let* {rename} = {convert} ctx {name} in")
313        })
314        .join("\n")
315}
316
317fn build_branch<'a>(
318    ctx: &GenerateCtx,
319    pat: &str,
320    fields: impl IntoIterator<Item = &'a Field>,
321    construct: &str,
322) -> String {
323    let convert = convert_vars(ctx, fields);
324    format!("| {pat} -> {convert} Ok ({construct})")
325}
326
327fn build_function(_ctx: &GenerateCtx, decl: &TypeDecl, branches: &str) -> String {
328    let ty_name = type_name_to_ocaml_ident(&decl.item_meta);
329    let signature = if decl.generics.types.is_empty() {
330        format!("{ty_name}_of_json (ctx : of_json_ctx) (js : json) : ({ty_name}, string) result =")
331    } else {
332        let types = &decl.generics.types;
333        let gen_vars_space = types
334            .iter()
335            .enumerate()
336            .map(|(i, _)| format!("'a{i}"))
337            .join(" ");
338        let gen_vars_comma = types
339            .iter()
340            .enumerate()
341            .map(|(i, _)| format!("'a{i}"))
342            .join(", ");
343
344        let mut args = Vec::new();
345        let mut ty_args = Vec::new();
346        for (i, _) in types.iter().enumerate() {
347            args.push(format!("arg{i}_of_json"));
348            ty_args.push(format!("(of_json_ctx -> json -> ('a{i}, string) result)"));
349        }
350        args.push("ctx".to_string());
351        ty_args.push("of_json_ctx".to_string());
352        args.push("js".to_string());
353        ty_args.push("json".to_string());
354
355        let ty_args = ty_args.into_iter().join(" -> ");
356        let args = args.into_iter().join(" ");
357        let fun_ty =
358            format!("{gen_vars_space}. {ty_args} -> (({gen_vars_comma}) {ty_name}, string) result");
359        format!("{ty_name}_of_json : {fun_ty} = fun {args} ->")
360    };
361    format!(
362        r#"
363        and {signature}
364          combine_error_msgs js __FUNCTION__
365            (match js with{branches} | _ -> Error "")
366        "#
367    )
368}
369
370fn type_decl_to_json_deserializer(ctx: &GenerateCtx, decl: &TypeDecl) -> String {
371    let return_ty = type_name_to_ocaml_ident(&decl.item_meta);
372    let return_ty = if decl.generics.types.is_empty() {
373        return_ty
374    } else {
375        format!("_ {return_ty}")
376    };
377
378    let branches = match &decl.kind {
379        _ if let Some(def) = ctx.manual_json_impls.get(&decl.def_id) => def.clone(),
380        TypeDeclKind::Struct(fields) if fields.is_empty() => {
381            build_branch(ctx, "`Null", fields, "()")
382        }
383        TypeDeclKind::Struct(fields)
384            if fields.elem_count() == 1
385                && fields[0].name.as_ref().is_some_and(|name| name == "_raw") =>
386        {
387            // These are the special strongly-typed integers.
388            let short_name = decl
389                .item_meta
390                .name
391                .name
392                .last()
393                .unwrap()
394                .as_ident()
395                .unwrap()
396                .0
397                .clone();
398            format!("| x -> {short_name}.id_of_json ctx x")
399        }
400        TypeDeclKind::Struct(fields)
401            if fields.elem_count() == 1
402                && (fields[0].name.is_none()
403                    || decl
404                        .item_meta
405                        .attr_info
406                        .attributes
407                        .iter()
408                        .filter_map(|a| a.as_unknown())
409                        .any(|a| a.to_string() == "serde(transparent)")) =>
410        {
411            let ty = &fields[0].ty;
412            let call = type_to_ocaml_call(ctx, ty);
413            format!("| x -> {call} ctx x")
414        }
415        TypeDeclKind::Alias(ty) => {
416            let call = type_to_ocaml_call(ctx, ty);
417            format!("| x -> {call} ctx x")
418        }
419        TypeDeclKind::Struct(fields) if fields.iter().all(|f| f.name.is_none()) => {
420            let mut fields = fields.clone();
421            for (i, f) in fields.iter_mut().enumerate() {
422                f.name = Some(format!("x{i}"));
423            }
424            let pat: String = fields
425                .iter()
426                .map(|f| f.name.as_deref().unwrap())
427                .map(|n| make_ocaml_ident(n))
428                .join(";");
429            let pat = format!("`List [ {pat} ]");
430            let construct = fields
431                .iter()
432                .map(|f| f.renamed_name().unwrap())
433                .map(|n| make_ocaml_ident(n))
434                .join(", ");
435            let construct = format!("( {construct} )");
436            build_branch(ctx, &pat, &fields, &construct)
437        }
438        TypeDeclKind::Struct(fields) => {
439            let fields = fields
440                .iter()
441                .filter(|field| {
442                    !field
443                        .attr_info
444                        .attributes
445                        .iter()
446                        .filter_map(|a| a.as_unknown())
447                        .any(|a| a.to_string() == "serde(skip)")
448                })
449                .collect_vec();
450            let pat: String = fields
451                .iter()
452                .map(|f| {
453                    let name = f.name.as_ref().unwrap();
454                    let var = if f.is_opaque() {
455                        "_"
456                    } else {
457                        &make_ocaml_ident(name)
458                    };
459                    format!("(\"{name}\", {var});")
460                })
461                .join("\n");
462            let pat = format!("`Assoc [ {pat} ]");
463            let construct = fields
464                .iter()
465                .filter(|f| !f.is_opaque())
466                .map(|f| f.renamed_name().unwrap())
467                .map(|n| make_ocaml_ident(n))
468                .join("; ");
469            let construct = format!("({{ {construct} }} : {return_ty})");
470            build_branch(ctx, &pat, fields, &construct)
471        }
472        TypeDeclKind::Enum(variants) => {
473            variants
474                .iter()
475                .filter(|v| !v.is_opaque())
476                .map(|variant| {
477                    let name = &variant.name;
478                    let rename = variant.renamed_name();
479                    if variant.fields.is_empty() {
480                        // Unit variant
481                        let pat = format!("`String \"{name}\"");
482                        build_branch(ctx, &pat, &variant.fields, rename)
483                    } else {
484                        let mut fields = variant.fields.clone();
485                        let inner_pat = if fields.iter().all(|f| f.name.is_none()) {
486                            // Tuple variant
487                            if variant.fields.elem_count() == 1 {
488                                let var = make_ocaml_ident(&variant.name);
489                                fields[0].name = Some(var.clone());
490                                var
491                            } else {
492                                for (i, f) in fields.iter_mut().enumerate() {
493                                    f.name = Some(format!("x_{i}"));
494                                }
495                                let pat =
496                                    fields.iter().map(|f| f.name.as_ref().unwrap()).join("; ");
497                                format!("`List [ {pat} ]")
498                            }
499                        } else {
500                            // Struct variant
501                            let pat = fields
502                                .iter()
503                                .map(|f| {
504                                    let name = f.name.as_ref().unwrap();
505                                    let var = if f.is_opaque() {
506                                        "_"
507                                    } else {
508                                        &make_ocaml_ident(name)
509                                    };
510                                    format!("(\"{name}\", {var});")
511                                })
512                                .join(" ");
513                            format!("`Assoc [ {pat} ]")
514                        };
515                        let pat = format!("`Assoc [ (\"{name}\", {inner_pat}) ]");
516                        let construct_fields = fields
517                            .iter()
518                            .map(|f| f.name.as_ref().unwrap())
519                            .map(|n| make_ocaml_ident(n))
520                            .join(", ");
521                        let construct = format!("{rename} ({construct_fields})");
522                        build_branch(ctx, &pat, &fields, &construct)
523                    }
524                })
525                .join("\n")
526        }
527        TypeDeclKind::Union(..) => todo!(),
528        TypeDeclKind::Opaque => todo!(),
529        TypeDeclKind::Error(_) => todo!(),
530    };
531    build_function(ctx, decl, &branches)
532}
533
534fn extract_doc_comments(attr_info: &AttrInfo) -> String {
535    attr_info
536        .attributes
537        .iter()
538        .filter_map(|a| a.as_doc_comment())
539        .join("\n")
540}
541
/// Make an OCaml doc comment `(** ... *)` that contains the given string, indenting it if
/// necessary.
///
/// Markdown escapes are translated to odoc syntax. A ```` ``` ```` fence (optionally tagged
/// `text` or `rust,ignore`) opens `{@rust[` or closes `]}`, and inline `` `code` `` becomes
/// `[code]`. An inline escape may span several lines, so its state is tracked across lines.
///
/// Single-line comments are emitted as `(**<text> *)`. For multi-line comments, one leading
/// space is stripped from each line. Every non-empty line after the first is then indented by
/// `indent_level` levels (two spaces each) plus four spaces, and the closing ` *)` goes on its
/// own line at `indent_level`. An empty comment yields an empty string.
fn build_doc_comment(comment: String, indent_level: usize) -> String {
    // Tracks which escape constructs are currently open across the lines of one comment.
    #[derive(Default)]
    struct Exchanger {
        is_in_open_escaped_block: bool,
        is_in_open_inline_escape: bool,
    }
    impl Exchanger {
        // Rewrite the markdown escape delimiters of one line into odoc syntax.
        fn exchange_escape_delimiters(&mut self, line: &str) -> String {
            if let Some((leading, rest)) = line.split_once("```") {
                // Fenced code block delimiter; only the first fence on a line is rewritten.
                // Strip all (hard-coded) possible ways to open the block.
                let rest = rest
                    .strip_prefix("text")
                    .or_else(|| rest.strip_prefix("rust,ignore"))
                    .unwrap_or(rest);
                let mut result = leading.to_owned();
                result.push_str(if self.is_in_open_escaped_block {
                    "]}"
                } else {
                    "{@rust["
                });
                self.is_in_open_escaped_block = !self.is_in_open_escaped_block;
                result.push_str(rest);
                result
            } else if line.contains('`') {
                // Inline escapes: every backtick alternately opens `[` and closes `]`.
                let mut parts = line.split('`');
                // `split` always yields at least one (possibly empty) part.
                let mut result = parts.next().unwrap_or("").to_owned();
                for part in parts {
                    result.push(if self.is_in_open_inline_escape { ']' } else { '[' });
                    self.is_in_open_inline_escape = !self.is_in_open_inline_escape;
                    result.push_str(part);
                }
                result
            } else {
                line.to_owned()
            }
        }
    }

    if comment.is_empty() {
        return comment;
    }
    if !comment.contains('\n') {
        let fixed_comment = Exchanger::default().exchange_escape_delimiters(&comment);
        return format!("(**{fixed_comment} *)");
    }

    let indent = "  ".repeat(indent_level);
    let mut exchanger = Exchanger::default();
    let body = comment
        .lines()
        .enumerate()
        .map(|(i, line)| {
            // Remove one leading space if there is one (there usually is because we write `///
            // comment` and not `///comment`).
            let line = line.strip_prefix(' ').unwrap_or(line);
            let fixed_line = exchanger.exchange_escape_delimiters(line);

            // The first line follows the `(**` marker, so it is not indented; neither are
            // empty lines.
            if i == 0 || fixed_line.is_empty() {
                fixed_line
            } else {
                format!("{indent}    {fixed_line}")
            }
        })
        .collect::<Vec<_>>()
        .join("\n");
    format!("(** {body}\n{indent} *)")
}
627
628fn build_type(_ctx: &GenerateCtx, decl: &TypeDecl, co_rec: bool, body: &str) -> String {
629    let ty_name = type_name_to_ocaml_ident(&decl.item_meta);
630    let generics = decl
631        .generics
632        .types
633        .iter()
634        .enumerate()
635        .map(|(i, _)| format!("'a{i}"))
636        .collect_vec();
637    let generics = match generics.as_slice() {
638        [] => String::new(),
639        [ty] => ty.clone(),
640        generics => format!("({})", generics.iter().join(",")),
641    };
642    let comment = extract_doc_comments(&decl.item_meta.attr_info);
643    let comment = build_doc_comment(comment, 0);
644    let keyword = if co_rec { "and" } else { "type" };
645    format!("\n{comment} {keyword} {generics} {ty_name} = {body}")
646}
647
/// Generate an OCaml type declaration that mirrors `decl`.
///
/// `co_rec` indicates whether this definition is co-recursive with the ones that come before (i.e.
/// should be declared with `and` instead of `type`).
///
/// The body is taken from `manual_type_impls` when present. Otherwise it is derived from the
/// declaration's shape: alias, unit, strongly-typed integer, single-value wrapper, tuple, record,
/// or variant type. For types listed in `opaque_for_visitor`, every field type is followed by
/// `[@visitors.opaque]`. Doc comments of the type, of record fields and of variants are carried
/// over. Variants whose fields are all named also get a generated "Fields:" list in their doc
/// comment.
///
/// Panics (`todo!`) on unions, opaque and error declarations.
fn type_decl_to_ocaml_decl(ctx: &GenerateCtx, decl: &TypeDecl, co_rec: bool) -> String {
    // Annotation appended after field types for types listed in `opaque_for_visitor`.
    let opaque = if ctx.opaque_for_visitor.contains(&decl.def_id) {
        "[@visitors.opaque]"
    } else {
        ""
    };
    let body = match &decl.kind {
        // A hand-written body overrides everything else.
        _ if let Some(def) = ctx.manual_type_impls.get(&decl.def_id) => def.clone(),
        TypeDeclKind::Alias(ty) => {
            let ty = type_to_ocaml_name(ctx, ty);
            format!("{ty} {opaque}")
        }
        TypeDeclKind::Struct(fields) if fields.is_empty() => "unit".to_string(),
        TypeDeclKind::Struct(fields)
            if fields.elem_count() == 1
                && fields[0].name.as_ref().is_some_and(|name| name == "_raw") =>
        {
            // These are the special strongly-typed integers.
            let short_name = decl
                .item_meta
                .name
                .name
                .last()
                .unwrap()
                .as_ident()
                .unwrap()
                .0
                .clone();
            format!("{short_name}.id [@visitors.opaque]")
        }
        // Single unnamed field, or `serde(transparent)`: the OCaml type is the inner type.
        TypeDeclKind::Struct(fields)
            if fields.elem_count() == 1
                && (fields[0].name.is_none()
                    || decl
                        .item_meta
                        .attr_info
                        .attributes
                        .iter()
                        .filter_map(|a| a.as_unknown())
                        .any(|a| a.to_string() == "serde(transparent)")) =>
        {
            let ty = type_to_ocaml_name(ctx, &fields[0].ty);
            format!("{ty} {opaque}")
        }
        // Tuple struct: a product (`*`) of its non-opaque field types.
        TypeDeclKind::Struct(fields) if fields.iter().all(|f| f.name.is_none()) => fields
            .iter()
            .filter(|f| !f.is_opaque())
            .map(|f| {
                let ty = type_to_ocaml_name(ctx, &f.ty);
                format!("{ty} {opaque}")
            })
            .join("*"),
        // Record: one field per non-opaque field, each followed by its doc comment.
        // NOTE(review): field names here are `renamed_name()` as-is, while the generated
        // deserializer builds records with `make_ocaml_ident(renamed_name())`. The two agree
        // only when that conversion leaves the name unchanged; confirm.
        TypeDeclKind::Struct(fields) => {
            let fields = fields
                .iter()
                .filter(|f| !f.is_opaque())
                .map(|f| {
                    let name = f.renamed_name().unwrap();
                    let ty = type_to_ocaml_name(ctx, &f.ty);
                    let comment = extract_doc_comments(&f.attr_info);
                    let comment = build_doc_comment(comment, 2);
                    format!("{name} : {ty} {opaque} {comment}")
                })
                .join(";");
            format!("{{ {fields} }}")
        }
        // Variant type: one constructor per non-opaque variant, each followed by its doc comment.
        TypeDeclKind::Enum(variants) => {
            variants
                .iter()
                .filter(|v| !v.is_opaque())
                .map(|variant| {
                    // Cloned so a generated "Fields:" doc line can be appended locally.
                    let mut attr_info = variant.attr_info.clone();
                    let rename = variant.renamed_name();
                    let ty = if variant.fields.is_empty() {
                        // Unit variant
                        String::new()
                    } else {
                        if variant.fields.iter().all(|f| f.name.is_some()) {
                            // Struct-like variant: OCaml constructors only carry a tuple, so
                            // document the field names and their doc comments instead.
                            let fields = variant
                                .fields
                                .iter()
                                .map(|f| {
                                    let comment = extract_doc_comments(&f.attr_info);
                                    let description = if comment.is_empty() {
                                        comment
                                    } else {
                                        format!(": {comment}")
                                    };
                                    format!("\n - [{}]{description}", f.name.as_ref().unwrap())
                                })
                                .join("");
                            let field_descriptions = format!("\n Fields:{fields}");
                            // Add a constructed doc-comment
                            attr_info
                                .attributes
                                .push(Attribute::DocComment(field_descriptions));
                        }
                        // NOTE(review): unlike the struct branches, opaque fields are not
                        // filtered out of the constructor arguments here; confirm intended.
                        let fields = variant
                            .fields
                            .iter()
                            .map(|f| {
                                let ty = type_to_ocaml_name(ctx, &f.ty);
                                format!("{ty} {opaque}")
                            })
                            .join("*");
                        format!(" of {fields}")
                    };
                    let comment = extract_doc_comments(&attr_info);
                    let comment = build_doc_comment(comment, 3);
                    format!("\n\n | {rename}{ty} {comment}")
                })
                .join("")
        }
        TypeDeclKind::Union(..) => todo!(),
        TypeDeclKind::Opaque => todo!(),
        TypeDeclKind::Error(_) => todo!(),
    };
    build_type(ctx, decl, co_rec, &body)
}
771
772fn generate_visitor_bases(
773    _ctx: &GenerateCtx,
774    name: &str,
775    inherits: &[&str],
776    reduce: bool,
777    ty_names: &[String],
778) -> String {
779    let mut out = String::new();
780    let make_inherit = |variety| {
781        if !inherits.is_empty() {
782            inherits
783                .iter()
784                .map(|ancestor| {
785                    if let [module, name] = ancestor.split(".").collect_vec().as_slice() {
786                        format!("inherit [_] {module}.{variety}_{name}")
787                    } else {
788                        format!("inherit [_] {variety}_{ancestor}")
789                    }
790                })
791                .join("\n")
792        } else {
793            format!("inherit [_] VisitorsRuntime.{variety}")
794        }
795    };
796
797    let iter_methods = ty_names
798        .iter()
799        .map(|ty| format!("method visit_{ty} : 'env -> {ty} -> unit = fun _ _ -> ()"))
800        .format("\n");
801    let _ = write!(
802        &mut out,
803        "
804        class ['self] iter_{name} =
805          object (self : 'self)
806            {}
807            {iter_methods}
808          end
809        ",
810        make_inherit("iter")
811    );
812
813    let map_methods = ty_names
814        .iter()
815        .map(|ty| format!("method visit_{ty} : 'env -> {ty} -> {ty} = fun _ x -> x"))
816        .format("\n");
817    let _ = write!(
818        &mut out,
819        "
820        class ['self] map_{name} =
821          object (self : 'self)
822            {}
823            {map_methods}
824          end
825        ",
826        make_inherit("map")
827    );
828
829    if reduce {
830        let reduce_methods = ty_names
831            .iter()
832            .map(|ty| format!("method visit_{ty} : 'env -> {ty} -> 'a = fun _ _ -> self#zero"))
833            .format("\n");
834        let _ = write!(
835            &mut out,
836            "
837            class virtual ['self] reduce_{name} =
838              object (self : 'self)
839                {}
840                {reduce_methods}
841              end
842            ",
843            make_inherit("reduce")
844        );
845
846        let mapreduce_methods = ty_names
847            .iter()
848            .map(|ty| {
849                format!("method visit_{ty} : 'env -> {ty} -> {ty} * 'a = fun _ x -> (x, self#zero)")
850            })
851            .format("\n");
852        let _ = write!(
853            &mut out,
854            "
855            class virtual ['self] mapreduce_{name} =
856              object (self : 'self)
857                {}
858                {mapreduce_methods}
859              end
860            ",
861            make_inherit("mapreduce")
862        );
863    }
864
865    out
866}
867
/// Configuration for deriving `visitors` classes on a group of generated OCaml type
/// declarations.
#[derive(Debug, Clone, Copy)]
struct DeriveVisitors {
    /// Base name of the derived visitor classes: the class of variety `v` is named `{v}_{name}`.
    name: &'static str,
    /// Visitors to inherit from: the class of variety `v` lists `{v}_{ancestor}` as an ancestor
    /// for each entry. When non-empty, the derived classes are also marked `nude`, i.e. they do
    /// not inherit `VisitorsRuntime` themselves.
    ancestors: &'static [&'static str],
    /// Whether to derive `reduce` and `mapreduce` visitors in addition to `iter` and `map`.
    reduce: bool,
    /// OCaml type names that get default visitor methods. When non-empty, an intermediate
    /// `{name}_base` family of classes (inheriting from `ancestors`) is generated to provide
    /// these methods, and it replaces `ancestors` as the sole ancestor of the derived visitors.
    extra_types: &'static [&'static str],
}
875
876/// The kind of code generation to perform.
877#[derive(Clone, Copy)]
878enum GenerationKind {
879    OfJson,
880    TypeDecl(Option<DeriveVisitors>),
881}
882
883/// Replace markers in `template` with auto-generated code.
884struct GenerateCodeFor {
885    template: PathBuf,
886    target: PathBuf,
887    /// Each list corresponds to a marker. We replace the ith `__REPLACE{i}__` marker with
888    /// generated code for each definition in the ith list.
889    ///
890    /// Eventually we should reorder definitions so the generated ones are all in one block.
891    /// Keeping the order is important while we migrate away from hand-written code.
892    markers: Vec<(GenerationKind, HashSet<TypeDeclId>)>,
893}
894
895impl GenerateCodeFor {
896    fn generate(&self, ctx: &GenerateCtx) -> Result<()> {
897        let mut template = fs::read_to_string(&self.template)
898            .with_context(|| format!("Failed to read template file {}", self.template.display()))?;
899        for (i, (kind, names)) in self.markers.iter().enumerate() {
900            let tys = names
901                .iter()
902                .map(|&id| &ctx.crate_data[id])
903                .sorted_by_key(|tdecl| {
904                    tdecl
905                        .item_meta
906                        .name
907                        .name
908                        .last()
909                        .unwrap()
910                        .as_ident()
911                        .unwrap()
912                });
913            let generated = match kind {
914                GenerationKind::OfJson => {
915                    let fns = tys
916                        .map(|ty| type_decl_to_json_deserializer(ctx, ty))
917                        .format("\n");
918                    format!("let rec ___ = ()\n{fns}")
919                }
920                GenerationKind::TypeDecl(visitors) => {
921                    let mut decls = tys
922                        .enumerate()
923                        .map(|(i, ty)| {
924                            let co_recursive = i != 0;
925                            type_decl_to_ocaml_decl(ctx, ty, co_recursive)
926                        })
927                        .join("\n");
928                    if let Some(visitors) = visitors {
929                        let &DeriveVisitors {
930                            name,
931                            mut ancestors,
932                            reduce,
933                            extra_types,
934                        } = visitors;
935                        let varieties: &[_] = if reduce {
936                            &["iter", "map", "reduce", "mapreduce"]
937                        } else {
938                            &["iter", "map"]
939                        };
940                        let intermediate_visitor_name;
941                        let intermediate_visitor_name_slice;
942                        if !extra_types.is_empty() {
943                            intermediate_visitor_name = format!("{name}_base");
944                            let intermediate_visitor = generate_visitor_bases(
945                                ctx,
946                                &intermediate_visitor_name,
947                                ancestors,
948                                reduce,
949                                extra_types
950                                    .iter()
951                                    .map(|s| s.to_string())
952                                    .collect_vec()
953                                    .as_slice(),
954                            );
955                            intermediate_visitor_name_slice = [intermediate_visitor_name.as_str()];
956                            ancestors = &intermediate_visitor_name_slice;
957                            decls = format!(
958                                "(* Ancestors for the {name} visitors *){intermediate_visitor}\n{decls}"
959                            );
960                        }
961                        let visitors = varieties
962                            .iter()
963                            .map(|variety| {
964                                let nude = if !ancestors.is_empty() {
965                                    format!("nude = true (* Don't inherit VisitorsRuntime *);")
966                                } else {
967                                    String::new()
968                                };
969                                let ancestors = format!(
970                                    "ancestors = [ {} ];",
971                                    ancestors
972                                        .iter()
973                                        .map(|a| format!("\"{variety}_{a}\""))
974                                        .join(";")
975                                );
976                                format!(
977                                    r#"
978                                    visitors {{
979                                        name = "{variety}_{name}";
980                                        monomorphic = ["env"];
981                                        variety = "{variety}";
982                                        {ancestors}
983                                        {nude}
984                                    }}
985                                "#
986                                )
987                            })
988                            .format(", ");
989                        let _ = write!(&mut decls, "\n[@@deriving show, eq, ord, {visitors}]");
990                    };
991                    decls
992                }
993            };
994            let placeholder = format!("(* __REPLACE{i}__ *)");
995            template = template.replace(&placeholder, &generated);
996        }
997
998        fs::write(&self.target, template)
999            .with_context(|| format!("Failed to write generated file {}", self.target.display()))?;
1000        Ok(())
1001    }
1002}
1003
1004fn main() -> Result<()> {
1005    let dir = PathBuf::from("src/bin/generate-ml");
1006    let charon_llbc = dir.join("charon-itself.ullbc");
1007    let reuse_llbc = std::env::var("CHARON_ML_REUSE_LLBC").is_ok(); // Useful when developping
1008    if !reuse_llbc {
1009        // Call charon on itself
1010        let mut cmd = Command::cargo_bin("charon")?;
1011        cmd.arg("cargo");
1012        cmd.arg("--hide-marker-traits");
1013        cmd.arg("--hide-allocator");
1014        cmd.arg("--ullbc");
1015        cmd.arg("--start-from=charon_lib::ast::krate::TranslatedCrate");
1016        cmd.arg("--start-from=charon_lib::ast::ullbc_ast::BodyContents");
1017        cmd.arg("--exclude=charon_lib::common::hash_consing::HashConsed");
1018        cmd.arg("--dest-file");
1019        cmd.arg(&charon_llbc);
1020        cmd.arg("--");
1021        cmd.arg("--lib");
1022        let output = cmd.output()?;
1023
1024        if !output.status.success() {
1025            let stderr = String::from_utf8(output.stderr.clone())?;
1026            bail!("Compilation failed: {stderr}")
1027        }
1028    }
1029
1030    let crate_data: TranslatedCrate = charon_lib::deserialize_llbc(&charon_llbc)?;
1031    let output_dir = if std::env::var("IN_CI").as_deref() == Ok("1") {
1032        dir.join("generated")
1033    } else {
1034        dir.join("../../../../charon-ml/src/generated")
1035    };
1036    generate_ml(crate_data, dir.join("templates"), output_dir)
1037}
1038
/// Generates the OCaml source files from the templates in `template_dir`, writing them to
/// `output_dir`.
///
/// `crate_data` is the translated crate whose type declarations the OCaml types, visitors and
/// JSON deserializers are generated from (`main` passes charon's own crate).
///
/// # Errors
///
/// Propagates the first error returned by [`GenerateCodeFor::generate`] (e.g. failing to read a
/// template or to write an output file).
fn generate_ml(
    crate_data: TranslatedCrate,
    template_dir: PathBuf,
    output_dir: PathBuf,
) -> anyhow::Result<()> {
    // Types whose OCaml type is written by hand: presumably (Rust type name, OCaml type used in
    // its place) pairs.
    let manual_type_impls = &[
        // Hand-written because we replace the `FileId` with the corresponding file.
        ("FileId", "file"),
        // Hand-written because we use `indexed_var` as a hack to be able to reuse field names.
        // TODO: remove the need for this hack.
        ("RegionParam", "(region_id, string option) indexed_var"),
        ("TypeParam", "(type_var_id, string) indexed_var"),
    ];
    // Types whose JSON deserializer is written by hand: (Rust type name, OCaml match arms of the
    // deserializer body).
    let manual_json_impls = &[
        // Hand-written because we filter out `None` values.
        (
            "Vector",
            indoc!(
                r#"
                | js ->
                    let* list = list_of_json (option_of_json arg1_of_json) ctx js in
                    Ok (List.filter_map (fun x -> x) list)
                "#
            ),
        ),
        // Hand-written because we replace the `FileId` with the corresponding `file`, looked up
        // in the deserialization context.
        (
            "FileId",
            indoc!(
                r#"
                | json ->
                    let* file_id = FileId.id_of_json ctx json in
                    let file = FileId.Map.find file_id ctx in
                    Ok file
                "#,
            ),
        ),
    ];
    // Types for which we don't want to generate a type at all.
    let dont_generate_ty = &[
        "ItemOpacity",
        "PredicateOrigin",
        "TraitTypeConstraintId",
        "Ty",
        "Vector",
    ];
    // Types that we don't want visitors to go into.
    let opaque_for_visitor = &["Name"];
    let ctx = GenerateCtx::new(
        &crate_data,
        manual_type_impls,
        manual_json_impls,
        opaque_for_visitor,
    );

    // Types excluded from the JSON deserializer sets computed below (their deserializers are
    // not generated here).
    let manually_implemented: HashSet<_> = [
        "ItemOpacity",
        "PredicateOrigin",
        "Ty", // We exclude it since `TyKind` is renamed to `ty`
        "Body",
        "FunDecl",
        "TranslatedCrate",
    ]
    .iter()
    .map(|name| ctx.id_from_name(name))
    .collect();

    // Compute type sets for json deserializers: types used only by LLBC bodies, types used only
    // by ULLBC bodies, and everything else (including the types shared by both kinds of bodies).
    let (gast_types, llbc_types, ullbc_types) = {
        let llbc_types: HashSet<_> = ctx.children_of("charon_lib::ast::llbc_ast::Statement");
        let ullbc_types: HashSet<_> = ctx.children_of("charon_lib::ast::ullbc_ast::BodyContents");
        let all_types: HashSet<_> = ctx.children_of("TranslatedCrate");

        let shared_types: HashSet<_> = llbc_types.intersection(&ullbc_types).copied().collect();
        let llbc_types: HashSet<_> = llbc_types.difference(&shared_types).copied().collect();
        let ullbc_types: HashSet<_> = ullbc_types.difference(&shared_types).copied().collect();

        let body_specific_types: HashSet<_> = llbc_types.union(&ullbc_types).copied().collect();
        let gast_types: HashSet<_> = all_types
            .difference(&body_specific_types)
            .copied()
            .collect();

        let gast_types: HashSet<_> = gast_types
            .difference(&manually_implemented)
            .copied()
            .collect();
        let llbc_types: HashSet<_> = llbc_types
            .difference(&manually_implemented)
            .copied()
            .collect();
        let ullbc_types: HashSet<_> = ullbc_types
            .difference(&manually_implemented)
            .copied()
            .collect();
        (gast_types, llbc_types, ullbc_types)
    };

    // Types already assigned to an output; seeded with the types we never generate so that
    // `markers_from_children` skips them.
    let mut processed_tys: HashSet<TypeDeclId> = dont_generate_ty
        .iter()
        .map(|name| ctx.id_from_name(name))
        .collect();
    // Each call to this will return the children of the listed types that haven't been returned
    // yet. By calling it in dependency order, this allows to organize types into files without
    // having to list them all.
    let mut markers_from_children = |ctx: &GenerateCtx, markers: &[_]| {
        markers
            .iter()
            .copied()
            .map(|(kind, type_names)| {
                let types: HashSet<_> = ctx.children_of_many(type_names);
                let unprocessed_types: HashSet<_> =
                    types.difference(&processed_tys).copied().collect();
                processed_tys.extend(unprocessed_types.iter().copied());
                (kind, unprocessed_types)
            })
            .collect()
    };

    // `vec!` evaluates its elements in order, so the `markers_from_children` calls below run top
    // to bottom: a type reachable from several entries is placed in the first one. Reordering
    // these entries changes which file each type ends up in.
    #[rustfmt::skip]
    let generate_code_for = vec![
        GenerateCodeFor {
            template: template_dir.join("Meta.ml"),
            target: output_dir.join("Generated_Meta.ml"),
            markers: markers_from_children(&ctx, &[
                (GenerationKind::TypeDecl(None), &[
                    "File",
                    "Span",
                    "AttrInfo",
                ]),
            ]),
        },
        GenerateCodeFor {
            template: template_dir.join("Values.ml"),
            target: output_dir.join("Generated_Values.ml"),
            markers: markers_from_children(&ctx, &[
                (GenerationKind::TypeDecl(Some(DeriveVisitors {
                    ancestors: &["big_int"],
                    name: "literal",
                    reduce: true,
                    extra_types: &[],
                })), &[
                    "Literal",
                    "IntegerTy",
                    "LiteralTy",
                ]),
            ]),
        },
        GenerateCodeFor {
            template: template_dir.join("Types.ml"),
            target: output_dir.join("Generated_Types.ml"),
            markers: markers_from_children(&ctx, &[
                (GenerationKind::TypeDecl(Some(DeriveVisitors {
                    ancestors: &["literal"],
                    name: "const_generic",
                    reduce: true,
                    extra_types: &[],
                })), &[
                    "TypeVarId",
                    "ConstGeneric",
                    "TraitClauseId",
                    "DeBruijnVar",
                    "ItemId",
                ]),
                // Can't merge into above because aeneas uses the above alongside their own partial
                // copy of `ty`, which causes method type clashes.
                (GenerationKind::TypeDecl(Some(DeriveVisitors {
                    ancestors: &["ty_base_base"],
                    name: "ty",
                    reduce: false,
                    extra_types: &["span"],
                })), &[
                    "TyKind",
                    "TraitImplRef",
                    "FunDeclRef",
                    "GlobalDeclRef",
                ]),
                // TODO: can't merge into above because of field name clashes (`types`, `regions` etc).
                (GenerationKind::TypeDecl(Some(DeriveVisitors {
                    ancestors: &["ty"],
                    name: "type_decl",
                    reduce: false,
                    extra_types: &[
                        "attr_info"
                    ],
                })), &[
                    "Binder",
                    "AbortKind",
                    "TypeDecl",
                ]),
            ]),
        },
        GenerateCodeFor {
            template: template_dir.join("Expressions.ml"),
            target: output_dir.join("Generated_Expressions.ml"),
            markers: markers_from_children(&ctx, &[
                (GenerationKind::TypeDecl(Some(DeriveVisitors {
                    ancestors: &["type_decl"],
                    name: "rvalue",
                    reduce: false,
                    extra_types: &[],
                })), &[
                    "Rvalue",
                ]),
            ]),
        },
        GenerateCodeFor {
            template: template_dir.join("GAst.ml"),
            target: output_dir.join("Generated_GAst.ml"),
            markers: markers_from_children(&ctx, &[
                (GenerationKind::TypeDecl(Some(DeriveVisitors {
                    ancestors: &["rvalue"],
                    name: "fun_sig",
                    reduce: false,
                    extra_types: &[],
                })), &[
                    "Call",
                    "Assert",
                    "ItemSource",
                    "Locals",
                    "FunSig",
                    "CopyNonOverlapping",
                    "Error",
                ]),
                // These have to be kept separate to avoid field name clashes
                (GenerationKind::TypeDecl(Some(DeriveVisitors {
                    ancestors: &["fun_sig"],
                    name: "global_decl",
                    reduce: false,
                    extra_types: &[],
                })), &[
                    "GlobalDecl",
                ]),
                (GenerationKind::TypeDecl(Some(DeriveVisitors {
                    ancestors: &["global_decl"],
                    name: "trait_decl",
                    reduce: false,
                    extra_types: &[],
                })), &[
                    "TraitDecl",
                ]),
                (GenerationKind::TypeDecl(Some(DeriveVisitors {
                    ancestors: &["trait_decl"],
                    name: "trait_impl",
                    reduce: false,
                    extra_types: &[],
                })), &[
                    "TraitImpl",
                ]),
                (GenerationKind::TypeDecl(None), &[
                    "CliOpts",
                    "GExprBody",
                    "DeclarationGroup",
                ]),
            ]),
        },
        GenerateCodeFor {
            template: template_dir.join("LlbcAst.ml"),
            target: output_dir.join("Generated_LlbcAst.ml"),
            markers: markers_from_children(&ctx, &[
                (GenerationKind::TypeDecl(Some(DeriveVisitors {
                    name: "statement_base",
                    ancestors: &["trait_impl"],
                    reduce: false,
                    extra_types: &[],
                })), &[
                    "charon_lib::ast::llbc_ast::Statement",
                ]),
            ]),
        },
        GenerateCodeFor {
            template: template_dir.join("UllbcAst.ml"),
            target: output_dir.join("Generated_UllbcAst.ml"),
            markers: markers_from_children(&ctx, &[
                (GenerationKind::TypeDecl(Some(DeriveVisitors {
                    ancestors: &["trait_impl"],
                    name: "statement",
                    reduce: false,
                    extra_types: &[],
                })), &[
                    "charon_lib::ast::ullbc_ast::Statement",
                    "charon_lib::ast::ullbc_ast::SwitchTargets",
                ]),
                // TODO: Can't merge with above because of field name clashes (`content` and `span`).
                (GenerationKind::TypeDecl(Some(DeriveVisitors {
                    ancestors: &["statement"],
                    name: "ullbc_ast",
                    reduce: false,
                    extra_types: &[],
                })), &[
                    "charon_lib::ast::ullbc_ast::BodyContents",
                ]),
            ]),
        },
        GenerateCodeFor {
            template: template_dir.join("GAstOfJson.ml"),
            target: output_dir.join("Generated_GAstOfJson.ml"),
            markers: vec![(GenerationKind::OfJson, gast_types)],
        },
        GenerateCodeFor {
            template: template_dir.join("LlbcOfJson.ml"),
            target: output_dir.join("Generated_LlbcOfJson.ml"),
            markers: vec![(GenerationKind::OfJson, llbc_types)],
        },
        GenerateCodeFor {
            template: template_dir.join("UllbcOfJson.ml"),
            target: output_dir.join("Generated_UllbcOfJson.ml"),
            markers: vec![(GenerationKind::OfJson, ullbc_types)],
        },
    ];
    // Write every file; stop at the first failure.
    for file in generate_code_for {
        file.generate(&ctx)?;
    }
    Ok(())
}