1use rustc_abi::{Scalar, Size, TagEncoding, Variants, WrappingRange};
2use rustc_hir::LangItem;
3use rustc_index::IndexVec;
4use rustc_middle::bug;
5use rustc_middle::mir::visit::Visitor;
6use rustc_middle::mir::*;
7use rustc_middle::ty::layout::PrimitiveExt;
8use rustc_middle::ty::{self, Ty, TyCtxt, TypingEnv};
9use rustc_session::Session;
10use tracing::debug;
11
/// MIR pass that inserts runtime checks in front of `Transmute` casts whose
/// target is an enum type. Each check ends in an `Assert` terminator (with
/// `UnwindAction::Unreachable`) that fails when the transmuted bits are not a
/// valid value of that enum.
pub(super) struct CheckEnums;
16
impl<'tcx> crate::MirPass<'tcx> for CheckEnums {
    /// The pass only runs when UB checks are enabled for the session.
    fn is_enabled(&self, sess: &Session) -> bool {
        sess.ub_checks()
    }

    /// Finds every enum transmute in `body` and inserts the matching check
    /// immediately before the statement that performs it.
    fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
        // The inserted checks end in `Assert` terminators that fail at runtime; when there is
        // no `PanicImpl` lang item the pass is skipped entirely.
        if tcx.lang_items().get(LangItem::PanicImpl).is_none() {
            return;
        }

        let typing_env = body.typing_env(tcx);
        let basic_blocks = body.basic_blocks.as_mut();
        let local_decls = &mut body.local_decls;

        // Walk blocks and statements backwards. Inserting a check splits `block` at
        // `statement_index`, moving that statement and everything after it into a newly pushed
        // block, so walking in reverse means only already-visited statements are moved and the
        // indices still to be visited stay valid. Both ranges are fixed when their loop starts,
        // so freshly pushed blocks and freshly inserted statements are never revisited.
        for block in basic_blocks.indices().rev() {
            for statement_index in (0..basic_blocks[block].statements.len()).rev() {
                let location = Location { block, statement_index };
                let statement = &basic_blocks[block].statements[statement_index];
                let source_info = statement.source_info;

                let mut finder = EnumFinder::new(tcx, local_decls, typing_env);
                finder.visit_statement(statement, location);

                for check in finder.into_found_enums() {
                    debug!("Inserting enum check");
                    // `block` keeps the statements before `location`; the checked statement and
                    // the old terminator move to `new_block`, which each check continues to
                    // once the value is known to be valid.
                    let new_block = split_block(basic_blocks, location);

                    match check {
                        EnumCheckType::Direct { source_op, discr, op_size, valid_discrs } => {
                            insert_direct_enum_check(
                                tcx,
                                local_decls,
                                basic_blocks,
                                block,
                                source_op,
                                discr,
                                op_size,
                                valid_discrs,
                                source_info,
                                new_block,
                            )
                        }
                        EnumCheckType::Uninhabited => insert_uninhabited_enum_check(
                            tcx,
                            local_decls,
                            &mut basic_blocks[block],
                            source_info,
                            new_block,
                        ),
                        EnumCheckType::WithNiche {
                            source_op,
                            discr,
                            op_size,
                            offset,
                            valid_range,
                        } => insert_niche_check(
                            tcx,
                            local_decls,
                            &mut basic_blocks[block],
                            source_op,
                            valid_range,
                            discr,
                            op_size,
                            offset,
                            source_info,
                            new_block,
                        ),
                    }
                }
            }
        }
    }

    /// This pass must always run when enabled (it is not an optional optimization).
    fn is_required(&self) -> bool {
        true
    }
}
100
/// The kind of runtime check to insert for one enum transmute, together with
/// the data that check needs.
enum EnumCheckType<'tcx> {
    /// The target enum's layout has no variants (`Variants::Empty`) while the
    /// transmuted source is inhabited; the inserted check always fails.
    Uninhabited,
    /// The target enum stores its tag directly (`TagEncoding::Direct`); the
    /// tag is compared against every discriminant in `valid_discrs`.
    Direct {
        /// The transmuted operand (as a copy, so the original statement keeps its own).
        source_op: Operand<'tcx>,
        /// Integer type and size of the tag.
        discr: TyAndSize<'tcx>,
        /// Size of the transmuted operand.
        op_size: Size,
        /// Discriminant values of the enum's variants (from `AdtDef::discriminants`).
        valid_discrs: Vec<u128>,
    },
    /// The target enum uses a niche-encoded tag (`TagEncoding::Niche`); the
    /// tag is tested against the tag scalar's `valid_range`.
    WithNiche {
        /// The transmuted operand (as a copy).
        source_op: Operand<'tcx>,
        /// Integer type and size of the tag.
        discr: TyAndSize<'tcx>,
        /// Size of the transmuted operand.
        op_size: Size,
        /// Offset of the tag field inside the enum layout.
        offset: Size,
        /// Valid (possibly wrap-around) range of the tag scalar.
        valid_range: WrappingRange,
    },
}
122
/// An integer type paired with its size; describes an enum tag.
#[derive(Debug, Copy, Clone)]
struct TyAndSize<'tcx> {
    /// Integer type of the tag (the tag primitive's `to_int_ty`).
    pub ty: Ty<'tcx>,
    /// Size of the tag (the tag primitive's `size`).
    pub size: Size,
}
128
/// A [`Visitor`] that looks for `Transmute` casts into enum types and records
/// which check each one needs.
struct EnumFinder<'a, 'tcx> {
    tcx: TyCtxt<'tcx>,
    /// Local declarations of the visited body; used to compute operand types.
    local_decls: &'a mut LocalDecls<'tcx>,
    typing_env: TypingEnv<'tcx>,
    /// Checks recorded so far, in visiting order.
    enums: Vec<EnumCheckType<'tcx>>,
}
137
138impl<'a, 'tcx> EnumFinder<'a, 'tcx> {
139 fn new(
140 tcx: TyCtxt<'tcx>,
141 local_decls: &'a mut LocalDecls<'tcx>,
142 typing_env: TypingEnv<'tcx>,
143 ) -> Self {
144 EnumFinder { tcx, local_decls, typing_env, enums: Vec::new() }
145 }
146
147 fn into_found_enums(self) -> Vec<EnumCheckType<'tcx>> {
149 self.enums
150 }
151}
152
153impl<'a, 'tcx> Visitor<'tcx> for EnumFinder<'a, 'tcx> {
154 fn visit_rvalue(&mut self, rvalue: &Rvalue<'tcx>, location: Location) {
155 if let Rvalue::Cast(CastKind::Transmute, op, ty) = rvalue {
156 let ty::Adt(adt_def, _) = ty.kind() else {
157 return;
158 };
159 if !adt_def.is_enum() {
160 return;
161 }
162
163 let Ok(enum_layout) = self.tcx.layout_of(self.typing_env.as_query_input(*ty)) else {
164 return;
165 };
166 let Ok(op_layout) = self
167 .tcx
168 .layout_of(self.typing_env.as_query_input(op.ty(self.local_decls, self.tcx)))
169 else {
170 return;
171 };
172
173 match enum_layout.variants {
174 Variants::Empty if op_layout.is_uninhabited() => return,
175 Variants::Empty => {
178 self.enums.push(EnumCheckType::Uninhabited);
181 }
182 Variants::Single { .. } => {}
184 Variants::Multiple {
186 tag_encoding: TagEncoding::Direct,
187 tag: Scalar::Initialized { value, .. },
188 ..
189 } => {
190 let valid_discrs =
191 adt_def.discriminants(self.tcx).map(|(_, discr)| discr.val).collect();
192
193 let discr =
194 TyAndSize { ty: value.to_int_ty(self.tcx), size: value.size(&self.tcx) };
195 self.enums.push(EnumCheckType::Direct {
196 source_op: op.to_copy(),
197 discr,
198 op_size: op_layout.size,
199 valid_discrs,
200 });
201 }
202 Variants::Multiple {
204 tag_encoding: TagEncoding::Niche { .. },
205 tag: Scalar::Initialized { value, valid_range, .. },
206 tag_field,
207 ..
208 } => {
209 let discr =
210 TyAndSize { ty: value.to_int_ty(self.tcx), size: value.size(&self.tcx) };
211 self.enums.push(EnumCheckType::WithNiche {
212 source_op: op.to_copy(),
213 discr,
214 op_size: op_layout.size,
215 offset: enum_layout.fields.offset(tag_field.as_usize()),
216 valid_range,
217 });
218 }
219 _ => return,
220 }
221
222 self.super_rvalue(rvalue, location);
223 }
224 }
225}
226
227fn split_block(
228 basic_blocks: &mut IndexVec<BasicBlock, BasicBlockData<'_>>,
229 location: Location,
230) -> BasicBlock {
231 let block_data = &mut basic_blocks[location.block];
232
233 let new_block = BasicBlockData::new_stmts(
235 block_data.statements.split_off(location.statement_index),
236 block_data.terminator.take(),
237 block_data.is_cleanup,
238 );
239
240 basic_blocks.push(new_block)
241}
242
/// Appends to `block_data` the statements that extract the enum tag (`discr`)
/// from `source_op` and widen it to a `u128` local, which is returned.
///
/// - `op_size` is the size of `source_op`.
/// - `offset` is the byte offset of the tag inside the operand (`None` means 0).
///
/// The final step is an `IntToInt` cast from `discr.ty` to `u128`. When
/// `discr.ty` is a signed integer type this widening sign-extends, so the high
/// bits of the result are not guaranteed to be zero. Callers that compare
/// against tag bits truncated to `discr.size` should mask first.
fn insert_discr_cast_to_u128<'tcx>(
    tcx: TyCtxt<'tcx>,
    local_decls: &mut IndexVec<Local, LocalDecl<'tcx>>,
    block_data: &mut BasicBlockData<'tcx>,
    source_op: Operand<'tcx>,
    discr: TyAndSize<'tcx>,
    op_size: Size,
    offset: Option<Size>,
    source_info: SourceInfo,
) -> Place<'tcx> {
    // Unsigned integer type of exactly `size` bytes; any other size is a compiler bug.
    let get_ty_for_size = |tcx: TyCtxt<'tcx>, size: Size| -> Ty<'tcx> {
        match size.bytes() {
            1 => tcx.types.u8,
            2 => tcx.types.u16,
            4 => tcx.types.u32,
            8 => tcx.types.u64,
            16 => tcx.types.u128,
            invalid => bug!("Found discriminant with invalid size, has {} bytes", invalid),
        }
    };

    let (cast_kind, discr_ty_bits) = if discr.size.bytes() < op_size.bytes() {
        // The tag is narrower than the operand: reinterpret the operand as
        // `[MaybeUninit<u8>; op_size]` so that only the tag's bytes are read as an integer.
        let mu = Ty::new_maybe_uninit(tcx, tcx.types.u8);
        let array_len = op_size.bytes();
        let mu_array_ty = Ty::new_array(tcx, mu, array_len);
        let mu_array =
            local_decls.push(LocalDecl::with_source_info(mu_array_ty, source_info)).into();
        let rvalue = Rvalue::Cast(CastKind::Transmute, source_op, mu_array_ty);
        block_data
            .statements
            .push(Statement::new(source_info, StatementKind::Assign(Box::new((mu_array, rvalue)))));

        // Select bytes `offset .. offset + discr.size`; this subslice is later transmuted
        // to `discr.ty`.
        let offset = offset.unwrap_or(Size::ZERO);
        let smaller_mu_array = mu_array.project_deeper(
            &[ProjectionElem::Subslice {
                from: offset.bytes(),
                to: offset.bytes() + discr.size.bytes(),
                from_end: false,
            }],
            tcx,
        );

        (CastKind::Transmute, Operand::Copy(smaller_mu_array))
    } else {
        // The tag is at least as wide as the operand. NOTE(review): since a transmute keeps the
        // size and an enum is at least as large as its tag, presumably the two sizes are equal
        // here. `offset` is not consulted on this path.
        let operand_int_ty = get_ty_for_size(tcx, op_size);

        let op_as_int =
            local_decls.push(LocalDecl::with_source_info(operand_int_ty, source_info)).into();
        let rvalue = Rvalue::Cast(CastKind::Transmute, source_op, operand_int_ty);
        block_data.statements.push(Statement::new(
            source_info,
            StatementKind::Assign(Box::new((op_as_int, rvalue))),
        ));

        (CastKind::IntToInt, Operand::Copy(op_as_int))
    };

    // Convert the raw tag bits into the tag's own integer type.
    let rvalue = Rvalue::Cast(cast_kind, discr_ty_bits, discr.ty);
    let discr_in_discr_ty =
        local_decls.push(LocalDecl::with_source_info(discr.ty, source_info)).into();
    block_data.statements.push(Statement::new(
        source_info,
        StatementKind::Assign(Box::new((discr_in_discr_ty, rvalue))),
    ));

    // Widen to u128, the common type the checks compare in (sign-extends for signed `discr.ty`).
    let const_u128 = Ty::new_uint(tcx, ty::UintTy::U128);
    let rvalue = Rvalue::Cast(CastKind::IntToInt, Operand::Copy(discr_in_discr_ty), const_u128);
    let discr = local_decls.push(LocalDecl::with_source_info(const_u128, source_info)).into();
    block_data
        .statements
        .push(Statement::new(source_info, StatementKind::Assign(Box::new((discr, rvalue)))));

    discr
}
324
325fn insert_direct_enum_check<'tcx>(
326 tcx: TyCtxt<'tcx>,
327 local_decls: &mut IndexVec<Local, LocalDecl<'tcx>>,
328 basic_blocks: &mut IndexVec<BasicBlock, BasicBlockData<'tcx>>,
329 current_block: BasicBlock,
330 source_op: Operand<'tcx>,
331 discr: TyAndSize<'tcx>,
332 op_size: Size,
333 discriminants: Vec<u128>,
334 source_info: SourceInfo,
335 new_block: BasicBlock,
336) {
337 let invalid_discr_block_data = BasicBlockData::new(None, false);
339 let invalid_discr_block = basic_blocks.push(invalid_discr_block_data);
340 let block_data = &mut basic_blocks[current_block];
341 let discr_place = insert_discr_cast_to_u128(
342 tcx,
343 local_decls,
344 block_data,
345 source_op,
346 discr,
347 op_size,
348 None,
349 source_info,
350 );
351
352 let mask = discr.size.unsigned_int_max();
354 let discr_masked =
355 local_decls.push(LocalDecl::with_source_info(tcx.types.u128, source_info)).into();
356 let rvalue = Rvalue::BinaryOp(
357 BinOp::BitAnd,
358 Box::new((
359 Operand::Copy(discr_place),
360 Operand::Constant(Box::new(ConstOperand {
361 span: source_info.span,
362 user_ty: None,
363 const_: Const::Val(ConstValue::from_u128(mask), tcx.types.u128),
364 })),
365 )),
366 );
367 block_data
368 .statements
369 .push(Statement::new(source_info, StatementKind::Assign(Box::new((discr_masked, rvalue)))));
370
371 block_data.terminator = Some(Terminator {
373 source_info,
374 kind: TerminatorKind::SwitchInt {
375 discr: Operand::Copy(discr_masked),
376 targets: SwitchTargets::new(
377 discriminants
378 .into_iter()
379 .map(|discr_val| (discr.size.truncate(discr_val), new_block)),
380 invalid_discr_block,
381 ),
382 },
383 });
384
385 basic_blocks[invalid_discr_block].terminator = Some(Terminator {
387 source_info,
388 kind: TerminatorKind::Assert {
389 cond: Operand::Constant(Box::new(ConstOperand {
390 span: source_info.span,
391 user_ty: None,
392 const_: Const::Val(ConstValue::from_bool(false), tcx.types.bool),
393 })),
394 expected: true,
395 target: new_block,
396 msg: Box::new(AssertKind::InvalidEnumConstruction(Operand::Copy(discr_masked))),
397 unwind: UnwindAction::Unreachable,
401 },
402 });
403}
404
405fn insert_uninhabited_enum_check<'tcx>(
406 tcx: TyCtxt<'tcx>,
407 local_decls: &mut IndexVec<Local, LocalDecl<'tcx>>,
408 block_data: &mut BasicBlockData<'tcx>,
409 source_info: SourceInfo,
410 new_block: BasicBlock,
411) {
412 let is_ok: Place<'_> =
413 local_decls.push(LocalDecl::with_source_info(tcx.types.bool, source_info)).into();
414 block_data.statements.push(Statement::new(
415 source_info,
416 StatementKind::Assign(Box::new((
417 is_ok,
418 Rvalue::Use(Operand::Constant(Box::new(ConstOperand {
419 span: source_info.span,
420 user_ty: None,
421 const_: Const::Val(ConstValue::from_bool(false), tcx.types.bool),
422 }))),
423 ))),
424 ));
425
426 block_data.terminator = Some(Terminator {
427 source_info,
428 kind: TerminatorKind::Assert {
429 cond: Operand::Copy(is_ok),
430 expected: true,
431 target: new_block,
432 msg: Box::new(AssertKind::InvalidEnumConstruction(Operand::Constant(Box::new(
433 ConstOperand {
434 span: source_info.span,
435 user_ty: None,
436 const_: Const::Val(ConstValue::from_u128(0), tcx.types.u128),
437 },
438 )))),
439 unwind: UnwindAction::Unreachable,
443 },
444 });
445}
446
447fn insert_niche_check<'tcx>(
448 tcx: TyCtxt<'tcx>,
449 local_decls: &mut IndexVec<Local, LocalDecl<'tcx>>,
450 block_data: &mut BasicBlockData<'tcx>,
451 source_op: Operand<'tcx>,
452 valid_range: WrappingRange,
453 discr: TyAndSize<'tcx>,
454 op_size: Size,
455 offset: Size,
456 source_info: SourceInfo,
457 new_block: BasicBlock,
458) {
459 let discr = insert_discr_cast_to_u128(
460 tcx,
461 local_decls,
462 block_data,
463 source_op,
464 discr,
465 op_size,
466 Some(offset),
467 source_info,
468 );
469
470 let start_const = Operand::Constant(Box::new(ConstOperand {
472 span: source_info.span,
473 user_ty: None,
474 const_: Const::Val(ConstValue::from_u128(valid_range.start), tcx.types.u128),
475 }));
476 let end_start_diff_const = Operand::Constant(Box::new(ConstOperand {
477 span: source_info.span,
478 user_ty: None,
479 const_: Const::Val(
480 ConstValue::from_u128(u128::wrapping_sub(valid_range.end, valid_range.start)),
481 tcx.types.u128,
482 ),
483 }));
484
485 let discr_diff: Place<'_> =
486 local_decls.push(LocalDecl::with_source_info(tcx.types.u128, source_info)).into();
487 block_data.statements.push(Statement::new(
488 source_info,
489 StatementKind::Assign(Box::new((
490 discr_diff,
491 Rvalue::BinaryOp(BinOp::Sub, Box::new((Operand::Copy(discr), start_const))),
492 ))),
493 ));
494
495 let is_ok: Place<'_> =
496 local_decls.push(LocalDecl::with_source_info(tcx.types.bool, source_info)).into();
497 block_data.statements.push(Statement::new(
498 source_info,
499 StatementKind::Assign(Box::new((
500 is_ok,
501 Rvalue::BinaryOp(
502 BinOp::Le,
504 Box::new((Operand::Copy(discr_diff), end_start_diff_const)),
505 ),
506 ))),
507 ));
508
509 block_data.terminator = Some(Terminator {
510 source_info,
511 kind: TerminatorKind::Assert {
512 cond: Operand::Copy(is_ok),
513 expected: true,
514 target: new_block,
515 msg: Box::new(AssertKind::InvalidEnumConstruction(Operand::Copy(discr))),
516 unwind: UnwindAction::Unreachable,
520 },
521 });
522}