Skip to content

Commit d417c8b

Browse files
committed
Make the enum check work for negative discriminants
The discriminant check was not working correctly for negative numbers. This change fixes that by masking out the relevant bits correctly.
1 parent ad3b725 commit d417c8b

File tree

2 files changed

+48
-4
lines changed

2 files changed

+48
-4
lines changed

compiler/rustc_mir_transform/src/check_enums.rs

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,7 @@ enum EnumCheckType<'tcx> {
120120
},
121121
}
122122

123+
#[derive(Debug, Copy, Clone)]
123124
struct TyAndSize<'tcx> {
124125
pub ty: Ty<'tcx>,
125126
pub size: Size,
@@ -338,7 +339,7 @@ fn insert_direct_enum_check<'tcx>(
338339
let invalid_discr_block_data = BasicBlockData::new(None, false);
339340
let invalid_discr_block = basic_blocks.push(invalid_discr_block_data);
340341
let block_data = &mut basic_blocks[current_block];
341-
let discr = insert_discr_cast_to_u128(
342+
let discr_place = insert_discr_cast_to_u128(
342343
tcx,
343344
local_decls,
344345
block_data,
@@ -349,13 +350,41 @@ fn insert_direct_enum_check<'tcx>(
349350
source_info,
350351
);
351352

353+
// Mask out the bits of the discriminant type.
354+
let mask = match discr.size.bytes() {
355+
1 => u8::MAX as u128,
356+
2 => u16::MAX as u128,
357+
4 => u32::MAX as u128,
358+
8 => u64::MAX as u128,
359+
16 => u128::MAX as u128,
360+
invalid => bug!("Found discriminant with invalid size, has {} bytes", invalid),
361+
};
362+
363+
let discr_masked =
364+
local_decls.push(LocalDecl::with_source_info(tcx.types.u128, source_info)).into();
365+
let rvalue = Rvalue::BinaryOp(
366+
BinOp::BitAnd,
367+
Box::new((
368+
Operand::Copy(discr_place),
369+
Operand::Constant(Box::new(ConstOperand {
370+
span: source_info.span,
371+
user_ty: None,
372+
const_: Const::Val(ConstValue::from_u128(mask), tcx.types.u128),
373+
})),
374+
)),
375+
);
376+
block_data.statements.push(Statement {
377+
source_info,
378+
kind: StatementKind::Assign(Box::new((discr_masked, rvalue))),
379+
});
380+
352381
// Branch based on the discriminant value.
353382
block_data.terminator = Some(Terminator {
354383
source_info,
355384
kind: TerminatorKind::SwitchInt {
356-
discr: Operand::Copy(discr),
385+
discr: Operand::Copy(discr_masked),
357386
targets: SwitchTargets::new(
358-
discriminants.into_iter().map(|discr| (discr, new_block)),
387+
discriminants.into_iter().map(|discr| (discr & mask, new_block)),
359388
invalid_discr_block,
360389
),
361390
},
@@ -372,7 +401,7 @@ fn insert_direct_enum_check<'tcx>(
372401
})),
373402
expected: true,
374403
target: new_block,
375-
msg: Box::new(AssertKind::InvalidEnumConstruction(Operand::Copy(discr))),
404+
msg: Box::new(AssertKind::InvalidEnumConstruction(Operand::Copy(discr_masked))),
376405
// This calls panic_invalid_enum_construction, which is #[rustc_nounwind].
377406
// We never want to insert an unwind into unsafe code, because unwinding could
378407
// make a failing UB check turn into much worse UB when we start unwinding.
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
//@ run-pass
2+
//@ compile-flags: -C debug-assertions
3+
4+
#[allow(dead_code)]
5+
enum Foo {
6+
A = -2,
7+
B = 1,
8+
}
9+
10+
fn main() {
11+
let _val: Foo =
12+
unsafe { std::mem::transmute::<i8, Foo>(-2) };
13+
let _val: Foo =
14+
unsafe { std::mem::transmute::<i8, Foo>(1) };
15+
}

0 commit comments

Comments
 (0)