generate_ml/
main.rs

1//! Generate ocaml deserialization code for our types.
2//!
3//! This binary runs charon on itself and generates the appropriate `<type>_of_json` functions for
4//! our types. The generated functions are inserted into `./generate-ml/GAstOfJson.template.ml` to
5//! construct the final `GAstOfJson.ml`.
6//!
7//! To run it, call `cargo run --bin generate-ml`. It is also run by `make generate-ml` in the
8//! crate root. Don't forget to format the output code after regenerating.
9#![feature(if_let_guard)]
10
11use anyhow::{bail, Context, Result};
12use assert_cmd::cargo::CommandCargoExt;
13use charon_lib::ast::*;
14use convert_case::{Case, Casing};
15use indoc::indoc;
16use itertools::Itertools;
17use std::collections::{HashMap, HashSet};
18use std::fmt::Write;
19use std::fs;
20use std::path::PathBuf;
21use std::process::Command;
22
23/// `Name` is a complex datastructure; to inspect it we serialize it a little bit.
24fn repr_name(_crate_data: &TranslatedCrate, n: &Name) -> String {
25    n.name
26        .iter()
27        .map(|path_elem| match path_elem {
28            PathElem::Ident(i, _) => i.clone(),
29            PathElem::Impl(..) => "<impl>".to_string(),
30        })
31        .join("::")
32}
33
34fn make_ocaml_ident(name: &str) -> String {
35    let mut name = name.to_case(Case::Snake);
36    if matches!(
37        &*name,
38        "virtual"
39            | "bool"
40            | "char"
41            | "struct"
42            | "type"
43            | "let"
44            | "fun"
45            | "open"
46            | "rec"
47            | "assert"
48            | "float"
49            | "end"
50            | "include"
51    ) {
52        name += "_";
53    }
54    name
55}
56fn type_name_to_ocaml_ident(item_meta: &ItemMeta) -> String {
57    let name = item_meta
58        .attr_info
59        .rename
60        .as_ref()
61        .unwrap_or(item_meta.name.name.last().unwrap().as_ident().unwrap().0);
62    make_ocaml_ident(name)
63}
64
65struct GenerateCtx<'a> {
66    crate_data: &'a TranslatedCrate,
67    name_to_type: HashMap<String, &'a TypeDecl>,
68    /// For each type, list the types it contains.
69    type_tree: HashMap<TypeDeclId, HashSet<TypeDeclId>>,
70    manual_type_impls: HashMap<TypeDeclId, String>,
71    manual_json_impls: HashMap<TypeDeclId, String>,
72    opaque_for_visitor: HashSet<TypeDeclId>,
73}
74
75impl<'a> GenerateCtx<'a> {
76    fn new(
77        crate_data: &'a TranslatedCrate,
78        manual_type_impls: &[(&str, &str)],
79        manual_json_impls: &[(&str, &str)],
80        opaque_for_visitor: &[&str],
81    ) -> Self {
82        let mut name_to_type: HashMap<String, &TypeDecl> = Default::default();
83        let mut type_tree = HashMap::default();
84        for ty in &crate_data.type_decls {
85            let long_name = repr_name(crate_data, &ty.item_meta.name);
86            if long_name.starts_with("charon_lib") {
87                let short_name = ty
88                    .item_meta
89                    .name
90                    .name
91                    .last()
92                    .unwrap()
93                    .as_ident()
94                    .unwrap()
95                    .0
96                    .clone();
97                name_to_type.insert(short_name, ty);
98            }
99            name_to_type.insert(long_name, ty);
100
101            let mut contained = HashSet::new();
102            ty.dyn_visit(|id: &TypeDeclId| {
103                contained.insert(*id);
104            });
105            type_tree.insert(ty.def_id, contained);
106        }
107
108        let mut ctx = GenerateCtx {
109            crate_data: &crate_data,
110            name_to_type,
111            type_tree,
112            manual_type_impls: Default::default(),
113            manual_json_impls: Default::default(),
114            opaque_for_visitor: Default::default(),
115        };
116        ctx.manual_type_impls = manual_type_impls
117            .iter()
118            .map(|(name, def)| (ctx.id_from_name(name), def.to_string()))
119            .collect();
120        ctx.manual_json_impls = manual_json_impls
121            .iter()
122            .map(|(name, def)| (ctx.id_from_name(name), def.to_string()))
123            .collect();
124        ctx.opaque_for_visitor = opaque_for_visitor
125            .iter()
126            .map(|name| ctx.id_from_name(name))
127            .collect();
128        ctx
129    }
130
131    fn id_from_name(&self, name: &str) -> TypeDeclId {
132        self.name_to_type
133            .get(name)
134            .expect(&format!("Name not found: `{name}`"))
135            .def_id
136    }
137
138    /// List the (recursive) children of this type.
139    fn children_of(&self, name: &str) -> HashSet<TypeDeclId> {
140        let start_id = self.id_from_name(name);
141        self.children_of_inner(vec![start_id])
142    }
143
144    /// List the (recursive) children of these types.
145    fn children_of_many(&self, names: &[&str]) -> HashSet<TypeDeclId> {
146        self.children_of_inner(names.iter().map(|name| self.id_from_name(name)).collect())
147    }
148
149    fn children_of_inner(&self, ty: Vec<TypeDeclId>) -> HashSet<TypeDeclId> {
150        let mut children = HashSet::new();
151        let mut stack = ty.to_vec();
152        while let Some(id) = stack.pop() {
153            if !children.contains(&id)
154                && self
155                    .crate_data
156                    .type_decls
157                    .get(id)
158                    .is_some_and(|decl| decl.item_meta.is_local)
159            {
160                children.insert(id);
161                if let Some(contained) = self.type_tree.get(&id) {
162                    stack.extend(contained);
163                }
164            }
165        }
166        children
167    }
168}
169
170/// Converts a type to the appropriate `*_of_json` call. In case of generics, this combines several
171/// functions, e.g. `list_of_json bool_of_json`.
172fn type_to_ocaml_call(ctx: &GenerateCtx, ty: &Ty) -> String {
173    match ty.kind() {
174        TyKind::Literal(LiteralTy::Bool) => "bool_of_json".to_string(),
175        TyKind::Literal(LiteralTy::Char) => "char_of_json".to_string(),
176        TyKind::Literal(LiteralTy::Integer(_)) => "int_of_json".to_string(),
177        TyKind::Literal(LiteralTy::Float(_)) => "float_of_json".to_string(),
178        TyKind::Adt(adt_kind, generics) => {
179            let mut expr = Vec::new();
180            for ty in &generics.types {
181                expr.push(type_to_ocaml_call(ctx, ty))
182            }
183            match adt_kind {
184                TypeId::Adt(id) => {
185                    let mut first = if let Some(tdecl) = ctx.crate_data.type_decls.get(*id) {
186                        type_name_to_ocaml_ident(&tdecl.item_meta)
187                    } else {
188                        format!("missing_type_{id}")
189                    };
190                    if first == "vec" {
191                        first = "list".to_string();
192                        expr.pop(); // Remove the allocator generic param
193                    }
194                    expr.insert(0, first + "_of_json");
195                }
196                TypeId::Builtin(BuiltinTy::Box) => expr.insert(0, "box_of_json".to_owned()),
197                TypeId::Tuple => {
198                    let name = match generics.types.elem_count() {
199                        2 => "pair_of_json".to_string(),
200                        3 => "triple_of_json".to_string(),
201                        len => format!("tuple_{len}_of_json"),
202                    };
203                    expr.insert(0, name);
204                }
205                _ => unimplemented!("{ty:?}"),
206            }
207            expr.into_iter().map(|f| format!("({f})")).join(" ")
208        }
209        TyKind::TypeVar(DeBruijnVar::Free(id)) => format!("arg{id}_of_json"),
210        _ => unimplemented!("{ty:?}"),
211    }
212}
213
214/// Converts a type to the appropriate ocaml name. In case of generics, this provides appropriate
215/// parameters.
216fn type_to_ocaml_name(ctx: &GenerateCtx, ty: &Ty) -> String {
217    match ty.kind() {
218        TyKind::Literal(LiteralTy::Bool) => "bool".to_string(),
219        TyKind::Literal(LiteralTy::Char) => "(Uchar.t [@visitors.opaque])".to_string(),
220        TyKind::Literal(LiteralTy::Integer(_)) => "int".to_string(),
221        TyKind::Literal(LiteralTy::Float(_)) => "float_of_json".to_string(),
222        TyKind::Adt(adt_kind, generics) => {
223            let mut args = generics
224                .types
225                .iter()
226                .map(|ty| type_to_ocaml_name(ctx, ty))
227                .map(|name| {
228                    if !name.chars().all(|c| c.is_alphanumeric()) {
229                        format!("({name})")
230                    } else {
231                        name
232                    }
233                })
234                .collect_vec();
235            match adt_kind {
236                TypeId::Adt(id) => {
237                    let mut base_ty = if let Some(tdecl) = ctx.crate_data.type_decls.get(*id) {
238                        type_name_to_ocaml_ident(&tdecl.item_meta)
239                    } else if let Some(name) = ctx.crate_data.item_name(*id) {
240                        eprintln!(
241                            "Warning: type {} missing from llbc",
242                            repr_name(ctx.crate_data, name)
243                        );
244                        name.name
245                            .last()
246                            .unwrap()
247                            .as_ident()
248                            .unwrap()
249                            .0
250                            .to_lowercase()
251                    } else {
252                        format!("missing_type_{id}")
253                    };
254                    if base_ty == "vec" {
255                        base_ty = "list".to_string();
256                        args.pop(); // Remove the allocator generic param
257                    }
258                    if base_ty == "vector" {
259                        base_ty = "list".to_string();
260                        args.remove(0); // Remove the index generic param
261                    }
262                    let args = match args.as_slice() {
263                        [] => String::new(),
264                        [arg] => arg.clone(),
265                        args => format!("({})", args.iter().join(",")),
266                    };
267                    format!("{args} {base_ty}")
268                }
269                TypeId::Builtin(BuiltinTy::Box) => args[0].clone(),
270                TypeId::Tuple => args.iter().join("*"),
271                _ => unimplemented!("{ty:?}"),
272            }
273        }
274        TyKind::TypeVar(DeBruijnVar::Free(id)) => format!("'a{id}"),
275        _ => unimplemented!("{ty:?}"),
276    }
277}
278
279fn convert_vars<'a>(ctx: &GenerateCtx, fields: impl IntoIterator<Item = &'a Field>) -> String {
280    fields
281        .into_iter()
282        .filter(|f| !f.is_opaque())
283        .map(|f| {
284            let name = make_ocaml_ident(f.name.as_deref().unwrap());
285            let rename = make_ocaml_ident(f.renamed_name().unwrap());
286            let convert = type_to_ocaml_call(ctx, &f.ty);
287            format!("let* {rename} = {convert} ctx {name} in")
288        })
289        .join("\n")
290}
291
292fn build_branch<'a>(
293    ctx: &GenerateCtx,
294    pat: &str,
295    fields: impl IntoIterator<Item = &'a Field>,
296    construct: &str,
297) -> String {
298    let convert = convert_vars(ctx, fields);
299    format!("| {pat} -> {convert} Ok ({construct})")
300}
301
302fn build_function(_ctx: &GenerateCtx, decl: &TypeDecl, branches: &str) -> String {
303    let ty_name = type_name_to_ocaml_ident(&decl.item_meta);
304    let signature = if decl.generics.types.is_empty() {
305        format!("{ty_name}_of_json (ctx : of_json_ctx) (js : json) : ({ty_name}, string) result =")
306    } else {
307        let types = &decl.generics.types;
308        let gen_vars_space = types
309            .iter()
310            .enumerate()
311            .map(|(i, _)| format!("'a{i}"))
312            .join(" ");
313        let gen_vars_comma = types
314            .iter()
315            .enumerate()
316            .map(|(i, _)| format!("'a{i}"))
317            .join(", ");
318
319        let mut args = Vec::new();
320        let mut ty_args = Vec::new();
321        for (i, _) in types.iter().enumerate() {
322            args.push(format!("arg{i}_of_json"));
323            ty_args.push(format!("(of_json_ctx -> json -> ('a{i}, string) result)"));
324        }
325        args.push("ctx".to_string());
326        ty_args.push("of_json_ctx".to_string());
327        args.push("js".to_string());
328        ty_args.push("json".to_string());
329
330        let ty_args = ty_args.into_iter().join(" -> ");
331        let args = args.into_iter().join(" ");
332        let fun_ty =
333            format!("{gen_vars_space}. {ty_args} -> (({gen_vars_comma}) {ty_name}, string) result");
334        format!("{ty_name}_of_json : {fun_ty} = fun {args} ->")
335    };
336    format!(
337        r#"
338        and {signature}
339          combine_error_msgs js __FUNCTION__
340            (match js with{branches} | _ -> Error "")
341        "#
342    )
343}
344
345fn type_decl_to_json_deserializer(ctx: &GenerateCtx, decl: &TypeDecl) -> String {
346    let return_ty = type_name_to_ocaml_ident(&decl.item_meta);
347    let return_ty = if decl.generics.types.is_empty() {
348        return_ty
349    } else {
350        format!("_ {return_ty}")
351    };
352
353    let branches = match &decl.kind {
354        _ if let Some(def) = ctx.manual_json_impls.get(&decl.def_id) => def.clone(),
355        TypeDeclKind::Struct(fields) if fields.is_empty() => {
356            build_branch(ctx, "`Null", fields, "()")
357        }
358        TypeDeclKind::Struct(fields)
359            if fields.elem_count() == 1
360                && fields[0].name.as_ref().is_some_and(|name| name == "_raw") =>
361        {
362            // These are the special strongly-typed integers.
363            let short_name = decl
364                .item_meta
365                .name
366                .name
367                .last()
368                .unwrap()
369                .as_ident()
370                .unwrap()
371                .0
372                .clone();
373            format!("| x -> {short_name}.id_of_json ctx x")
374        }
375        TypeDeclKind::Struct(fields)
376            if fields.elem_count() == 1
377                && (fields[0].name.is_none()
378                    || decl
379                        .item_meta
380                        .attr_info
381                        .attributes
382                        .iter()
383                        .filter_map(|a| a.as_unknown())
384                        .any(|a| a.to_string() == "serde(transparent)")) =>
385        {
386            let ty = &fields[0].ty;
387            let call = type_to_ocaml_call(ctx, ty);
388            format!("| x -> {call} ctx x")
389        }
390        TypeDeclKind::Alias(ty) => {
391            let call = type_to_ocaml_call(ctx, ty);
392            format!("| x -> {call} ctx x")
393        }
394        TypeDeclKind::Struct(fields) if fields.iter().all(|f| f.name.is_none()) => {
395            let mut fields = fields.clone();
396            for (i, f) in fields.iter_mut().enumerate() {
397                f.name = Some(format!("x{i}"));
398            }
399            let pat: String = fields
400                .iter()
401                .map(|f| f.name.as_deref().unwrap())
402                .map(|n| make_ocaml_ident(n))
403                .join(";");
404            let pat = format!("`List [ {pat} ]");
405            let construct = fields
406                .iter()
407                .map(|f| f.renamed_name().unwrap())
408                .map(|n| make_ocaml_ident(n))
409                .join(", ");
410            let construct = format!("( {construct} )");
411            build_branch(ctx, &pat, &fields, &construct)
412        }
413        TypeDeclKind::Struct(fields) => {
414            let fields = fields
415                .iter()
416                .filter(|field| {
417                    !field
418                        .attr_info
419                        .attributes
420                        .iter()
421                        .filter_map(|a| a.as_unknown())
422                        .any(|a| a.to_string() == "serde(skip)")
423                })
424                .collect_vec();
425            let pat: String = fields
426                .iter()
427                .map(|f| {
428                    let name = f.name.as_ref().unwrap();
429                    let var = if f.is_opaque() {
430                        "_"
431                    } else {
432                        &make_ocaml_ident(name)
433                    };
434                    format!("(\"{name}\", {var});")
435                })
436                .join("\n");
437            let pat = format!("`Assoc [ {pat} ]");
438            let construct = fields
439                .iter()
440                .filter(|f| !f.is_opaque())
441                .map(|f| f.renamed_name().unwrap())
442                .map(|n| make_ocaml_ident(n))
443                .join("; ");
444            let construct = format!("({{ {construct} }} : {return_ty})");
445            build_branch(ctx, &pat, fields, &construct)
446        }
447        TypeDeclKind::Enum(variants) => {
448            variants
449                .iter()
450                .filter(|v| !v.is_opaque())
451                .map(|variant| {
452                    let name = &variant.name;
453                    let rename = variant.renamed_name();
454                    if variant.fields.is_empty() {
455                        // Unit variant
456                        let pat = format!("`String \"{name}\"");
457                        build_branch(ctx, &pat, &variant.fields, rename)
458                    } else {
459                        let mut fields = variant.fields.clone();
460                        let inner_pat = if fields.iter().all(|f| f.name.is_none()) {
461                            // Tuple variant
462                            if variant.fields.elem_count() == 1 {
463                                let var = make_ocaml_ident(&variant.name);
464                                fields[0].name = Some(var.clone());
465                                var
466                            } else {
467                                for (i, f) in fields.iter_mut().enumerate() {
468                                    f.name = Some(format!("x_{i}"));
469                                }
470                                let pat =
471                                    fields.iter().map(|f| f.name.as_ref().unwrap()).join("; ");
472                                format!("`List [ {pat} ]")
473                            }
474                        } else {
475                            // Struct variant
476                            let pat = fields
477                                .iter()
478                                .map(|f| {
479                                    let name = f.name.as_ref().unwrap();
480                                    let var = if f.is_opaque() {
481                                        "_"
482                                    } else {
483                                        &make_ocaml_ident(name)
484                                    };
485                                    format!("(\"{name}\", {var});")
486                                })
487                                .join(" ");
488                            format!("`Assoc [ {pat} ]")
489                        };
490                        let pat = format!("`Assoc [ (\"{name}\", {inner_pat}) ]");
491                        let construct_fields = fields
492                            .iter()
493                            .map(|f| f.name.as_ref().unwrap())
494                            .map(|n| make_ocaml_ident(n))
495                            .join(", ");
496                        let construct = format!("{rename} ({construct_fields})");
497                        build_branch(ctx, &pat, &fields, &construct)
498                    }
499                })
500                .join("\n")
501        }
502        TypeDeclKind::Union(..) => todo!(),
503        TypeDeclKind::Opaque => todo!(),
504        TypeDeclKind::Error(_) => todo!(),
505    };
506    build_function(ctx, decl, &branches)
507}
508
509fn extract_doc_comments(attr_info: &AttrInfo) -> String {
510    attr_info
511        .attributes
512        .iter()
513        .filter_map(|a| a.as_doc_comment())
514        .join("\n")
515}
516
/// Wraps `comment` into an OCaml doc comment (`(** ... *)`).
///
/// An empty comment is returned unchanged. A single-line comment is emitted on one line. For
/// a multi-line comment, one leading space is removed from every line, continuation lines are
/// indented by `indent_level` two-space steps plus four spaces, and the closing marker is put
/// on its own line.
fn build_doc_comment(comment: String, indent_level: usize) -> String {
    if comment.is_empty() {
        return comment;
    }
    if !comment.contains('\n') {
        return format!("(**{comment} *)");
    }

    let indent = "  ".repeat(indent_level);
    let mut body = String::new();
    for (idx, raw_line) in comment.lines().enumerate() {
        if idx > 0 {
            body.push('\n');
        }
        // Doc comments are usually written `/// text`, which leaves one leading space.
        let line = raw_line.strip_prefix(' ').unwrap_or(raw_line);
        // The first line directly follows the `(**` marker and empty lines stay empty; every
        // other line is indented.
        if idx != 0 && !line.is_empty() {
            body.push_str(&indent);
            body.push_str("    ");
        }
        body.push_str(line);
    }
    format!("(** {body}\n{indent} *)")
}
546
547fn build_type(_ctx: &GenerateCtx, decl: &TypeDecl, co_rec: bool, body: &str) -> String {
548    let ty_name = type_name_to_ocaml_ident(&decl.item_meta);
549    let generics = decl
550        .generics
551        .types
552        .iter()
553        .enumerate()
554        .map(|(i, _)| format!("'a{i}"))
555        .collect_vec();
556    let generics = match generics.as_slice() {
557        [] => String::new(),
558        [ty] => ty.clone(),
559        generics => format!("({})", generics.iter().join(",")),
560    };
561    let comment = extract_doc_comments(&decl.item_meta.attr_info);
562    let comment = build_doc_comment(comment, 0);
563    let keyword = if co_rec { "and" } else { "type" };
564    format!("\n{comment} {keyword} {generics} {ty_name} = {body}")
565}
566
567/// Generate an ocaml type declaration that mirrors `decl`.
568///
569/// `co_rec` indicates whether this definition is co-recursive with the ones that come before (i.e.
570/// should be declared with `and` instead of `type`).
571fn type_decl_to_ocaml_decl(ctx: &GenerateCtx, decl: &TypeDecl, co_rec: bool) -> String {
572    let opaque = if ctx.opaque_for_visitor.contains(&decl.def_id) {
573        "[@visitors.opaque]"
574    } else {
575        ""
576    };
577    let body = match &decl.kind {
578        _ if let Some(def) = ctx.manual_type_impls.get(&decl.def_id) => def.clone(),
579        TypeDeclKind::Alias(ty) => {
580            let ty = type_to_ocaml_name(ctx, ty);
581            format!("{ty} {opaque}")
582        }
583        TypeDeclKind::Struct(fields) if fields.is_empty() => "unit".to_string(),
584        TypeDeclKind::Struct(fields)
585            if fields.elem_count() == 1
586                && fields[0].name.as_ref().is_some_and(|name| name == "_raw") =>
587        {
588            // These are the special strongly-typed integers.
589            let short_name = decl
590                .item_meta
591                .name
592                .name
593                .last()
594                .unwrap()
595                .as_ident()
596                .unwrap()
597                .0
598                .clone();
599            format!("{short_name}.id [@visitors.opaque]")
600        }
601        TypeDeclKind::Struct(fields)
602            if fields.elem_count() == 1
603                && (fields[0].name.is_none()
604                    || decl
605                        .item_meta
606                        .attr_info
607                        .attributes
608                        .iter()
609                        .filter_map(|a| a.as_unknown())
610                        .any(|a| a.to_string() == "serde(transparent)")) =>
611        {
612            let ty = type_to_ocaml_name(ctx, &fields[0].ty);
613            format!("{ty} {opaque}")
614        }
615        TypeDeclKind::Struct(fields) if fields.iter().all(|f| f.name.is_none()) => fields
616            .iter()
617            .filter(|f| !f.is_opaque())
618            .map(|f| {
619                let ty = type_to_ocaml_name(ctx, &f.ty);
620                format!("{ty} {opaque}")
621            })
622            .join("*"),
623        TypeDeclKind::Struct(fields) => {
624            let fields = fields
625                .iter()
626                .filter(|f| !f.is_opaque())
627                .map(|f| {
628                    let name = f.renamed_name().unwrap();
629                    let ty = type_to_ocaml_name(ctx, &f.ty);
630                    let comment = extract_doc_comments(&f.attr_info);
631                    let comment = build_doc_comment(comment, 2);
632                    format!("{name} : {ty} {opaque} {comment}")
633                })
634                .join(";");
635            format!("{{ {fields} }}")
636        }
637        TypeDeclKind::Enum(variants) => {
638            variants
639                .iter()
640                .filter(|v| !v.is_opaque())
641                .map(|variant| {
642                    let mut attr_info = variant.attr_info.clone();
643                    let rename = variant.renamed_name();
644                    let ty = if variant.fields.is_empty() {
645                        // Unit variant
646                        String::new()
647                    } else {
648                        if variant.fields.iter().all(|f| f.name.is_some()) {
649                            let fields = variant
650                                .fields
651                                .iter()
652                                .map(|f| {
653                                    let comment = extract_doc_comments(&f.attr_info);
654                                    let description = if comment.is_empty() {
655                                        comment
656                                    } else {
657                                        format!(": {comment}")
658                                    };
659                                    format!("\n - [{}]{description}", f.name.as_ref().unwrap())
660                                })
661                                .join("");
662                            let field_descriptions = format!("\n Fields:{fields}");
663                            // Add a constructed doc-comment
664                            attr_info
665                                .attributes
666                                .push(Attribute::DocComment(field_descriptions));
667                        }
668                        let fields = variant
669                            .fields
670                            .iter()
671                            .map(|f| {
672                                let ty = type_to_ocaml_name(ctx, &f.ty);
673                                format!("{ty} {opaque}")
674                            })
675                            .join("*");
676                        format!(" of {fields}")
677                    };
678                    let comment = extract_doc_comments(&attr_info);
679                    let comment = build_doc_comment(comment, 3);
680                    format!("\n\n | {rename}{ty} {comment}")
681                })
682                .join("")
683        }
684        TypeDeclKind::Union(..) => todo!(),
685        TypeDeclKind::Opaque => todo!(),
686        TypeDeclKind::Error(_) => todo!(),
687    };
688    build_type(ctx, decl, co_rec, &body)
689}
690
691fn generate_visitor_bases(
692    _ctx: &GenerateCtx,
693    name: &str,
694    inherits: &[&str],
695    reduce: bool,
696    ty_names: &[String],
697) -> String {
698    let mut out = String::new();
699    let make_inherit = |variety| {
700        if !inherits.is_empty() {
701            inherits
702                .iter()
703                .map(|ancestor| {
704                    if let [module, name] = ancestor.split(".").collect_vec().as_slice() {
705                        format!("inherit [_] {module}.{variety}_{name}")
706                    } else {
707                        format!("inherit [_] {variety}_{ancestor}")
708                    }
709                })
710                .join("\n")
711        } else {
712            format!("inherit [_] VisitorsRuntime.{variety}")
713        }
714    };
715
716    let iter_methods = ty_names
717        .iter()
718        .map(|ty| format!("method visit_{ty} : 'env -> {ty} -> unit = fun _ _ -> ()"))
719        .format("\n");
720    let _ = write!(
721        &mut out,
722        "
723        class ['self] iter_{name} =
724          object (self : 'self)
725            {}
726            {iter_methods}
727          end
728        ",
729        make_inherit("iter")
730    );
731
732    let map_methods = ty_names
733        .iter()
734        .map(|ty| format!("method visit_{ty} : 'env -> {ty} -> {ty} = fun _ x -> x"))
735        .format("\n");
736    let _ = write!(
737        &mut out,
738        "
739        class ['self] map_{name} =
740          object (self : 'self)
741            {}
742            {map_methods}
743          end
744        ",
745        make_inherit("map")
746    );
747
748    if reduce {
749        let reduce_methods = ty_names
750            .iter()
751            .map(|ty| format!("method visit_{ty} : 'env -> {ty} -> 'a = fun _ _ -> self#zero"))
752            .format("\n");
753        let _ = write!(
754            &mut out,
755            "
756            class virtual ['self] reduce_{name} =
757              object (self : 'self)
758                {}
759                {reduce_methods}
760              end
761            ",
762            make_inherit("reduce")
763        );
764
765        let mapreduce_methods = ty_names
766            .iter()
767            .map(|ty| {
768                format!("method visit_{ty} : 'env -> {ty} -> {ty} * 'a = fun _ x -> (x, self#zero)")
769            })
770            .format("\n");
771        let _ = write!(
772            &mut out,
773            "
774            class virtual ['self] mapreduce_{name} =
775              object (self : 'self)
776                {}
777                {mapreduce_methods}
778              end
779            ",
780            make_inherit("mapreduce")
781        );
782    }
783
784    out
785}
786
/// Settings for the visitors to derive on a group of generated type declarations.
#[derive(Clone, Copy)]
struct DeriveVisitors {
    /// Name of the visitor group. NOTE(review): presumably the suffix of the generated
    /// `iter_*`/`map_*` classes; confirm against the generation code.
    name: &'static str,
    /// Names of the visitors to inherit from.
    ancestors: &'static [&'static str],
    /// Whether the `reduce` and `mapreduce` varieties are generated in addition to `iter` and
    /// `map`.
    reduce: bool,
    /// Additional type names involved in the visitors. NOTE(review): their exact role is
    /// decided by the generation code; confirm there.
    extra_types: &'static [&'static str],
}
794
795/// The kind of code generation to perform.
796#[derive(Clone, Copy)]
797enum GenerationKind {
798    OfJson,
799    TypeDecl(Option<DeriveVisitors>),
800}
801
/// Replace markers in `template` with auto-generated code.
struct GenerateCodeFor {
    /// Path of the template file to read.
    template: PathBuf,
    /// Path the template contents are written to once the markers have been replaced.
    target: PathBuf,
    /// Each list corresponds to a marker. We replace the ith `(* __REPLACE{i}__ *)` marker
    /// (counting from 0) with generated code for each definition in the ith list.
    ///
    /// Eventually we should reorder definitions so the generated ones are all in one block.
    /// Keeping the order is important while we migrate away from hand-written code.
    markers: Vec<(GenerationKind, HashSet<TypeDeclId>)>,
}
813
814impl GenerateCodeFor {
815    fn generate(&self, ctx: &GenerateCtx) -> Result<()> {
816        let mut template = fs::read_to_string(&self.template)
817            .with_context(|| format!("Failed to read template file {}", self.template.display()))?;
818        for (i, (kind, names)) in self.markers.iter().enumerate() {
819            let tys = names.iter().copied().sorted().map(|id| &ctx.crate_data[id]);
820            let generated = match kind {
821                GenerationKind::OfJson => {
822                    let fns = tys
823                        .map(|ty| type_decl_to_json_deserializer(ctx, ty))
824                        .format("\n");
825                    format!("let rec ___ = ()\n{fns}")
826                }
827                GenerationKind::TypeDecl(visitors) => {
828                    let mut decls = tys
829                        .enumerate()
830                        .map(|(i, ty)| {
831                            let co_recursive = i != 0;
832                            type_decl_to_ocaml_decl(ctx, ty, co_recursive)
833                        })
834                        .join("\n");
835                    if let Some(visitors) = visitors {
836                        let DeriveVisitors {
837                            name,
838                            mut ancestors,
839                            reduce,
840                            extra_types,
841                        } = visitors;
842                        let varieties: &[_] = if *reduce {
843                            &["iter", "map", "reduce", "mapreduce"]
844                        } else {
845                            &["iter", "map"]
846                        };
847                        let intermediate_visitor_name;
848                        let intermediate_visitor_name_slice;
849                        if !extra_types.is_empty() {
850                            intermediate_visitor_name = format!("{name}_base");
851                            let intermediate_visitor = generate_visitor_bases(
852                                ctx,
853                                &intermediate_visitor_name,
854                                ancestors,
855                                *reduce,
856                                extra_types
857                                    .iter()
858                                    .map(|s| s.to_string())
859                                    .collect_vec()
860                                    .as_slice(),
861                            );
862                            intermediate_visitor_name_slice = [intermediate_visitor_name.as_str()];
863                            ancestors = &intermediate_visitor_name_slice;
864                            decls = format!("(* Ancestors for the {name} visitors *){intermediate_visitor}\n{decls}");
865                        }
866                        let visitors = varieties
867                            .iter()
868                            .map(|variety| {
869                                let nude = if !ancestors.is_empty() {
870                                    format!("nude = true (* Don't inherit VisitorsRuntime *);")
871                                } else {
872                                    String::new()
873                                };
874                                let ancestors = format!(
875                                    "ancestors = [ {} ];",
876                                    ancestors
877                                        .iter()
878                                        .map(|a| format!("\"{variety}_{a}\""))
879                                        .join(";")
880                                );
881                                format!(
882                                    r#"
883                                    visitors {{
884                                        name = "{variety}_{name}";
885                                        monomorphic = ["env"];
886                                        variety = "{variety}";
887                                        {ancestors}
888                                        {nude}
889                                    }}
890                                "#
891                                )
892                            })
893                            .format(", ");
894                        let _ = write!(&mut decls, "\n[@@deriving show, eq, ord, {visitors}]");
895                    };
896                    decls
897                }
898            };
899            let placeholder = format!("(* __REPLACE{i}__ *)");
900            template = template.replace(&placeholder, &generated);
901        }
902
903        fs::write(&self.target, template)
904            .with_context(|| format!("Failed to write generated file {}", self.target.display()))?;
905        Ok(())
906    }
907}
908
909fn main() -> Result<()> {
910    let dir = PathBuf::from("src/bin/generate-ml");
911    let charon_llbc = dir.join("charon-itself.llbc");
912    let reuse_llbc = std::env::var("CHARON_ML_REUSE_LLBC").is_ok(); // Useful when developping
913    if !reuse_llbc {
914        // Call charon on itself
915        let mut cmd = Command::cargo_bin("charon")?;
916        cmd.arg("--cargo-arg=--lib");
917        cmd.arg("--hide-marker-traits");
918        cmd.arg("--dest-file");
919        cmd.arg(&charon_llbc);
920        let output = cmd.output()?;
921
922        if !output.status.success() {
923            let stderr = String::from_utf8(output.stderr.clone())?;
924            bail!("Compilation failed: {stderr}")
925        }
926    }
927
928    let crate_data: TranslatedCrate = charon_lib::deserialize_llbc(&charon_llbc)?;
929    let output_dir = if std::env::var("IN_CI").as_deref() == Ok("1") {
930        dir.join("generated")
931    } else {
932        dir.join("../../../../charon-ml/src/generated")
933    };
934    generate_ml(crate_data, dir.join("templates"), output_dir)
935}
936
/// Generates the OCaml files in `output_dir` from the templates found in `template_dir`.
///
/// Builds a `GenerateCtx` for `crate_data`, decides which type declarations go into which output
/// file, then runs `GenerateCodeFor::generate` for each file in turn, stopping at the first
/// error.
fn generate_ml(
    crate_data: TranslatedCrate,
    template_dir: PathBuf,
    output_dir: PathBuf,
) -> anyhow::Result<()> {
    // Pairs of (rust type name, OCaml type text) handed to `GenerateCtx::new`.
    // NOTE(review): presumably the text replaces the generated type body — confirm against
    // `GenerateCtx`.
    let manual_type_impls = &[
        // Hand-written because we replace the `FileId` with the corresponding file.
        ("FileId", "file"),
        // Hand-written because the rust version is an enum with custom (de)serialization
        // functions.
        (
            "ScalarValue",
            indoc!(
                "
                (* Note that we use unbounded integers everywhere.
                   We then harcode the boundaries for the different types.
                 *)
                { value : big_int; int_ty : integer_type }
                "
            ),
        ),
        // Hand-written because we encode sequences differently.
        // TODO: encode sequences identically.
        (
            "charon_lib::ast::llbc_ast::RawStatement",
            indoc!(
                "
                | Assign of place * rvalue
                | SetDiscriminant of place * variant_id
                | StorageLive of local_id
                | StorageDead of local_id
                | Deinit of place
                | Drop of place
                | Assert of assertion
                | Call of call
                | Abort of abort_kind
                | Return
                | Break of int
                    (** Break to (outer) loop. The [int] identifies the loop to break to:
                        * 0: break to the first outer loop (the current loop)
                        * 1: break to the second outer loop
                        * ...
                        *)
                | Continue of int
                    (** Continue to (outer) loop. The loop identifier works
                        the same way as for {!Break} *)
                | Nop
                | Sequence of statement * statement
                | Switch of switch
                | Loop of statement
                | Error of string
                "
            ),
        ),
        // Hand-written because we encode sequences differently.
        ("charon_lib::ast::llbc_ast::Block", "statement"),
        // Handwritten because we use `indexed_var` as a hack to be able to reuse field names.
        // TODO: remove the need for this hack.
        ("RegionVar", "(region_id, string option) indexed_var"),
        ("TypeVar", "(type_var_id, string) indexed_var"),
    ];
    // Pairs of (rust type name, OCaml match arms) handed to `GenerateCtx::new`.
    // NOTE(review): presumably the arms replace the generated deserializer body — confirm
    // against `GenerateCtx`.
    let manual_json_impls = &[
        // Hand-written because we filter out `None` values.
        (
            "Vector",
            indoc!(
                r#"
                | js ->
                    let* list = list_of_json (option_of_json arg1_of_json) ctx js in
                    Ok (List.filter_map (fun x -> x) list)
                "#
            ),
        ),
        // Hand-written because we replace the `FileId` with the corresponding file name.
        (
            "FileId",
            indoc!(
                r#"
                | json ->
                    let* file_id = FileId.id_of_json ctx json in
                    let file = FileId.Map.find file_id ctx in
                    Ok file
                "#,
            ),
        ),
        // Hand-written because the rust version is an enum with custom (de)serialization
        // functions.
        (
            "ScalarValue",
            indoc!(
                r#"
                | `Assoc [ (ty, bi) ] ->
                    let big_int_of_json (js : json) : (big_int, string) result =
                      combine_error_msgs js __FUNCTION__
                        (match js with
                        | `Int i -> Ok (Z.of_int i)
                        | `String is -> Ok (Z.of_string is)
                        | _ -> Error "")
                    in
                    let* value = big_int_of_json bi in
                    let* int_ty = integer_type_of_json ctx (`String ty) in
                    let sv = { value; int_ty } in
                    if not (check_scalar_value_in_range sv) then
                      raise (Failure ("Scalar value not in range: " ^ show_scalar_value sv));
                    Ok sv
                "#
            ),
        ),
        // Hand-written because we encode sequences differently.
        (
            "charon_lib::ast::llbc_ast::Block",
            indoc!(
                r#"
                | `Assoc [ ("span", span); ("statements", statements) ] -> begin
                    let* span = span_of_json ctx span in
                    let* statements =
                      list_of_json statement_of_json ctx statements
                    in
                    match List.rev statements with
                    | [] -> Ok { span; content = Nop; comments_before = [] }
                    | last :: rest ->
                        let seq =
                          List.fold_left
                            (fun acc st -> { span = st.span; content = Sequence (st, acc); comments_before = [] })
                            last rest
                        in
                        Ok seq
                  end
                "#
            ),
        ),
    ];
    // Types for which we don't want to generate a type at all.
    let dont_generate_ty = &[
        "ItemOpacity",
        "PredicateOrigin",
        "TraitTypeConstraintId",
        "Ty",
        "Vector",
    ];
    // Types that we don't want visitors to go into.
    let opaque_for_visitor = &["Name"];
    let ctx = GenerateCtx::new(
        &crate_data,
        manual_type_impls,
        manual_json_impls,
        opaque_for_visitor,
    );

    // Compute the sets of types to be put in each module.
    // These ids are subtracted from every deserializer set computed below.
    let manually_implemented: HashSet<_> = [
        "ItemOpacity",
        "PredicateOrigin",
        "Ty", // We exclude it since `TyKind` is renamed to `ty`
        "Opaque",
        "Body",
        "FunDecl",
        "TranslatedCrate",
    ]
    .iter()
    .map(|name| ctx.id_from_name(name))
    .collect();

    // Compute type sets for json deserializers.
    let (gast_types, llbc_types, ullbc_types) = {
        let llbc_types: HashSet<_> = ctx.children_of("charon_lib::ast::llbc_ast::Statement");
        let ullbc_types: HashSet<_> = ctx.children_of("charon_lib::ast::ullbc_ast::BodyContents");
        // Types that are children of both roots belong to neither body-specific set.
        let common_types: HashSet<_> = llbc_types.intersection(&ullbc_types).copied().collect();

        // Keep only what is specific to llbc (resp. ullbc) and not manually implemented.
        let llbc_types: HashSet<_> = llbc_types
            .difference(&common_types.union(&manually_implemented).copied().collect())
            .copied()
            .collect();
        let ullbc_types: HashSet<_> = ullbc_types
            .difference(&common_types.union(&manually_implemented).copied().collect())
            .copied()
            .collect();

        // Everything else under `TranslatedCrate`, again minus the manually implemented types.
        let body_specific_types: HashSet<_> = llbc_types.union(&ullbc_types).copied().collect();
        let gast_types: HashSet<_> = ctx
            .children_of("TranslatedCrate")
            .difference(
                &body_specific_types
                    .union(&manually_implemented)
                    .copied()
                    .collect(),
            )
            .copied()
            .collect();

        (gast_types, llbc_types, ullbc_types)
    };

    // Seeded with the types we never generate, so that `markers_from_children` never returns
    // them.
    let mut processed_tys: HashSet<TypeDeclId> = dont_generate_ty
        .iter()
        .map(|name| ctx.id_from_name(name))
        .collect();
    // Each call to this will return the children of the listed types that haven't been returned
    // yet. By calling it in dependency order, this allows to organize types into files without
    // having to list them all.
    let mut markers_from_children = |ctx: &GenerateCtx, markers: &[_]| {
        markers
            .iter()
            .copied()
            .map(|(kind, type_names)| {
                let types: HashSet<_> = ctx.children_of_many(type_names);
                let unprocessed_types: HashSet<_> =
                    types.difference(&processed_tys).copied().collect();
                processed_tys.extend(unprocessed_types.iter().copied());
                (kind, unprocessed_types)
            })
            .collect()
    };

    // The order of the entries below matters: `markers_from_children` records the types it has
    // already handed out, so a type ends up in the first marker whose children include it.
    #[rustfmt::skip]
    let generate_code_for = vec![
        GenerateCodeFor {
            template: template_dir.join("Meta.ml"),
            target: output_dir.join("Generated_Meta.ml"),
            markers: markers_from_children(&ctx, &[
                (GenerationKind::TypeDecl(None), &[
                    "File",
                    "Span",
                    "AttrInfo",
                ]),
            ]),
        },
        GenerateCodeFor {
            template: template_dir.join("Values.ml"),
            target: output_dir.join("Generated_Values.ml"),
            markers: markers_from_children(&ctx, &[
                (GenerationKind::TypeDecl(Some(DeriveVisitors {
                    ancestors: &["big_int"],
                    name: "literal",
                    reduce: true,
                    extra_types: &[],
                })), &[
                    "Literal",
                    "IntegerTy",
                    "LiteralTy",
                ]),
            ]),
        },
        GenerateCodeFor {
            template: template_dir.join("Types.ml"),
            target: output_dir.join("Generated_Types.ml"),
            markers: markers_from_children(&ctx, &[
                (GenerationKind::TypeDecl(Some(DeriveVisitors {
                    ancestors: &["literal"],
                    name: "const_generic",
                    reduce: true,
                    extra_types: &[],
                })), &[
                    "RegionId",
                    "TypeVarId",
                    "ConstGeneric",
                    "TraitClauseId",
                    "DeBruijnVar",
                    "AnyTransId",
                ]),
                // Can't merge into above because aeneas uses the above alongside their own partial
                // copy of `ty`, which causes method type clashes.
                (GenerationKind::TypeDecl(Some(DeriveVisitors {
                    ancestors: &["ty_base_base"],
                    name: "ty",
                    reduce: false,
                    extra_types: &[],
                })), &[
                    "TyKind",
                    "TraitImplRef",
                    "FunDeclRef",
                    "GlobalDeclRef",
                ]),
                // TODO: can't merge into above because of field name clashes (`types`, `regions` etc).
                (GenerationKind::TypeDecl(Some(DeriveVisitors {
                    ancestors: &["ty"],
                    name: "type_decl",
                    reduce: false,
                    extra_types: &[
                        "span", "attr_info"
                    ],
                })), &[
                    "Binder",
                    "AbortKind",
                    "TypeDecl",
                ]),
            ]),
        },
        GenerateCodeFor {
            template: template_dir.join("Expressions.ml"),
            target: output_dir.join("Generated_Expressions.ml"),
            markers: markers_from_children(&ctx, &[
                (GenerationKind::TypeDecl(Some(DeriveVisitors {
                    ancestors: &["type_decl"],
                    name: "rvalue",
                    reduce: false,
                    extra_types: &[],
                })), &[
                    "Rvalue",
                ]),
            ]),
        },
        GenerateCodeFor {
            template: template_dir.join("GAst.ml"),
            target: output_dir.join("Generated_GAst.ml"),
            markers: markers_from_children(&ctx, &[
                (GenerationKind::TypeDecl(Some(DeriveVisitors {
                    ancestors: &["rvalue"],
                    name: "fun_sig",
                    reduce: false,
                    extra_types: &[],
                })), &[
                    "Call",
                    "Assert",
                    "ItemKind",
                    "Locals",
                    "FunSig",
                ]),
                // These have to be kept separate to avoid field name clashes
                (GenerationKind::TypeDecl(Some(DeriveVisitors {
                    ancestors: &["fun_sig"],
                    name: "global_decl",
                    reduce: false,
                    extra_types: &[],
                })), &[
                    "GlobalDecl",
                ]),
                (GenerationKind::TypeDecl(Some(DeriveVisitors {
                    ancestors: &["global_decl"],
                    name: "trait_decl",
                    reduce: false,
                    extra_types: &[],
                })), &[
                    "TraitDecl",
                ]),
                (GenerationKind::TypeDecl(Some(DeriveVisitors {
                    ancestors: &["trait_decl"],
                    name: "trait_impl",
                    reduce: false,
                    extra_types: &[],
                })), &[
                    "TraitImpl",
                ]),
                (GenerationKind::TypeDecl(None), &[
                    "CliOpts",
                    "GExprBody",
                    "DeclarationGroup",
                ]),
            ]),
        },
        GenerateCodeFor {
            template: template_dir.join("LlbcAst.ml"),
            target: output_dir.join("Generated_LlbcAst.ml"),
            markers: markers_from_children(&ctx, &[
                (GenerationKind::TypeDecl(Some(DeriveVisitors {
                    name: "statement",
                    ancestors: &["trait_impl"],
                    reduce: false,
                    extra_types: &[],
                })), &[
                    "charon_lib::ast::llbc_ast::Statement",
                ]),
            ]),
        },
        GenerateCodeFor {
            template: template_dir.join("UllbcAst.ml"),
            target: output_dir.join("Generated_UllbcAst.ml"),
            markers: markers_from_children(&ctx, &[
                (GenerationKind::TypeDecl(Some(DeriveVisitors {
                    ancestors: &["trait_impl"],
                    name: "statement",
                    reduce: false,
                    extra_types: &[],
                })), &[
                    "charon_lib::ast::ullbc_ast::Statement",
                    "charon_lib::ast::ullbc_ast::SwitchTargets",
                ]),
                // TODO: Can't merge with above because of field name clashes (`content` and `span`).
                (GenerationKind::TypeDecl(Some(DeriveVisitors {
                    ancestors: &["statement"],
                    name: "ullbc_ast",
                    reduce: false,
                    extra_types: &[],
                })), &[
                    "charon_lib::ast::ullbc_ast::BodyContents",
                ]),
            ]),
        },
        // The deserializer files use the type sets computed up front instead of
        // `markers_from_children`.
        GenerateCodeFor {
            template: template_dir.join("GAstOfJson.ml"),
            target: output_dir.join("Generated_GAstOfJson.ml"),
            markers: vec![(GenerationKind::OfJson, gast_types)],
        },
        GenerateCodeFor {
            template: template_dir.join("LlbcOfJson.ml"),
            target: output_dir.join("Generated_LlbcOfJson.ml"),
            markers: vec![(GenerationKind::OfJson, llbc_types)],
        },
        GenerateCodeFor {
            template: template_dir.join("UllbcOfJson.ml"),
            target: output_dir.join("Generated_UllbcOfJson.ml"),
            markers: vec![(GenerationKind::OfJson, ullbc_types)],
        },
    ];
    for file in generate_code_for {
        file.generate(&ctx)?;
    }
    Ok(())
}