rustc_mir_transform/
check_enums.rs

use rustc_abi::{Primitive, Scalar, Size, TagEncoding, Variants, WrappingRange};
use rustc_hir::LangItem;
use rustc_index::IndexVec;
use rustc_middle::bug;
use rustc_middle::mir::visit::Visitor;
use rustc_middle::mir::*;
use rustc_middle::ty::layout::PrimitiveExt;
use rustc_middle::ty::{self, Ty, TyCtxt, TypingEnv};
use rustc_session::Session;
use tracing::debug;
11
12/// This pass inserts checks for a valid enum discriminant where they are most
13/// likely to find UB, because checking everywhere like Miri would generate too
14/// much MIR.
15pub(super) struct CheckEnums;
16
17impl<'tcx> crate::MirPass<'tcx> for CheckEnums {
18    fn is_enabled(&self, sess: &Session) -> bool {
19        sess.ub_checks()
20    }
21
22    fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
23        // This pass emits new panics. If for whatever reason we do not have a panic
24        // implementation, running this pass may cause otherwise-valid code to not compile.
25        if tcx.lang_items().get(LangItem::PanicImpl).is_none() {
26            return;
27        }
28
29        let typing_env = body.typing_env(tcx);
30        let basic_blocks = body.basic_blocks.as_mut();
31        let local_decls = &mut body.local_decls;
32
33        // This operation inserts new blocks. Each insertion changes the Location for all
34        // statements/blocks after. Iterating or visiting the MIR in order would require updating
35        // our current location after every insertion. By iterating backwards, we dodge this issue:
36        // The only Locations that an insertion changes have already been handled.
37        for block in basic_blocks.indices().rev() {
38            for statement_index in (0..basic_blocks[block].statements.len()).rev() {
39                let location = Location { block, statement_index };
40                let statement = &basic_blocks[block].statements[statement_index];
41                let source_info = statement.source_info;
42
43                let mut finder = EnumFinder::new(tcx, local_decls, typing_env);
44                finder.visit_statement(statement, location);
45
46                for check in finder.into_found_enums() {
47                    debug!("Inserting enum check");
48                    let new_block = split_block(basic_blocks, location);
49
50                    match check {
51                        EnumCheckType::Direct { source_op, discr, op_size, valid_discrs } => {
52                            insert_direct_enum_check(
53                                tcx,
54                                local_decls,
55                                basic_blocks,
56                                block,
57                                source_op,
58                                discr,
59                                op_size,
60                                valid_discrs,
61                                source_info,
62                                new_block,
63                            )
64                        }
65                        EnumCheckType::Uninhabited => insert_uninhabited_enum_check(
66                            tcx,
67                            local_decls,
68                            &mut basic_blocks[block],
69                            source_info,
70                            new_block,
71                        ),
72                        EnumCheckType::WithNiche {
73                            source_op,
74                            discr,
75                            op_size,
76                            offset,
77                            valid_range,
78                        } => insert_niche_check(
79                            tcx,
80                            local_decls,
81                            &mut basic_blocks[block],
82                            source_op,
83                            valid_range,
84                            discr,
85                            op_size,
86                            offset,
87                            source_info,
88                            new_block,
89                        ),
90                    }
91                }
92            }
93        }
94    }
95
96    fn is_required(&self) -> bool {
97        true
98    }
99}
100
101/// Represent the different kind of enum checks we can insert.
102enum EnumCheckType<'tcx> {
103    /// We know we try to create an uninhabited enum from an inhabited variant.
104    Uninhabited,
105    /// We know the enum does no niche optimizations and can thus easily compute
106    /// the valid discriminants.
107    Direct {
108        source_op: Operand<'tcx>,
109        discr: TyAndSize<'tcx>,
110        op_size: Size,
111        valid_discrs: Vec<u128>,
112    },
113    /// We try to construct an enum that has a niche.
114    WithNiche {
115        source_op: Operand<'tcx>,
116        discr: TyAndSize<'tcx>,
117        op_size: Size,
118        offset: Size,
119        valid_range: WrappingRange,
120    },
121}
122
123struct TyAndSize<'tcx> {
124    pub ty: Ty<'tcx>,
125    pub size: Size,
126}
127
128/// A [Visitor] that finds the construction of enums and evaluates which checks
129/// we should apply.
130struct EnumFinder<'a, 'tcx> {
131    tcx: TyCtxt<'tcx>,
132    local_decls: &'a mut LocalDecls<'tcx>,
133    typing_env: TypingEnv<'tcx>,
134    enums: Vec<EnumCheckType<'tcx>>,
135}
136
137impl<'a, 'tcx> EnumFinder<'a, 'tcx> {
138    fn new(
139        tcx: TyCtxt<'tcx>,
140        local_decls: &'a mut LocalDecls<'tcx>,
141        typing_env: TypingEnv<'tcx>,
142    ) -> Self {
143        EnumFinder { tcx, local_decls, typing_env, enums: Vec::new() }
144    }
145
146    /// Returns the found enum creations and which checks should be inserted.
147    fn into_found_enums(self) -> Vec<EnumCheckType<'tcx>> {
148        self.enums
149    }
150}
151
152impl<'a, 'tcx> Visitor<'tcx> for EnumFinder<'a, 'tcx> {
153    fn visit_rvalue(&mut self, rvalue: &Rvalue<'tcx>, location: Location) {
154        if let Rvalue::Cast(CastKind::Transmute, op, ty) = rvalue {
155            let ty::Adt(adt_def, _) = ty.kind() else {
156                return;
157            };
158            if !adt_def.is_enum() {
159                return;
160            }
161
162            let Ok(enum_layout) = self.tcx.layout_of(self.typing_env.as_query_input(*ty)) else {
163                return;
164            };
165            let Ok(op_layout) = self
166                .tcx
167                .layout_of(self.typing_env.as_query_input(op.ty(self.local_decls, self.tcx)))
168            else {
169                return;
170            };
171
172            match enum_layout.variants {
173                Variants::Empty if op_layout.is_uninhabited() => return,
174                // An empty enum that tries to be constructed from an inhabited value, this
175                // is never correct.
176                Variants::Empty => {
177                    // The enum layout is uninhabited but we construct it from sth inhabited.
178                    // This is always UB.
179                    self.enums.push(EnumCheckType::Uninhabited);
180                }
181                // Construction of Single value enums is always fine.
182                Variants::Single { .. } => {}
183                // Construction of an enum with multiple variants but no niche optimizations.
184                Variants::Multiple {
185                    tag_encoding: TagEncoding::Direct,
186                    tag: Scalar::Initialized { value, .. },
187                    ..
188                } => {
189                    let valid_discrs =
190                        adt_def.discriminants(self.tcx).map(|(_, discr)| discr.val).collect();
191
192                    let discr =
193                        TyAndSize { ty: value.to_int_ty(self.tcx), size: value.size(&self.tcx) };
194                    self.enums.push(EnumCheckType::Direct {
195                        source_op: op.to_copy(),
196                        discr,
197                        op_size: op_layout.size,
198                        valid_discrs,
199                    });
200                }
201                // Construction of an enum with multiple variants and niche optimizations.
202                Variants::Multiple {
203                    tag_encoding: TagEncoding::Niche { .. },
204                    tag: Scalar::Initialized { value, valid_range, .. },
205                    tag_field,
206                    ..
207                } => {
208                    let discr =
209                        TyAndSize { ty: value.to_int_ty(self.tcx), size: value.size(&self.tcx) };
210                    self.enums.push(EnumCheckType::WithNiche {
211                        source_op: op.to_copy(),
212                        discr,
213                        op_size: op_layout.size,
214                        offset: enum_layout.fields.offset(tag_field.as_usize()),
215                        valid_range,
216                    });
217                }
218                _ => return,
219            }
220
221            self.super_rvalue(rvalue, location);
222        }
223    }
224}
225
226fn split_block(
227    basic_blocks: &mut IndexVec<BasicBlock, BasicBlockData<'_>>,
228    location: Location,
229) -> BasicBlock {
230    let block_data = &mut basic_blocks[location.block];
231
232    // Drain every statement after this one and move the current terminator to a new basic block.
233    let new_block = BasicBlockData {
234        statements: block_data.statements.split_off(location.statement_index),
235        terminator: block_data.terminator.take(),
236        is_cleanup: block_data.is_cleanup,
237    };
238
239    basic_blocks.push(new_block)
240}
241
242/// Inserts the cast of an operand (any type) to a u128 value that holds the discriminant value.
243fn insert_discr_cast_to_u128<'tcx>(
244    tcx: TyCtxt<'tcx>,
245    local_decls: &mut IndexVec<Local, LocalDecl<'tcx>>,
246    block_data: &mut BasicBlockData<'tcx>,
247    source_op: Operand<'tcx>,
248    discr: TyAndSize<'tcx>,
249    op_size: Size,
250    offset: Option<Size>,
251    source_info: SourceInfo,
252) -> Place<'tcx> {
253    let get_ty_for_size = |tcx: TyCtxt<'tcx>, size: Size| -> Ty<'tcx> {
254        match size.bytes() {
255            1 => tcx.types.u8,
256            2 => tcx.types.u16,
257            4 => tcx.types.u32,
258            8 => tcx.types.u64,
259            16 => tcx.types.u128,
260            invalid => bug!("Found discriminant with invalid size, has {} bytes", invalid),
261        }
262    };
263
264    let (cast_kind, discr_ty_bits) = if discr.size.bytes() < op_size.bytes() {
265        // The discriminant is less wide than the operand, cast the operand into
266        // [MaybeUninit; N] and then index into it.
267        let mu = Ty::new_maybe_uninit(tcx, tcx.types.u8);
268        let array_len = op_size.bytes();
269        let mu_array_ty = Ty::new_array(tcx, mu, array_len);
270        let mu_array =
271            local_decls.push(LocalDecl::with_source_info(mu_array_ty, source_info)).into();
272        let rvalue = Rvalue::Cast(CastKind::Transmute, source_op, mu_array_ty);
273        block_data.statements.push(Statement {
274            source_info,
275            kind: StatementKind::Assign(Box::new((mu_array, rvalue))),
276        });
277
278        // Index into the array of MaybeUninit to get something that is actually
279        // as wide as the discriminant.
280        let offset = offset.unwrap_or(Size::ZERO);
281        let smaller_mu_array = mu_array.project_deeper(
282            &[ProjectionElem::Subslice {
283                from: offset.bytes(),
284                to: offset.bytes() + discr.size.bytes(),
285                from_end: false,
286            }],
287            tcx,
288        );
289
290        (CastKind::Transmute, Operand::Copy(smaller_mu_array))
291    } else {
292        let operand_int_ty = get_ty_for_size(tcx, op_size);
293
294        let op_as_int =
295            local_decls.push(LocalDecl::with_source_info(operand_int_ty, source_info)).into();
296        let rvalue = Rvalue::Cast(CastKind::Transmute, source_op, operand_int_ty);
297        block_data.statements.push(Statement {
298            source_info,
299            kind: StatementKind::Assign(Box::new((op_as_int, rvalue))),
300        });
301
302        (CastKind::IntToInt, Operand::Copy(op_as_int))
303    };
304
305    // Cast the resulting value to the actual discriminant integer type.
306    let rvalue = Rvalue::Cast(cast_kind, discr_ty_bits, discr.ty);
307    let discr_in_discr_ty =
308        local_decls.push(LocalDecl::with_source_info(discr.ty, source_info)).into();
309    block_data.statements.push(Statement {
310        source_info,
311        kind: StatementKind::Assign(Box::new((discr_in_discr_ty, rvalue))),
312    });
313
314    // Cast the discriminant to a u128 (base for comparisions of enum discriminants).
315    let const_u128 = Ty::new_uint(tcx, ty::UintTy::U128);
316    let rvalue = Rvalue::Cast(CastKind::IntToInt, Operand::Copy(discr_in_discr_ty), const_u128);
317    let discr = local_decls.push(LocalDecl::with_source_info(const_u128, source_info)).into();
318    block_data
319        .statements
320        .push(Statement { source_info, kind: StatementKind::Assign(Box::new((discr, rvalue))) });
321
322    discr
323}
324
325fn insert_direct_enum_check<'tcx>(
326    tcx: TyCtxt<'tcx>,
327    local_decls: &mut IndexVec<Local, LocalDecl<'tcx>>,
328    basic_blocks: &mut IndexVec<BasicBlock, BasicBlockData<'tcx>>,
329    current_block: BasicBlock,
330    source_op: Operand<'tcx>,
331    discr: TyAndSize<'tcx>,
332    op_size: Size,
333    discriminants: Vec<u128>,
334    source_info: SourceInfo,
335    new_block: BasicBlock,
336) {
337    // Insert a new target block that is branched to in case of an invalid discriminant.
338    let invalid_discr_block_data = BasicBlockData::new(None, false);
339    let invalid_discr_block = basic_blocks.push(invalid_discr_block_data);
340    let block_data = &mut basic_blocks[current_block];
341    let discr = insert_discr_cast_to_u128(
342        tcx,
343        local_decls,
344        block_data,
345        source_op,
346        discr,
347        op_size,
348        None,
349        source_info,
350    );
351
352    // Branch based on the discriminant value.
353    block_data.terminator = Some(Terminator {
354        source_info,
355        kind: TerminatorKind::SwitchInt {
356            discr: Operand::Copy(discr),
357            targets: SwitchTargets::new(
358                discriminants.into_iter().map(|discr| (discr, new_block)),
359                invalid_discr_block,
360            ),
361        },
362    });
363
364    // Abort in case of an invalid enum discriminant.
365    basic_blocks[invalid_discr_block].terminator = Some(Terminator {
366        source_info,
367        kind: TerminatorKind::Assert {
368            cond: Operand::Constant(Box::new(ConstOperand {
369                span: source_info.span,
370                user_ty: None,
371                const_: Const::Val(ConstValue::from_bool(false), tcx.types.bool),
372            })),
373            expected: true,
374            target: new_block,
375            msg: Box::new(AssertKind::InvalidEnumConstruction(Operand::Copy(discr))),
376            // This calls panic_invalid_enum_construction, which is #[rustc_nounwind].
377            // We never want to insert an unwind into unsafe code, because unwinding could
378            // make a failing UB check turn into much worse UB when we start unwinding.
379            unwind: UnwindAction::Unreachable,
380        },
381    });
382}
383
384fn insert_uninhabited_enum_check<'tcx>(
385    tcx: TyCtxt<'tcx>,
386    local_decls: &mut IndexVec<Local, LocalDecl<'tcx>>,
387    block_data: &mut BasicBlockData<'tcx>,
388    source_info: SourceInfo,
389    new_block: BasicBlock,
390) {
391    let is_ok: Place<'_> =
392        local_decls.push(LocalDecl::with_source_info(tcx.types.bool, source_info)).into();
393    block_data.statements.push(Statement {
394        source_info,
395        kind: StatementKind::Assign(Box::new((
396            is_ok,
397            Rvalue::Use(Operand::Constant(Box::new(ConstOperand {
398                span: source_info.span,
399                user_ty: None,
400                const_: Const::Val(ConstValue::from_bool(false), tcx.types.bool),
401            }))),
402        ))),
403    });
404
405    block_data.terminator = Some(Terminator {
406        source_info,
407        kind: TerminatorKind::Assert {
408            cond: Operand::Copy(is_ok),
409            expected: true,
410            target: new_block,
411            msg: Box::new(AssertKind::InvalidEnumConstruction(Operand::Constant(Box::new(
412                ConstOperand {
413                    span: source_info.span,
414                    user_ty: None,
415                    const_: Const::Val(ConstValue::from_u128(0), tcx.types.u128),
416                },
417            )))),
418            // This calls panic_invalid_enum_construction, which is #[rustc_nounwind].
419            // We never want to insert an unwind into unsafe code, because unwinding could
420            // make a failing UB check turn into much worse UB when we start unwinding.
421            unwind: UnwindAction::Unreachable,
422        },
423    });
424}
425
426fn insert_niche_check<'tcx>(
427    tcx: TyCtxt<'tcx>,
428    local_decls: &mut IndexVec<Local, LocalDecl<'tcx>>,
429    block_data: &mut BasicBlockData<'tcx>,
430    source_op: Operand<'tcx>,
431    valid_range: WrappingRange,
432    discr: TyAndSize<'tcx>,
433    op_size: Size,
434    offset: Size,
435    source_info: SourceInfo,
436    new_block: BasicBlock,
437) {
438    let discr = insert_discr_cast_to_u128(
439        tcx,
440        local_decls,
441        block_data,
442        source_op,
443        discr,
444        op_size,
445        Some(offset),
446        source_info,
447    );
448
449    // Compare the discriminant agains the valid_range.
450    let start_const = Operand::Constant(Box::new(ConstOperand {
451        span: source_info.span,
452        user_ty: None,
453        const_: Const::Val(ConstValue::from_u128(valid_range.start), tcx.types.u128),
454    }));
455    let end_start_diff_const = Operand::Constant(Box::new(ConstOperand {
456        span: source_info.span,
457        user_ty: None,
458        const_: Const::Val(
459            ConstValue::from_u128(u128::wrapping_sub(valid_range.end, valid_range.start)),
460            tcx.types.u128,
461        ),
462    }));
463
464    let discr_diff: Place<'_> =
465        local_decls.push(LocalDecl::with_source_info(tcx.types.u128, source_info)).into();
466    block_data.statements.push(Statement {
467        source_info,
468        kind: StatementKind::Assign(Box::new((
469            discr_diff,
470            Rvalue::BinaryOp(BinOp::Sub, Box::new((Operand::Copy(discr), start_const))),
471        ))),
472    });
473
474    let is_ok: Place<'_> =
475        local_decls.push(LocalDecl::with_source_info(tcx.types.bool, source_info)).into();
476    block_data.statements.push(Statement {
477        source_info,
478        kind: StatementKind::Assign(Box::new((
479            is_ok,
480            Rvalue::BinaryOp(
481                // This is a `WrappingRange`, so make sure to get the wrapping right.
482                BinOp::Le,
483                Box::new((Operand::Copy(discr_diff), end_start_diff_const)),
484            ),
485        ))),
486    });
487
488    block_data.terminator = Some(Terminator {
489        source_info,
490        kind: TerminatorKind::Assert {
491            cond: Operand::Copy(is_ok),
492            expected: true,
493            target: new_block,
494            msg: Box::new(AssertKind::InvalidEnumConstruction(Operand::Copy(discr))),
495            // This calls panic_invalid_enum_construction, which is #[rustc_nounwind].
496            // We never want to insert an unwind into unsafe code, because unwinding could
497            // make a failing UB check turn into much worse UB when we start unwinding.
498            unwind: UnwindAction::Unreachable,
499        },
500    });
501}