Skip to content

Commit e4f3431

Browse files
committed
Flatten aggregates into locals.
1 parent b550eab commit e4f3431

21 files changed

+1043
-41
lines changed

compiler/rustc_mir_transform/src/lib.rs

+2
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ pub mod simplify;
9393
mod simplify_branches;
9494
mod simplify_comparison_integral;
9595
mod simplify_try;
96+
mod sroa;
9697
mod uninhabited_enum_branching;
9798
mod unreachable_prop;
9899

@@ -563,6 +564,7 @@ fn run_optimization_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
563564
&remove_zsts::RemoveZsts,
564565
&const_goto::ConstGoto,
565566
&remove_unneeded_drops::RemoveUnneededDrops,
567+
&sroa::ScalarReplacementOfAggregates,
566568
&match_branches::MatchBranchSimplification,
567569
// inst combine is after MatchBranchSimplification to clean up Ne(_1, false)
568570
&multiple_return_terminators::MultipleReturnTerminators,
+348
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,348 @@
1+
use crate::MirPass;
2+
use rustc_data_structures::fx::{FxIndexMap, IndexEntry};
3+
use rustc_index::bit_set::BitSet;
4+
use rustc_index::vec::IndexVec;
5+
use rustc_middle::mir::visit::*;
6+
use rustc_middle::mir::*;
7+
use rustc_middle::ty::TyCtxt;
8+
9+
pub struct ScalarReplacementOfAggregates;
10+
11+
impl<'tcx> MirPass<'tcx> for ScalarReplacementOfAggregates {
12+
fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
13+
sess.mir_opt_level() >= 4
14+
}
15+
16+
fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
17+
let escaping = escaping_locals(&*body);
18+
debug!(?escaping);
19+
let replacements = compute_flattening(tcx, body, escaping);
20+
debug!(?replacements);
21+
replace_flattened_locals(tcx, body, replacements);
22+
}
23+
}
24+
25+
/// Identify all locals that are not eligible for SROA.
26+
///
27+
/// There are 3 cases:
28+
/// - the aggegated local is used or passed to other code (function parameters and arguments);
29+
/// - the locals is a union or an enum;
30+
/// - the local's address is taken, and thus the relative addresses of the fields are observable to
31+
/// client code.
32+
fn escaping_locals(body: &Body<'_>) -> BitSet<Local> {
33+
let mut set = BitSet::new_empty(body.local_decls.len());
34+
set.insert_range(RETURN_PLACE..=Local::from_usize(body.arg_count));
35+
for (local, decl) in body.local_decls().iter_enumerated() {
36+
if decl.ty.is_union() || decl.ty.is_enum() {
37+
set.insert(local);
38+
}
39+
}
40+
let mut visitor = EscapeVisitor { set };
41+
visitor.visit_body(body);
42+
return visitor.set;
43+
44+
struct EscapeVisitor {
45+
set: BitSet<Local>,
46+
}
47+
48+
impl<'tcx> Visitor<'tcx> for EscapeVisitor {
49+
fn visit_local(&mut self, local: Local, _: PlaceContext, _: Location) {
50+
self.set.insert(local);
51+
}
52+
53+
fn visit_place(&mut self, place: &Place<'tcx>, context: PlaceContext, location: Location) {
54+
// Mirror the implementation in PreFlattenVisitor.
55+
if let &[PlaceElem::Field(..), ..] = &place.projection[..] {
56+
return;
57+
}
58+
self.super_place(place, context, location);
59+
}
60+
61+
fn visit_rvalue(&mut self, rvalue: &Rvalue<'tcx>, location: Location) {
62+
if let Rvalue::AddressOf(.., place) | Rvalue::Ref(.., place) = rvalue {
63+
if !place.is_indirect() {
64+
// Raw pointers may be used to access anything inside the enclosing place.
65+
self.set.insert(place.local);
66+
return;
67+
}
68+
}
69+
self.super_rvalue(rvalue, location)
70+
}
71+
72+
fn visit_statement(&mut self, statement: &Statement<'tcx>, location: Location) {
73+
if let StatementKind::StorageLive(..)
74+
| StatementKind::StorageDead(..)
75+
| StatementKind::Deinit(..) = statement.kind
76+
{
77+
// Storage statements are expanded in run_pass.
78+
return;
79+
}
80+
self.super_statement(statement, location)
81+
}
82+
83+
fn visit_terminator(&mut self, terminator: &Terminator<'tcx>, location: Location) {
84+
// Drop implicitly calls `drop_in_place`, which takes a `&mut`.
85+
// This implies that `Drop` implicitly takes the address of the place.
86+
if let TerminatorKind::Drop { place, .. }
87+
| TerminatorKind::DropAndReplace { place, .. } = terminator.kind
88+
{
89+
if !place.is_indirect() {
90+
// Raw pointers may be used to access anything inside the enclosing place.
91+
self.set.insert(place.local);
92+
return;
93+
}
94+
}
95+
self.super_terminator(terminator, location);
96+
}
97+
98+
// We ignore anything that happens in debuginfo, since we expand it using
99+
// `VarDebugInfoContents::Composite`.
100+
fn visit_var_debug_info(&mut self, _: &VarDebugInfo<'tcx>) {}
101+
}
102+
}
103+
104+
#[derive(Default, Debug)]
105+
struct ReplacementMap<'tcx> {
106+
fields: FxIndexMap<PlaceRef<'tcx>, Local>,
107+
}
108+
109+
/// Compute the replacement of flattened places into locals.
110+
///
111+
/// For each eligible place, we assign a new local to each accessed field.
112+
/// The replacement will be done later in `ReplacementVisitor`.
113+
fn compute_flattening<'tcx>(
114+
tcx: TyCtxt<'tcx>,
115+
body: &mut Body<'tcx>,
116+
escaping: BitSet<Local>,
117+
) -> ReplacementMap<'tcx> {
118+
let mut visitor = PreFlattenVisitor {
119+
tcx,
120+
escaping,
121+
local_decls: &mut body.local_decls,
122+
map: Default::default(),
123+
};
124+
for (block, bbdata) in body.basic_blocks.iter_enumerated() {
125+
visitor.visit_basic_block_data(block, bbdata);
126+
}
127+
return visitor.map;
128+
129+
struct PreFlattenVisitor<'tcx, 'll> {
130+
tcx: TyCtxt<'tcx>,
131+
local_decls: &'ll mut LocalDecls<'tcx>,
132+
escaping: BitSet<Local>,
133+
map: ReplacementMap<'tcx>,
134+
}
135+
136+
impl<'tcx, 'll> PreFlattenVisitor<'tcx, 'll> {
137+
fn create_place(&mut self, place: PlaceRef<'tcx>) {
138+
if self.escaping.contains(place.local) {
139+
return;
140+
}
141+
142+
match self.map.fields.entry(place) {
143+
IndexEntry::Occupied(_) => {}
144+
IndexEntry::Vacant(v) => {
145+
let ty = place.ty(&*self.local_decls, self.tcx).ty;
146+
let local = self.local_decls.push(LocalDecl {
147+
ty,
148+
user_ty: None,
149+
..self.local_decls[place.local].clone()
150+
});
151+
v.insert(local);
152+
}
153+
}
154+
}
155+
}
156+
157+
impl<'tcx, 'll> Visitor<'tcx> for PreFlattenVisitor<'tcx, 'll> {
158+
fn visit_place(&mut self, place: &Place<'tcx>, _: PlaceContext, _: Location) {
159+
if let &[PlaceElem::Field(..), ..] = &place.projection[..] {
160+
let pr = PlaceRef { local: place.local, projection: &place.projection[..1] };
161+
self.create_place(pr)
162+
}
163+
}
164+
}
165+
}
166+
167+
/// Perform the replacement computed by `compute_flattening`.
168+
fn replace_flattened_locals<'tcx>(
169+
tcx: TyCtxt<'tcx>,
170+
body: &mut Body<'tcx>,
171+
replacements: ReplacementMap<'tcx>,
172+
) {
173+
let mut all_dead_locals = BitSet::new_empty(body.local_decls.len());
174+
for p in replacements.fields.keys() {
175+
all_dead_locals.insert(p.local);
176+
}
177+
debug!(?all_dead_locals);
178+
if all_dead_locals.is_empty() {
179+
return;
180+
}
181+
182+
let mut fragments = IndexVec::new();
183+
for (k, v) in &replacements.fields {
184+
fragments.ensure_contains_elem(k.local, || Vec::new());
185+
fragments[k.local].push((&k.projection[..], *v));
186+
}
187+
debug!(?fragments);
188+
189+
let mut visitor = ReplacementVisitor {
190+
tcx,
191+
local_decls: &body.local_decls,
192+
replacements,
193+
all_dead_locals,
194+
fragments,
195+
};
196+
for (bb, data) in body.basic_blocks.as_mut_preserves_cfg().iter_enumerated_mut() {
197+
visitor.visit_basic_block_data(bb, data);
198+
}
199+
for scope in &mut body.source_scopes {
200+
visitor.visit_source_scope_data(scope);
201+
}
202+
for (index, annotation) in body.user_type_annotations.iter_enumerated_mut() {
203+
visitor.visit_user_type_annotation(index, annotation);
204+
}
205+
for var_debug_info in &mut body.var_debug_info {
206+
visitor.visit_var_debug_info(var_debug_info);
207+
}
208+
}
209+
210+
struct ReplacementVisitor<'tcx, 'll> {
211+
tcx: TyCtxt<'tcx>,
212+
/// This is only used to compute the type for `VarDebugInfoContents::Composite`.
213+
local_decls: &'ll LocalDecls<'tcx>,
214+
/// Work to do.
215+
replacements: ReplacementMap<'tcx>,
216+
/// This is used to check that we are not leaving references to replaced locals behind.
217+
all_dead_locals: BitSet<Local>,
218+
/// Pre-computed list of all "new" locals for each "old" local. This is used to expand storage
219+
/// and deinit statement and debuginfo.
220+
fragments: IndexVec<Local, Vec<(&'tcx [PlaceElem<'tcx>], Local)>>,
221+
}
222+
223+
impl<'tcx, 'll> ReplacementVisitor<'tcx, 'll> {
224+
fn gather_debug_info_fragments(
225+
&self,
226+
place: PlaceRef<'tcx>,
227+
) -> Vec<VarDebugInfoFragment<'tcx>> {
228+
let mut fragments = Vec::new();
229+
let parts = &self.fragments[place.local];
230+
for (proj, replacement_local) in parts {
231+
if proj.starts_with(place.projection) {
232+
fragments.push(VarDebugInfoFragment {
233+
projection: proj[place.projection.len()..].to_vec(),
234+
contents: Place::from(*replacement_local),
235+
});
236+
}
237+
}
238+
fragments
239+
}
240+
241+
fn replace_place(&self, place: PlaceRef<'tcx>) -> Option<Place<'tcx>> {
242+
if let &[PlaceElem::Field(..), ref rest @ ..] = place.projection {
243+
let pr = PlaceRef { local: place.local, projection: &place.projection[..1] };
244+
let local = self.replacements.fields.get(&pr)?;
245+
Some(Place { local: *local, projection: self.tcx.intern_place_elems(&rest) })
246+
} else {
247+
None
248+
}
249+
}
250+
}
251+
252+
impl<'tcx, 'll> MutVisitor<'tcx> for ReplacementVisitor<'tcx, 'll> {
253+
fn tcx(&self) -> TyCtxt<'tcx> {
254+
self.tcx
255+
}
256+
257+
fn visit_statement(&mut self, statement: &mut Statement<'tcx>, location: Location) {
258+
if let StatementKind::StorageLive(..)
259+
| StatementKind::StorageDead(..)
260+
| StatementKind::Deinit(..) = statement.kind
261+
{
262+
// Storage statements are expanded in run_pass.
263+
return;
264+
}
265+
self.super_statement(statement, location)
266+
}
267+
268+
fn visit_place(&mut self, place: &mut Place<'tcx>, context: PlaceContext, location: Location) {
269+
if let Some(repl) = self.replace_place(place.as_ref()) {
270+
*place = repl
271+
} else {
272+
self.super_place(place, context, location)
273+
}
274+
}
275+
276+
fn visit_var_debug_info(&mut self, var_debug_info: &mut VarDebugInfo<'tcx>) {
277+
match &mut var_debug_info.value {
278+
VarDebugInfoContents::Place(ref mut place) => {
279+
if let Some(repl) = self.replace_place(place.as_ref()) {
280+
*place = repl;
281+
} else if self.all_dead_locals.contains(place.local) {
282+
let ty = place.ty(self.local_decls, self.tcx).ty;
283+
let fragments = self.gather_debug_info_fragments(place.as_ref());
284+
var_debug_info.value = VarDebugInfoContents::Composite { ty, fragments };
285+
}
286+
}
287+
VarDebugInfoContents::Composite { ty: _, ref mut fragments } => {
288+
let mut new_fragments = Vec::new();
289+
fragments
290+
.drain_filter(|fragment| {
291+
if let Some(repl) = self.replace_place(fragment.contents.as_ref()) {
292+
fragment.contents = repl;
293+
true
294+
} else if self.all_dead_locals.contains(fragment.contents.local) {
295+
let frg = self.gather_debug_info_fragments(fragment.contents.as_ref());
296+
new_fragments.extend(frg.into_iter().map(|mut f| {
297+
f.projection.splice(0..0, fragment.projection.iter().copied());
298+
f
299+
}));
300+
false
301+
} else {
302+
true
303+
}
304+
})
305+
.for_each(drop);
306+
fragments.extend(new_fragments);
307+
}
308+
VarDebugInfoContents::Const(_) => {}
309+
}
310+
}
311+
312+
fn visit_basic_block_data(&mut self, bb: BasicBlock, bbdata: &mut BasicBlockData<'tcx>) {
313+
self.super_basic_block_data(bb, bbdata);
314+
315+
#[derive(Debug)]
316+
enum Stmt {
317+
StorageLive,
318+
StorageDead,
319+
Deinit,
320+
}
321+
322+
bbdata.expand_statements(|stmt| {
323+
let source_info = stmt.source_info;
324+
let (stmt, origin_local) = match &stmt.kind {
325+
StatementKind::StorageLive(l) => (Stmt::StorageLive, *l),
326+
StatementKind::StorageDead(l) => (Stmt::StorageDead, *l),
327+
StatementKind::Deinit(p) if let Some(l) = p.as_local() => (Stmt::Deinit, l),
328+
_ => return None,
329+
};
330+
if !self.all_dead_locals.contains(origin_local) {
331+
return None;
332+
}
333+
let final_locals = self.fragments.get(origin_local)?;
334+
Some(final_locals.iter().map(move |&(_, l)| {
335+
let kind = match stmt {
336+
Stmt::StorageLive => StatementKind::StorageLive(l),
337+
Stmt::StorageDead => StatementKind::StorageDead(l),
338+
Stmt::Deinit => StatementKind::Deinit(Box::new(l.into())),
339+
};
340+
Statement { source_info, kind }
341+
}))
342+
});
343+
}
344+
345+
fn visit_local(&mut self, local: &mut Local, _: PlaceContext, _: Location) {
346+
assert!(!self.all_dead_locals.contains(*local));
347+
}
348+
}

0 commit comments

Comments
 (0)