charon_lib/name_matcher/parser.rs

use std::{fmt, str::FromStr};

use itertools::Itertools;
use nom::{
    bytes::complete::{tag, take_while, take_while1},
    character::complete::{multispace0, multispace1, satisfy},
    combinator::{map_res, not, success},
    error::ParseError,
    multi::separated_list0,
    sequence::{delimited, preceded},
    Parser,
};
use nom_supreme::{error::ErrorTree, ParserExt};

use super::{PatElem, PatTy, Pattern};
use crate::ast::RefKind;
17
18type ParseResult<'a, T> = nom::IResult<&'a str, T, ErrorTree<&'a str>>;
19
20/// Extra methods on parsers.
21trait ParserExtExt<I, O, E>: Parser<I, O, E> + Sized
22where
23    I: Clone,
24    E: ParseError<I>,
25{
26    fn followed_by<F, O2>(self, suffix: F) -> impl Parser<I, O, E>
27    where
28        F: Parser<I, O2, E>,
29    {
30        self.terminated(suffix)
31    }
32}
33impl<I, O, E, P> ParserExtExt<I, O, E> for P
34where
35    I: Clone,
36    E: ParseError<I>,
37    P: Parser<I, O, E>,
38{
39}
40
41/// The entry point for this module: parses a string into a `Pattern`.
42impl FromStr for Pattern {
43    type Err = ErrorTree<String>;
44    fn from_str(s: &str) -> Result<Self, Self::Err> {
45        parse_pattern_complete(s)
46    }
47}
48
49fn parse_pattern_complete(i: &str) -> Result<Pattern, ErrorTree<String>> {
50    nom_supreme::final_parser::final_parser(parse_pattern)(i)
51        .map_err(|e: ErrorTree<_>| e.map_locations(|s: &str| s.to_string()))
52}
53
54fn parse_pattern(i: &str) -> ParseResult<'_, Pattern> {
55    separated_list0(tag("::").followed_by(multispace0), parse_pat_elem)
56        .map(|elems| Pattern { elems })
57        .parse(i)
58}
59
60impl fmt::Display for Pattern {
61    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
62        self.elems.iter().format("::").fmt(f)
63    }
64}
65
66impl fmt::Debug for Pattern {
67    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
68        write!(f, "{self}")
69    }
70}
71
72fn parse_pat_elem(i: &str) -> ParseResult<'_, PatElem> {
73    let parse_glob = tag("*").map(|_| PatElem::Glob);
74    parse_glob
75        .or(parse_impl_elem)
76        .or(parse_simple_elem)
77        .parse(i)
78}
79
80fn parse_simple_elem(i: &str) -> ParseResult<'_, PatElem> {
81    let ident = take_while(|c: char| c.is_alphanumeric() || c == '_');
82    let (i, ident) = ident.followed_by(multispace0).parse(i)?;
83    if ident == "_" {
84        success(PatElem::Glob).parse(i)
85    } else {
86        let args = delimited(
87            tag("<").followed_by(multispace0),
88            separated_list0(
89                tag(",").followed_by(multispace0),
90                parse_pat_ty.followed_by(multispace0),
91            ),
92            tag(">"),
93        );
94        args.opt()
95            .map(|args| PatElem::Ident {
96                name: ident.to_string(),
97                generics: args.unwrap_or_default(),
98                is_trait: false,
99            })
100            .parse(i)
101    }
102}
103
104fn parse_impl_elem(i: &str) -> ParseResult<'_, PatElem> {
105    let for_ty = preceded(
106        tag("for").followed_by(multispace1),
107        parse_pat_ty.followed_by(multispace0),
108    );
109    let impl_contents = parse_pattern.followed_by(multispace0).and(for_ty.opt());
110    let impl_expr = tag("{").followed_by(multispace0).precedes(
111        delimited(
112            tag("impl").followed_by(multispace1.cut()).opt(),
113            impl_contents,
114            tag("}"),
115        )
116        .cut(),
117    );
118    map_res(impl_expr, |(mut pat, for_ty)| {
119        if let Some(for_ty) = for_ty {
120            let last_elem = pat
121                .elems
122                .last_mut()
123                .ok_or_else(|| anyhow::anyhow!("trait path must be nonempty"))?;
124            let PatElem::Ident {
125                generics, is_trait, ..
126            } = last_elem
127            else {
128                return Err(anyhow::anyhow!("trait path must end in an ident"));
129            };
130            // Set the type as the first generic arg.
131            generics.insert(0, for_ty);
132            *is_trait = true;
133        }
134        Ok(PatElem::Impl(pat.into()))
135    })
136    .parse(i)
137}
138
139impl fmt::Display for PatElem {
140    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
141        match self {
142            PatElem::Ident {
143                name,
144                generics,
145                is_trait,
146            } => {
147                write!(f, "{name}")?;
148                let generics = generics.as_slice();
149                let (ty, generics) = if let [ty, generics @ ..] = generics
150                    && *is_trait
151                {
152                    (Some(ty), generics)
153                } else {
154                    (None, generics)
155                };
156                if !generics.is_empty() {
157                    write!(f, "<{}>", generics.iter().format(", "))?;
158                }
159                if let Some(ty) = ty {
160                    write!(f, " for {ty}")?;
161                }
162                Ok(())
163            }
164            PatElem::Impl(pat) => write!(f, "{{impl {pat}}}"),
165            PatElem::Glob => write!(f, "_"),
166        }
167    }
168}
169
170fn parse_pat_ty(i: &str) -> ParseResult<'_, PatTy> {
171    let mutability = tag("mut").followed_by(multispace0).opt().map(|mtbl| {
172        if mtbl.is_some() {
173            RefKind::Mut
174        } else {
175            RefKind::Shared
176        }
177    });
178    tag("&")
179        .followed_by(multispace0)
180        .precedes(mutability.and(parse_pat_ty))
181        .map(|(mtbl, ty)| PatTy::Ref(mtbl, ty.into()))
182        .or(parse_pattern.map(PatTy::Pat))
183        .parse(i)
184}
185
186impl fmt::Display for PatTy {
187    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
188        match self {
189            PatTy::Pat(p) => write!(f, "{p}"),
190            PatTy::Ref(RefKind::Shared, ty) => write!(f, "&{ty}"),
191            PatTy::Ref(RefKind::Mut, ty) => write!(f, "&mut {ty}"),
192        }
193    }
194}
195
#[test]
fn test_roundtrip() {
    // Inputs that must print back exactly as written.
    let idempotent_test_strings = [
        "crate::foo::bar",
        "blah::_",
        "blah::_foo",
        "a::b::Type",
        "a::b::Type<_, _>",
        "Clone",
        "usize",
        "foo::{impl Clone for usize}::clone",
        "foo::{impl Clone for &&usize}",
        "foo::{impl PartialEq<_> for Type<_, _>}",
        "foo::{impl PartialEq<usize> for Box<u8>}",
        "foo::{impl foo::Trait<core::option::Option<_>> for alloc::boxed::Box<_>}::method",
    ];
    // (input, canonical printed form) pairs.
    let other_test_strings = [
        ("blah::*", "blah::_"),
        ("crate  ::  foo  ::bar ", "crate::foo::bar"),
        ("a::b::Type < _  ,  _ >", "a::b::Type<_, _>"),
        ("{ impl  Clone  for  usize }", "{impl Clone for usize}"),
        ("{Clone for usize}", "{impl Clone for usize}"),
    ];
    // Inputs the parser must reject.
    let failures = [
        "{implClone for usize}",
        "{impl Clone forusize}",
        "foo::{impl  for alloc::boxed::Box<_>}::method",
        "foo::{impl foo::_ for alloc::boxed::Box<_>}::method",
        "foo::{impl &Clone for usize}",
    ];

    let roundtrip_cases = idempotent_test_strings
        .into_iter()
        .map(|s| (s, s))
        .chain(other_test_strings);
    for (input, expected) in roundtrip_cases {
        let printed = Pattern::parse(input)
            .map_err(|e| e.to_string())
            .unwrap()
            .to_string();
        assert_eq!(printed, expected);
    }

    for input in failures {
        let parsed = Pattern::parse(input);
        assert!(
            parsed.is_err(),
            "Pattern parsed correctly but shouldn't: `{input}`"
        );
    }
}