// rustc_mir_transform/check_enums.rs

1use rustc_abi::{Scalar, Size, TagEncoding, Variants, WrappingRange};
2use rustc_hir::LangItem;
3use rustc_index::IndexVec;
4use rustc_middle::bug;
5use rustc_middle::mir::visit::Visitor;
6use rustc_middle::mir::*;
7use rustc_middle::ty::layout::PrimitiveExt;
8use rustc_middle::ty::{self, Ty, TyCtxt, TypingEnv};
9use rustc_session::Session;
10use tracing::debug;
11
12/// This pass inserts checks for a valid enum discriminant where they are most
13/// likely to find UB, because checking everywhere like Miri would generate too
14/// much MIR.
15pub(super) struct CheckEnums;
16
17impl<'tcx> crate::MirPass<'tcx> for CheckEnums {
18    fn is_enabled(&self, sess: &Session) -> bool {
19        sess.ub_checks()
20    }
21
22    fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
23        // This pass emits new panics. If for whatever reason we do not have a panic
24        // implementation, running this pass may cause otherwise-valid code to not compile.
25        if tcx.lang_items().get(LangItem::PanicImpl).is_none() {
26            return;
27        }
28
29        let typing_env = body.typing_env(tcx);
30        let basic_blocks = body.basic_blocks.as_mut();
31        let local_decls = &mut body.local_decls;
32
33        // This operation inserts new blocks. Each insertion changes the Location for all
34        // statements/blocks after. Iterating or visiting the MIR in order would require updating
35        // our current location after every insertion. By iterating backwards, we dodge this issue:
36        // The only Locations that an insertion changes have already been handled.
37        for block in basic_blocks.indices().rev() {
38            for statement_index in (0..basic_blocks[block].statements.len()).rev() {
39                let location = Location { block, statement_index };
40                let statement = &basic_blocks[block].statements[statement_index];
41                let source_info = statement.source_info;
42
43                let mut finder = EnumFinder::new(tcx, local_decls, typing_env);
44                finder.visit_statement(statement, location);
45
46                for check in finder.into_found_enums() {
47                    debug!("Inserting enum check");
48                    let new_block = split_block(basic_blocks, location);
49
50                    match check {
51                        EnumCheckType::Direct { source_op, discr, op_size, valid_discrs } => {
52                            insert_direct_enum_check(
53                                tcx,
54                                local_decls,
55                                basic_blocks,
56                                block,
57                                source_op,
58                                discr,
59                                op_size,
60                                valid_discrs,
61                                source_info,
62                                new_block,
63                            )
64                        }
65                        EnumCheckType::Uninhabited => insert_uninhabited_enum_check(
66                            tcx,
67                            local_decls,
68                            &mut basic_blocks[block],
69                            source_info,
70                            new_block,
71                        ),
72                        EnumCheckType::WithNiche {
73                            source_op,
74                            discr,
75                            op_size,
76                            offset,
77                            valid_range,
78                        } => insert_niche_check(
79                            tcx,
80                            local_decls,
81                            &mut basic_blocks[block],
82                            source_op,
83                            valid_range,
84                            discr,
85                            op_size,
86                            offset,
87                            source_info,
88                            new_block,
89                        ),
90                    }
91                }
92            }
93        }
94    }
95
96    fn is_required(&self) -> bool {
97        true
98    }
99}
100
101/// Represent the different kind of enum checks we can insert.
102enum EnumCheckType<'tcx> {
103    /// We know we try to create an uninhabited enum from an inhabited variant.
104    Uninhabited,
105    /// We know the enum does no niche optimizations and can thus easily compute
106    /// the valid discriminants.
107    Direct {
108        source_op: Operand<'tcx>,
109        discr: TyAndSize<'tcx>,
110        op_size: Size,
111        valid_discrs: Vec<u128>,
112    },
113    /// We try to construct an enum that has a niche.
114    WithNiche {
115        source_op: Operand<'tcx>,
116        discr: TyAndSize<'tcx>,
117        op_size: Size,
118        offset: Size,
119        valid_range: WrappingRange,
120    },
121}
122
123#[derive(Debug, Copy, Clone)]
124struct TyAndSize<'tcx> {
125    pub ty: Ty<'tcx>,
126    pub size: Size,
127}
128
129/// A [Visitor] that finds the construction of enums and evaluates which checks
130/// we should apply.
131struct EnumFinder<'a, 'tcx> {
132    tcx: TyCtxt<'tcx>,
133    local_decls: &'a mut LocalDecls<'tcx>,
134    typing_env: TypingEnv<'tcx>,
135    enums: Vec<EnumCheckType<'tcx>>,
136}
137
138impl<'a, 'tcx> EnumFinder<'a, 'tcx> {
139    fn new(
140        tcx: TyCtxt<'tcx>,
141        local_decls: &'a mut LocalDecls<'tcx>,
142        typing_env: TypingEnv<'tcx>,
143    ) -> Self {
144        EnumFinder { tcx, local_decls, typing_env, enums: Vec::new() }
145    }
146
147    /// Returns the found enum creations and which checks should be inserted.
148    fn into_found_enums(self) -> Vec<EnumCheckType<'tcx>> {
149        self.enums
150    }
151}
152
153impl<'a, 'tcx> Visitor<'tcx> for EnumFinder<'a, 'tcx> {
154    fn visit_rvalue(&mut self, rvalue: &Rvalue<'tcx>, location: Location) {
155        if let Rvalue::Cast(CastKind::Transmute, op, ty) = rvalue {
156            let ty::Adt(adt_def, _) = ty.kind() else {
157                return;
158            };
159            if !adt_def.is_enum() {
160                return;
161            }
162
163            let Ok(enum_layout) = self.tcx.layout_of(self.typing_env.as_query_input(*ty)) else {
164                return;
165            };
166            let Ok(op_layout) = self
167                .tcx
168                .layout_of(self.typing_env.as_query_input(op.ty(self.local_decls, self.tcx)))
169            else {
170                return;
171            };
172
173            match enum_layout.variants {
174                Variants::Empty if op_layout.is_uninhabited() => return,
175                // An empty enum that tries to be constructed from an inhabited value, this
176                // is never correct.
177                Variants::Empty => {
178                    // The enum layout is uninhabited but we construct it from sth inhabited.
179                    // This is always UB.
180                    self.enums.push(EnumCheckType::Uninhabited);
181                }
182                // Construction of Single value enums is always fine.
183                Variants::Single { .. } => {}
184                // Construction of an enum with multiple variants but no niche optimizations.
185                Variants::Multiple {
186                    tag_encoding: TagEncoding::Direct,
187                    tag: Scalar::Initialized { value, .. },
188                    ..
189                } => {
190                    let valid_discrs =
191                        adt_def.discriminants(self.tcx).map(|(_, discr)| discr.val).collect();
192
193                    let discr =
194                        TyAndSize { ty: value.to_int_ty(self.tcx), size: value.size(&self.tcx) };
195                    self.enums.push(EnumCheckType::Direct {
196                        source_op: op.to_copy(),
197                        discr,
198                        op_size: op_layout.size,
199                        valid_discrs,
200                    });
201                }
202                // Construction of an enum with multiple variants and niche optimizations.
203                Variants::Multiple {
204                    tag_encoding: TagEncoding::Niche { .. },
205                    tag: Scalar::Initialized { value, valid_range, .. },
206                    tag_field,
207                    ..
208                } => {
209                    let discr =
210                        TyAndSize { ty: value.to_int_ty(self.tcx), size: value.size(&self.tcx) };
211                    self.enums.push(EnumCheckType::WithNiche {
212                        source_op: op.to_copy(),
213                        discr,
214                        op_size: op_layout.size,
215                        offset: enum_layout.fields.offset(tag_field.as_usize()),
216                        valid_range,
217                    });
218                }
219                _ => return,
220            }
221
222            self.super_rvalue(rvalue, location);
223        }
224    }
225}
226
227fn split_block(
228    basic_blocks: &mut IndexVec<BasicBlock, BasicBlockData<'_>>,
229    location: Location,
230) -> BasicBlock {
231    let block_data = &mut basic_blocks[location.block];
232
233    // Drain every statement after this one and move the current terminator to a new basic block.
234    let new_block = BasicBlockData::new_stmts(
235        block_data.statements.split_off(location.statement_index),
236        block_data.terminator.take(),
237        block_data.is_cleanup,
238    );
239
240    basic_blocks.push(new_block)
241}
242
243/// Inserts the cast of an operand (any type) to a u128 value that holds the discriminant value.
244fn insert_discr_cast_to_u128<'tcx>(
245    tcx: TyCtxt<'tcx>,
246    local_decls: &mut IndexVec<Local, LocalDecl<'tcx>>,
247    block_data: &mut BasicBlockData<'tcx>,
248    source_op: Operand<'tcx>,
249    discr: TyAndSize<'tcx>,
250    op_size: Size,
251    offset: Option<Size>,
252    source_info: SourceInfo,
253) -> Place<'tcx> {
254    let get_ty_for_size = |tcx: TyCtxt<'tcx>, size: Size| -> Ty<'tcx> {
255        match size.bytes() {
256            1 => tcx.types.u8,
257            2 => tcx.types.u16,
258            4 => tcx.types.u32,
259            8 => tcx.types.u64,
260            16 => tcx.types.u128,
261            invalid => bug!("Found discriminant with invalid size, has {} bytes", invalid),
262        }
263    };
264
265    let (cast_kind, discr_ty_bits) = if discr.size.bytes() < op_size.bytes() {
266        // The discriminant is less wide than the operand, cast the operand into
267        // [MaybeUninit; N] and then index into it.
268        let mu = Ty::new_maybe_uninit(tcx, tcx.types.u8);
269        let array_len = op_size.bytes();
270        let mu_array_ty = Ty::new_array(tcx, mu, array_len);
271        let mu_array =
272            local_decls.push(LocalDecl::with_source_info(mu_array_ty, source_info)).into();
273        let rvalue = Rvalue::Cast(CastKind::Transmute, source_op, mu_array_ty);
274        block_data
275            .statements
276            .push(Statement::new(source_info, StatementKind::Assign(Box::new((mu_array, rvalue)))));
277
278        // Index into the array of MaybeUninit to get something that is actually
279        // as wide as the discriminant.
280        let offset = offset.unwrap_or(Size::ZERO);
281        let smaller_mu_array = mu_array.project_deeper(
282            &[ProjectionElem::Subslice {
283                from: offset.bytes(),
284                to: offset.bytes() + discr.size.bytes(),
285                from_end: false,
286            }],
287            tcx,
288        );
289
290        (CastKind::Transmute, Operand::Copy(smaller_mu_array))
291    } else {
292        let operand_int_ty = get_ty_for_size(tcx, op_size);
293
294        let op_as_int =
295            local_decls.push(LocalDecl::with_source_info(operand_int_ty, source_info)).into();
296        let rvalue = Rvalue::Cast(CastKind::Transmute, source_op, operand_int_ty);
297        block_data.statements.push(Statement::new(
298            source_info,
299            StatementKind::Assign(Box::new((op_as_int, rvalue))),
300        ));
301
302        (CastKind::IntToInt, Operand::Copy(op_as_int))
303    };
304
305    // Cast the resulting value to the actual discriminant integer type.
306    let rvalue = Rvalue::Cast(cast_kind, discr_ty_bits, discr.ty);
307    let discr_in_discr_ty =
308        local_decls.push(LocalDecl::with_source_info(discr.ty, source_info)).into();
309    block_data.statements.push(Statement::new(
310        source_info,
311        StatementKind::Assign(Box::new((discr_in_discr_ty, rvalue))),
312    ));
313
314    // Cast the discriminant to a u128 (base for comparisons of enum discriminants).
315    let const_u128 = Ty::new_uint(tcx, ty::UintTy::U128);
316    let rvalue = Rvalue::Cast(CastKind::IntToInt, Operand::Copy(discr_in_discr_ty), const_u128);
317    let discr = local_decls.push(LocalDecl::with_source_info(const_u128, source_info)).into();
318    block_data
319        .statements
320        .push(Statement::new(source_info, StatementKind::Assign(Box::new((discr, rvalue)))));
321
322    discr
323}
324
325fn insert_direct_enum_check<'tcx>(
326    tcx: TyCtxt<'tcx>,
327    local_decls: &mut IndexVec<Local, LocalDecl<'tcx>>,
328    basic_blocks: &mut IndexVec<BasicBlock, BasicBlockData<'tcx>>,
329    current_block: BasicBlock,
330    source_op: Operand<'tcx>,
331    discr: TyAndSize<'tcx>,
332    op_size: Size,
333    discriminants: Vec<u128>,
334    source_info: SourceInfo,
335    new_block: BasicBlock,
336) {
337    // Insert a new target block that is branched to in case of an invalid discriminant.
338    let invalid_discr_block_data = BasicBlockData::new(None, false);
339    let invalid_discr_block = basic_blocks.push(invalid_discr_block_data);
340    let block_data = &mut basic_blocks[current_block];
341    let discr_place = insert_discr_cast_to_u128(
342        tcx,
343        local_decls,
344        block_data,
345        source_op,
346        discr,
347        op_size,
348        None,
349        source_info,
350    );
351
352    // Mask out the bits of the discriminant type.
353    let mask = discr.size.unsigned_int_max();
354    let discr_masked =
355        local_decls.push(LocalDecl::with_source_info(tcx.types.u128, source_info)).into();
356    let rvalue = Rvalue::BinaryOp(
357        BinOp::BitAnd,
358        Box::new((
359            Operand::Copy(discr_place),
360            Operand::Constant(Box::new(ConstOperand {
361                span: source_info.span,
362                user_ty: None,
363                const_: Const::Val(ConstValue::from_u128(mask), tcx.types.u128),
364            })),
365        )),
366    );
367    block_data
368        .statements
369        .push(Statement::new(source_info, StatementKind::Assign(Box::new((discr_masked, rvalue)))));
370
371    // Branch based on the discriminant value.
372    block_data.terminator = Some(Terminator {
373        source_info,
374        kind: TerminatorKind::SwitchInt {
375            discr: Operand::Copy(discr_masked),
376            targets: SwitchTargets::new(
377                discriminants
378                    .into_iter()
379                    .map(|discr_val| (discr.size.truncate(discr_val), new_block)),
380                invalid_discr_block,
381            ),
382        },
383    });
384
385    // Abort in case of an invalid enum discriminant.
386    basic_blocks[invalid_discr_block].terminator = Some(Terminator {
387        source_info,
388        kind: TerminatorKind::Assert {
389            cond: Operand::Constant(Box::new(ConstOperand {
390                span: source_info.span,
391                user_ty: None,
392                const_: Const::Val(ConstValue::from_bool(false), tcx.types.bool),
393            })),
394            expected: true,
395            target: new_block,
396            msg: Box::new(AssertKind::InvalidEnumConstruction(Operand::Copy(discr_masked))),
397            // This calls panic_invalid_enum_construction, which is #[rustc_nounwind].
398            // We never want to insert an unwind into unsafe code, because unwinding could
399            // make a failing UB check turn into much worse UB when we start unwinding.
400            unwind: UnwindAction::Unreachable,
401        },
402    });
403}
404
405fn insert_uninhabited_enum_check<'tcx>(
406    tcx: TyCtxt<'tcx>,
407    local_decls: &mut IndexVec<Local, LocalDecl<'tcx>>,
408    block_data: &mut BasicBlockData<'tcx>,
409    source_info: SourceInfo,
410    new_block: BasicBlock,
411) {
412    let is_ok: Place<'_> =
413        local_decls.push(LocalDecl::with_source_info(tcx.types.bool, source_info)).into();
414    block_data.statements.push(Statement::new(
415        source_info,
416        StatementKind::Assign(Box::new((
417            is_ok,
418            Rvalue::Use(Operand::Constant(Box::new(ConstOperand {
419                span: source_info.span,
420                user_ty: None,
421                const_: Const::Val(ConstValue::from_bool(false), tcx.types.bool),
422            }))),
423        ))),
424    ));
425
426    block_data.terminator = Some(Terminator {
427        source_info,
428        kind: TerminatorKind::Assert {
429            cond: Operand::Copy(is_ok),
430            expected: true,
431            target: new_block,
432            msg: Box::new(AssertKind::InvalidEnumConstruction(Operand::Constant(Box::new(
433                ConstOperand {
434                    span: source_info.span,
435                    user_ty: None,
436                    const_: Const::Val(ConstValue::from_u128(0), tcx.types.u128),
437                },
438            )))),
439            // This calls panic_invalid_enum_construction, which is #[rustc_nounwind].
440            // We never want to insert an unwind into unsafe code, because unwinding could
441            // make a failing UB check turn into much worse UB when we start unwinding.
442            unwind: UnwindAction::Unreachable,
443        },
444    });
445}
446
447fn insert_niche_check<'tcx>(
448    tcx: TyCtxt<'tcx>,
449    local_decls: &mut IndexVec<Local, LocalDecl<'tcx>>,
450    block_data: &mut BasicBlockData<'tcx>,
451    source_op: Operand<'tcx>,
452    valid_range: WrappingRange,
453    discr: TyAndSize<'tcx>,
454    op_size: Size,
455    offset: Size,
456    source_info: SourceInfo,
457    new_block: BasicBlock,
458) {
459    let discr = insert_discr_cast_to_u128(
460        tcx,
461        local_decls,
462        block_data,
463        source_op,
464        discr,
465        op_size,
466        Some(offset),
467        source_info,
468    );
469
470    // Compare the discriminant against the valid_range.
471    let start_const = Operand::Constant(Box::new(ConstOperand {
472        span: source_info.span,
473        user_ty: None,
474        const_: Const::Val(ConstValue::from_u128(valid_range.start), tcx.types.u128),
475    }));
476    let end_start_diff_const = Operand::Constant(Box::new(ConstOperand {
477        span: source_info.span,
478        user_ty: None,
479        const_: Const::Val(
480            ConstValue::from_u128(u128::wrapping_sub(valid_range.end, valid_range.start)),
481            tcx.types.u128,
482        ),
483    }));
484
485    let discr_diff: Place<'_> =
486        local_decls.push(LocalDecl::with_source_info(tcx.types.u128, source_info)).into();
487    block_data.statements.push(Statement::new(
488        source_info,
489        StatementKind::Assign(Box::new((
490            discr_diff,
491            Rvalue::BinaryOp(BinOp::Sub, Box::new((Operand::Copy(discr), start_const))),
492        ))),
493    ));
494
495    let is_ok: Place<'_> =
496        local_decls.push(LocalDecl::with_source_info(tcx.types.bool, source_info)).into();
497    block_data.statements.push(Statement::new(
498        source_info,
499        StatementKind::Assign(Box::new((
500            is_ok,
501            Rvalue::BinaryOp(
502                // This is a `WrappingRange`, so make sure to get the wrapping right.
503                BinOp::Le,
504                Box::new((Operand::Copy(discr_diff), end_start_diff_const)),
505            ),
506        ))),
507    ));
508
509    block_data.terminator = Some(Terminator {
510        source_info,
511        kind: TerminatorKind::Assert {
512            cond: Operand::Copy(is_ok),
513            expected: true,
514            target: new_block,
515            msg: Box::new(AssertKind::InvalidEnumConstruction(Operand::Copy(discr))),
516            // This calls panic_invalid_enum_construction, which is #[rustc_nounwind].
517            // We never want to insert an unwind into unsafe code, because unwinding could
518            // make a failing UB check turn into much worse UB when we start unwinding.
519            unwind: UnwindAction::Unreachable,
520        },
521    });
522}