Skip to content

Commit 3e6729f

Browse files
committed
Use an interpreter in jump threading.
1 parent c2354aa commit 3e6729f

File tree

4 files changed

+196
-24
lines changed

4 files changed

+196
-24
lines changed

compiler/rustc_mir_transform/src/jump_threading.rs

+74-24
Original file line numberDiff line numberDiff line change
@@ -36,15 +36,19 @@
3636
//! cost by `MAX_COST`.
3737
3838
use rustc_arena::DroplessArena;
39+
use rustc_const_eval::interpret::{ImmTy, Immediate, InterpCx, OpTy, Projectable};
3940
use rustc_data_structures::fx::FxHashSet;
4041
use rustc_index::bit_set::BitSet;
4142
use rustc_index::IndexVec;
43+
use rustc_middle::mir::interpret::Scalar;
4244
use rustc_middle::mir::visit::Visitor;
4345
use rustc_middle::mir::*;
44-
use rustc_middle::ty::{self, ScalarInt, Ty, TyCtxt};
46+
use rustc_middle::ty::{self, ScalarInt, TyCtxt};
4547
use rustc_mir_dataflow::value_analysis::{Map, PlaceIndex, State, TrackElem};
48+
use rustc_span::DUMMY_SP;
4649

4750
use crate::cost_checker::CostChecker;
51+
use crate::dataflow_const_prop::DummyMachine;
4852

4953
pub struct JumpThreading;
5054

@@ -70,6 +74,7 @@ impl<'tcx> MirPass<'tcx> for JumpThreading {
7074
let mut finder = TOFinder {
7175
tcx,
7276
param_env,
77+
ecx: InterpCx::new(tcx, DUMMY_SP, param_env, DummyMachine),
7378
body,
7479
arena: &arena,
7580
map: &map,
@@ -141,6 +146,7 @@ struct ThreadingOpportunity {
141146
struct TOFinder<'tcx, 'a> {
142147
tcx: TyCtxt<'tcx>,
143148
param_env: ty::ParamEnv<'tcx>,
149+
ecx: InterpCx<'tcx, 'tcx, DummyMachine>,
144150
body: &'a Body<'tcx>,
145151
map: &'a Map,
146152
loop_headers: &'a BitSet<BasicBlock>,
@@ -328,25 +334,75 @@ impl<'tcx, 'a> TOFinder<'tcx, 'a> {
328334
}
329335

330336
#[instrument(level = "trace", skip(self))]
331-
fn process_operand(
337+
fn process_immediate(
332338
&mut self,
333339
bb: BasicBlock,
334340
lhs: PlaceIndex,
335-
rhs: &Operand<'tcx>,
341+
rhs: ImmTy<'tcx>,
336342
state: &mut State<ConditionSet<'a>>,
337343
) -> Option<!> {
338344
let register_opportunity = |c: Condition| {
339345
debug!(?bb, ?c.target, "register");
340346
self.opportunities.push(ThreadingOpportunity { chain: vec![bb], target: c.target })
341347
};
342348

349+
let conditions = state.try_get_idx(lhs, self.map)?;
350+
if let Immediate::Scalar(Scalar::Int(int)) = *rhs {
351+
conditions.iter_matches(int).for_each(register_opportunity);
352+
}
353+
354+
None
355+
}
356+
357+
#[instrument(level = "trace", skip(self))]
358+
fn process_operand(
359+
&mut self,
360+
bb: BasicBlock,
361+
lhs: PlaceIndex,
362+
rhs: &Operand<'tcx>,
363+
state: &mut State<ConditionSet<'a>>,
364+
) -> Option<!> {
343365
match rhs {
344366
// If we expect `lhs ?= A`, we have an opportunity if we assume `constant == A`.
345367
Operand::Constant(constant) => {
346-
let conditions = state.try_get_idx(lhs, self.map)?;
347-
let constant =
348-
constant.const_.normalize(self.tcx, self.param_env).try_to_scalar_int()?;
349-
conditions.iter_matches(constant).for_each(register_opportunity);
368+
let constant = self.ecx.eval_mir_constant(&constant.const_, None, None).ok()?;
369+
self.map.for_each_projection_value(
370+
lhs,
371+
constant,
372+
&mut |elem, op| match elem {
373+
TrackElem::Field(idx) => self.ecx.project_field(op, idx.as_usize()).ok(),
374+
TrackElem::Variant(idx) => self.ecx.project_downcast(op, idx).ok(),
375+
TrackElem::Discriminant => {
376+
let variant = self.ecx.read_discriminant(op).ok()?;
377+
let discr_value =
378+
self.ecx.discriminant_for_variant(op.layout.ty, variant).ok()?;
379+
Some(discr_value.into())
380+
}
381+
TrackElem::DerefLen => {
382+
let op: OpTy<'_> = self.ecx.deref_pointer(op).ok()?.into();
383+
let len_usize = op.len(&self.ecx).ok()?;
384+
let layout = self
385+
.tcx
386+
.layout_of(self.param_env.and(self.tcx.types.usize))
387+
.unwrap();
388+
Some(ImmTy::from_uint(len_usize, layout).into())
389+
}
390+
},
391+
&mut |place, op| {
392+
if let Some(conditions) = state.try_get_idx(place, self.map)
393+
&& let Ok(imm) = self.ecx.read_immediate_raw(op)
394+
&& let Some(imm) = imm.right()
395+
&& let Immediate::Scalar(Scalar::Int(int)) = *imm
396+
{
397+
conditions.iter_matches(int).for_each(|c: Condition| {
398+
self.opportunities.push(ThreadingOpportunity {
399+
chain: vec![bb],
400+
target: c.target,
401+
})
402+
})
403+
}
404+
},
405+
);
350406
}
351407
// Transfer the conditions on the copied rhs.
352408
Operand::Move(rhs) | Operand::Copy(rhs) => {
@@ -373,26 +429,14 @@ impl<'tcx, 'a> TOFinder<'tcx, 'a> {
373429
// Below, `lhs` is the return value of `mutated_statement`,
374430
// the place to which `conditions` apply.
375431

376-
let discriminant_for_variant = |enum_ty: Ty<'tcx>, variant_index| {
377-
let discr = enum_ty.discriminant_for_variant(self.tcx, variant_index)?;
378-
let discr_layout = self.tcx.layout_of(self.param_env.and(discr.ty)).ok()?;
379-
let scalar = ScalarInt::try_from_uint(discr.val, discr_layout.size)?;
380-
Some(Operand::const_from_scalar(
381-
self.tcx,
382-
discr.ty,
383-
scalar.into(),
384-
rustc_span::DUMMY_SP,
385-
))
386-
};
387-
388432
match &stmt.kind {
389433
// If we expect `discriminant(place) ?= A`,
390434
// we have an opportunity if `variant_index ?= A`.
391435
StatementKind::SetDiscriminant { box place, variant_index } => {
392436
let discr_target = self.map.find_discr(place.as_ref())?;
393437
let enum_ty = place.ty(self.body, self.tcx).ty;
394-
let discr = discriminant_for_variant(enum_ty, *variant_index)?;
395-
self.process_operand(bb, discr_target, &discr, state)?;
438+
let discr = self.ecx.discriminant_for_variant(enum_ty, *variant_index).ok()?;
439+
self.process_immediate(bb, discr_target, discr, state)?;
396440
}
397441
// If we expect `lhs ?= true`, we have an opportunity if we assume `lhs == true`.
398442
StatementKind::Intrinsic(box NonDivergingIntrinsic::Assume(
@@ -422,10 +466,16 @@ impl<'tcx, 'a> TOFinder<'tcx, 'a> {
422466
AggregateKind::Adt(_, variant_index, ..) if agg_ty.is_enum() => {
423467
if let Some(discr_target) =
424468
self.map.apply(lhs, TrackElem::Discriminant)
425-
&& let Some(discr_value) =
426-
discriminant_for_variant(agg_ty, *variant_index)
469+
&& let Ok(discr_value) = self
470+
.ecx
471+
.discriminant_for_variant(agg_ty, *variant_index)
427472
{
428-
self.process_operand(bb, discr_target, &discr_value, state);
473+
self.process_immediate(
474+
bb,
475+
discr_target,
476+
discr_value,
477+
state,
478+
);
429479
}
430480
self.map.apply(lhs, TrackElem::Variant(*variant_index))?
431481
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
- // MIR for `aggregate` before JumpThreading
2+
+ // MIR for `aggregate` after JumpThreading
3+
4+
fn aggregate(_1: u8) -> u8 {
5+
debug x => _1;
6+
let mut _0: u8;
7+
let _2: u8;
8+
let _3: u8;
9+
let mut _4: (u8, u8);
10+
let mut _5: bool;
11+
let mut _6: u8;
12+
scope 1 {
13+
debug a => _2;
14+
debug b => _3;
15+
}
16+
17+
bb0: {
18+
StorageLive(_4);
19+
_4 = const _;
20+
StorageLive(_2);
21+
_2 = (_4.0: u8);
22+
StorageLive(_3);
23+
_3 = (_4.1: u8);
24+
StorageDead(_4);
25+
StorageLive(_5);
26+
StorageLive(_6);
27+
_6 = _2;
28+
_5 = Eq(move _6, const 7_u8);
29+
- switchInt(move _5) -> [0: bb2, otherwise: bb1];
30+
+ goto -> bb2;
31+
}
32+
33+
bb1: {
34+
StorageDead(_6);
35+
_0 = _3;
36+
goto -> bb3;
37+
}
38+
39+
bb2: {
40+
StorageDead(_6);
41+
_0 = _2;
42+
goto -> bb3;
43+
}
44+
45+
bb3: {
46+
StorageDead(_5);
47+
StorageDead(_3);
48+
StorageDead(_2);
49+
return;
50+
}
51+
}
52+
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
- // MIR for `aggregate` before JumpThreading
2+
+ // MIR for `aggregate` after JumpThreading
3+
4+
fn aggregate(_1: u8) -> u8 {
5+
debug x => _1;
6+
let mut _0: u8;
7+
let _2: u8;
8+
let _3: u8;
9+
let mut _4: (u8, u8);
10+
let mut _5: bool;
11+
let mut _6: u8;
12+
scope 1 {
13+
debug a => _2;
14+
debug b => _3;
15+
}
16+
17+
bb0: {
18+
StorageLive(_4);
19+
_4 = const _;
20+
StorageLive(_2);
21+
_2 = (_4.0: u8);
22+
StorageLive(_3);
23+
_3 = (_4.1: u8);
24+
StorageDead(_4);
25+
StorageLive(_5);
26+
StorageLive(_6);
27+
_6 = _2;
28+
_5 = Eq(move _6, const 7_u8);
29+
- switchInt(move _5) -> [0: bb2, otherwise: bb1];
30+
+ goto -> bb2;
31+
}
32+
33+
bb1: {
34+
StorageDead(_6);
35+
_0 = _3;
36+
goto -> bb3;
37+
}
38+
39+
bb2: {
40+
StorageDead(_6);
41+
_0 = _2;
42+
goto -> bb3;
43+
}
44+
45+
bb3: {
46+
StorageDead(_5);
47+
StorageDead(_3);
48+
StorageDead(_2);
49+
return;
50+
}
51+
}
52+

tests/mir-opt/jump_threading.rs

+18
Original file line numberDiff line numberDiff line change
@@ -455,7 +455,23 @@ fn disappearing_bb(x: u8) -> u8 {
455455
)
456456
}
457457

458+
/// Verify that we can thread jumps when the constant comes from a static.
459+
fn aggregate(x: u8) -> u8 {
460+
// CHECK-LABEL: fn aggregate(
461+
// CHECK-NOT: switchInt(
462+
463+
const FOO: (u8, u8) = (5, 13);
464+
465+
let (a, b) = FOO;
466+
if a == 7 {
467+
b
468+
} else {
469+
a
470+
}
471+
}
472+
458473
fn main() {
474+
// CHECK-LABEL: fn main(
459475
too_complex(Ok(0));
460476
identity(Ok(0));
461477
custom_discr(false);
@@ -466,6 +482,7 @@ fn main() {
466482
mutable_ref();
467483
renumbered_bb(true);
468484
disappearing_bb(7);
485+
aggregate(7);
469486
}
470487

471488
// EMIT_MIR jump_threading.too_complex.JumpThreading.diff
@@ -478,3 +495,4 @@ fn main() {
478495
// EMIT_MIR jump_threading.mutable_ref.JumpThreading.diff
479496
// EMIT_MIR jump_threading.renumbered_bb.JumpThreading.diff
480497
// EMIT_MIR jump_threading.disappearing_bb.JumpThreading.diff
498+
// EMIT_MIR jump_threading.aggregate.JumpThreading.diff

0 commit comments

Comments
 (0)