Skip to content

New pass to optimize ifconditions on integrals to switches on the integer #75370

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions src/librustc_middle/mir/interpret/value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -503,6 +503,11 @@ impl<'tcx, Tag> Scalar<Tag> {
self.to_unsigned_with_bit_width(64).map(|v| u64::try_from(v).unwrap())
}

/// Converts the scalar to produce an `u128`. Fails if the scalar is a pointer.
pub fn to_u128(self) -> InterpResult<'static, u128> {
self.to_unsigned_with_bit_width(128)
}

pub fn to_machine_usize(self, cx: &impl HasDataLayout) -> InterpResult<'static, u64> {
let b = self.to_bits(cx.data_layout().pointer_size)?;
Ok(u64::try_from(b).unwrap())
Expand Down Expand Up @@ -535,6 +540,11 @@ impl<'tcx, Tag> Scalar<Tag> {
self.to_signed_with_bit_width(64).map(|v| i64::try_from(v).unwrap())
}

/// Converts the scalar to produce an `i128`. Fails if the scalar is a pointer.
pub fn to_i128(self) -> InterpResult<'static, i128> {
self.to_signed_with_bit_width(128)
}

pub fn to_machine_isize(self, cx: &impl HasDataLayout) -> InterpResult<'static, i64> {
let sz = cx.data_layout().pointer_size;
let b = self.to_bits(sz)?;
Expand Down
13 changes: 13 additions & 0 deletions src/librustc_middle/mir/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1430,6 +1430,15 @@ pub enum StatementKind<'tcx> {
Nop,
}

impl<'tcx> StatementKind<'tcx> {
pub fn as_assign_mut(&mut self) -> Option<&mut Box<(Place<'tcx>, Rvalue<'tcx>)>> {
match self {
StatementKind::Assign(x) => Some(x),
_ => None,
}
}
}

/// Describes what kind of retag is to be performed.
#[derive(Copy, Clone, TyEncodable, TyDecodable, Debug, PartialEq, Eq, HashStable)]
pub enum RetagKind {
Expand Down Expand Up @@ -1843,6 +1852,10 @@ impl<'tcx> Operand<'tcx> {
})
}

pub fn is_move(&self) -> bool {
matches!(self, Operand::Move(..))
}

/// Convenience helper to make a literal-like constant from a given scalar value.
/// Since this is used to synthesize MIR, assumes `user_ty` is None.
pub fn const_from_scalar(
Expand Down
2 changes: 2 additions & 0 deletions src/librustc_mir/transform/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ pub mod required_consts;
pub mod rustc_peek;
pub mod simplify;
pub mod simplify_branches;
pub mod simplify_comparison_integral;
pub mod simplify_try;
pub mod uninhabited_enum_branching;
pub mod unreachable_prop;
Expand Down Expand Up @@ -456,6 +457,7 @@ fn run_optimization_passes<'tcx>(
&match_branches::MatchBranchSimplification,
&const_prop::ConstProp,
&simplify_branches::SimplifyBranches::new("after-const-prop"),
&simplify_comparison_integral::SimplifyComparisonIntegral,
&simplify_try::SimplifyArmIdentity,
&simplify_try::SimplifyBranchSame,
&copy_prop::CopyPropagation,
Expand Down
226 changes: 226 additions & 0 deletions src/librustc_mir/transform/simplify_comparison_integral.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,226 @@
use super::{MirPass, MirSource};
use rustc_middle::{
mir::{
interpret::Scalar, BasicBlock, BinOp, Body, Operand, Place, Rvalue, Statement,
StatementKind, TerminatorKind,
},
ty::{Ty, TyCtxt},
};

/// Pass to convert `if` conditions on integrals into switches on the integral.
/// For an example, it turns something like
///
/// ```
/// _3 = Eq(move _4, const 43i32);
/// StorageDead(_4);
/// switchInt(_3) -> [false: bb2, otherwise: bb3];
/// ```
///
/// into:
///
/// ```
/// switchInt(_4) -> [43i32: bb3, otherwise: bb2];
/// ```
pub struct SimplifyComparisonIntegral;

impl<'tcx> MirPass<'tcx> for SimplifyComparisonIntegral {
fn run_pass(&self, _: TyCtxt<'tcx>, source: MirSource<'tcx>, body: &mut Body<'tcx>) {
trace!("Running SimplifyComparisonIntegral on {:?}", source);

let helper = OptimizationFinder { body };
let opts = helper.find_optimizations();
let mut storage_deads_to_insert = vec![];
let mut storage_deads_to_remove: Vec<(usize, BasicBlock)> = vec![];
for opt in opts {
trace!("SUCCESS: Applying {:?}", opt);
// replace terminator with a switchInt that switches on the integer directly
let bbs = &mut body.basic_blocks_mut();
let bb = &mut bbs[opt.bb_idx];
// We only use the bits for the untyped, not length checked `values` field. Thus we are
// not using any of the convenience wrappers here and directly access the bits.
let new_value = match opt.branch_value_scalar {
Scalar::Raw { data, .. } => data,
Scalar::Ptr(_) => continue,
};
const FALSE: u128 = 0;
let mut new_targets = opt.targets.clone();
let first_is_false_target = opt.values[0] == FALSE;
match opt.op {
BinOp::Eq => {
// if the assignment was Eq we want the true case to be first
if first_is_false_target {
new_targets.swap(0, 1);
}
}
BinOp::Ne => {
// if the assignment was Ne we want the false case to be first
if !first_is_false_target {
new_targets.swap(0, 1);
}
}
_ => unreachable!(),
}

let terminator = bb.terminator_mut();

// add StorageDead for the place switched on at the top of each target
for bb_idx in new_targets.iter() {
storage_deads_to_insert.push((
*bb_idx,
Statement {
source_info: terminator.source_info,
kind: StatementKind::StorageDead(opt.to_switch_on.local),
},
));
}

terminator.kind = TerminatorKind::SwitchInt {
discr: Operand::Move(opt.to_switch_on),
switch_ty: opt.branch_value_ty,
values: vec![new_value].into(),
targets: new_targets,
};

// delete comparison statement if it the value being switched on was moved, which means it can not be user later on
if opt.can_remove_bin_op_stmt {
bb.statements[opt.bin_op_stmt_idx].make_nop();
} else {
// if the integer being compared to a const integral is being moved into the comparison,
// e.g `_2 = Eq(move _3, const 'x');`
// we want to avoid making a double move later on in the switchInt on _3.
// So to avoid `switchInt(move _3) -> ['x': bb2, otherwise: bb1];`,
// we convert the move in the comparison statement to a copy.

// unwrap is safe as we know this statement is an assign
let box (_, rhs) = bb.statements[opt.bin_op_stmt_idx].kind.as_assign_mut().unwrap();

use Operand::*;
match rhs {
Rvalue::BinaryOp(_, ref mut left @ Move(_), Constant(_)) => {
*left = Copy(opt.to_switch_on);
}
Rvalue::BinaryOp(_, Constant(_), ref mut right @ Move(_)) => {
*right = Copy(opt.to_switch_on);
}
_ => (),
}
}

// remove StorageDead (if it exists) being used in the assign of the comparison
for (stmt_idx, stmt) in bb.statements.iter().enumerate() {
if !matches!(stmt.kind, StatementKind::StorageDead(local) if local == opt.to_switch_on.local)
{
continue;
}
storage_deads_to_remove.push((stmt_idx, opt.bb_idx))
}
}

for (idx, bb_idx) in storage_deads_to_remove {
body.basic_blocks_mut()[bb_idx].statements[idx].make_nop();
}

for (idx, stmt) in storage_deads_to_insert {
body.basic_blocks_mut()[idx].statements.insert(0, stmt);
}
}
}

struct OptimizationFinder<'a, 'tcx> {
body: &'a Body<'tcx>,
}

impl<'a, 'tcx> OptimizationFinder<'a, 'tcx> {
fn find_optimizations(&self) -> Vec<OptimizationInfo<'tcx>> {
self.body
.basic_blocks()
.iter_enumerated()
.filter_map(|(bb_idx, bb)| {
// find switch
let (place_switched_on, values, targets, place_switched_on_moved) = match &bb
.terminator()
.kind
{
rustc_middle::mir::TerminatorKind::SwitchInt {
discr, values, targets, ..
} => Some((discr.place()?, values, targets, discr.is_move())),
_ => None,
}?;

// find the statement that assigns the place being switched on
bb.statements.iter().enumerate().rev().find_map(|(stmt_idx, stmt)| {
match &stmt.kind {
rustc_middle::mir::StatementKind::Assign(box (lhs, rhs))
if *lhs == place_switched_on =>
{
match rhs {
Rvalue::BinaryOp(op @ (BinOp::Eq | BinOp::Ne), left, right) => {
let (branch_value_scalar, branch_value_ty, to_switch_on) =
find_branch_value_info(left, right)?;

Some(OptimizationInfo {
bin_op_stmt_idx: stmt_idx,
bb_idx,
can_remove_bin_op_stmt: place_switched_on_moved,
to_switch_on,
branch_value_scalar,
branch_value_ty,
op: *op,
values: values.clone().into_owned(),
targets: targets.clone(),
})
}
_ => None,
}
}
_ => None,
}
})
})
.collect()
}
}

fn find_branch_value_info<'tcx>(
left: &Operand<'tcx>,
right: &Operand<'tcx>,
) -> Option<(Scalar, Ty<'tcx>, Place<'tcx>)> {
// check that either left or right is a constant.
// if any are, we can use the other to switch on, and the constant as a value in a switch
use Operand::*;
match (left, right) {
(Constant(branch_value), Copy(to_switch_on) | Move(to_switch_on))
| (Copy(to_switch_on) | Move(to_switch_on), Constant(branch_value)) => {
let branch_value_ty = branch_value.literal.ty;
// we only want to apply this optimization if we are matching on integrals (and chars), as it is not possible to switch on floats
if !branch_value_ty.is_integral() && !branch_value_ty.is_char() {
return None;
};
let branch_value_scalar = branch_value.literal.val.try_to_scalar()?;
Some((branch_value_scalar, branch_value_ty, *to_switch_on))
}
_ => None,
}
}

#[derive(Debug)]
struct OptimizationInfo<'tcx> {
/// Basic block to apply the optimization
bb_idx: BasicBlock,
/// Statement index of Eq/Ne assignment that can be removed. None if the assignment can not be removed - i.e the statement is used later on
bin_op_stmt_idx: usize,
/// Can remove Eq/Ne assignment
can_remove_bin_op_stmt: bool,
/// Place that needs to be switched on. This place is of type integral
to_switch_on: Place<'tcx>,
/// Constant to use in switch target value
branch_value_scalar: Scalar,
/// Type of the constant value
branch_value_ty: Ty<'tcx>,
/// Either Eq or Ne
op: BinOp,
/// Current values used in the switch target. This needs to be replaced with the branch_value
values: Vec<u128>,
/// Current targets used in the switch
targets: Vec<BasicBlock>,
}
65 changes: 65 additions & 0 deletions src/test/mir-opt/if-condition-int.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
// compile-flags: -O
// EMIT_MIR if_condition_int.opt_u32.SimplifyComparisonIntegral.diff
// EMIT_MIR if_condition_int.opt_negative.SimplifyComparisonIntegral.diff
// EMIT_MIR if_condition_int.opt_char.SimplifyComparisonIntegral.diff
// EMIT_MIR if_condition_int.opt_i8.SimplifyComparisonIntegral.diff
// EMIT_MIR if_condition_int.dont_opt_bool.SimplifyComparisonIntegral.diff
// EMIT_MIR if_condition_int.opt_multiple_ifs.SimplifyComparisonIntegral.diff
// EMIT_MIR if_condition_int.dont_remove_comparison.SimplifyComparisonIntegral.diff
// EMIT_MIR if_condition_int.dont_opt_floats.SimplifyComparisonIntegral.diff

fn opt_u32(x: u32) -> u32 {
if x == 42 { 0 } else { 1 }
}

// don't opt: it is already optimal to switch on the bool
fn dont_opt_bool(x: bool) -> u32 {
if x { 0 } else { 1 }
}

fn opt_char(x: char) -> u32 {
if x == 'x' { 0 } else { 1 }
}

fn opt_i8(x: i8) -> u32 {
if x == 42 { 0 } else { 1 }
}

fn opt_negative(x: i32) -> u32 {
if x == -42 { 0 } else { 1 }
}

fn opt_multiple_ifs(x: u32) -> u32 {
if x == 42 {
0
} else if x != 21 {
1
} else {
2
}
}

// test that we optimize, but do not remove the b statement, as that is used later on
fn dont_remove_comparison(a: i8) -> i32 {
let b = a == 17;
match b {
false => 10 + b as i32,
true => 100 + b as i32,
}
}

// test that we do not optimize on floats
fn dont_opt_floats(a: f32) -> i32 {
if a == -42.0 { 0 } else { 1 }
}

fn main() {
opt_u32(0);
opt_char('0');
opt_i8(22);
dont_opt_bool(false);
opt_negative(0);
opt_multiple_ifs(0);
dont_remove_comparison(11);
dont_opt_floats(1.0);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
- // MIR for `dont_opt_bool` before SimplifyComparisonIntegral
+ // MIR for `dont_opt_bool` after SimplifyComparisonIntegral

fn dont_opt_bool(_1: bool) -> u32 {
debug x => _1; // in scope 0 at $DIR/if-condition-int.rs:16:18: 16:19
let mut _0: u32; // return place in scope 0 at $DIR/if-condition-int.rs:16:30: 16:33
let mut _2: bool; // in scope 0 at $DIR/if-condition-int.rs:17:8: 17:9

bb0: {
StorageLive(_2); // scope 0 at $DIR/if-condition-int.rs:17:8: 17:9
_2 = _1; // scope 0 at $DIR/if-condition-int.rs:17:8: 17:9
switchInt(_2) -> [false: bb1, otherwise: bb2]; // scope 0 at $DIR/if-condition-int.rs:17:5: 17:26
}

bb1: {
_0 = const 1_u32; // scope 0 at $DIR/if-condition-int.rs:17:23: 17:24
goto -> bb3; // scope 0 at $DIR/if-condition-int.rs:17:5: 17:26
}

bb2: {
_0 = const 0_u32; // scope 0 at $DIR/if-condition-int.rs:17:12: 17:13
goto -> bb3; // scope 0 at $DIR/if-condition-int.rs:17:5: 17:26
}

bb3: {
StorageDead(_2); // scope 0 at $DIR/if-condition-int.rs:18:1: 18:2
return; // scope 0 at $DIR/if-condition-int.rs:18:2: 18:2
}
}

Loading