Skip to content
This repository was archived by the owner on May 28, 2025. It is now read-only.

Commit 66e9a6f

Browse files
committed
Auto merge of rust-lang#116042 - Nadrieril:linear-pass-take-2, r=<try>
[Experiment] Rewrite exhaustiveness in one pass Arm reachability checking does a quadratic amount of work: for each arm we check if it is reachable given the arms above it. This feels wasteful since we often end up re-exploring the same cases when we check for exhaustiveness. This PR is an attempt to check reachability at the same time as exhaustiveness. This opens the door to a bunch of code simplifications I'm very excited about. The main question is whether I can get actual performance gains out of this. I had started the experiment in rust-lang#111720 but I can't reopen it. r? `@ghost`
2 parents 3ff244b + 9575d53 commit 66e9a6f

File tree

12 files changed

+1055
-796
lines changed

12 files changed

+1055
-796
lines changed

compiler/rustc_middle/src/mir/consts.rs

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -297,7 +297,15 @@ impl<'tcx> Const<'tcx> {
297297
tcx: TyCtxt<'tcx>,
298298
param_env: ty::ParamEnv<'tcx>,
299299
) -> Option<ScalarInt> {
300-
self.try_eval_scalar(tcx, param_env)?.try_to_int().ok()
300+
match self {
301+
// Fast path for already evaluated constants.
302+
Const::Val(ConstValue::Scalar(Scalar::Int(scalar_int)), _) => Some(scalar_int),
303+
Const::Ty(c)
304+
if let ty::ConstKind::Value(ty::ValTree::Leaf(scalar_int)) = c.kind()
305+
&& c.ty().is_primitive()
306+
=> Some(scalar_int),
307+
_ => self.try_eval_scalar(tcx, param_env)?.try_to_int().ok(),
308+
}
301309
}
302310

303311
#[inline]

compiler/rustc_middle/src/thir.rs

Lines changed: 234 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,16 +16,18 @@ use rustc_hir::RangeEnd;
1616
use rustc_index::newtype_index;
1717
use rustc_index::IndexVec;
1818
use rustc_middle::middle::region;
19-
use rustc_middle::mir::interpret::AllocId;
19+
use rustc_middle::mir::interpret::{AllocId, Scalar};
2020
use rustc_middle::mir::{self, BinOp, BorrowKind, FakeReadCause, Mutability, UnOp};
2121
use rustc_middle::ty::adjustment::PointerCoercion;
22+
use rustc_middle::ty::layout::IntegerExt;
2223
use rustc_middle::ty::GenericArgsRef;
23-
use rustc_middle::ty::{self, AdtDef, FnSig, List, Ty, UpvarArgs};
24+
use rustc_middle::ty::{self, AdtDef, FnSig, List, Ty, TyCtxt, UpvarArgs};
2425
use rustc_middle::ty::{CanonicalUserType, CanonicalUserTypeAnnotation};
2526
use rustc_span::def_id::LocalDefId;
2627
use rustc_span::{sym, Span, Symbol, DUMMY_SP};
27-
use rustc_target::abi::{FieldIdx, VariantIdx};
28+
use rustc_target::abi::{FieldIdx, Integer, Size, VariantIdx};
2829
use rustc_target::asm::InlineAsmRegOrRegClass;
30+
use std::cmp::Ordering;
2931
use std::fmt;
3032
use std::ops::Index;
3133

@@ -773,12 +775,238 @@ pub enum PatKind<'tcx> {
773775
},
774776
}
775777

778+
/// A range pattern.
779+
/// The boundaries must be of the same type and that type must be numeric.
776780
#[derive(Clone, Debug, PartialEq, HashStable, TypeVisitable)]
777781
pub struct PatRange<'tcx> {
778-
pub lo: mir::Const<'tcx>,
779-
pub hi: mir::Const<'tcx>,
782+
pub lo: PatRangeBoundary<'tcx>,
783+
pub hi: PatRangeBoundary<'tcx>,
780784
#[type_visitable(ignore)]
781785
pub end: RangeEnd,
786+
pub ty: Ty<'tcx>,
787+
}
788+
789+
impl<'tcx> PatRange<'tcx> {
790+
/// Whether this range covers the full extent of possible values (best-effort, we ignore floats).
791+
#[inline]
792+
pub fn is_full_range(&self, tcx: TyCtxt<'tcx>, param_env: ty::ParamEnv<'tcx>) -> Option<bool> {
793+
let lo = self.lo.to_const(self.ty, tcx, param_env);
794+
let hi = self.hi.to_const(self.ty, tcx, param_env);
795+
796+
let (min, max, size, bias) = match *self.ty.kind() {
797+
ty::Char => (0, std::char::MAX as u128, Size::from_bits(32), 0),
798+
ty::Int(ity) => {
799+
let size = Integer::from_int_ty(&tcx, ity).size();
800+
let max = size.truncate(u128::MAX);
801+
let bias = 1u128 << (size.bits() - 1);
802+
(0, max, size, bias)
803+
}
804+
ty::Uint(uty) => {
805+
let size = Integer::from_uint_ty(&tcx, uty).size();
806+
let max = size.unsigned_int_max();
807+
(0, max, size, 0)
808+
}
809+
_ => return None,
810+
};
811+
// We want to compare ranges numerically, but the order of the bitwise representation of
812+
// signed integers does not match their numeric order. Thus, to correct the ordering, we
813+
// need to shift the range of signed integers to correct the comparison. This is achieved by
814+
// XORing with a bias (see pattern/deconstruct_pat.rs for another pertinent example of this
815+
// pattern).
816+
//
817+
// Also, for performance, it's important to only do the second `try_to_bits` if necessary.
818+
let lo = lo.try_to_bits(size).unwrap() ^ bias;
819+
if lo <= min {
820+
let hi = hi.try_to_bits(size).unwrap() ^ bias;
821+
if hi > max || hi == max && self.end == RangeEnd::Included {
822+
return Some(true);
823+
}
824+
}
825+
Some(false)
826+
}
827+
828+
#[inline]
829+
pub fn contains(
830+
&self,
831+
value: mir::Const<'tcx>,
832+
tcx: TyCtxt<'tcx>,
833+
param_env: ty::ParamEnv<'tcx>,
834+
) -> Option<bool> {
835+
use Ordering::*;
836+
debug_assert_eq!(self.ty, value.ty());
837+
let ty = self.ty;
838+
let value = PatRangeBoundary::new_finite(value, tcx, param_env);
839+
// For performance, it's important to only do the second comparison if necessary.
840+
Some(
841+
match self.lo.compare_with(value, ty, tcx, param_env)? {
842+
Less | Equal => true,
843+
Greater => false,
844+
} && match value.compare_with(self.hi, ty, tcx, param_env)? {
845+
Less => true,
846+
Equal => self.end == RangeEnd::Included,
847+
Greater => false,
848+
},
849+
)
850+
}
851+
852+
#[inline]
853+
pub fn overlaps(
854+
&self,
855+
other: &Self,
856+
tcx: TyCtxt<'tcx>,
857+
param_env: ty::ParamEnv<'tcx>,
858+
) -> Option<bool> {
859+
use Ordering::*;
860+
debug_assert_eq!(self.ty, other.ty);
861+
// For performance, it's important to only do the second comparison if necessary.
862+
Some(
863+
match other.lo.compare_with(self.hi, self.ty, tcx, param_env)? {
864+
Less => true,
865+
Equal => self.end == RangeEnd::Included,
866+
Greater => false,
867+
} && match self.lo.compare_with(other.hi, self.ty, tcx, param_env)? {
868+
Less => true,
869+
Equal => other.end == RangeEnd::Included,
870+
Greater => false,
871+
},
872+
)
873+
}
874+
}
875+
876+
impl<'tcx> fmt::Display for PatRange<'tcx> {
877+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
878+
if let PatRangeBoundary::Finite { value, .. } = &self.lo {
879+
write!(f, "{value}")?;
880+
}
881+
write!(f, "{}", self.end)?;
882+
if let PatRangeBoundary::Finite { value, .. } = &self.hi {
883+
write!(f, "{value}")?;
884+
}
885+
Ok(())
886+
}
887+
}
888+
889+
/// A (possibly open) boundary of a range pattern.
890+
/// If present, the const must be of a numeric type.
891+
#[derive(Copy, Clone, Debug, PartialEq, HashStable, TypeVisitable)]
892+
pub enum PatRangeBoundary<'tcx> {
893+
Finite { value: mir::Const<'tcx> },
894+
// PosInfinity,
895+
// NegInfinity,
896+
}
897+
898+
impl<'tcx> PatRangeBoundary<'tcx> {
899+
#[inline]
900+
pub fn new_finite(
901+
value: mir::Const<'tcx>,
902+
_tcx: TyCtxt<'tcx>,
903+
_param_env: ty::ParamEnv<'tcx>,
904+
) -> Self {
905+
Self::Finite { value }
906+
}
907+
#[inline]
908+
pub fn lower_bound(ty: Ty<'tcx>, tcx: TyCtxt<'tcx>) -> Self {
909+
// Self::NegInfinity
910+
// Unwrap is ok because the type is known to be numeric.
911+
let c = ty.numeric_min_val(tcx).unwrap();
912+
let value = mir::Const::from_ty_const(c, tcx);
913+
Self::Finite { value }
914+
}
915+
#[inline]
916+
pub fn upper_bound(ty: Ty<'tcx>, tcx: TyCtxt<'tcx>) -> Self {
917+
// Self::PosInfinity
918+
// Unwrap is ok because the type is known to be numeric.
919+
let c = ty.numeric_max_val(tcx).unwrap();
920+
let value = mir::Const::from_ty_const(c, tcx);
921+
Self::Finite { value }
922+
}
923+
924+
#[inline]
925+
pub fn to_const(
926+
self,
927+
_ty: Ty<'tcx>,
928+
_tcx: TyCtxt<'tcx>,
929+
_param_env: ty::ParamEnv<'tcx>,
930+
) -> mir::Const<'tcx> {
931+
match self {
932+
Self::Finite { value } => value,
933+
// Self::PosInfinity | Self::NegInfinity => unreachable!(),
934+
}
935+
}
936+
#[inline]
937+
pub fn eval_bits(
938+
self,
939+
_ty: Ty<'tcx>,
940+
tcx: TyCtxt<'tcx>,
941+
param_env: ty::ParamEnv<'tcx>,
942+
) -> u128 {
943+
match self {
944+
Self::Finite { value } => value.eval_bits(tcx, param_env),
945+
// Self::NegInfinity => {
946+
// // Unwrap is ok because the type is known to be numeric.
947+
// ty.numeric_min_val_as_bits(tcx).unwrap()
948+
// }
949+
// Self::PosInfinity => {
950+
// // Unwrap is ok because the type is known to be numeric.
951+
// ty.numeric_max_val_as_bits(tcx).unwrap()
952+
// }
953+
}
954+
}
955+
956+
#[instrument(skip(tcx), level = "debug")]
957+
#[inline]
958+
pub fn compare_with(
959+
self,
960+
other: Self,
961+
ty: Ty<'tcx>,
962+
tcx: TyCtxt<'tcx>,
963+
param_env: ty::ParamEnv<'tcx>,
964+
) -> Option<Ordering> {
965+
use PatRangeBoundary::*;
966+
match (self, other) {
967+
// (PosInfinity, PosInfinity) => return Some(Ordering::Equal),
968+
// (NegInfinity, NegInfinity) => return Some(Ordering::Equal),
969+
970+
// This code is hot when compiling matches with many ranges. So we
971+
// special-case extraction of evaluated scalars for speed, for types where
972+
// raw data comparisons are appropriate. E.g. `unicode-normalization` has
973+
// many ranges such as '\u{037A}'..='\u{037F}', and chars can be compared
974+
// in this way.
975+
(Finite { value: mir::Const::Ty(a) }, Finite { value: mir::Const::Ty(b) })
976+
if matches!(ty.kind(), ty::Uint(_) | ty::Char) =>
977+
{
978+
return Some(a.kind().cmp(&b.kind()));
979+
}
980+
_ => {}
981+
}
982+
983+
let a = self.eval_bits(ty, tcx, param_env);
984+
let b = other.eval_bits(ty, tcx, param_env);
985+
986+
match ty.kind() {
987+
ty::Float(ty::FloatTy::F32) => {
988+
use rustc_apfloat::Float;
989+
let a = rustc_apfloat::ieee::Single::from_bits(a);
990+
let b = rustc_apfloat::ieee::Single::from_bits(b);
991+
a.partial_cmp(&b)
992+
}
993+
ty::Float(ty::FloatTy::F64) => {
994+
use rustc_apfloat::Float;
995+
let a = rustc_apfloat::ieee::Double::from_bits(a);
996+
let b = rustc_apfloat::ieee::Double::from_bits(b);
997+
a.partial_cmp(&b)
998+
}
999+
ty::Int(ity) => {
1000+
use rustc_middle::ty::layout::IntegerExt;
1001+
let size = rustc_target::abi::Integer::from_int_ty(&tcx, *ity).size();
1002+
let a = size.sign_extend(a) as i128;
1003+
let b = size.sign_extend(b) as i128;
1004+
Some(a.cmp(&b))
1005+
}
1006+
ty::Uint(_) | ty::Char => Some(a.cmp(&b)),
1007+
_ => bug!(),
1008+
}
1009+
}
7821010
}
7831011

7841012
impl<'tcx> fmt::Display for Pat<'tcx> {
@@ -904,11 +1132,7 @@ impl<'tcx> fmt::Display for Pat<'tcx> {
9041132
write!(f, "{subpattern}")
9051133
}
9061134
PatKind::Constant { value } => write!(f, "{value}"),
907-
PatKind::Range(box PatRange { lo, hi, end }) => {
908-
write!(f, "{lo}")?;
909-
write!(f, "{end}")?;
910-
write!(f, "{hi}")
911-
}
1135+
PatKind::Range(ref range) => write!(f, "{range}"),
9121136
PatKind::Slice { ref prefix, ref slice, ref suffix }
9131137
| PatKind::Array { ref prefix, ref slice, ref suffix } => {
9141138
write!(f, "[")?;

compiler/rustc_middle/src/ty/util.rs

Lines changed: 43 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ use rustc_index::bit_set::GrowableBitSet;
1919
use rustc_macros::HashStable;
2020
use rustc_session::Limit;
2121
use rustc_span::sym;
22-
use rustc_target::abi::{Integer, IntegerType, Size};
22+
use rustc_target::abi::{Integer, IntegerType, Primitive, Size};
2323
use rustc_target::spec::abi::Abi;
2424
use smallvec::SmallVec;
2525
use std::{fmt, iter};
@@ -917,54 +917,62 @@ impl<'tcx> TypeFolder<TyCtxt<'tcx>> for OpaqueTypeExpander<'tcx> {
917917
}
918918

919919
impl<'tcx> Ty<'tcx> {
920+
/// Returns the `Size` for primitive types (bool, uint, int, char, float).
921+
pub fn primitive_size(self, tcx: TyCtxt<'tcx>) -> Size {
922+
match *self.kind() {
923+
ty::Bool => Size::from_bytes(1),
924+
ty::Char => Size::from_bytes(4),
925+
ty::Int(ity) => Integer::from_int_ty(&tcx, ity).size(),
926+
ty::Uint(uty) => Integer::from_uint_ty(&tcx, uty).size(),
927+
ty::Float(ty::FloatTy::F32) => Primitive::F32.size(&tcx),
928+
ty::Float(ty::FloatTy::F64) => Primitive::F64.size(&tcx),
929+
_ => bug!("non primitive type"),
930+
}
931+
}
932+
920933
pub fn int_size_and_signed(self, tcx: TyCtxt<'tcx>) -> (Size, bool) {
921-
let (int, signed) = match *self.kind() {
922-
ty::Int(ity) => (Integer::from_int_ty(&tcx, ity), true),
923-
ty::Uint(uty) => (Integer::from_uint_ty(&tcx, uty), false),
934+
match *self.kind() {
935+
ty::Int(ity) => (Integer::from_int_ty(&tcx, ity).size(), true),
936+
ty::Uint(uty) => (Integer::from_uint_ty(&tcx, uty).size(), false),
924937
_ => bug!("non integer discriminant"),
925-
};
926-
(int.size(), signed)
938+
}
927939
}
928940

929-
/// Returns the maximum value for the given numeric type (including `char`s)
930-
/// or returns `None` if the type is not numeric.
931-
pub fn numeric_max_val(self, tcx: TyCtxt<'tcx>) -> Option<ty::Const<'tcx>> {
932-
let val = match self.kind() {
941+
/// Returns the minimum and maximum values for the given numeric type (including `char`s) or
942+
/// returns `None` if the type is not numeric.
943+
pub fn numeric_min_and_max_as_bits(self, tcx: TyCtxt<'tcx>) -> Option<(u128, u128)> {
944+
use rustc_apfloat::ieee::{Double, Single};
945+
Some(match self.kind() {
933946
ty::Int(_) | ty::Uint(_) => {
934947
let (size, signed) = self.int_size_and_signed(tcx);
935-
let val =
948+
let min = if signed { size.truncate(size.signed_int_min() as u128) } else { 0 };
949+
let max =
936950
if signed { size.signed_int_max() as u128 } else { size.unsigned_int_max() };
937-
Some(val)
951+
(min, max)
938952
}
939-
ty::Char => Some(std::char::MAX as u128),
940-
ty::Float(fty) => Some(match fty {
941-
ty::FloatTy::F32 => rustc_apfloat::ieee::Single::INFINITY.to_bits(),
942-
ty::FloatTy::F64 => rustc_apfloat::ieee::Double::INFINITY.to_bits(),
943-
}),
944-
_ => None,
945-
};
953+
ty::Char => (0, std::char::MAX as u128),
954+
ty::Float(ty::FloatTy::F32) => {
955+
((-Single::INFINITY).to_bits(), Single::INFINITY.to_bits())
956+
}
957+
ty::Float(ty::FloatTy::F64) => {
958+
((-Double::INFINITY).to_bits(), Double::INFINITY.to_bits())
959+
}
960+
_ => return None,
961+
})
962+
}
946963

947-
val.map(|v| ty::Const::from_bits(tcx, v, ty::ParamEnv::empty().and(self)))
964+
/// Returns the maximum value for the given numeric type (including `char`s)
965+
/// or returns `None` if the type is not numeric.
966+
pub fn numeric_max_val(self, tcx: TyCtxt<'tcx>) -> Option<ty::Const<'tcx>> {
967+
self.numeric_min_and_max_as_bits(tcx)
968+
.map(|(_, max)| ty::Const::from_bits(tcx, max, ty::ParamEnv::empty().and(self)))
948969
}
949970

950971
/// Returns the minimum value for the given numeric type (including `char`s)
951972
/// or returns `None` if the type is not numeric.
952973
pub fn numeric_min_val(self, tcx: TyCtxt<'tcx>) -> Option<ty::Const<'tcx>> {
953-
let val = match self.kind() {
954-
ty::Int(_) | ty::Uint(_) => {
955-
let (size, signed) = self.int_size_and_signed(tcx);
956-
let val = if signed { size.truncate(size.signed_int_min() as u128) } else { 0 };
957-
Some(val)
958-
}
959-
ty::Char => Some(0),
960-
ty::Float(fty) => Some(match fty {
961-
ty::FloatTy::F32 => (-::rustc_apfloat::ieee::Single::INFINITY).to_bits(),
962-
ty::FloatTy::F64 => (-::rustc_apfloat::ieee::Double::INFINITY).to_bits(),
963-
}),
964-
_ => None,
965-
};
966-
967-
val.map(|v| ty::Const::from_bits(tcx, v, ty::ParamEnv::empty().and(self)))
974+
self.numeric_min_and_max_as_bits(tcx)
975+
.map(|(min, _)| ty::Const::from_bits(tcx, min, ty::ParamEnv::empty().and(self)))
968976
}
969977

970978
/// Checks whether values of this type `T` are *moved* or *copied*

0 commit comments

Comments
 (0)