Skip to content

Commit 5e3f1b1

Browse files
committed
Auto merge of #75382 - JulianKnodt:match_branches, r=oli-obk
First iteration of simplify match branches This is a simple MIR pass that attempts to convert ``` bb0: { StorageLive(_2); _3 = discriminant(_1); switchInt(move _3) -> [0isize: bb2, otherwise: bb1]; } bb1: { _2 = const false; goto -> bb3; } bb2: { _2 = const true; goto -> bb3; } ``` into ``` bb0: { StorageLive(_2); _3 = discriminant(_1); _2 = _3 == 0; goto -> bb3; } ``` There are still missing components(like checking if the assignments are bools). Was hoping that this could get some review though. Handles #75141 r? @oli-obk
2 parents b6396b7 + 46e5699 commit 5e3f1b1

5 files changed

+240
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
use crate::transform::{MirPass, MirSource};
2+
use rustc_middle::mir::*;
3+
use rustc_middle::ty::TyCtxt;
4+
5+
pub struct MatchBranchSimplification;
6+
7+
// What's the intent of this pass?
8+
// If one block is found that switches between blocks which both go to the same place
9+
// AND both of these blocks set a similar const in their ->
10+
// condense into 1 block based on discriminant AND goto the destination afterwards
11+
12+
impl<'tcx> MirPass<'tcx> for MatchBranchSimplification {
13+
fn run_pass(&self, tcx: TyCtxt<'tcx>, src: MirSource<'tcx>, body: &mut Body<'tcx>) {
14+
let param_env = tcx.param_env(src.def_id());
15+
let bbs = body.basic_blocks_mut();
16+
'outer: for bb_idx in bbs.indices() {
17+
let (discr, val, switch_ty, first, second) = match bbs[bb_idx].terminator().kind {
18+
TerminatorKind::SwitchInt {
19+
discr: Operand::Move(ref place),
20+
switch_ty,
21+
ref targets,
22+
ref values,
23+
..
24+
} if targets.len() == 2 && values.len() == 1 => {
25+
(place, values[0], switch_ty, targets[0], targets[1])
26+
}
27+
// Only optimize switch int statements
28+
_ => continue,
29+
};
30+
31+
// Check that destinations are identical, and if not, then don't optimize this block
32+
if &bbs[first].terminator().kind != &bbs[second].terminator().kind {
33+
continue;
34+
}
35+
36+
// Check that blocks are assignments of consts to the same place or same statement,
37+
// and match up 1-1, if not don't optimize this block.
38+
let first_stmts = &bbs[first].statements;
39+
let scnd_stmts = &bbs[second].statements;
40+
if first_stmts.len() != scnd_stmts.len() {
41+
continue;
42+
}
43+
for (f, s) in first_stmts.iter().zip(scnd_stmts.iter()) {
44+
match (&f.kind, &s.kind) {
45+
// If two statements are exactly the same just ignore them.
46+
(f_s, s_s) if f_s == s_s => (),
47+
48+
(
49+
StatementKind::Assign(box (lhs_f, Rvalue::Use(Operand::Constant(f_c)))),
50+
StatementKind::Assign(box (lhs_s, Rvalue::Use(Operand::Constant(s_c)))),
51+
) if lhs_f == lhs_s => {
52+
if let Some(f_c) = f_c.literal.try_eval_bool(tcx, param_env) {
53+
// This should also be a bool because it's writing to the same place
54+
let s_c = s_c.literal.try_eval_bool(tcx, param_env).unwrap();
55+
if f_c != s_c {
56+
// have to check this here because f_c & s_c might have
57+
// different spans.
58+
continue;
59+
}
60+
}
61+
continue 'outer;
62+
}
63+
// If there are not exclusively assignments, then ignore this
64+
_ => continue 'outer,
65+
}
66+
}
67+
// Take owenership of items now that we know we can optimize.
68+
let discr = discr.clone();
69+
let (from, first) = bbs.pick2_mut(bb_idx, first);
70+
71+
let new_stmts = first.statements.iter().cloned().map(|mut s| {
72+
if let StatementKind::Assign(box (_, ref mut rhs)) = s.kind {
73+
if let Rvalue::Use(Operand::Constant(c)) = rhs {
74+
let size = tcx.layout_of(param_env.and(switch_ty)).unwrap().size;
75+
let const_cmp = Operand::const_from_scalar(
76+
tcx,
77+
switch_ty,
78+
crate::interpret::Scalar::from_uint(val, size),
79+
rustc_span::DUMMY_SP,
80+
);
81+
if let Some(c) = c.literal.try_eval_bool(tcx, param_env) {
82+
let op = if c { BinOp::Eq } else { BinOp::Ne };
83+
*rhs = Rvalue::BinaryOp(op, Operand::Move(discr), const_cmp);
84+
}
85+
}
86+
}
87+
s
88+
});
89+
from.statements.extend(new_stmts);
90+
from.terminator_mut().kind = first.terminator().kind.clone();
91+
}
92+
}
93+
}

src/librustc_mir/transform/mod.rs

+2
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ pub mod generator;
2929
pub mod inline;
3030
pub mod instcombine;
3131
pub mod instrument_coverage;
32+
pub mod match_branches;
3233
pub mod no_landing_pads;
3334
pub mod nrvo;
3435
pub mod promote_consts;
@@ -440,6 +441,7 @@ fn run_optimization_passes<'tcx>(
440441
// with async primitives.
441442
&generator::StateTransform,
442443
&instcombine::InstCombine,
444+
&match_branches::MatchBranchSimplification,
443445
&const_prop::ConstProp,
444446
&simplify_branches::SimplifyBranches::new("after-const-prop"),
445447
&simplify_try::SimplifyArmIdentity,
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
- // MIR for `foo` before MatchBranchSimplification
2+
+ // MIR for `foo` after MatchBranchSimplification
3+
4+
fn foo(_1: std::option::Option<()>) -> () {
5+
debug bar => _1; // in scope 0 at $DIR/matches_reduce_branches.rs:4:8: 4:11
6+
let mut _0: (); // return place in scope 0 at $DIR/matches_reduce_branches.rs:4:25: 4:25
7+
let mut _2: bool; // in scope 0 at $SRC_DIR/core/src/macros/mod.rs:LL:COL
8+
let mut _3: isize; // in scope 0 at $DIR/matches_reduce_branches.rs:5:22: 5:26
9+
10+
bb0: {
11+
StorageLive(_2); // scope 0 at $SRC_DIR/core/src/macros/mod.rs:LL:COL
12+
_3 = discriminant(_1); // scope 0 at $DIR/matches_reduce_branches.rs:5:22: 5:26
13+
- switchInt(move _3) -> [0_isize: bb2, otherwise: bb1]; // scope 0 at $DIR/matches_reduce_branches.rs:5:22: 5:26
14+
+ _2 = Eq(move _3, const 0_isize); // scope 0 at $SRC_DIR/core/src/macros/mod.rs:LL:COL
15+
+ // ty::Const
16+
+ // + ty: isize
17+
+ // + val: Value(Scalar(0x00000000))
18+
+ // mir::Constant
19+
+ // + span: $DIR/matches_reduce_branches.rs:1:1: 1:1
20+
+ // + literal: Const { ty: isize, val: Value(Scalar(0x00000000)) }
21+
+ goto -> bb3; // scope 0 at $DIR/matches_reduce_branches.rs:5:22: 5:26
22+
}
23+
24+
bb1: {
25+
_2 = const false; // scope 0 at $SRC_DIR/core/src/macros/mod.rs:LL:COL
26+
// ty::Const
27+
// + ty: bool
28+
// + val: Value(Scalar(0x00))
29+
// mir::Constant
30+
// + span: $SRC_DIR/core/src/macros/mod.rs:LL:COL
31+
// + literal: Const { ty: bool, val: Value(Scalar(0x00)) }
32+
goto -> bb3; // scope 0 at $SRC_DIR/core/src/macros/mod.rs:LL:COL
33+
}
34+
35+
bb2: {
36+
_2 = const true; // scope 0 at $SRC_DIR/core/src/macros/mod.rs:LL:COL
37+
// ty::Const
38+
// + ty: bool
39+
// + val: Value(Scalar(0x01))
40+
// mir::Constant
41+
// + span: $SRC_DIR/core/src/macros/mod.rs:LL:COL
42+
// + literal: Const { ty: bool, val: Value(Scalar(0x01)) }
43+
goto -> bb3; // scope 0 at $SRC_DIR/core/src/macros/mod.rs:LL:COL
44+
}
45+
46+
bb3: {
47+
switchInt(_2) -> [false: bb4, otherwise: bb5]; // scope 0 at $DIR/matches_reduce_branches.rs:5:5: 7:6
48+
}
49+
50+
bb4: {
51+
_0 = const (); // scope 0 at $DIR/matches_reduce_branches.rs:5:5: 7:6
52+
// ty::Const
53+
// + ty: ()
54+
// + val: Value(Scalar(<ZST>))
55+
// mir::Constant
56+
// + span: $DIR/matches_reduce_branches.rs:5:5: 7:6
57+
// + literal: Const { ty: (), val: Value(Scalar(<ZST>)) }
58+
goto -> bb5; // scope 0 at $DIR/matches_reduce_branches.rs:5:5: 7:6
59+
}
60+
61+
bb5: {
62+
StorageDead(_2); // scope 0 at $DIR/matches_reduce_branches.rs:8:1: 8:2
63+
return; // scope 0 at $DIR/matches_reduce_branches.rs:8:2: 8:2
64+
}
65+
}
66+
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
- // MIR for `foo` before MatchBranchSimplification
2+
+ // MIR for `foo` after MatchBranchSimplification
3+
4+
fn foo(_1: std::option::Option<()>) -> () {
5+
debug bar => _1; // in scope 0 at $DIR/matches_reduce_branches.rs:4:8: 4:11
6+
let mut _0: (); // return place in scope 0 at $DIR/matches_reduce_branches.rs:4:25: 4:25
7+
let mut _2: bool; // in scope 0 at $SRC_DIR/core/src/macros/mod.rs:LL:COL
8+
let mut _3: isize; // in scope 0 at $DIR/matches_reduce_branches.rs:5:22: 5:26
9+
10+
bb0: {
11+
StorageLive(_2); // scope 0 at $SRC_DIR/core/src/macros/mod.rs:LL:COL
12+
_3 = discriminant(_1); // scope 0 at $DIR/matches_reduce_branches.rs:5:22: 5:26
13+
- switchInt(move _3) -> [0_isize: bb2, otherwise: bb1]; // scope 0 at $DIR/matches_reduce_branches.rs:5:22: 5:26
14+
+ _2 = Eq(move _3, const 0_isize); // scope 0 at $SRC_DIR/core/src/macros/mod.rs:LL:COL
15+
+ // ty::Const
16+
+ // + ty: isize
17+
+ // + val: Value(Scalar(0x0000000000000000))
18+
+ // mir::Constant
19+
+ // + span: $DIR/matches_reduce_branches.rs:1:1: 1:1
20+
+ // + literal: Const { ty: isize, val: Value(Scalar(0x0000000000000000)) }
21+
+ goto -> bb3; // scope 0 at $DIR/matches_reduce_branches.rs:5:22: 5:26
22+
}
23+
24+
bb1: {
25+
_2 = const false; // scope 0 at $SRC_DIR/core/src/macros/mod.rs:LL:COL
26+
// ty::Const
27+
// + ty: bool
28+
// + val: Value(Scalar(0x00))
29+
// mir::Constant
30+
// + span: $SRC_DIR/core/src/macros/mod.rs:LL:COL
31+
// + literal: Const { ty: bool, val: Value(Scalar(0x00)) }
32+
goto -> bb3; // scope 0 at $SRC_DIR/core/src/macros/mod.rs:LL:COL
33+
}
34+
35+
bb2: {
36+
_2 = const true; // scope 0 at $SRC_DIR/core/src/macros/mod.rs:LL:COL
37+
// ty::Const
38+
// + ty: bool
39+
// + val: Value(Scalar(0x01))
40+
// mir::Constant
41+
// + span: $SRC_DIR/core/src/macros/mod.rs:LL:COL
42+
// + literal: Const { ty: bool, val: Value(Scalar(0x01)) }
43+
goto -> bb3; // scope 0 at $SRC_DIR/core/src/macros/mod.rs:LL:COL
44+
}
45+
46+
bb3: {
47+
switchInt(_2) -> [false: bb4, otherwise: bb5]; // scope 0 at $DIR/matches_reduce_branches.rs:5:5: 7:6
48+
}
49+
50+
bb4: {
51+
_0 = const (); // scope 0 at $DIR/matches_reduce_branches.rs:5:5: 7:6
52+
// ty::Const
53+
// + ty: ()
54+
// + val: Value(Scalar(<ZST>))
55+
// mir::Constant
56+
// + span: $DIR/matches_reduce_branches.rs:5:5: 7:6
57+
// + literal: Const { ty: (), val: Value(Scalar(<ZST>)) }
58+
goto -> bb5; // scope 0 at $DIR/matches_reduce_branches.rs:5:5: 7:6
59+
}
60+
61+
bb5: {
62+
StorageDead(_2); // scope 0 at $DIR/matches_reduce_branches.rs:8:1: 8:2
63+
return; // scope 0 at $DIR/matches_reduce_branches.rs:8:2: 8:2
64+
}
65+
}
66+
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
// EMIT_MIR_FOR_EACH_BIT_WIDTH
2+
// EMIT_MIR matches_reduce_branches.foo.MatchBranchSimplification.diff
3+
4+
fn foo(bar: Option<()>) {
5+
if matches!(bar, None) {
6+
()
7+
}
8+
}
9+
10+
fn main() {
11+
let _ = foo(None);
12+
let _ = foo(Some(()));
13+
}

0 commit comments

Comments
 (0)