Skip to content

Commit 5d56e1d

Browse files
committed
Specialize equality for [T] and comparison for [u8]
Where T is a type that can be compared for equality bytewise, we can use memcmp. We can also use memcmp for PartialOrd and Ord on [u8], and by extension on &str. This is an improvement, for example, for the comparison [u8] == [u8], which used to emit a loop that compared the slices byte by byte. One worry could be that this introduces function calls to memcmp in contexts where the comparison should really be inlined or even optimized out, but LLVM recognizes memcmp specifically and takes care of that.
1 parent a09f386 commit 5d56e1d

File tree

3 files changed

+139
-38
lines changed

3 files changed

+139
-38
lines changed

src/libcore/lib.rs

+1
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@
7575
#![feature(unwind_attributes)]
7676
#![feature(repr_simd, platform_intrinsics)]
7777
#![feature(rustc_attrs)]
78+
#![feature(specialization)]
7879
#![feature(staged_api)]
7980
#![feature(unboxed_closures)]
8081
#![feature(question_mark)]

src/libcore/slice.rs

+135-16
Original file line numberDiff line numberDiff line change
@@ -1630,12 +1630,59 @@ pub unsafe fn from_raw_parts_mut<'a, T>(p: *mut T, len: usize) -> &'a mut [T] {
16301630
}
16311631

16321632
//
1633-
// Boilerplate traits
1633+
// Comparison traits
16341634
//
16351635

1636+
extern {
1637+
/// Calls the implementation-provided memcmp.
1638+
///
1639+
/// Interprets the data as u8.
1640+
///
1641+
/// Return 0 for equal, < 0 for less than and > 0 for greater
1642+
/// than.
1643+
// FIXME(#32610): Return type should be c_int
1644+
fn memcmp(s1: *const u8, s2: *const u8, n: usize) -> i32;
1645+
}
1646+
16361647
#[stable(feature = "rust1", since = "1.0.0")]
16371648
impl<A, B> PartialEq<[B]> for [A] where A: PartialEq<B> {
16381649
fn eq(&self, other: &[B]) -> bool {
1650+
SlicePartialEq::equal(self, other)
1651+
}
1652+
1653+
fn ne(&self, other: &[B]) -> bool {
1654+
SlicePartialEq::not_equal(self, other)
1655+
}
1656+
}
1657+
1658+
#[stable(feature = "rust1", since = "1.0.0")]
1659+
impl<T: Eq> Eq for [T] {}
1660+
1661+
#[stable(feature = "rust1", since = "1.0.0")]
1662+
impl<T: Ord> Ord for [T] {
1663+
fn cmp(&self, other: &[T]) -> Ordering {
1664+
SliceOrd::compare(self, other)
1665+
}
1666+
}
1667+
1668+
#[stable(feature = "rust1", since = "1.0.0")]
1669+
impl<T: PartialOrd> PartialOrd for [T] {
1670+
fn partial_cmp(&self, other: &[T]) -> Option<Ordering> {
1671+
SlicePartialOrd::partial_compare(self, other)
1672+
}
1673+
}
1674+
1675+
// intermediate trait for specialization of slice's PartialEq
1676+
trait SlicePartialEq<B> {
1677+
fn equal(&self, other: &[B]) -> bool;
1678+
fn not_equal(&self, other: &[B]) -> bool;
1679+
}
1680+
1681+
// Generic slice equality
1682+
impl<A, B> SlicePartialEq<B> for [A]
1683+
where A: PartialEq<B>
1684+
{
1685+
default fn equal(&self, other: &[B]) -> bool {
16391686
if self.len() != other.len() {
16401687
return false;
16411688
}
@@ -1648,7 +1695,8 @@ impl<A, B> PartialEq<[B]> for [A] where A: PartialEq<B> {
16481695

16491696
true
16501697
}
1651-
fn ne(&self, other: &[B]) -> bool {
1698+
1699+
default fn not_equal(&self, other: &[B]) -> bool {
16521700
if self.len() != other.len() {
16531701
return true;
16541702
}
@@ -1663,12 +1711,35 @@ impl<A, B> PartialEq<[B]> for [A] where A: PartialEq<B> {
16631711
}
16641712
}
16651713

1666-
#[stable(feature = "rust1", since = "1.0.0")]
1667-
impl<T: Eq> Eq for [T] {}
1714+
// Use memcmp for bytewise equality when the types allow
1715+
impl<A> SlicePartialEq<A> for [A]
1716+
where A: PartialEq<A> + BytewiseEquality
1717+
{
1718+
fn equal(&self, other: &[A]) -> bool {
1719+
if self.len() != other.len() {
1720+
return false;
1721+
}
1722+
unsafe {
1723+
let size = mem::size_of_val(self);
1724+
memcmp(self.as_ptr() as *const u8,
1725+
other.as_ptr() as *const u8, size) == 0
1726+
}
1727+
}
16681728

1669-
#[stable(feature = "rust1", since = "1.0.0")]
1670-
impl<T: Ord> Ord for [T] {
1671-
fn cmp(&self, other: &[T]) -> Ordering {
1729+
fn not_equal(&self, other: &[A]) -> bool {
1730+
!self.equal(other)
1731+
}
1732+
}
1733+
1734+
// intermediate trait for specialization of slice's PartialOrd
1735+
trait SlicePartialOrd<B> {
1736+
fn partial_compare(&self, other: &[B]) -> Option<Ordering>;
1737+
}
1738+
1739+
impl<A> SlicePartialOrd<A> for [A]
1740+
where A: PartialOrd
1741+
{
1742+
default fn partial_compare(&self, other: &[A]) -> Option<Ordering> {
16721743
let l = cmp::min(self.len(), other.len());
16731744

16741745
// Slice to the loop iteration range to enable bound check
@@ -1677,19 +1748,32 @@ impl<T: Ord> Ord for [T] {
16771748
let rhs = &other[..l];
16781749

16791750
for i in 0..l {
1680-
match lhs[i].cmp(&rhs[i]) {
1681-
Ordering::Equal => (),
1751+
match lhs[i].partial_cmp(&rhs[i]) {
1752+
Some(Ordering::Equal) => (),
16821753
non_eq => return non_eq,
16831754
}
16841755
}
16851756

1686-
self.len().cmp(&other.len())
1757+
self.len().partial_cmp(&other.len())
16871758
}
16881759
}
16891760

1690-
#[stable(feature = "rust1", since = "1.0.0")]
1691-
impl<T: PartialOrd> PartialOrd for [T] {
1692-
fn partial_cmp(&self, other: &[T]) -> Option<Ordering> {
1761+
impl SlicePartialOrd<u8> for [u8] {
1762+
#[inline]
1763+
fn partial_compare(&self, other: &[u8]) -> Option<Ordering> {
1764+
Some(SliceOrd::compare(self, other))
1765+
}
1766+
}
1767+
1768+
// intermediate trait for specialization of slice's Ord
1769+
trait SliceOrd<B> {
1770+
fn compare(&self, other: &[B]) -> Ordering;
1771+
}
1772+
1773+
impl<A> SliceOrd<A> for [A]
1774+
where A: Ord
1775+
{
1776+
default fn compare(&self, other: &[A]) -> Ordering {
16931777
let l = cmp::min(self.len(), other.len());
16941778

16951779
// Slice to the loop iteration range to enable bound check
@@ -1698,12 +1782,47 @@ impl<T: PartialOrd> PartialOrd for [T] {
16981782
let rhs = &other[..l];
16991783

17001784
for i in 0..l {
1701-
match lhs[i].partial_cmp(&rhs[i]) {
1702-
Some(Ordering::Equal) => (),
1785+
match lhs[i].cmp(&rhs[i]) {
1786+
Ordering::Equal => (),
17031787
non_eq => return non_eq,
17041788
}
17051789
}
17061790

1707-
self.len().partial_cmp(&other.len())
1791+
self.len().cmp(&other.len())
1792+
}
1793+
}
1794+
1795+
// memcmp compares a sequence of unsigned bytes lexicographically.
1796+
// This matches the order we want for [u8], but no others (not even [i8]).
1797+
impl SliceOrd<u8> for [u8] {
1798+
#[inline]
1799+
fn compare(&self, other: &[u8]) -> Ordering {
1800+
let order = unsafe {
1801+
memcmp(self.as_ptr(), other.as_ptr(),
1802+
cmp::min(self.len(), other.len()))
1803+
};
1804+
if order == 0 {
1805+
self.len().cmp(&other.len())
1806+
} else if order < 0 {
1807+
Less
1808+
} else {
1809+
Greater
1810+
}
1811+
}
1812+
}
1813+
1814+
/// Trait implemented for types that can be compared for equality using
1815+
/// their bytewise representation
1816+
trait BytewiseEquality { }
1817+
1818+
macro_rules! impl_marker_for {
1819+
($traitname:ident, $($ty:ty)*) => {
1820+
$(
1821+
impl $traitname for $ty { }
1822+
)*
17081823
}
17091824
}
1825+
1826+
impl_marker_for!(BytewiseEquality,
1827+
u8 i8 u16 i16 u32 i32 u64 i64 usize isize char bool);
1828+

src/libcore/str/mod.rs

+3-22
Original file line numberDiff line numberDiff line change
@@ -1150,16 +1150,7 @@ Section: Comparing strings
11501150
#[lang = "str_eq"]
11511151
#[inline]
11521152
fn eq_slice(a: &str, b: &str) -> bool {
1153-
a.len() == b.len() && unsafe { cmp_slice(a, b, a.len()) == 0 }
1154-
}
1155-
1156-
/// Bytewise slice comparison.
1157-
/// NOTE: This uses the system's memcmp, which is currently dramatically
1158-
/// faster than comparing each byte in a loop.
1159-
#[inline]
1160-
unsafe fn cmp_slice(a: &str, b: &str, len: usize) -> i32 {
1161-
extern { fn memcmp(s1: *const i8, s2: *const i8, n: usize) -> i32; }
1162-
memcmp(a.as_ptr() as *const i8, b.as_ptr() as *const i8, len)
1153+
a.as_bytes() == b.as_bytes()
11631154
}
11641155

11651156
/*
@@ -1328,8 +1319,7 @@ Section: Trait implementations
13281319
*/
13291320

13301321
mod traits {
1331-
use cmp::{self, Ordering, Ord, PartialEq, PartialOrd, Eq};
1332-
use cmp::Ordering::{Less, Greater};
1322+
use cmp::{Ord, Ordering, PartialEq, PartialOrd, Eq};
13331323
use iter::Iterator;
13341324
use option::Option;
13351325
use option::Option::Some;
@@ -1340,16 +1330,7 @@ mod traits {
13401330
impl Ord for str {
13411331
#[inline]
13421332
fn cmp(&self, other: &str) -> Ordering {
1343-
let cmp = unsafe {
1344-
super::cmp_slice(self, other, cmp::min(self.len(), other.len()))
1345-
};
1346-
if cmp == 0 {
1347-
self.len().cmp(&other.len())
1348-
} else if cmp < 0 {
1349-
Less
1350-
} else {
1351-
Greater
1352-
}
1333+
self.as_bytes().cmp(other.as_bytes())
13531334
}
13541335
}
13551336

0 commit comments

Comments
 (0)