|
| 1 | +//! Replaces calls to `Iter::next` with small, specialized MIR implementations, for some common iterators. |
| 2 | +use rustc_abi::{FieldIdx, VariantIdx}; |
| 3 | +use rustc_index::IndexVec; |
| 4 | +use rustc_middle::mir::interpret::Scalar; |
| 5 | +use rustc_middle::mir::{SourceInfo, *}; |
| 6 | +use rustc_middle::ty::{self, AdtDef, AdtKind, GenericArgs, Ty, TyCtxt}; |
| 7 | +use rustc_span::Span; |
| 8 | +use rustc_type_ir::inherent::*; |
| 9 | +use tracing::trace; |
| 10 | + |
| 11 | +use crate::hir::def_id::{CrateNum, DefId}; |
| 12 | + |
/// MIR pass that replaces calls to `Iterator::next` on slice iterators with
/// small, specialized inline MIR (cursor comparison + pointer bump).
pub(super) enum StreamlineIter {
    /// The pass can run: `core` is the crate that defines the slice iterator
    /// types, and `iter_next` is the `DefId` of the `Iterator::next` lang item.
    Working { core: CrateNum, iter_next: DefId },
    /// The `next` lang item could not be resolved; the pass is a no-op.
    Disabled,
}
| 17 | +impl StreamlineIter { |
| 18 | + pub(crate) fn new(tcx: TyCtxt<'_>) -> Self { |
| 19 | + let Some(iter_next) = tcx.lang_items().next_fn() else { |
| 20 | + return Self::Disabled; |
| 21 | + }; |
| 22 | + let core = iter_next.krate; |
| 23 | + Self::Working { core, iter_next } |
| 24 | + } |
| 25 | +} |
impl<'tcx> crate::MirPass<'tcx> for StreamlineIter {
    /// Only enabled at `mir-opt-level >= 2`, and only when the `Iterator::next`
    /// lang item was resolved in [`StreamlineIter::new`].
    fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
        sess.mir_opt_level() > 1 && (matches!(self, StreamlineIter::Working { .. }))
    }
    // Temporary allow for dev purposes
    #[allow(unused_variables, unused_mut, unreachable_code)]
    fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
        trace!("Running StreamlineIter on {:?}", body.source);
        let Self::Working { core, iter_next } = self else {
            return;
        };
        // Mutable view that promises to preserve the CFG; we track whether we
        // actually change it (`cfg_invalid`) and invalidate the cache at the end.
        let mut bbs = body.basic_blocks.as_mut_preserves_cfg();
        let locals = &mut body.local_decls;
        // If any optimizations were performed, invalidate the cache.
        let mut cfg_invalid = false;

        // 1st. Go through all terminators, find calls.
        for bid in (0..(bbs.len())).into_iter().map(BasicBlock::from_usize) {
            let mut bb = &bbs[bid];
            // Check if this is the call to std::slice::Iter::next OR std::slice::IterMut::next
            let Some(InlineSliceNextCandidate {
                iter_place,
                iter_adt,
                iter_args,
                fn_span,
                source_info,
                destination,
                target,
            }) = terminator_iter_next(&bb.terminator, *iter_next, *core, tcx)
            else {
                continue;
            };
            // Find the relevant field: the sole non-`PhantomData` ADT field of
            // the iterator struct (presumably the `NonNull<T>` cursor — the
            // shape was validated in `terminator_iter_next`).
            let (notnull_idx, notnull_ty) = iter_adt
                .variant(VariantIdx::ZERO)
                .fields
                .iter()
                .enumerate()
                .map(|(idx, field)| (FieldIdx::from_usize(idx), field.ty(tcx, iter_args)))
                .filter(|(idx, ty)| match ty.kind() {
                    ty::Adt(adt, _) => !adt.is_phantom_data(),
                    _ => false,
                })
                .next()
                .unwrap();
            // `iter_place` holds a reference to the iterator, so deref first.
            let iter_place = tcx.mk_place_deref(iter_place);
            let ptr_nonull = tcx.mk_place_field(iter_place, notnull_idx, notnull_ty);
            let ty::Adt(non_null_adt, on_null_arg) = notnull_ty.kind() else {
                continue;
            };
            // Project through the wrapper to its raw-pointer payload field.
            let (inner_idx, inner_t) = non_null_adt
                .variant(VariantIdx::ZERO)
                .fields
                .iter()
                .enumerate()
                .map(|(idx, field)| (FieldIdx::from_usize(idx), field.ty(tcx, on_null_arg)))
                .filter(|(idx, ty)| match ty.kind() {
                    ty::RawPtr(_, _) => true,
                    _ => false,
                })
                .next()
                .unwrap();
            // `pointer` is the place of the iterator's current-element pointer.
            let pointer = tcx.mk_place_field(ptr_nonull, inner_idx, inner_t);
            // Increment pointer: `pointer = Offset(pointer, 1)` (element-sized).
            let val = Operand::Copy(pointer);
            let one = Operand::const_from_scalar(
                tcx,
                tcx.types.usize,
                Scalar::from_target_usize(1, &tcx),
                fn_span,
            );
            let offset = Rvalue::BinaryOp(BinOp::Offset, Box::new((val, one)));
            let incr =
                Statement { kind: StatementKind::Assign(Box::new((pointer, offset))), source_info };
            // Allocate the check & cast_end local:
            let check = locals.push(LocalDecl::new(tcx.types.bool, fn_span));
            // Bounds check: locate the raw-pointer (`end`-like) field of the iterator.
            let (idx, ty) = iter_adt
                .variant(VariantIdx::ZERO)
                .fields
                .iter()
                .enumerate()
                .map(|(idx, field)| (FieldIdx::from_usize(idx), field.ty(tcx, iter_args)))
                .filter(|(idx, ty)| match ty.kind() {
                    ty::RawPtr(_, _) => true,
                    _ => false,
                })
                .next()
                .unwrap();

            let end_ptr = tcx.mk_place_field(iter_place, idx, ty);
            let end_ptr = Operand::Copy(end_ptr);
            let ptr = Operand::Copy(pointer);
            let pointer_ty = pointer.ty(locals, tcx).ty;
            // Cast the end pointer to the cursor's pointer type so the `Eq`
            // below compares operands of identical type.
            let end_ptr_after_cast = locals.push(LocalDecl::new(pointer_ty, fn_span));
            let cast_end_ptr = Rvalue::Cast(CastKind::PtrToPtr, end_ptr, pointer_ty);
            let ptr_cast = Statement {
                kind: StatementKind::Assign(Box::new((end_ptr_after_cast.into(), cast_end_ptr))),
                source_info,
            };

            // `check = (pointer == end)` — true iff the iterator is exhausted.
            let is_empty = Rvalue::BinaryOp(
                BinOp::Eq,
                Box::new((ptr, Operand::Copy(end_ptr_after_cast.into()))),
            );
            let check_iter_empty = Statement {
                kind: StatementKind::Assign(Box::new((check.into(), is_empty))),
                source_info,
            };

            // Create the Some and None blocks; both rejoin the call's target.
            let rejoin = Terminator { kind: TerminatorKind::Goto { target }, source_info };
            let mut some_block = BasicBlockData::new(Some(rejoin.clone()), false);
            let mut none_block = BasicBlockData::new(Some(rejoin), false);
            // Create the None value (variant 0 of the destination `Option`, no fields).
            let dst_ty = destination.ty(locals, tcx);
            let ty::Adt(option_adt, option_gargs) = dst_ty.ty.kind() else {
                continue;
            };
            let none_val = Rvalue::Aggregate(
                Box::new(AggregateKind::Adt(
                    option_adt.did(),
                    VariantIdx::ZERO,
                    option_gargs,
                    None,
                    None,
                )),
                IndexVec::new(),
            );
            let set_none = Statement {
                kind: StatementKind::Assign(Box::new((destination, none_val))),
                source_info,
            };
            none_block.statements.push(set_none);
            // Cast the pointer to a reference, preserving lifetimes.
            let ref_ty = option_gargs[0].expect_ty();
            let ref_local = locals.push(LocalDecl::new(ref_ty, fn_span));

            let ty::Ref(region, _, muta) = ref_ty.kind() else {
                continue;
            };
            // Copy the cursor into a fresh local, then reborrow `*local` with
            // the mutability demanded by the destination's reference type.
            let pointer_local = locals.push(LocalDecl::new(pointer_ty, fn_span));
            let pointer_assign = Rvalue::Use(Operand::Copy(pointer));
            let pointer_assign = Statement {
                kind: StatementKind::Assign(Box::new((pointer_local.into(), pointer_assign))),
                source_info,
            };
            let borrow = if *muta == Mutability::Not {
                BorrowKind::Shared
            } else {
                BorrowKind::Mut { kind: MutBorrowKind::Default }
            };
            let rf = Rvalue::Ref(*region, borrow, tcx.mk_place_deref(pointer_local.into()));
            let rf = Statement {
                kind: StatementKind::Assign(Box::new((ref_local.into(), rf))),
                source_info,
            };
            // `Some(ref_local)` — variant 1 of the destination `Option`.
            let some_val = Rvalue::Aggregate(
                Box::new(AggregateKind::Adt(
                    option_adt.did(),
                    VariantIdx::from_usize(1),
                    option_gargs,
                    None,
                    None,
                )),
                [Operand::Move(ref_local.into())].into(),
            );
            let set_some = Statement {
                kind: StatementKind::Assign(Box::new((destination, some_val))),
                source_info,
            };
            // Order matters: take the reference *before* bumping the cursor.
            some_block.statements.push(pointer_assign);
            some_block.statements.push(rf);
            some_block.statements.push(incr);
            some_block.statements.push(set_some);

            // Get the new blocks in place - this invalidates caches!
            cfg_invalid = true;
            let some_bb = bbs.push(some_block);
            let none_bb = bbs.push(none_block);

            // Change the original block: replace the `Call` terminator with a
            // switch on `check` (0 = not exhausted -> Some path, else -> None path).
            let mut bb = &mut bbs[bid];
            bb.terminator = Some(Terminator {
                kind: TerminatorKind::SwitchInt {
                    discr: Operand::Move(check.into()),
                    targets: SwitchTargets::new(std::iter::once((0, some_bb)), none_bb),
                },
                source_info,
            });
            bb.statements.push(ptr_cast);
            bb.statements.push(check_iter_empty);
        }
        if cfg_invalid {
            body.basic_blocks.invalidate_cfg_cache();
        }
    }

    /// This pass must not be silently removed from the pass list once enabled.
    fn is_required(&self) -> bool {
        true
    }
}
| 228 | +fn not_zst<'tcx>(t: Ty<'tcx>, tcx: TyCtxt<'tcx>) -> bool { |
| 229 | + match t.kind() { |
| 230 | + ty::Uint(_) |
| 231 | + | ty::Int(_) |
| 232 | + | ty::Bool |
| 233 | + | ty::Float(_) |
| 234 | + | ty::Char |
| 235 | + | ty::Ref(..) |
| 236 | + | ty::RawPtr(..) |
| 237 | + | ty::FnPtr(..) => true, |
| 238 | + ty::Tuple(elements) => elements.iter().any(|ty| not_zst(ty, tcx)), |
| 239 | + ty::Array(elem, count) if count.try_to_target_usize(tcx).is_some_and(|count| count > 0) => { |
| 240 | + not_zst(*elem, tcx) |
| 241 | + } |
| 242 | + ty::Array(_, _) => false, |
| 243 | + ty::Never | ty::FnDef(..) => false, |
| 244 | + ty::Adt(def, args) => match def.adt_kind() { |
| 245 | + AdtKind::Enum => def.variants().len() > 1, |
| 246 | + AdtKind::Struct | AdtKind::Union => def |
| 247 | + .variant(VariantIdx::ZERO) |
| 248 | + .fields |
| 249 | + .iter() |
| 250 | + .any(|field| not_zst(field.ty(tcx, args), tcx)), |
| 251 | + }, |
| 252 | + // Generic's, can't determine if they are not-zst's at compile time. |
| 253 | + ty::Param(..) | ty::Alias(..) | ty::Bound(..) => false, |
| 254 | + // Those should not occur here, but I still handle them just in case. |
| 255 | + ty::Str | ty::Slice(..) | ty::Foreign(_) | ty::Dynamic(..) => false, |
| 256 | + ty::Pat(..) | ty::UnsafeBinder(..) | ty::Infer(..) | ty::Placeholder(_) | ty::Error(_) => { |
| 257 | + false |
| 258 | + } |
| 259 | + // There are ways to check if those are ZSTs, but this is not worth it ATM. |
| 260 | + ty::Closure(..) |
| 261 | + | ty::CoroutineClosure(..) |
| 262 | + | ty::Coroutine(..) |
| 263 | + | ty::CoroutineWitness(..) => false, |
| 264 | + } |
| 265 | +} |
| 266 | +//-Copt-level=3 -Zmir-opt-level=3 --emit=llvm-ir -C debug-assertions=no |
/// Everything extracted from a `slice::Iter::next` / `slice::IterMut::next`
/// call site that `run_pass` needs in order to rewrite that call inline.
struct InlineSliceNextCandidate<'tcx> {
    /// Place holding the (mutable) reference to the iterator itself.
    iter_place: Place<'tcx>,
    /// The iterator's ADT definition (`slice::iter::Iter` / `IterMut`).
    iter_adt: AdtDef<'tcx>,
    /// Generic arguments of the iterator type: a lifetime and the element type.
    iter_args: &'tcx GenericArgs<'tcx>,
    /// Span of the original `next` call.
    fn_span: Span,
    /// Source info copied from the call terminator.
    source_info: SourceInfo,
    /// Place the `Option<&T>` / `Option<&mut T>` result is written to.
    destination: Place<'tcx>,
    /// Block where control flow continues after the call returns.
    target: BasicBlock,
}
| 276 | +/// This function checks if this is a call to `std::slice::Iter::next` OR `std::slice::IterMut::next`. |
| 277 | +/// Currently, it uses a bunch of ulgy things to do so, but if those iterators become lang items, then |
| 278 | +/// this could be replaced by a simple DefID check. |
| 279 | +#[allow(unreachable_code, unused_variables)] |
| 280 | +fn terminator_iter_next<'tcx>( |
| 281 | + terminator: &Option<Terminator<'tcx>>, |
| 282 | + iter_next: DefId, |
| 283 | + core: CrateNum, |
| 284 | + tcx: TyCtxt<'tcx>, |
| 285 | +) -> Option<InlineSliceNextCandidate<'tcx>> { |
| 286 | + use rustc_type_ir::inherent::*; |
| 287 | + let Terminator { kind, source_info } = terminator.as_ref()?; |
| 288 | + let TerminatorKind::Call { |
| 289 | + ref func, |
| 290 | + ref args, |
| 291 | + destination, |
| 292 | + target, |
| 293 | + unwind: _, |
| 294 | + call_source: _, |
| 295 | + fn_span, |
| 296 | + } = kind |
| 297 | + else { |
| 298 | + return None; |
| 299 | + }; |
| 300 | + // 2. Check that the `func` of the call is known. |
| 301 | + let func = func.constant()?; |
| 302 | + // 3. Check that the `func` is FnDef |
| 303 | + let ty::FnDef(defid, generic_args) = func.ty().kind() else { |
| 304 | + return None; |
| 305 | + }; |
| 306 | + // 4. Check that this is Iter::next |
| 307 | + if *defid != iter_next { |
| 308 | + return None; |
| 309 | + } |
| 310 | + // 5. Extract parts of the iterator |
| 311 | + let iter_ty = generic_args[0].expect_ty(); |
| 312 | + let ty::Adt(iter_adt, iter_args) = iter_ty.kind() else { |
| 313 | + return None; |
| 314 | + }; |
| 315 | + if iter_adt.did().krate != core { |
| 316 | + return None; |
| 317 | + } |
| 318 | + // 6. Check its argument count - this is a short, cheap check |
| 319 | + if iter_args.len() != 2 { |
| 320 | + return None; |
| 321 | + } |
| 322 | + // 7. Check that the first arg is a lifetime |
| 323 | + if iter_args[0].as_region().is_none() { |
| 324 | + return None; |
| 325 | + } |
| 326 | + // 8. Check that this ADT is a struct, and has 3 fields. |
| 327 | + if !iter_adt.is_struct() { |
| 328 | + return None; |
| 329 | + } |
| 330 | + if iter_adt.all_fields().count() != 3 { |
| 331 | + return None; |
| 332 | + } |
| 333 | + // Check that it has a *const T field. |
| 334 | + if !iter_adt.all_field_tys(tcx).skip_binder().into_iter().any(|ty| match ty.kind() { |
| 335 | + ty::RawPtr(_, _) => true, |
| 336 | + _ => false, |
| 337 | + }) { |
| 338 | + return None; |
| 339 | + } |
| 340 | + // 7. Check that the name of this ADT is `slice::iter::Iter`. This is a janky way to check if this is the iterator we are interested in. |
| 341 | + let name = format!("{:?}", iter_adt.did()); |
| 342 | + if !name.as_str().contains("slice::iter::Iter") { |
| 343 | + return None; |
| 344 | + } |
| 345 | + // We now know this is a slice iterator - so we can optimize it ! |
| 346 | + // Check if we know if this is not a `zst` |
| 347 | + if !not_zst(iter_args[1].expect_ty(), tcx) { |
| 348 | + return None; |
| 349 | + } |
| 350 | + |
| 351 | + // We found `slice::iter::Iter`, now, we can work on optimizing it away. |
| 352 | + // 1. Get the `ptr.pointer` field - this is the field we will increment. |
| 353 | + // We know that Iter::next() takes a &mut self, which can't be a constant(?). So, we only worry about Operand::Move or Operand::Copy, which can be turned into places. |
| 354 | + let Some(iter_place) = args[0].node.place() else { |
| 355 | + return None; |
| 356 | + }; |
| 357 | + Some(InlineSliceNextCandidate { |
| 358 | + iter_place, |
| 359 | + iter_adt: *iter_adt, |
| 360 | + iter_args, |
| 361 | + fn_span: *fn_span, |
| 362 | + source_info: *source_info, |
| 363 | + destination: *destination, |
| 364 | + target: target.as_ref().copied()?, |
| 365 | + }) |
| 366 | +} |
0 commit comments