|
| 1 | +use crate::transform::{MirPass, MirSource}; |
| 2 | +use rustc_middle::mir::*; |
| 3 | +use rustc_middle::ty::TyCtxt; |
| 4 | + |
| 5 | +pub struct MatchBranchSimplification; |
| 6 | + |
| 7 | +// What's the intent of this pass? |
| 8 | +// If one block is found that switches between blocks which both go to the same place |
| 9 | +// AND both of these blocks set a similar const in their -> |
| 10 | +// condense into 1 block based on discriminant AND goto the destination afterwards |
| 11 | + |
| 12 | +impl<'tcx> MirPass<'tcx> for MatchBranchSimplification { |
| 13 | + fn run_pass(&self, tcx: TyCtxt<'tcx>, src: MirSource<'tcx>, body: &mut Body<'tcx>) { |
| 14 | + let param_env = tcx.param_env(src.def_id()); |
| 15 | + let bbs = body.basic_blocks_mut(); |
| 16 | + 'outer: for bb_idx in bbs.indices() { |
| 17 | + let (discr, val, switch_ty, first, second) = match bbs[bb_idx].terminator().kind { |
| 18 | + TerminatorKind::SwitchInt { |
| 19 | + discr: Operand::Move(ref place), |
| 20 | + switch_ty, |
| 21 | + ref targets, |
| 22 | + ref values, |
| 23 | + .. |
| 24 | + } if targets.len() == 2 && values.len() == 1 => { |
| 25 | + (place, values[0], switch_ty, targets[0], targets[1]) |
| 26 | + } |
| 27 | + // Only optimize switch int statements |
| 28 | + _ => continue, |
| 29 | + }; |
| 30 | + |
| 31 | + // Check that destinations are identical, and if not, then don't optimize this block |
| 32 | + if &bbs[first].terminator().kind != &bbs[second].terminator().kind { |
| 33 | + continue; |
| 34 | + } |
| 35 | + |
| 36 | + // Check that blocks are assignments of consts to the same place or same statement, |
| 37 | + // and match up 1-1, if not don't optimize this block. |
| 38 | + let first_stmts = &bbs[first].statements; |
| 39 | + let scnd_stmts = &bbs[second].statements; |
| 40 | + if first_stmts.len() != scnd_stmts.len() { |
| 41 | + continue; |
| 42 | + } |
| 43 | + for (f, s) in first_stmts.iter().zip(scnd_stmts.iter()) { |
| 44 | + match (&f.kind, &s.kind) { |
| 45 | + // If two statements are exactly the same just ignore them. |
| 46 | + (f_s, s_s) if f_s == s_s => (), |
| 47 | + |
| 48 | + ( |
| 49 | + StatementKind::Assign(box (lhs_f, Rvalue::Use(Operand::Constant(f_c)))), |
| 50 | + StatementKind::Assign(box (lhs_s, Rvalue::Use(Operand::Constant(s_c)))), |
| 51 | + ) if lhs_f == lhs_s => { |
| 52 | + if let Some(f_c) = f_c.literal.try_eval_bool(tcx, param_env) { |
| 53 | + // This should also be a bool because it's writing to the same place |
| 54 | + let s_c = s_c.literal.try_eval_bool(tcx, param_env).unwrap(); |
| 55 | + if f_c != s_c { |
| 56 | + // have to check this here because f_c & s_c might have |
| 57 | + // different spans. |
| 58 | + continue; |
| 59 | + } |
| 60 | + } |
| 61 | + continue 'outer; |
| 62 | + } |
| 63 | + // If there are not exclusively assignments, then ignore this |
| 64 | + _ => continue 'outer, |
| 65 | + } |
| 66 | + } |
| 67 | + // Take owenership of items now that we know we can optimize. |
| 68 | + let discr = discr.clone(); |
| 69 | + let (from, first) = bbs.pick2_mut(bb_idx, first); |
| 70 | + |
| 71 | + let new_stmts = first.statements.iter().cloned().map(|mut s| { |
| 72 | + if let StatementKind::Assign(box (_, ref mut rhs)) = s.kind { |
| 73 | + if let Rvalue::Use(Operand::Constant(c)) = rhs { |
| 74 | + let size = tcx.layout_of(param_env.and(switch_ty)).unwrap().size; |
| 75 | + let const_cmp = Operand::const_from_scalar( |
| 76 | + tcx, |
| 77 | + switch_ty, |
| 78 | + crate::interpret::Scalar::from_uint(val, size), |
| 79 | + rustc_span::DUMMY_SP, |
| 80 | + ); |
| 81 | + if let Some(c) = c.literal.try_eval_bool(tcx, param_env) { |
| 82 | + let op = if c { BinOp::Eq } else { BinOp::Ne }; |
| 83 | + *rhs = Rvalue::BinaryOp(op, Operand::Move(discr), const_cmp); |
| 84 | + } |
| 85 | + } |
| 86 | + } |
| 87 | + s |
| 88 | + }); |
| 89 | + from.statements.extend(new_stmts); |
| 90 | + from.terminator_mut().kind = first.terminator().kind.clone(); |
| 91 | + } |
| 92 | + } |
| 93 | +} |
0 commit comments