rustc_builtin_macros/deriving/cmp/partial_eq.rs
1use rustc_ast::ptr::P;
2use rustc_ast::{BinOpKind, BorrowKind, Expr, ExprKind, MetaItem, Mutability};
3use rustc_expand::base::{Annotatable, ExtCtxt};
4use rustc_span::{Span, sym};
5use thin_vec::thin_vec;
6
7use crate::deriving::generic::ty::*;
8use crate::deriving::generic::*;
9use crate::deriving::{path_local, path_std};
10
11/// Expands a `#[derive(PartialEq)]` attribute into an implementation for the
12/// target item.
13pub(crate) fn expand_deriving_partial_eq(
14 cx: &ExtCtxt<'_>,
15 span: Span,
16 mitem: &MetaItem,
17 item: &Annotatable,
18 push: &mut dyn FnMut(Annotatable),
19 is_const: bool,
20) {
21 let structural_trait_def = TraitDef {
22 span,
23 path: path_std!(marker::StructuralPartialEq),
24 skip_path_as_bound: true, // crucial!
25 needs_copy_as_bound_if_packed: false,
26 additional_bounds: Vec::new(),
27 // We really don't support unions, but that's already checked by the impl generated below;
28 // a second check here would lead to redundant error messages.
29 supports_unions: true,
30 methods: Vec::new(),
31 associated_types: Vec::new(),
32 is_const: false,
33 is_staged_api_crate: cx.ecfg.features.staged_api(),
34 };
35 structural_trait_def.expand(cx, mitem, item, push);
36
37 // No need to generate `ne`, the default suffices, and not generating it is
38 // faster.
39 let methods = vec![MethodDef {
40 name: sym::eq,
41 generics: Bounds::empty(),
42 explicit_self: true,
43 nonself_args: vec![(self_ref(), sym::other)],
44 ret_ty: Path(path_local!(bool)),
45 attributes: thin_vec![cx.attr_word(sym::inline, span)],
46 fieldless_variants_strategy: FieldlessVariantsStrategy::Unify,
47 combine_substructure: combine_substructure(Box::new(|a, b, c| {
48 BlockOrExpr::new_expr(get_substructure_equality_expr(a, b, c))
49 })),
50 }];
51
52 let trait_def = TraitDef {
53 span,
54 path: path_std!(cmp::PartialEq),
55 skip_path_as_bound: false,
56 needs_copy_as_bound_if_packed: true,
57 additional_bounds: Vec::new(),
58 supports_unions: false,
59 methods,
60 associated_types: Vec::new(),
61 is_const,
62 is_staged_api_crate: cx.ecfg.features.staged_api(),
63 };
64 trait_def.expand(cx, mitem, item, push)
65}
66
67/// Generates the equality expression for a struct or enum variant when deriving
68/// `PartialEq`.
69///
70/// This function generates an expression that checks if all fields of a struct
71/// or enum variant are equal.
72/// - Scalar fields are compared first for efficiency, followed by compound
73/// fields.
74/// - If there are no fields, returns `true` (fieldless types are always equal).
75///
76/// Whether a field is considered "scalar" is determined by comparing the symbol
77/// of its type to a set of known scalar type symbols (e.g., `i32`, `u8`, etc).
78/// This check is based on the type's symbol.
79///
80/// ### Example 1
81/// ```
82/// #[derive(PartialEq)]
83/// struct i32;
84///
85/// // Here, `field_2` is of type `i32`, but since it's a user-defined type (not
86/// // the primitive), it will not be treated as scalar. The function will still
87/// // check equality of `field_2` first because the symbol matches `i32`.
88/// #[derive(PartialEq)]
89/// struct Struct {
90/// field_1: &'static str,
91/// field_2: i32,
92/// }
93/// ```
94///
95/// ### Example 2
96/// ```
97/// mod ty {
98/// pub type i32 = i32;
99/// }
100///
101/// // Here, `field_2` is of type `ty::i32`, which is a type alias for `i32`.
102/// // However, the function will not reorder the fields because the symbol for
103/// // `ty::i32` does not match the symbol for the primitive `i32`
104/// // ("ty::i32" != "i32").
105/// #[derive(PartialEq)]
106/// struct Struct {
107/// field_1: &'static str,
108/// field_2: ty::i32,
109/// }
110/// ```
111///
112/// For enums, the discriminant is compared first, then the rest of the fields.
113///
114/// # Panics
115///
116/// If called on static or all-fieldless enums/structs, which should not occur
117/// during derive expansion.
118fn get_substructure_equality_expr(
119 cx: &ExtCtxt<'_>,
120 span: Span,
121 substructure: &Substructure<'_>,
122) -> P<Expr> {
123 use SubstructureFields::*;
124
125 match substructure.fields {
126 EnumMatching(.., fields) | Struct(.., fields) => {
127 let combine = move |acc, field| {
128 let rhs = get_field_equality_expr(cx, field);
129 if let Some(lhs) = acc {
130 // Combine the previous comparison with the current field
131 // using logical AND.
132 return Some(cx.expr_binary(field.span, BinOpKind::And, lhs, rhs));
133 }
134 // Start the chain with the first field's comparison.
135 Some(rhs)
136 };
137
138 // First compare scalar fields, then compound fields, combining all
139 // with logical AND.
140 return fields
141 .iter()
142 .filter(|field| !field.maybe_scalar)
143 .fold(fields.iter().filter(|field| field.maybe_scalar).fold(None, combine), combine)
144 // If there are no fields, treat as always equal.
145 .unwrap_or_else(|| cx.expr_bool(span, true));
146 }
147 EnumDiscr(disc, match_expr) => {
148 let lhs = get_field_equality_expr(cx, disc);
149 let Some(match_expr) = match_expr else {
150 return lhs;
151 };
152 // Compare the discriminant first (cheaper), then the rest of the
153 // fields.
154 return cx.expr_binary(disc.span, BinOpKind::And, lhs, match_expr.clone());
155 }
156 StaticEnum(..) => cx.dcx().span_bug(
157 span,
158 "unexpected static enum encountered during `derive(PartialEq)` expansion",
159 ),
160 StaticStruct(..) => cx.dcx().span_bug(
161 span,
162 "unexpected static struct encountered during `derive(PartialEq)` expansion",
163 ),
164 AllFieldlessEnum(..) => cx.dcx().span_bug(
165 span,
166 "unexpected all-fieldless enum encountered during `derive(PartialEq)` expansion",
167 ),
168 }
169}
170
171/// Generates an equality comparison expression for a single struct or enum
172/// field.
173///
174/// This function produces an AST expression that compares the `self` and
175/// `other` values for a field using `==`. It removes any leading references
176/// from both sides for readability. If the field is a block expression, it is
177/// wrapped in parentheses to ensure valid syntax.
178///
179/// # Panics
180///
181/// Panics if there are not exactly two arguments to compare (should be `self`
182/// and `other`).
183fn get_field_equality_expr(cx: &ExtCtxt<'_>, field: &FieldInfo) -> P<Expr> {
184 let [rhs] = &field.other_selflike_exprs[..] else {
185 cx.dcx().span_bug(field.span, "not exactly 2 arguments in `derive(PartialEq)`");
186 };
187
188 cx.expr_binary(
189 field.span,
190 BinOpKind::Eq,
191 wrap_block_expr(cx, peel_refs(&field.self_expr)),
192 wrap_block_expr(cx, peel_refs(rhs)),
193 )
194}
195
196/// Removes all leading immutable references from an expression.
197///
198/// This is used to strip away any number of leading `&` from an expression
199/// (e.g., `&&&T` becomes `T`). Only removes immutable references; mutable
200/// references are preserved.
201fn peel_refs(mut expr: &P<Expr>) -> P<Expr> {
202 while let ExprKind::AddrOf(BorrowKind::Ref, Mutability::Not, inner) = &expr.kind {
203 expr = &inner;
204 }
205 expr.clone()
206}
207
208/// Wraps a block expression in parentheses to ensure valid AST in macro
209/// expansion output.
210///
211/// If the given expression is a block, it is wrapped in parentheses; otherwise,
212/// it is returned unchanged.
213fn wrap_block_expr(cx: &ExtCtxt<'_>, expr: P<Expr>) -> P<Expr> {
214 if matches!(&expr.kind, ExprKind::Block(..)) {
215 return cx.expr_paren(expr.span, expr);
216 }
217 expr
218}