Skip to content

Commit e82f68e

Browse files
eddybFirestar99
authored andcommitted
WIP: (TODO: finish bottom-up cleanups) bottom-up inlining
1 parent abaac31 commit e82f68e

File tree

2 files changed

+107
-82
lines changed

2 files changed

+107
-82
lines changed

crates/rustc_codegen_spirv/src/linker/inline.rs

Lines changed: 96 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,6 @@ use rustc_session::Session;
1717
use smallvec::SmallVec;
1818
use std::mem;
1919

20-
type FunctionMap = FxHashMap<Word, Function>;
21-
2220
// FIXME(eddyb) this is a bit silly, but this keeps being repeated everywhere.
2321
fn next_id(header: &mut ModuleHeader) -> Word {
2422
let result = header.bound;
@@ -30,6 +28,9 @@ pub fn inline(sess: &Session, module: &mut Module) -> super::Result<()> {
3028
// This algorithm gets real sad if there's recursion - but, good news, SPIR-V bans recursion
3129
deny_recursion_in_module(sess, module)?;
3230

31+
// Compute the call-graph that will drive (inside-out, aka bottom-up) inlining.
32+
let (call_graph, func_id_to_idx) = CallGraph::collect_with_func_id_to_idx(module);
33+
3334
let custom_ext_inst_set_import = module
3435
.ext_inst_imports
3536
.iter()
@@ -39,62 +40,7 @@ pub fn inline(sess: &Session, module: &mut Module) -> super::Result<()> {
3940
})
4041
.map(|inst| inst.result_id.unwrap());
4142

42-
// HACK(eddyb) compute the set of functions that may `Abort` *transitively*,
43-
// which is only needed because of how we inline (sometimes it's outside-in,
44-
// aka top-down, instead of always being inside-out, aka bottom-up).
45-
//
46-
// (inlining is needed in the first place because our custom `Abort`
47-
// instructions get lowered to a simple `OpReturn` in entry-points, but
48-
// that requires that they get inlined all the way up to the entry-points)
49-
let functions_that_may_abort = custom_ext_inst_set_import
50-
.map(|custom_ext_inst_set_import| {
51-
let mut may_abort_by_id = FxHashSet::default();
52-
53-
// FIXME(eddyb) use this `CallGraph` abstraction more during inlining.
54-
let call_graph = CallGraph::collect(module);
55-
for func_idx in call_graph.post_order() {
56-
let func_id = module.functions[func_idx].def_id().unwrap();
57-
58-
let any_callee_may_abort = call_graph.callees[func_idx].iter().any(|&callee_idx| {
59-
may_abort_by_id.contains(&module.functions[callee_idx].def_id().unwrap())
60-
});
61-
if any_callee_may_abort {
62-
may_abort_by_id.insert(func_id);
63-
continue;
64-
}
65-
66-
let may_abort_directly = module.functions[func_idx].blocks.iter().any(|block| {
67-
match &block.instructions[..] {
68-
[.., last_normal_inst, terminator_inst]
69-
if last_normal_inst.class.opcode == Op::ExtInst
70-
&& last_normal_inst.operands[0].unwrap_id_ref()
71-
== custom_ext_inst_set_import
72-
&& CustomOp::decode_from_ext_inst(last_normal_inst)
73-
== CustomOp::Abort =>
74-
{
75-
assert_eq!(terminator_inst.class.opcode, Op::Unreachable);
76-
true
77-
}
78-
79-
_ => false,
80-
}
81-
});
82-
if may_abort_directly {
83-
may_abort_by_id.insert(func_id);
84-
}
85-
}
86-
87-
may_abort_by_id
88-
})
89-
.unwrap_or_default();
90-
91-
let functions = module
92-
.functions
93-
.iter()
94-
.map(|f| (f.def_id().unwrap(), f.clone()))
95-
.collect();
96-
let legal_globals = LegalGlobal::gather_from_module(module);
97-
43+
/*
9844
// Drop all the functions we'll be inlining. (This also means we won't waste time processing
9945
// inlines in functions that will get inlined)
10046
let mut dropped_ids = FxHashSet::default();
@@ -123,6 +69,9 @@ pub fn inline(sess: &Session, module: &mut Module) -> super::Result<()> {
12369
));
12470
}
12571
}
72+
*/
73+
74+
let legal_globals = LegalGlobal::gather_from_module(module);
12675

12776
let header = module.header.as_mut().unwrap();
12877
// FIXME(eddyb) clippy false positive (separate `map` required for borrowck).
@@ -149,6 +98,8 @@ pub fn inline(sess: &Session, module: &mut Module) -> super::Result<()> {
14998
id
15099
}),
151100

101+
func_id_to_idx,
102+
152103
id_to_name: module
153104
.debug_names
154105
.iter()
@@ -168,22 +119,61 @@ pub fn inline(sess: &Session, module: &mut Module) -> super::Result<()> {
168119
annotations: &mut module.annotations,
169120
types_global_values: &mut module.types_global_values,
170121

171-
functions: &functions,
172-
legal_globals: &legal_globals,
173-
functions_that_may_abort: &functions_that_may_abort,
122+
legal_globals,
123+
124+
// NOTE(eddyb) this is needed because our custom `Abort` instructions get
125+
// lowered to a simple `OpReturn` in entry-points, but that requires that
126+
// they get inlined all the way up to the entry-points in the first place.
127+
functions_that_may_abort: module
128+
.functions
129+
.iter()
130+
.filter_map(|func| {
131+
let custom_ext_inst_set_import = custom_ext_inst_set_import?;
132+
func.blocks
133+
.iter()
134+
.any(|block| match &block.instructions[..] {
135+
[.., last_normal_inst, terminator_inst]
136+
if last_normal_inst.class.opcode == Op::ExtInst
137+
&& last_normal_inst.operands[0].unwrap_id_ref()
138+
== custom_ext_inst_set_import
139+
&& CustomOp::decode_from_ext_inst(last_normal_inst)
140+
== CustomOp::Abort =>
141+
{
142+
assert_eq!(terminator_inst.class.opcode, Op::Unreachable);
143+
true
144+
}
145+
146+
_ => false,
147+
})
148+
.then_some(func.def_id().unwrap())
149+
})
150+
.collect(),
174151
};
175-
for function in &mut module.functions {
176-
inliner.inline_fn(function);
177-
fuse_trivial_branches(function);
152+
153+
let mut functions: Vec<_> = mem::take(&mut module.functions)
154+
.into_iter()
155+
.map(Ok)
156+
.collect();
157+
158+
// Inline functions in post-order (aka inside-out aka bottom-out) - that is,
159+
// callees are processed before their callers, to avoid duplicating work.
160+
for func_idx in call_graph.post_order() {
161+
let mut function = mem::replace(&mut functions[func_idx], Err(FuncIsBeingInlined)).unwrap();
162+
inliner.inline_fn(&mut function, &functions);
163+
fuse_trivial_branches(&mut function);
164+
functions[func_idx] = Ok(function);
178165
}
179166

167+
module.functions = functions.into_iter().map(|func| func.unwrap()).collect();
168+
169+
/*
180170
// Drop OpName etc. for inlined functions
181171
module.debug_names.retain(|inst| {
182172
!inst
183173
.operands
184174
.iter()
185175
.any(|op| op.id_ref_any().is_some_and(|id| dropped_ids.contains(&id)))
186-
});
176+
});*/
187177

188178
Ok(())
189179
}
@@ -451,19 +441,27 @@ fn should_inline(
451441
Ok(callee_control.contains(FunctionControl::INLINE))
452442
}
453443

444+
/// Helper error type for `Inliner`'s `functions` field, indicating a `Function`
445+
/// was taken out of its slot because it's being inlined.
446+
#[derive(Debug)]
447+
struct FuncIsBeingInlined;
448+
454449
// Steps:
455450
// Move OpVariable decls
456451
// Rewrite return
457452
// Renumber IDs
458453
// Insert blocks
459454

460-
struct Inliner<'m, 'map> {
455+
struct Inliner<'m> {
461456
/// ID of `OpExtInstImport` for our custom "extended instruction set"
462457
/// (see `crate::custom_insts` for more details).
463458
custom_ext_inst_set_import: Word,
464459

465460
op_type_void_id: Word,
466461

462+
/// Map from each function's ID to its index in `functions`.
463+
func_id_to_idx: FxHashMap<Word, usize>,
464+
467465
/// Pre-collected `OpName`s, that can be used to find any function's name
468466
/// during inlining (to be able to generate debuginfo that uses names).
469467
id_to_name: FxHashMap<Word, &'m str>,
@@ -480,13 +478,12 @@ struct Inliner<'m, 'map> {
480478
annotations: &'m mut Vec<Instruction>,
481479
types_global_values: &'m mut Vec<Instruction>,
482480

483-
functions: &'map FunctionMap,
484-
legal_globals: &'map FxHashMap<Word, LegalGlobal>,
485-
functions_that_may_abort: &'map FxHashSet<Word>,
481+
legal_globals: FxHashMap<Word, LegalGlobal>,
482+
functions_that_may_abort: FxHashSet<Word>,
486483
// rewrite_rules: FxHashMap<Word, Word>,
487484
}
488485

489-
impl Inliner<'_, '_> {
486+
impl Inliner<'_> {
490487
fn id(&mut self) -> Word {
491488
next_id(self.header)
492489
}
@@ -531,19 +528,29 @@ impl Inliner<'_, '_> {
531528
inst_id
532529
}
533530

534-
fn inline_fn(&mut self, function: &mut Function) {
531+
fn inline_fn(
532+
&mut self,
533+
function: &mut Function,
534+
functions: &[Result<Function, FuncIsBeingInlined>],
535+
) {
535536
let mut block_idx = 0;
536537
while block_idx < function.blocks.len() {
537538
// If we successfully inlined a block, then repeat processing on the same block, in
538539
// case the newly inlined block has more inlined calls.
539540
// TODO: This is quadratic
540-
if !self.inline_block(function, block_idx) {
541+
if !self.inline_block(function, block_idx, functions) {
542+
// TODO(eddyb) skip past the inlined callee without rescanning it.
541543
block_idx += 1;
542544
}
543545
}
544546
}
545547

546-
fn inline_block(&mut self, caller: &mut Function, block_idx: usize) -> bool {
548+
fn inline_block(
549+
&mut self,
550+
caller: &mut Function,
551+
block_idx: usize,
552+
functions: &[Result<Function, FuncIsBeingInlined>],
553+
) -> bool {
547554
// Find the first inlined OpFunctionCall
548555
let call = caller.blocks[block_idx]
549556
.instructions
@@ -554,8 +561,8 @@ impl Inliner<'_, '_> {
554561
(
555562
index,
556563
inst,
557-
self.functions
558-
.get(&inst.operands[0].id_ref_any().unwrap())
564+
functions[self.func_id_to_idx[&inst.operands[0].id_ref_any().unwrap()]]
565+
.as_ref()
559566
.unwrap(),
560567
)
561568
})
@@ -565,8 +572,8 @@ impl Inliner<'_, '_> {
565572
call_inst: inst,
566573
};
567574
match should_inline(
568-
self.legal_globals,
569-
self.functions_that_may_abort,
575+
&self.legal_globals,
576+
&self.functions_that_may_abort,
570577
f,
571578
Some(call_site),
572579
) {
@@ -578,6 +585,16 @@ impl Inliner<'_, '_> {
578585
None => return false,
579586
Some(call) => call,
580587
};
588+
589+
// Propagate "may abort" from callee to caller (i.e. as aborts get inlined).
590+
if self
591+
.functions_that_may_abort
592+
.contains(&callee.def_id().unwrap())
593+
{
594+
self.functions_that_may_abort
595+
.insert(caller.def_id().unwrap());
596+
}
597+
581598
let call_result_type = {
582599
let ty = call_inst.result_type.unwrap();
583600
if ty == self.op_type_void_id {
@@ -589,6 +606,7 @@ impl Inliner<'_, '_> {
589606
let call_result_id = call_inst.result_id.unwrap();
590607

591608
// Get the debuginfo instructions that apply to the call.
609+
// TODO(eddyb) only one instruction should be necessary here w/ bottom-up.
592610
let custom_ext_inst_set_import = self.custom_ext_inst_set_import;
593611
let call_debug_insts = caller.blocks[block_idx].instructions[..call_index]
594612
.iter()
@@ -863,6 +881,7 @@ impl Inliner<'_, '_> {
863881
..
864882
} = *self;
865883

884+
// TODO(eddyb) kill this as it shouldn't be needed for bottom-up inline.
866885
// HACK(eddyb) this is terrible, but we have to deal with it because of
867886
// how this inliner is outside-in, instead of inside-out, meaning that
868887
// context builds up "outside" of the callee blocks, inside the caller.

crates/rustc_codegen_spirv/src/linker/ipo.rs

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
use indexmap::IndexSet;
66
use rspirv::dr::Module;
7-
use rspirv::spirv::Op;
7+
use rspirv::spirv::{Op, Word};
88
use rustc_data_structures::fx::FxHashMap;
99

1010
// FIXME(eddyb) use newtyped indices and `IndexVec`.
@@ -19,6 +19,9 @@ pub struct CallGraph {
1919

2020
impl CallGraph {
2121
pub fn collect(module: &Module) -> Self {
22+
Self::collect_with_func_id_to_idx(module).0
23+
}
24+
pub fn collect_with_func_id_to_idx(module: &Module) -> (Self, FxHashMap<Word, FuncIdx>) {
2225
let func_id_to_idx: FxHashMap<_, _> = module
2326
.functions
2427
.iter()
@@ -51,10 +54,13 @@ impl CallGraph {
5154
.collect()
5255
})
5356
.collect();
54-
Self {
55-
entry_points,
56-
callees,
57-
}
57+
(
58+
Self {
59+
entry_points,
60+
callees,
61+
},
62+
func_id_to_idx,
63+
)
5864
}
5965

6066
/// Order functions using a post-order traversal, i.e. callees before callers.

0 commit comments

Comments
 (0)