1use rustc_index::{Idx, IndexVec};
2use rustc_middle::mir::*;
3use rustc_middle::ty::Ty;
4use rustc_span::Span;
5use tracing::debug;
6
7pub(crate) struct MirPatch<'tcx> {
12 term_patch_map: IndexVec<BasicBlock, Option<TerminatorKind<'tcx>>>,
13 new_blocks: Vec<BasicBlockData<'tcx>>,
14 new_statements: Vec<(Location, StatementKind<'tcx>)>,
15 new_locals: Vec<LocalDecl<'tcx>>,
16 resume_block: Option<BasicBlock>,
17 unreachable_cleanup_block: Option<BasicBlock>,
19 unreachable_no_cleanup_block: Option<BasicBlock>,
21 terminate_block: Option<(BasicBlock, UnwindTerminateReason)>,
23 body_span: Span,
24 next_local: usize,
25}
26
27impl<'tcx> MirPatch<'tcx> {
28 pub(crate) fn new(body: &Body<'tcx>) -> Self {
30 let mut result = MirPatch {
31 term_patch_map: IndexVec::from_elem(None, &body.basic_blocks),
32 new_blocks: vec![],
33 new_statements: vec![],
34 new_locals: vec![],
35 next_local: body.local_decls.len(),
36 resume_block: None,
37 unreachable_cleanup_block: None,
38 unreachable_no_cleanup_block: None,
39 terminate_block: None,
40 body_span: body.span,
41 };
42
43 for (bb, block) in body.basic_blocks.iter_enumerated() {
44 if matches!(block.terminator().kind, TerminatorKind::UnwindResume)
46 && block.statements.is_empty()
47 {
48 result.resume_block = Some(bb);
49 continue;
50 }
51
52 if matches!(block.terminator().kind, TerminatorKind::Unreachable)
54 && block.statements.is_empty()
55 {
56 if block.is_cleanup {
57 result.unreachable_cleanup_block = Some(bb);
58 } else {
59 result.unreachable_no_cleanup_block = Some(bb);
60 }
61 continue;
62 }
63
64 if let TerminatorKind::UnwindTerminate(reason) = block.terminator().kind
66 && block.statements.is_empty()
67 {
68 result.terminate_block = Some((bb, reason));
69 continue;
70 }
71 }
72
73 result
74 }
75
76 pub(crate) fn resume_block(&mut self) -> BasicBlock {
77 if let Some(bb) = self.resume_block {
78 return bb;
79 }
80
81 let bb = self.new_block(BasicBlockData::new(
82 Some(Terminator {
83 source_info: SourceInfo::outermost(self.body_span),
84 kind: TerminatorKind::UnwindResume,
85 }),
86 true,
87 ));
88 self.resume_block = Some(bb);
89 bb
90 }
91
92 pub(crate) fn unreachable_cleanup_block(&mut self) -> BasicBlock {
93 if let Some(bb) = self.unreachable_cleanup_block {
94 return bb;
95 }
96
97 let bb = self.new_block(BasicBlockData::new(
98 Some(Terminator {
99 source_info: SourceInfo::outermost(self.body_span),
100 kind: TerminatorKind::Unreachable,
101 }),
102 true,
103 ));
104 self.unreachable_cleanup_block = Some(bb);
105 bb
106 }
107
108 pub(crate) fn unreachable_no_cleanup_block(&mut self) -> BasicBlock {
109 if let Some(bb) = self.unreachable_no_cleanup_block {
110 return bb;
111 }
112
113 let bb = self.new_block(BasicBlockData::new(
114 Some(Terminator {
115 source_info: SourceInfo::outermost(self.body_span),
116 kind: TerminatorKind::Unreachable,
117 }),
118 false,
119 ));
120 self.unreachable_no_cleanup_block = Some(bb);
121 bb
122 }
123
124 pub(crate) fn terminate_block(&mut self, reason: UnwindTerminateReason) -> BasicBlock {
125 if let Some((cached_bb, cached_reason)) = self.terminate_block
126 && reason == cached_reason
127 {
128 return cached_bb;
129 }
130
131 let bb = self.new_block(BasicBlockData::new(
132 Some(Terminator {
133 source_info: SourceInfo::outermost(self.body_span),
134 kind: TerminatorKind::UnwindTerminate(reason),
135 }),
136 true,
137 ));
138 self.terminate_block = Some((bb, reason));
139 bb
140 }
141
142 pub(crate) fn is_term_patched(&self, bb: BasicBlock) -> bool {
144 self.term_patch_map[bb].is_some()
145 }
146
147 pub(crate) fn block<'a>(
149 &'a self,
150 body: &'a Body<'tcx>,
151 bb: BasicBlock,
152 ) -> &'a BasicBlockData<'tcx> {
153 match bb.index().checked_sub(body.basic_blocks.len()) {
154 Some(new) => &self.new_blocks[new],
155 None => &body[bb],
156 }
157 }
158
159 pub(crate) fn terminator_loc(&self, body: &Body<'tcx>, bb: BasicBlock) -> Location {
160 let offset = self.block(body, bb).statements.len();
161 Location { block: bb, statement_index: offset }
162 }
163
164 pub(crate) fn new_local_with_info(
166 &mut self,
167 ty: Ty<'tcx>,
168 span: Span,
169 local_info: LocalInfo<'tcx>,
170 ) -> Local {
171 let index = self.next_local;
172 self.next_local += 1;
173 let mut new_decl = LocalDecl::new(ty, span);
174 **new_decl.local_info.as_mut().unwrap_crate_local() = local_info;
175 self.new_locals.push(new_decl);
176 Local::new(index)
177 }
178
179 pub(crate) fn new_temp(&mut self, ty: Ty<'tcx>, span: Span) -> Local {
181 let index = self.next_local;
182 self.next_local += 1;
183 self.new_locals.push(LocalDecl::new(ty, span));
184 Local::new(index)
185 }
186
187 pub(crate) fn local_ty(&self, local: Local) -> Ty<'tcx> {
189 let local = local.as_usize();
190 assert!(local < self.next_local);
191 let new_local_idx = self.new_locals.len() - (self.next_local - local);
192 self.new_locals[new_local_idx].ty
193 }
194
195 pub(crate) fn new_block(&mut self, data: BasicBlockData<'tcx>) -> BasicBlock {
197 let block = self.term_patch_map.next_index();
198 debug!("MirPatch: new_block: {:?}: {:?}", block, data);
199 self.new_blocks.push(data);
200 self.term_patch_map.push(None);
201 block
202 }
203
204 pub(crate) fn patch_terminator(&mut self, block: BasicBlock, new: TerminatorKind<'tcx>) {
206 assert!(self.term_patch_map[block].is_none());
207 debug!("MirPatch: patch_terminator({:?}, {:?})", block, new);
208 self.term_patch_map[block] = Some(new);
209 }
210
211 pub(crate) fn add_statement(&mut self, loc: Location, stmt: StatementKind<'tcx>) {
225 debug!("MirPatch: add_statement({:?}, {:?})", loc, stmt);
226 self.new_statements.push((loc, stmt));
227 }
228
229 pub(crate) fn add_assign(&mut self, loc: Location, place: Place<'tcx>, rv: Rvalue<'tcx>) {
231 self.add_statement(loc, StatementKind::Assign(Box::new((place, rv))));
232 }
233
234 pub(crate) fn apply(self, body: &mut Body<'tcx>) {
236 debug!(
237 "MirPatch: {:?} new temps, starting from index {}: {:?}",
238 self.new_locals.len(),
239 body.local_decls.len(),
240 self.new_locals
241 );
242 debug!(
243 "MirPatch: {} new blocks, starting from index {}",
244 self.new_blocks.len(),
245 body.basic_blocks.len()
246 );
247 let bbs = if self.term_patch_map.is_empty() && self.new_blocks.is_empty() {
248 body.basic_blocks.as_mut_preserves_cfg()
249 } else {
250 body.basic_blocks.as_mut()
251 };
252 bbs.extend(self.new_blocks);
253 body.local_decls.extend(self.new_locals);
254 for (src, patch) in self.term_patch_map.into_iter_enumerated() {
255 if let Some(patch) = patch {
256 debug!("MirPatch: patching block {:?}", src);
257 bbs[src].terminator_mut().kind = patch;
258 }
259 }
260
261 let mut new_statements = self.new_statements;
262
263 new_statements.sort_by_key(|s| s.0);
266
267 let mut delta = 0;
268 let mut last_bb = START_BLOCK;
269 for (mut loc, stmt) in new_statements {
270 if loc.block != last_bb {
271 delta = 0;
272 last_bb = loc.block;
273 }
274 debug!("MirPatch: adding statement {:?} at loc {:?}+{}", stmt, loc, delta);
275 loc.statement_index += delta;
276 let source_info = Self::source_info_for_index(&body[loc.block], loc);
277 body[loc.block]
278 .statements
279 .insert(loc.statement_index, Statement::new(source_info, stmt));
280 delta += 1;
281 }
282 }
283
284 fn source_info_for_index(data: &BasicBlockData<'_>, loc: Location) -> SourceInfo {
285 match data.statements.get(loc.statement_index) {
286 Some(stmt) => stmt.source_info,
287 None => data.terminator().source_info,
288 }
289 }
290
291 pub(crate) fn source_info_for_location(&self, body: &Body<'tcx>, loc: Location) -> SourceInfo {
292 let data = self.block(body, loc.block);
293 Self::source_info_for_index(data, loc)
294 }
295}