// rustc_builtin_macros/deriving/cmp/partial_eq.rs

1use rustc_ast::ptr::P;
2use rustc_ast::{BinOpKind, BorrowKind, Expr, ExprKind, MetaItem, Mutability};
3use rustc_expand::base::{Annotatable, ExtCtxt};
4use rustc_span::{Span, sym};
5use thin_vec::thin_vec;
6
7use crate::deriving::generic::ty::*;
8use crate::deriving::generic::*;
9use crate::deriving::{path_local, path_std};
10
11/// Expands a `#[derive(PartialEq)]` attribute into an implementation for the
12/// target item.
13pub(crate) fn expand_deriving_partial_eq(
14    cx: &ExtCtxt<'_>,
15    span: Span,
16    mitem: &MetaItem,
17    item: &Annotatable,
18    push: &mut dyn FnMut(Annotatable),
19    is_const: bool,
20) {
21    let structural_trait_def = TraitDef {
22        span,
23        path: path_std!(marker::StructuralPartialEq),
24        skip_path_as_bound: true, // crucial!
25        needs_copy_as_bound_if_packed: false,
26        additional_bounds: Vec::new(),
27        // We really don't support unions, but that's already checked by the impl generated below;
28        // a second check here would lead to redundant error messages.
29        supports_unions: true,
30        methods: Vec::new(),
31        associated_types: Vec::new(),
32        is_const: false,
33    };
34    structural_trait_def.expand(cx, mitem, item, push);
35
36    // No need to generate `ne`, the default suffices, and not generating it is
37    // faster.
38    let methods = vec![MethodDef {
39        name: sym::eq,
40        generics: Bounds::empty(),
41        explicit_self: true,
42        nonself_args: vec![(self_ref(), sym::other)],
43        ret_ty: Path(path_local!(bool)),
44        attributes: thin_vec![cx.attr_word(sym::inline, span)],
45        fieldless_variants_strategy: FieldlessVariantsStrategy::Unify,
46        combine_substructure: combine_substructure(Box::new(|a, b, c| {
47            BlockOrExpr::new_expr(get_substructure_equality_expr(a, b, c))
48        })),
49    }];
50
51    let trait_def = TraitDef {
52        span,
53        path: path_std!(cmp::PartialEq),
54        skip_path_as_bound: false,
55        needs_copy_as_bound_if_packed: true,
56        additional_bounds: Vec::new(),
57        supports_unions: false,
58        methods,
59        associated_types: Vec::new(),
60        is_const,
61    };
62    trait_def.expand(cx, mitem, item, push)
63}
64
65/// Generates the equality expression for a struct or enum variant when deriving
66/// `PartialEq`.
67///
68/// This function generates an expression that checks if all fields of a struct
69/// or enum variant are equal.
70/// - Scalar fields are compared first for efficiency, followed by compound
71///   fields.
72/// - If there are no fields, returns `true` (fieldless types are always equal).
73///
74/// Whether a field is considered "scalar" is determined by comparing the symbol
75/// of its type to a set of known scalar type symbols (e.g., `i32`, `u8`, etc).
76/// This check is based on the type's symbol.
77///
78/// ### Example 1
79/// ```
80/// #[derive(PartialEq)]
81/// struct i32;
82///
83/// // Here, `field_2` is of type `i32`, but since it's a user-defined type (not
84/// // the primitive), it will not be treated as scalar. The function will still
85/// // check equality of `field_2` first because the symbol matches `i32`.
86/// #[derive(PartialEq)]
87/// struct Struct {
88///     field_1: &'static str,
89///     field_2: i32,
90/// }
91/// ```
92///
93/// ### Example 2
94/// ```
95/// mod ty {
96///     pub type i32 = i32;
97/// }
98///
99/// // Here, `field_2` is of type `ty::i32`, which is a type alias for `i32`.
100/// // However, the function will not reorder the fields because the symbol for
101/// // `ty::i32` does not match the symbol for the primitive `i32`
102/// // ("ty::i32" != "i32").
103/// #[derive(PartialEq)]
104/// struct Struct {
105///     field_1: &'static str,
106///     field_2: ty::i32,
107/// }
108/// ```
109///
110/// For enums, the discriminant is compared first, then the rest of the fields.
111///
112/// # Panics
113///
114/// If called on static or all-fieldless enums/structs, which should not occur
115/// during derive expansion.
116fn get_substructure_equality_expr(
117    cx: &ExtCtxt<'_>,
118    span: Span,
119    substructure: &Substructure<'_>,
120) -> P<Expr> {
121    use SubstructureFields::*;
122
123    match substructure.fields {
124        EnumMatching(.., fields) | Struct(.., fields) => {
125            let combine = move |acc, field| {
126                let rhs = get_field_equality_expr(cx, field);
127                if let Some(lhs) = acc {
128                    // Combine the previous comparison with the current field
129                    // using logical AND.
130                    return Some(cx.expr_binary(field.span, BinOpKind::And, lhs, rhs));
131                }
132                // Start the chain with the first field's comparison.
133                Some(rhs)
134            };
135
136            // First compare scalar fields, then compound fields, combining all
137            // with logical AND.
138            return fields
139                .iter()
140                .filter(|field| !field.maybe_scalar)
141                .fold(fields.iter().filter(|field| field.maybe_scalar).fold(None, combine), combine)
142                // If there are no fields, treat as always equal.
143                .unwrap_or_else(|| cx.expr_bool(span, true));
144        }
145        EnumDiscr(disc, match_expr) => {
146            let lhs = get_field_equality_expr(cx, disc);
147            let Some(match_expr) = match_expr else {
148                return lhs;
149            };
150            // Compare the discriminant first (cheaper), then the rest of the
151            // fields.
152            return cx.expr_binary(disc.span, BinOpKind::And, lhs, match_expr.clone());
153        }
154        StaticEnum(..) => cx.dcx().span_bug(
155            span,
156            "unexpected static enum encountered during `derive(PartialEq)` expansion",
157        ),
158        StaticStruct(..) => cx.dcx().span_bug(
159            span,
160            "unexpected static struct encountered during `derive(PartialEq)` expansion",
161        ),
162        AllFieldlessEnum(..) => cx.dcx().span_bug(
163            span,
164            "unexpected all-fieldless enum encountered during `derive(PartialEq)` expansion",
165        ),
166    }
167}
168
169/// Generates an equality comparison expression for a single struct or enum
170/// field.
171///
172/// This function produces an AST expression that compares the `self` and
173/// `other` values for a field using `==`. It removes any leading references
174/// from both sides for readability. If the field is a block expression, it is
175/// wrapped in parentheses to ensure valid syntax.
176///
177/// # Panics
178///
179/// Panics if there are not exactly two arguments to compare (should be `self`
180/// and `other`).
181fn get_field_equality_expr(cx: &ExtCtxt<'_>, field: &FieldInfo) -> P<Expr> {
182    let [rhs] = &field.other_selflike_exprs[..] else {
183        cx.dcx().span_bug(field.span, "not exactly 2 arguments in `derive(PartialEq)`");
184    };
185
186    cx.expr_binary(
187        field.span,
188        BinOpKind::Eq,
189        wrap_block_expr(cx, peel_refs(&field.self_expr)),
190        wrap_block_expr(cx, peel_refs(rhs)),
191    )
192}
193
194/// Removes all leading immutable references from an expression.
195///
196/// This is used to strip away any number of leading `&` from an expression
197/// (e.g., `&&&T` becomes `T`). Only removes immutable references; mutable
198/// references are preserved.
199fn peel_refs(mut expr: &P<Expr>) -> P<Expr> {
200    while let ExprKind::AddrOf(BorrowKind::Ref, Mutability::Not, inner) = &expr.kind {
201        expr = &inner;
202    }
203    expr.clone()
204}
205
206/// Wraps a block expression in parentheses to ensure valid AST in macro
207/// expansion output.
208///
209/// If the given expression is a block, it is wrapped in parentheses; otherwise,
210/// it is returned unchanged.
211fn wrap_block_expr(cx: &ExtCtxt<'_>, expr: P<Expr>) -> P<Expr> {
212    if matches!(&expr.kind, ExprKind::Block(..)) {
213        return cx.expr_paren(expr.span, expr);
214    }
215    expr
216}