1use std::{fmt, str::FromStr};
2
3use itertools::Itertools;
4use nom::{
5 bytes::complete::{tag, take_while},
6 character::complete::{multispace0, multispace1},
7 combinator::{map_res, success},
8 error::ParseError,
9 multi::separated_list0,
10 sequence::{delimited, preceded},
11 Parser,
12};
13use nom_supreme::{error::ErrorTree, ParserExt};
14
15use super::{PatElem, PatTy, Pattern};
16use crate::ast::RefKind;
17
18type ParseResult<'a, T> = nom::IResult<&'a str, T, ErrorTree<&'a str>>;
19
20trait ParserExtExt<I, O, E>: Parser<I, O, E> + Sized
22where
23 I: Clone,
24 E: ParseError<I>,
25{
26 fn followed_by<F, O2>(self, suffix: F) -> impl Parser<I, O, E>
27 where
28 F: Parser<I, O2, E>,
29 {
30 self.terminated(suffix)
31 }
32}
33impl<I, O, E, P> ParserExtExt<I, O, E> for P
34where
35 I: Clone,
36 E: ParseError<I>,
37 P: Parser<I, O, E>,
38{
39}
40
41impl FromStr for Pattern {
43 type Err = ErrorTree<String>;
44 fn from_str(s: &str) -> Result<Self, Self::Err> {
45 parse_pattern_complete(s)
46 }
47}
48
49fn parse_pattern_complete(i: &str) -> Result<Pattern, ErrorTree<String>> {
50 nom_supreme::final_parser::final_parser(parse_pattern)(i)
51 .map_err(|e: ErrorTree<_>| e.map_locations(|s: &str| s.to_string()))
52}
53
54fn parse_pattern(i: &str) -> ParseResult<'_, Pattern> {
55 separated_list0(tag("::").followed_by(multispace0), parse_pat_elem)
56 .map(|elems| Pattern { elems })
57 .parse(i)
58}
59
60impl fmt::Display for Pattern {
61 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
62 self.elems.iter().format("::").fmt(f)
63 }
64}
65
66impl fmt::Debug for Pattern {
67 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
68 write!(f, "{self}")
69 }
70}
71
72fn parse_pat_elem(i: &str) -> ParseResult<'_, PatElem> {
73 let parse_glob = tag("*").map(|_| PatElem::Glob);
74 parse_glob
75 .or(parse_impl_elem)
76 .or(parse_simple_elem)
77 .parse(i)
78}
79
80fn parse_simple_elem(i: &str) -> ParseResult<'_, PatElem> {
81 let ident = take_while(|c: char| c.is_alphanumeric() || c == '_');
82 let (i, ident) = ident.followed_by(multispace0).parse(i)?;
83 if ident == "_" {
84 success(PatElem::Glob).parse(i)
85 } else {
86 let args = delimited(
87 tag("<").followed_by(multispace0),
88 separated_list0(
89 tag(",").followed_by(multispace0),
90 parse_pat_ty.followed_by(multispace0),
91 ),
92 tag(">"),
93 );
94 args.opt()
95 .map(|args| PatElem::Ident {
96 name: ident.to_string(),
97 generics: args.unwrap_or_default(),
98 is_trait: false,
99 })
100 .parse(i)
101 }
102}
103
104fn parse_impl_elem(i: &str) -> ParseResult<'_, PatElem> {
105 let for_ty = preceded(
106 tag("for").followed_by(multispace1),
107 parse_pat_ty.followed_by(multispace0),
108 );
109 let impl_contents = parse_pattern.followed_by(multispace0).and(for_ty.opt());
110 let impl_expr = tag("{").followed_by(multispace0).precedes(
111 delimited(
112 tag("impl").followed_by(multispace1.cut()).opt(),
113 impl_contents,
114 tag("}"),
115 )
116 .cut(),
117 );
118 map_res(impl_expr, |(mut pat, for_ty)| {
119 if let Some(for_ty) = for_ty {
120 let last_elem = pat
121 .elems
122 .last_mut()
123 .ok_or_else(|| anyhow::anyhow!("trait path must be nonempty"))?;
124 let PatElem::Ident {
125 generics, is_trait, ..
126 } = last_elem
127 else {
128 return Err(anyhow::anyhow!("trait path must end in an ident"));
129 };
130 generics.insert(0, for_ty);
132 *is_trait = true;
133 }
134 Ok(PatElem::Impl(pat.into()))
135 })
136 .parse(i)
137}
138
139impl fmt::Display for PatElem {
140 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
141 match self {
142 PatElem::Ident {
143 name,
144 generics,
145 is_trait,
146 } => {
147 write!(f, "{name}")?;
148 let generics = generics.as_slice();
149 let (ty, generics) = if let [ty, generics @ ..] = generics
150 && *is_trait
151 {
152 (Some(ty), generics)
153 } else {
154 (None, generics)
155 };
156 if !generics.is_empty() {
157 write!(f, "<{}>", generics.iter().format(", "))?;
158 }
159 if let Some(ty) = ty {
160 write!(f, " for {ty}")?;
161 }
162 Ok(())
163 }
164 PatElem::Impl(pat) => write!(f, "{{impl {pat}}}"),
165 PatElem::Glob => write!(f, "_"),
166 }
167 }
168}
169
170fn parse_pat_ty(i: &str) -> ParseResult<'_, PatTy> {
171 let mutability = tag("mut").followed_by(multispace0).opt().map(|mtbl| {
172 if mtbl.is_some() {
173 RefKind::Mut
174 } else {
175 RefKind::Shared
176 }
177 });
178 tag("&")
179 .followed_by(multispace0)
180 .precedes(mutability.and(parse_pat_ty))
181 .map(|(mtbl, ty)| PatTy::Ref(mtbl, ty.into()))
182 .or(parse_pattern.map(PatTy::Pat))
183 .parse(i)
184}
185
186impl fmt::Display for PatTy {
187 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
188 match self {
189 PatTy::Pat(p) => write!(f, "{p}"),
190 PatTy::Ref(RefKind::Shared, ty) => write!(f, "&{ty}"),
191 PatTy::Ref(RefKind::Mut, ty) => write!(f, "&mut {ty}"),
192 }
193 }
194}
195
#[test]
fn test_roundtrip() {
    // Inputs that must print back exactly as written.
    const IDEMPOTENT: &[&str] = &[
        "crate::foo::bar",
        "blah::_",
        "blah::_foo",
        "a::b::Type",
        "a::b::Type<_, _>",
        "Clone",
        "usize",
        "foo::{impl Clone for usize}::clone",
        "foo::{impl Clone for &&usize}",
        "foo::{impl PartialEq<_> for Type<_, _>}",
        "foo::{impl PartialEq<usize> for Box<u8>}",
        "foo::{impl foo::Trait<core::option::Option<_>> for alloc::boxed::Box<_>}::method",
    ];
    // (input, canonical rendering) pairs where printing normalizes the input.
    const NORMALIZED: &[(&str, &str)] = &[
        ("blah::*", "blah::_"),
        ("crate :: foo ::bar ", "crate::foo::bar"),
        ("a::b::Type < _ , _ >", "a::b::Type<_, _>"),
        ("{ impl Clone for usize }", "{impl Clone for usize}"),
        ("{Clone for usize}", "{impl Clone for usize}"),
    ];
    // Inputs the parser must reject.
    const REJECTED: &[&str] = &[
        "{implClone for usize}",
        "{impl Clone forusize}",
        "foo::{impl for alloc::boxed::Box<_>}::method",
        "foo::{impl foo::_ for alloc::boxed::Box<_>}::method",
        "foo::{impl &Clone for usize}",
    ];

    let check = |input: &'static str, expected: &'static str| {
        let pat = Pattern::parse(input).map_err(|e| e.to_string()).unwrap();
        assert_eq!(pat.to_string(), expected);
    };
    for &same in IDEMPOTENT {
        check(same, same);
    }
    for &(input, expected) in NORMALIZED {
        check(input, expected);
    }

    for &input in REJECTED {
        assert!(
            Pattern::parse(input).is_err(),
            "Pattern parsed correctly but shouldn't: `{input}`"
        );
    }
}