1use rustc_hir::lang_items::LangItem;
2use rustc_index::IndexVec;
3use rustc_middle::mir::visit::{MutatingUseContext, NonMutatingUseContext, PlaceContext, Visitor};
4use rustc_middle::mir::*;
5use rustc_middle::ty::{self, Ty, TyCtxt};
6use tracing::{debug, trace};
7
8pub(crate) struct PointerCheck<'tcx> {
11 pub(crate) cond: Operand<'tcx>,
12 pub(crate) assert_kind: Box<AssertKind<Operand<'tcx>>>,
13}
14
/// Selects whether borrows through a dereferenced raw pointer are checked in
/// addition to the other accesses `PointerFinder` looks at.
///
/// Derives `Debug` so the mode can be logged, and `PartialEq`/`Eq` so it can be
/// compared directly instead of through `matches!`.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub(crate) enum BorrowCheckMode {
    /// Also check places used in `SharedBorrow` and mutable `Borrow` contexts.
    IncludeBorrows,
    /// Skip places that are only borrowed.
    ExcludeBorrows,
}
23
24pub(crate) fn check_pointers<'tcx, F>(
44 tcx: TyCtxt<'tcx>,
45 body: &mut Body<'tcx>,
46 excluded_pointees: &[Ty<'tcx>],
47 on_finding: F,
48 borrow_check_mode: BorrowCheckMode,
49) where
50 F: Fn(
51 TyCtxt<'tcx>,
52 Place<'tcx>,
53 Ty<'tcx>,
54 PlaceContext,
55 &mut IndexVec<Local, LocalDecl<'tcx>>,
56 &mut Vec<Statement<'tcx>>,
57 SourceInfo,
58 ) -> PointerCheck<'tcx>,
59{
60 if tcx.lang_items().get(LangItem::PanicImpl).is_none() {
63 return;
64 }
65
66 let typing_env = body.typing_env(tcx);
67 let basic_blocks = body.basic_blocks.as_mut();
68 let local_decls = &mut body.local_decls;
69
70 for block in basic_blocks.indices().rev() {
75 for statement_index in (0..basic_blocks[block].statements.len()).rev() {
76 let location = Location { block, statement_index };
77 let statement = &basic_blocks[block].statements[statement_index];
78 let source_info = statement.source_info;
79
80 let mut finder = PointerFinder::new(
81 tcx,
82 local_decls,
83 typing_env,
84 excluded_pointees,
85 borrow_check_mode,
86 );
87 finder.visit_statement(statement, location);
88
89 for (local, ty, context) in finder.into_found_pointers() {
90 debug!("Inserting check for {:?}", ty);
91 let new_block = split_block(basic_blocks, location);
92
93 let block_data = &mut basic_blocks[block];
97 let pointer_check = on_finding(
98 tcx,
99 local,
100 ty,
101 context,
102 local_decls,
103 &mut block_data.statements,
104 source_info,
105 );
106 block_data.terminator = Some(Terminator {
107 source_info,
108 kind: TerminatorKind::Assert {
109 cond: pointer_check.cond,
110 expected: true,
111 target: new_block,
112 msg: pointer_check.assert_kind,
113 unwind: UnwindAction::Unreachable,
118 },
119 });
120 }
121 }
122 }
123}
124
125struct PointerFinder<'a, 'tcx> {
126 tcx: TyCtxt<'tcx>,
127 local_decls: &'a mut LocalDecls<'tcx>,
128 typing_env: ty::TypingEnv<'tcx>,
129 pointers: Vec<(Place<'tcx>, Ty<'tcx>, PlaceContext)>,
130 excluded_pointees: &'a [Ty<'tcx>],
131 borrow_check_mode: BorrowCheckMode,
132}
133
134impl<'a, 'tcx> PointerFinder<'a, 'tcx> {
135 fn new(
136 tcx: TyCtxt<'tcx>,
137 local_decls: &'a mut LocalDecls<'tcx>,
138 typing_env: ty::TypingEnv<'tcx>,
139 excluded_pointees: &'a [Ty<'tcx>],
140 borrow_check_mode: BorrowCheckMode,
141 ) -> Self {
142 PointerFinder {
143 tcx,
144 local_decls,
145 typing_env,
146 excluded_pointees,
147 pointers: Vec::new(),
148 borrow_check_mode,
149 }
150 }
151
152 fn into_found_pointers(self) -> Vec<(Place<'tcx>, Ty<'tcx>, PlaceContext)> {
153 self.pointers
154 }
155
156 fn should_visit_place(&self, context: PlaceContext) -> bool {
161 match context {
162 PlaceContext::MutatingUse(
163 MutatingUseContext::Store
164 | MutatingUseContext::Call
165 | MutatingUseContext::Yield
166 | MutatingUseContext::Drop,
167 ) => true,
168 PlaceContext::NonMutatingUse(
169 NonMutatingUseContext::Copy | NonMutatingUseContext::Move,
170 ) => true,
171 PlaceContext::MutatingUse(MutatingUseContext::Borrow)
172 | PlaceContext::NonMutatingUse(NonMutatingUseContext::SharedBorrow) => {
173 matches!(self.borrow_check_mode, BorrowCheckMode::IncludeBorrows)
174 }
175 _ => false,
176 }
177 }
178}
179
180impl<'a, 'tcx> Visitor<'tcx> for PointerFinder<'a, 'tcx> {
181 fn visit_place(&mut self, place: &Place<'tcx>, context: PlaceContext, location: Location) {
182 if !self.should_visit_place(context) || !place.is_indirect() {
183 return;
184 }
185
186 let pointer = Place::from(place.local);
189 let pointer_ty = self.local_decls[place.local].ty;
190
191 if !pointer_ty.is_raw_ptr() {
193 trace!("Indirect, but not based on an raw ptr, not checking {:?}", place);
194 return;
195 }
196
197 let pointee_ty =
198 pointer_ty.builtin_deref(true).expect("no builtin_deref for an raw pointer");
199 if !pointee_ty.is_sized(self.tcx, self.typing_env) {
201 trace!("Raw pointer, but pointee is not known to be sized: {:?}", pointer_ty);
202 return;
203 }
204
205 let element_ty = match pointee_ty.kind() {
207 ty::Array(ty, _) => *ty,
208 _ => pointee_ty,
209 };
210 if self.excluded_pointees.contains(&element_ty) {
211 trace!("Skipping pointer for type: {:?}", pointee_ty);
212 return;
213 }
214
215 self.pointers.push((pointer, pointee_ty, context));
216
217 self.super_place(place, context, location);
218 }
219}
220
221fn split_block(
222 basic_blocks: &mut IndexVec<BasicBlock, BasicBlockData<'_>>,
223 location: Location,
224) -> BasicBlock {
225 let block_data = &mut basic_blocks[location.block];
226
227 let new_block = BasicBlockData {
229 statements: block_data.statements.split_off(location.statement_index),
230 terminator: block_data.terminator.take(),
231 is_cleanup: block_data.is_cleanup,
232 };
233
234 basic_blocks.push(new_block)
235}