Commit 37e77e9

Auto merge of rust-lang#136745 - FractalFir:master, r=<try>
[perf experiment] A MIR pass dedicated to optimizing common iterators

**This PR is a perf experiment, and is not meant to be accepted. I am creating it to request a perf run.**

# Motivation

Currently, many commonly used iterators don't get inlined during MIR optimization, which leads to an increased amount of LLVM IR. Since those iterators are also generic, they are unlikely to be inlined anytime soon.

# Optimizing slice iterators

This PR adds an experimental pass which replaces two commonly used iterators (`std::slice::Iter` and `std::slice::IterMut`) with inline implementations. Should this pass show potential for performance gains, I will work on an improved version which also handles other common iterators from `core` (e.g. `Range`, `Enumerate`). A proper implementation will require other, bigger changes (e.g. *maybe* marking certain iterators as lang items for quicker lookup). Because of that, I am asking for a perf run, to see if that effort will be worth it.
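To make the target concrete, here is a hypothetical caller (illustrative only; `sum_slice` is not from this PR). The desugared loop performs one generic call to `<std::slice::Iter<'_, f32> as Iterator>::next` per iteration, and that call shape is what the pass rewrites:

```rust
// Hypothetical example of code the pass targets, not taken from this PR.
#[no_mangle]
pub fn sum_slice(s: &[f32]) -> f32 {
    let mut iter = s.iter();
    let mut total = 0.0;
    // Each `next` call is a generic `Iterator::next` call in MIR; the pass
    // replaces it with an inline pointer comparison and pointer bump.
    while let Some(x) = iter.next() {
        total += x;
    }
    total
}
```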
2 parents 8ad2c97 + 1d4d571 commit 37e77e9

File tree

- compiler/rustc_mir_transform/src/lib.rs
- compiler/rustc_mir_transform/src/streamline_iter.rs
- compiler/rustc_mir_transform/src/validate.rs
- tests/mir-opt/slice_iter.rs

4 files changed: +375 −1 lines changed


compiler/rustc_mir_transform/src/lib.rs

+3
@@ -171,6 +171,7 @@ declare_passes! {
     mod required_consts : RequiredConstsVisitor;
     mod post_analysis_normalize : PostAnalysisNormalize;
     mod sanity_check : SanityCheck;
+    mod streamline_iter : StreamlineIter;
     // This pass is public to allow external drivers to perform MIR cleanup
     pub mod simplify :
         SimplifyCfg {
@@ -646,6 +647,8 @@ fn run_optimization_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
     // Add some UB checks before any UB gets optimized away.
     &check_alignment::CheckAlignment,
     &check_null::CheckNull,
+    // Done as early as possible: this is a cheap(?) pass which reduces the amount of MIR by a fair bit.
+    &streamline_iter::StreamlineIter::new(tcx),
     // Before inlining: trim down MIR with passes to reduce inlining work.

     // Has to be done before inlining, otherwise actual call will be almost always inlined.
compiler/rustc_mir_transform/src/streamline_iter.rs

+366

@@ -0,0 +1,366 @@
//! Replaces calls to `Iter::next` with small, specialized MIR implementations for some common iterators.
use rustc_abi::{FieldIdx, VariantIdx};
use rustc_index::IndexVec;
use rustc_middle::mir::interpret::Scalar;
use rustc_middle::mir::{SourceInfo, *};
use rustc_middle::ty::{self, AdtDef, AdtKind, GenericArgs, Ty, TyCtxt};
use rustc_span::Span;
use rustc_type_ir::inherent::*;
use tracing::trace;

use crate::hir::def_id::{CrateNum, DefId};

pub(super) enum StreamlineIter {
    Working { core: CrateNum, iter_next: DefId },
    Disabled,
}
impl StreamlineIter {
    pub(crate) fn new(tcx: TyCtxt<'_>) -> Self {
        let Some(iter_next) = tcx.lang_items().next_fn() else {
            return Self::Disabled;
        };
        let core = iter_next.krate;
        Self::Working { core, iter_next }
    }
}
impl<'tcx> crate::MirPass<'tcx> for StreamlineIter {
    fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
        sess.mir_opt_level() > 1 && (matches!(self, StreamlineIter::Working { .. }))
    }
    // Temporary allow for dev purposes
    #[allow(unused_variables, unused_mut, unreachable_code)]
    fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
        trace!("Running StreamlineIter on {:?}", body.source);
        let Self::Working { core, iter_next } = self else {
            return;
        };
        let mut bbs = body.basic_blocks.as_mut_preserves_cfg();
        let locals = &mut body.local_decls;
        // If any optimizations were performed, invalidate the cache.
        let mut cfg_invalid = false;

        // 1st. Go through all terminators, find calls.
        for bid in (0..(bbs.len())).into_iter().map(BasicBlock::from_usize) {
            let mut bb = &bbs[bid];
            // Check if this is the call to std::slice::Iter::next OR std::slice::IterMut::next
            let Some(InlineSliceNextCandidate {
                iter_place,
                iter_adt,
                iter_args,
                fn_span,
                source_info,
                destination,
                target,
            }) = terminator_iter_next(&bb.terminator, *iter_next, *core, tcx)
            else {
                continue;
            };
            // Find the relevant field:
            let (notnull_idx, notnull_ty) = iter_adt
                .variant(VariantIdx::ZERO)
                .fields
                .iter()
                .enumerate()
                .map(|(idx, field)| (FieldIdx::from_usize(idx), field.ty(tcx, iter_args)))
                .filter(|(idx, ty)| match ty.kind() {
                    ty::Adt(adt, _) => !adt.is_phantom_data(),
                    _ => false,
                })
                .next()
                .unwrap();
            let iter_place = tcx.mk_place_deref(iter_place);
            let ptr_nonull = tcx.mk_place_field(iter_place, notnull_idx, notnull_ty);
            let ty::Adt(non_null_adt, on_null_arg) = notnull_ty.kind() else {
                continue;
            };
            let (inner_idx, inner_t) = non_null_adt
                .variant(VariantIdx::ZERO)
                .fields
                .iter()
                .enumerate()
                .map(|(idx, field)| (FieldIdx::from_usize(idx), field.ty(tcx, on_null_arg)))
                .filter(|(idx, ty)| match ty.kind() {
                    ty::RawPtr(_, _) => true,
                    _ => false,
                })
                .next()
                .unwrap();
            let pointer = tcx.mk_place_field(ptr_nonull, inner_idx, inner_t);
            // Increment pointer
            let val = Operand::Copy(pointer);
            let one = Operand::const_from_scalar(
                tcx,
                tcx.types.usize,
                Scalar::from_target_usize(1, &tcx),
                fn_span,
            );
            let offset = Rvalue::BinaryOp(BinOp::Offset, Box::new((val, one)));
            let incr =
                Statement { kind: StatementKind::Assign(Box::new((pointer, offset))), source_info };
            // Allocate the check & cast_end local:
            let check = locals.push(LocalDecl::new(tcx.types.bool, fn_span));
            // Bounds check
            let (idx, ty) = iter_adt
                .variant(VariantIdx::ZERO)
                .fields
                .iter()
                .enumerate()
                .map(|(idx, field)| (FieldIdx::from_usize(idx), field.ty(tcx, iter_args)))
                .filter(|(idx, ty)| match ty.kind() {
                    ty::RawPtr(_, _) => true,
                    _ => false,
                })
                .next()
                .unwrap();

            let end_ptr = tcx.mk_place_field(iter_place, idx, ty);
            let end_ptr = Operand::Copy(end_ptr);
            let ptr = Operand::Copy(pointer);
            let pointer_ty = pointer.ty(locals, tcx).ty;
            let end_ptr_after_cast = locals.push(LocalDecl::new(pointer_ty, fn_span));
            let cast_end_ptr = Rvalue::Cast(CastKind::PtrToPtr, end_ptr, pointer_ty);
            let ptr_cast = Statement {
                kind: StatementKind::Assign(Box::new((end_ptr_after_cast.into(), cast_end_ptr))),
                source_info,
            };

            let is_empty = Rvalue::BinaryOp(
                BinOp::Eq,
                Box::new((ptr, Operand::Copy(end_ptr_after_cast.into()))),
            );
            let check_iter_empty = Statement {
                kind: StatementKind::Assign(Box::new((check.into(), is_empty))),
                source_info,
            };

            // Create the Some and None blocks
            let rejoin = Terminator { kind: TerminatorKind::Goto { target }, source_info };
            let mut some_block = BasicBlockData::new(Some(rejoin.clone()), false);
            let mut none_block = BasicBlockData::new(Some(rejoin), false);
            // Create the None value
            let dst_ty = destination.ty(locals, tcx);
            let ty::Adt(option_adt, option_gargs) = dst_ty.ty.kind() else {
                continue;
            };
            let none_val = Rvalue::Aggregate(
                Box::new(AggregateKind::Adt(
                    option_adt.did(),
                    VariantIdx::ZERO,
                    option_gargs,
                    None,
                    None,
                )),
                IndexVec::new(),
            );
            let set_none = Statement {
                kind: StatementKind::Assign(Box::new((destination, none_val))),
                source_info,
            };
            none_block.statements.push(set_none);
            // Cast the pointer to a reference, preserving lifetimes.
            let ref_ty = option_gargs[0].expect_ty();
            let ref_local = locals.push(LocalDecl::new(ref_ty, fn_span));

            let ty::Ref(region, _, muta) = ref_ty.kind() else {
                continue;
            };
            let pointer_local = locals.push(LocalDecl::new(pointer_ty, fn_span));
            let pointer_assign = Rvalue::Use(Operand::Copy(pointer));
            let pointer_assign = Statement {
                kind: StatementKind::Assign(Box::new((pointer_local.into(), pointer_assign))),
                source_info,
            };
            let borrow = if *muta == Mutability::Not {
                BorrowKind::Shared
            } else {
                BorrowKind::Mut { kind: MutBorrowKind::Default }
            };
            let rf = Rvalue::Ref(*region, borrow, tcx.mk_place_deref(pointer_local.into()));
            let rf = Statement {
                kind: StatementKind::Assign(Box::new((ref_local.into(), rf))),
                source_info,
            };
            let some_val = Rvalue::Aggregate(
                Box::new(AggregateKind::Adt(
                    option_adt.did(),
                    VariantIdx::from_usize(1),
                    option_gargs,
                    None,
                    None,
                )),
                [Operand::Move(ref_local.into())].into(),
            );
            let set_some = Statement {
                kind: StatementKind::Assign(Box::new((destination, some_val))),
                source_info,
            };
            some_block.statements.push(pointer_assign);
            some_block.statements.push(rf);
            some_block.statements.push(incr);
            some_block.statements.push(set_some);

            // Get the new blocks in place - this invalidates caches!
            cfg_invalid = true;
            let some_bb = bbs.push(some_block);
            let none_bb = bbs.push(none_block);

            // Change the original block.
            let mut bb = &mut bbs[bid];
            bb.terminator = Some(Terminator {
                kind: TerminatorKind::SwitchInt {
                    discr: Operand::Move(check.into()),
                    targets: SwitchTargets::new(std::iter::once((0, some_bb)), none_bb),
                },
                source_info,
            });
            bb.statements.push(ptr_cast);
            bb.statements.push(check_iter_empty);
        }
        if cfg_invalid {
            body.basic_blocks.invalidate_cfg_cache();
        }
    }

    fn is_required(&self) -> bool {
        true
    }
}
fn not_zst<'tcx>(t: Ty<'tcx>, tcx: TyCtxt<'tcx>) -> bool {
    match t.kind() {
        ty::Uint(_)
        | ty::Int(_)
        | ty::Bool
        | ty::Float(_)
        | ty::Char
        | ty::Ref(..)
        | ty::RawPtr(..)
        | ty::FnPtr(..) => true,
        ty::Tuple(elements) => elements.iter().any(|ty| not_zst(ty, tcx)),
        ty::Array(elem, count) if count.try_to_target_usize(tcx).is_some_and(|count| count > 0) => {
            not_zst(*elem, tcx)
        }
        ty::Array(_, _) => false,
        ty::Never | ty::FnDef(..) => false,
        ty::Adt(def, args) => match def.adt_kind() {
            AdtKind::Enum => def.variants().len() > 1,
            AdtKind::Struct | AdtKind::Union => def
                .variant(VariantIdx::ZERO)
                .fields
                .iter()
                .any(|field| not_zst(field.ty(tcx, args), tcx)),
        },
        // Generics: we can't determine whether these are non-ZSTs at compile time.
        ty::Param(..) | ty::Alias(..) | ty::Bound(..) => false,
        // These should not occur here, but they are still handled, just in case.
        ty::Str | ty::Slice(..) | ty::Foreign(_) | ty::Dynamic(..) => false,
        ty::Pat(..) | ty::UnsafeBinder(..) | ty::Infer(..) | ty::Placeholder(_) | ty::Error(_) => {
            false
        }
        // There are ways to check if those are ZSTs, but this is not worth it ATM.
        ty::Closure(..)
        | ty::CoroutineClosure(..)
        | ty::Coroutine(..)
        | ty::CoroutineWitness(..) => false,
    }
}
//-Copt-level=3 -Zmir-opt-level=3 --emit=llvm-ir -C debug-assertions=no
struct InlineSliceNextCandidate<'tcx> {
    iter_place: Place<'tcx>,
    iter_adt: AdtDef<'tcx>,
    iter_args: &'tcx GenericArgs<'tcx>,
    fn_span: Span,
    source_info: SourceInfo,
    destination: Place<'tcx>,
    target: BasicBlock,
}
/// This function checks if this is a call to `std::slice::Iter::next` OR `std::slice::IterMut::next`.
/// Currently, it uses a bunch of ugly things to do so, but if those iterators become lang items,
/// then this could be replaced by a simple DefId check.
#[allow(unreachable_code, unused_variables)]
fn terminator_iter_next<'tcx>(
    terminator: &Option<Terminator<'tcx>>,
    iter_next: DefId,
    core: CrateNum,
    tcx: TyCtxt<'tcx>,
) -> Option<InlineSliceNextCandidate<'tcx>> {
    use rustc_type_ir::inherent::*;
    let Terminator { kind, source_info } = terminator.as_ref()?;
    let TerminatorKind::Call {
        ref func,
        ref args,
        destination,
        target,
        unwind: _,
        call_source: _,
        fn_span,
    } = kind
    else {
        return None;
    };
    // 2. Check that the `func` of the call is known.
    let func = func.constant()?;
    // 3. Check that the `func` is FnDef
    let ty::FnDef(defid, generic_args) = func.ty().kind() else {
        return None;
    };
    // 4. Check that this is Iter::next
    if *defid != iter_next {
        return None;
    }
    // 5. Extract parts of the iterator
    let iter_ty = generic_args[0].expect_ty();
    let ty::Adt(iter_adt, iter_args) = iter_ty.kind() else {
        return None;
    };
    if iter_adt.did().krate != core {
        return None;
    }
    // 6. Check its argument count - this is a short, cheap check
    if iter_args.len() != 2 {
        return None;
    }
    // 7. Check that the first arg is a lifetime
    if iter_args[0].as_region().is_none() {
        return None;
    }
    // 8. Check that this ADT is a struct, and has 3 fields.
    if !iter_adt.is_struct() {
        return None;
    }
    if iter_adt.all_fields().count() != 3 {
        return None;
    }
    // Check that it has a *const T field.
    if !iter_adt.all_field_tys(tcx).skip_binder().into_iter().any(|ty| match ty.kind() {
        ty::RawPtr(_, _) => true,
        _ => false,
    }) {
        return None;
    }
    // 9. Check that the name of this ADT is `slice::iter::Iter`. This is a janky way to check if this is the iterator we are interested in.
    let name = format!("{:?}", iter_adt.did());
    if !name.as_str().contains("slice::iter::Iter") {
        return None;
    }
    // We now know this is a slice iterator - so we can optimize it!
    // Check that the element type is known not to be a ZST.
    if !not_zst(iter_args[1].expect_ty(), tcx) {
        return None;
    }

    // We found `slice::iter::Iter`; now we can work on optimizing it away.
    // 1. Get the `ptr.pointer` field - this is the field we will increment.
    // We know that Iter::next() takes a &mut self, which can't be a constant(?). So, we only worry
    // about Operand::Move or Operand::Copy, which can be turned into places.
    let Some(iter_place) = args[0].node.place() else {
        return None;
    };
    Some(InlineSliceNextCandidate {
        iter_place,
        iter_adt: *iter_adt,
        iter_args,
        fn_span: *fn_span,
        source_info: *source_info,
        destination: *destination,
        target: target.as_ref().copied()?,
    })
}
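Behaviorally, the MIR that `run_pass` emits for a `dest = Iter::next(&mut it)` call corresponds roughly to the Rust sketch below. This is a hand-written illustration rather than actual pass output, and `ptr`/`end` stand in for the iterator's private `NonNull` data pointer and end-pointer fields, i.e. an assumption about `core`'s internal layout:

```rust
// Illustrative sketch of the control flow the pass builds, not real output.
unsafe fn inline_next<'a, T>(ptr: &mut *const T, end: *const T) -> Option<&'a T> {
    // Original block: PtrToPtr-cast the end pointer, compare it against the
    // data pointer, then `SwitchInt` on the result.
    if *ptr == end {
        // "None" block: build `Option::None`, then `Goto` the original target.
        None
    } else {
        // "Some" block: reborrow the current element first, then bump the
        // pointer via `BinOp::Offset` (the same statement order as the pass).
        // SAFETY: assumes `*ptr` still points into the slice, as `Iter` upholds.
        let element: &'a T = unsafe { &**ptr };
        *ptr = unsafe { (*ptr).offset(1) };
        Some(element)
    }
}
```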

compiler/rustc_mir_transform/src/validate.rs

+1 −1
@@ -956,7 +956,7 @@ impl<'a, 'tcx> Visitor<'tcx> for TypeChecker<'a, 'tcx> {
                 .tcx
                 .normalize_erasing_regions(self.typing_env, dest.ty(self.tcx, args));
             if !self.mir_assign_valid_types(src.ty(self.body, self.tcx), dest_ty) {
-                self.fail(location, "adt field has the wrong type");
+                self.fail(location, &format!("adt field has the wrong type. src:{:?} dest_ty:{dest_ty:?} src:{src:?}", src.ty(self.body, self.tcx)));
             }
         }
     }
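The validator change above only enriches the existing "adt field has the wrong type" failure with the offending source and destination types, presumably to ease debugging the `Option` aggregates the new pass builds.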

tests/mir-opt/slice_iter.rs

+5
@@ -0,0 +1,5 @@
#[no_mangle]
// EMIT_MIR slice_iter.built.after.mir
fn slice_iter_next<'a>(s_iter: &mut std::slice::Iter<'a, f32>) -> Option<&'a f32> {
    s_iter.next()
}
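For reference, mir-opt tests like this one are normally blessed via the standard rustc workflow, e.g. `./x test tests/mir-opt/slice_iter.rs --bless` (assumed here, not part of the diff), which regenerates the expected `slice_iter.built.after.mir` dump named in the `EMIT_MIR` comment.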
