
Commit 7953508

committed
Revert "Auto merge of #103779 - the8472:simd-str-contains, r=thomcc"
The current implementation seems to be unsound. See #104726.
1 parent a78c9be
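
For context (not part of the commit itself): with the specialized override removed, `&str` falls back to `Pattern`'s provided default for `is_contained_in`, which drives the generic Two-Way searcher. Roughly, that default looks like this:

    fn is_contained_in(self, haystack: &'a str) -> bool {
        self.into_searcher(haystack).next_match().is_some()
    }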

File tree

library/alloc/benches/str.rs
library/alloc/tests/str.rs
library/core/src/str/pattern.rs

3 files changed: +12 -311 lines

library/alloc/benches/str.rs

+7 -58
@@ -1,4 +1,3 @@
-use core::iter::Iterator;
 use test::{black_box, Bencher};
 
 #[bench]
@@ -123,13 +122,14 @@ fn bench_contains_short_short(b: &mut Bencher) {
     let haystack = "Lorem ipsum dolor sit amet, consectetur adipiscing elit.";
     let needle = "sit";
 
-    b.bytes = haystack.len() as u64;
     b.iter(|| {
-        assert!(black_box(haystack).contains(black_box(needle)));
+        assert!(haystack.contains(needle));
     })
 }
 
-static LONG_HAYSTACK: &str = "\
+#[bench]
+fn bench_contains_short_long(b: &mut Bencher) {
+    let haystack = "\
 Lorem ipsum dolor sit amet, consectetur adipiscing elit. Suspendisse quis lorem sit amet dolor \
 ultricies condimentum. Praesent iaculis purus elit, ac malesuada quam malesuada in. Duis sed orci \
 eros. Suspendisse sit amet magna mollis, mollis nunc luctus, imperdiet mi. Integer fringilla non \
@@ -164,48 +164,10 @@ feugiat. Etiam quis mauris vel risus luctus mattis a a nunc. Nullam orci quam, i
 vehicula in, porttitor ut nibh. Duis sagittis adipiscing nisl vitae congue. Donec mollis risus eu \
 leo suscipit, varius porttitor nulla porta. Pellentesque ut sem nec nisi euismod vehicula. Nulla \
 malesuada sollicitudin quam eu fermentum.";
-
-#[bench]
-fn bench_contains_2b_repeated_long(b: &mut Bencher) {
-    let haystack = LONG_HAYSTACK;
-    let needle = "::";
-
-    b.bytes = haystack.len() as u64;
-    b.iter(|| {
-        assert!(!black_box(haystack).contains(black_box(needle)));
-    })
-}
-
-#[bench]
-fn bench_contains_short_long(b: &mut Bencher) {
-    let haystack = LONG_HAYSTACK;
     let needle = "english";
 
-    b.bytes = haystack.len() as u64;
-    b.iter(|| {
-        assert!(!black_box(haystack).contains(black_box(needle)));
-    })
-}
-
-#[bench]
-fn bench_contains_16b_in_long(b: &mut Bencher) {
-    let haystack = LONG_HAYSTACK;
-    let needle = "english language";
-
-    b.bytes = haystack.len() as u64;
-    b.iter(|| {
-        assert!(!black_box(haystack).contains(black_box(needle)));
-    })
-}
-
-#[bench]
-fn bench_contains_32b_in_long(b: &mut Bencher) {
-    let haystack = LONG_HAYSTACK;
-    let needle = "the english language sample text";
-
-    b.bytes = haystack.len() as u64;
     b.iter(|| {
-        assert!(!black_box(haystack).contains(black_box(needle)));
+        assert!(!haystack.contains(needle));
     })
 }
 
@@ -214,20 +176,8 @@ fn bench_contains_bad_naive(b: &mut Bencher) {
     let haystack = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa";
     let needle = "aaaaaaaab";
 
-    b.bytes = haystack.len() as u64;
-    b.iter(|| {
-        assert!(!black_box(haystack).contains(black_box(needle)));
-    })
-}
-
-#[bench]
-fn bench_contains_bad_simd(b: &mut Bencher) {
-    let haystack = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa";
-    let needle = "aaabaaaa";
-
-    b.bytes = haystack.len() as u64;
     b.iter(|| {
-        assert!(!black_box(haystack).contains(black_box(needle)));
+        assert!(!haystack.contains(needle));
     })
 }
 
@@ -236,9 +186,8 @@ fn bench_contains_equal(b: &mut Bencher) {
     let haystack = "Lorem ipsum dolor sit amet, consectetur adipiscing elit.";
     let needle = "Lorem ipsum dolor sit amet, consectetur adipiscing elit.";
 
-    b.bytes = haystack.len() as u64;
     b.iter(|| {
-        assert!(black_box(haystack).contains(black_box(needle)));
+        assert!(haystack.contains(needle));
     })
 }

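Side note on the benchmark changes above: the reverted version set b.bytes for throughput reporting and wrapped both operands in test::black_box so the optimizer cannot const-fold the search away. A standalone sketch of that pattern (nightly-only test crate; the function name is illustrative):

    #![feature(test)]
    extern crate test;

    use test::{black_box, Bencher};

    #[bench]
    fn bench_contains_sketch(b: &mut Bencher) {
        let haystack = "Lorem ipsum dolor sit amet, consectetur adipiscing elit.";
        let needle = "sit";

        // Report throughput in bytes/s rather than only ns/iter.
        b.bytes = haystack.len() as u64;
        b.iter(|| {
            // black_box hides the operands from the optimizer, so the
            // contains() call is actually performed on every iteration.
            assert!(black_box(haystack).contains(black_box(needle)));
        })
    }
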
library/alloc/tests/str.rs

+5 -21
@@ -1590,27 +1590,11 @@ fn test_bool_from_str() {
     assert_eq!("not even a boolean".parse::<bool>().ok(), None);
 }
 
-fn check_contains_all_substrings(haystack: &str) {
-    let mut modified_needle = String::new();
-
-    for i in 0..haystack.len() {
-        // check different haystack lengths since we special-case short haystacks.
-        let haystack = &haystack[0..i];
-        assert!(haystack.contains(""));
-        for j in 0..haystack.len() {
-            for k in j + 1..=haystack.len() {
-                let needle = &haystack[j..k];
-                assert!(haystack.contains(needle));
-                modified_needle.clear();
-                modified_needle.push_str(needle);
-                modified_needle.replace_range(0..1, "\0");
-                assert!(!haystack.contains(&modified_needle));
-
-                modified_needle.clear();
-                modified_needle.push_str(needle);
-                modified_needle.replace_range(needle.len() - 1..needle.len(), "\0");
-                assert!(!haystack.contains(&modified_needle));
-            }
+fn check_contains_all_substrings(s: &str) {
+    assert!(s.contains(""));
+    for i in 0..s.len() {
+        for j in i + 1..=s.len() {
+            assert!(s.contains(&s[i..j]));
         }
     }
 }

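The simplified helper enumerates all O(n^2) substrings of s and asserts each is found, dropping the mutated-needle negative checks and the per-length haystack truncation the old version used. A hypothetical caller, just to show the intended use (not part of the diff):

    #[test]
    fn contains_all_substrings_smoke() {
        // Keep the input short: the helper checks every (i, j) pair,
        // so cost grows quadratically with the string length.
        check_contains_all_substrings("banana");
    }
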
library/core/src/str/pattern.rs

+0 -232
@@ -39,7 +39,6 @@
 )]
 
 use crate::cmp;
-use crate::cmp::Ordering;
 use crate::fmt;
 use crate::slice::memchr;
 
@@ -947,32 +946,6 @@ impl<'a, 'b> Pattern<'a> for &'b str {
         haystack.as_bytes().starts_with(self.as_bytes())
     }
 
-    /// Checks whether the pattern matches anywhere in the haystack
-    #[inline]
-    fn is_contained_in(self, haystack: &'a str) -> bool {
-        if self.len() == 0 {
-            return true;
-        }
-
-        match self.len().cmp(&haystack.len()) {
-            Ordering::Less => {
-                if self.len() == 1 {
-                    return haystack.as_bytes().contains(&self.as_bytes()[0]);
-                }
-
-                #[cfg(all(target_arch = "x86_64", target_feature = "sse2"))]
-                if self.len() <= 32 {
-                    if let Some(result) = simd_contains(self, haystack) {
-                        return result;
-                    }
-                }
-
-                self.into_searcher(haystack).next_match().is_some()
-            }
-            _ => self == haystack,
-        }
-    }
-
     /// Removes the pattern from the front of haystack, if it matches.
     #[inline]
     fn strip_prefix_of(self, haystack: &'a str) -> Option<&'a str> {
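
In short, the override removed above dispatched as follows: empty needles match trivially; needles as long as or longer than the haystack reduce to a plain equality check; single-byte needles scan the haystack bytes directly; needles of 2 to 32 bytes on SSE2-capable x86_64 tried simd_contains, falling through when it returned None; everything else used the generic Two-Way searcher.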
@@ -1711,208 +1684,3 @@ impl TwoWayStrategy for RejectAndMatch {
         SearchStep::Match(a, b)
     }
 }
-
-/// SIMD search for short needles based on
-/// Wojciech Muła's "SIMD-friendly algorithms for substring searching"[0]
-///
-/// It skips ahead by the vector width on each iteration (rather than by the needle length, as
-/// two-way does) by probing the first and last byte of the needle across the whole vector width
-/// and only doing full needle comparisons when the vectorized probe indicated potential matches.
-///
-/// Since the x86_64 baseline only offers SSE2 we only use u8x16 here.
-/// If we ever ship std for x86-64-v3 or adapt this for other platforms then wider vectors
-/// should be evaluated.
-///
-/// For haystacks smaller than vector-size + needle length it falls back to
-/// a naive O(n*m) search, so this implementation should not be called on larger needles.
-///
-/// [0]: http://0x80.pl/articles/simd-strfind.html#sse-avx2
-#[cfg(all(target_arch = "x86_64", target_feature = "sse2"))]
-#[inline]
-fn simd_contains(needle: &str, haystack: &str) -> Option<bool> {
-    let needle = needle.as_bytes();
-    let haystack = haystack.as_bytes();
-
-    debug_assert!(needle.len() > 1);
-
-    use crate::ops::BitAnd;
-    use crate::simd::mask8x16 as Mask;
-    use crate::simd::u8x16 as Block;
-    use crate::simd::{SimdPartialEq, ToBitMask};
-
-    let first_probe = needle[0];
-
-    // the offset used for the 2nd vector
-    let second_probe_offset = if needle.len() == 2 {
-        // never bail out on len=2 needles because the probes will fully cover them and have
-        // no degenerate cases.
-        1
-    } else {
-        // try a few bytes in case the first and last byte of the needle are the same
-        let Some(second_probe_offset) = (needle.len().saturating_sub(4)..needle.len()).rfind(|&idx| needle[idx] != first_probe) else {
-            // fall back to other search methods if we can't find any different bytes
-            // since we could otherwise hit some degenerate cases
-            return None;
-        };
-        second_probe_offset
-    };
-
-    // do a naive search if the haystack is too small to fit
-    if haystack.len() < Block::LANES + second_probe_offset {
-        return Some(haystack.windows(needle.len()).any(|c| c == needle));
-    }
-
-    let first_probe: Block = Block::splat(first_probe);
-    let second_probe: Block = Block::splat(needle[second_probe_offset]);
-    // the first byte is already checked by the outer loop. to verify a match only the
-    // remainder has to be compared.
-    let trimmed_needle = &needle[1..];
-
-    // this #[cold] is load-bearing, benchmark before removing it...
-    let check_mask = #[cold]
-    |idx, mask: u16, skip: bool| -> bool {
-        if skip {
-            return false;
-        }
-
-        // and so is this. optimizations are weird.
-        let mut mask = mask;
-
-        while mask != 0 {
-            let trailing = mask.trailing_zeros();
-            let offset = idx + trailing as usize + 1;
-            // SAFETY: mask has between 0 and 15 trailing zeroes; we skip one additional byte that was already compared
-            // and then take trimmed_needle.len() bytes. This is within the bounds defined by the outer loop.
-            unsafe {
-                let sub = haystack.get_unchecked(offset..).get_unchecked(..trimmed_needle.len());
-                if small_slice_eq(sub, trimmed_needle) {
-                    return true;
-                }
-            }
-            mask &= !(1 << trailing);
-        }
-        return false;
-    };
-
-    let test_chunk = |idx| -> u16 {
-        // SAFETY: this requires at least LANES bytes being readable at idx;
-        // that is ensured by the loop ranges (see comments below)
-        let a: Block = unsafe { haystack.as_ptr().add(idx).cast::<Block>().read_unaligned() };
-        // SAFETY: this requires LANES + second_probe_offset bytes being readable at idx
-        let b: Block = unsafe {
-            haystack.as_ptr().add(idx).add(second_probe_offset).cast::<Block>().read_unaligned()
-        };
-        let eq_first: Mask = a.simd_eq(first_probe);
-        let eq_last: Mask = b.simd_eq(second_probe);
-        let both = eq_first.bitand(eq_last);
-        let mask = both.to_bitmask();
-
-        return mask;
-    };
-
-    let mut i = 0;
-    let mut result = false;
-    // The loop condition must ensure that there's enough headroom to read LANES bytes,
-    // and not only at the current index but also at the index shifted by second_probe_offset.
-    const UNROLL: usize = 4;
-    while i + second_probe_offset + UNROLL * Block::LANES < haystack.len() && !result {
-        let mut masks = [0u16; UNROLL];
-        for j in 0..UNROLL {
-            masks[j] = test_chunk(i + j * Block::LANES);
-        }
-        for j in 0..UNROLL {
-            let mask = masks[j];
-            if mask != 0 {
-                result |= check_mask(i + j * Block::LANES, mask, result);
-            }
-        }
-        i += UNROLL * Block::LANES;
-    }
-    while i + second_probe_offset + Block::LANES < haystack.len() && !result {
-        let mask = test_chunk(i);
-        if mask != 0 {
-            result |= check_mask(i, mask, result);
-        }
-        i += Block::LANES;
-    }
-
-    // Process the tail that didn't fit into LANES-sized steps.
-    // This simply repeats the same procedure but as a right-aligned chunk instead
-    // of a left-aligned one. The last byte must be exactly flush with the string end so
-    // we don't miss a single byte or read out of bounds.
-    let i = haystack.len() - second_probe_offset - Block::LANES;
-    let mask = test_chunk(i);
-    if mask != 0 {
-        result |= check_mask(i, mask, result);
-    }
-
-    Some(result)
-}
-
-/// Compares short slices for equality.
-///
-/// It avoids a call to libc's memcmp, which is faster on long slices
-/// due to SIMD optimizations but incurs a function call overhead.
-///
-/// # Safety
-///
-/// Both slices must have the same length.
-#[cfg(all(target_arch = "x86_64", target_feature = "sse2"))] // only called on x86
-#[inline]
-unsafe fn small_slice_eq(x: &[u8], y: &[u8]) -> bool {
-    // This function is adapted from
-    // https://github.com/BurntSushi/memchr/blob/8037d11b4357b0f07be2bb66dc2659d9cf28ad32/src/memmem/util.rs#L32
-
-    // If we don't have enough bytes to do 4-byte-at-a-time loads, then
-    // fall back to the naive slow version.
-    //
-    // Potential alternative: We could do a copy_nonoverlapping combined with a mask instead
-    // of a loop. Benchmark it.
-    if x.len() < 4 {
-        for (&b1, &b2) in x.iter().zip(y) {
-            if b1 != b2 {
-                return false;
-            }
-        }
-        return true;
-    }
-    // When we have 4 or more bytes to compare, then proceed in chunks of 4 at
-    // a time using unaligned loads.
-    //
-    // Also, why do 4 byte loads instead of, say, 8 byte loads? The reason is
-    // that this particular version of memcmp is likely to be called with tiny
-    // needles. That means that if we do 8 byte loads, then a higher proportion
-    // of memcmp calls will use the slower variant above. With that said, this
-    // is a hypothesis and is only loosely supported by benchmarks. There's
-    // likely some improvement that could be made here. The main thing here
-    // though is to optimize for latency, not throughput.
-
-    // SAFETY: Via the conditional above, we know that both `px` and `py`
-    // have the same length, so `px < pxend` implies that `py < pyend`.
-    // Thus, dereferencing both `px` and `py` in the loop below is safe.
-    //
-    // Moreover, we set `pxend` and `pyend` to be 4 bytes before the actual
-    // end of `px` and `py`. Thus, the final dereference outside of the
-    // loop is guaranteed to be valid. (The final comparison will overlap with
-    // the last comparison done in the loop for lengths that aren't multiples
-    // of four.)
-    //
-    // Finally, we needn't worry about alignment here, since we do unaligned
-    // loads.
-    unsafe {
-        let (mut px, mut py) = (x.as_ptr(), y.as_ptr());
-        let (pxend, pyend) = (px.add(x.len() - 4), py.add(y.len() - 4));
-        while px < pxend {
-            let vx = (px as *const u32).read_unaligned();
-            let vy = (py as *const u32).read_unaligned();
-            if vx != vy {
-                return false;
-            }
-            px = px.add(4);
-            py = py.add(4);
-        }
-        let vx = (pxend as *const u32).read_unaligned();
-        let vy = (pyend as *const u32).read_unaligned();
-        vx == vy
-    }
-}

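For readers unfamiliar with the removed algorithm, here is a scalar sketch of the two ideas it combined; both functions below are illustrative stand-ins, not the reverted code. The first shows Muła's first/last-byte probe (the SSE2 version evaluated this probe for 16 candidate windows at once via splat, simd_eq, and a bitmask); the second shows the overlapping-chunk equality trick behind small_slice_eq, rendered in safe Rust:

    /// Scalar rendition of the probe-then-verify idea from simd_contains.
    fn two_probe_contains(haystack: &[u8], needle: &[u8]) -> bool {
        assert!(needle.len() > 1);
        let (first, last) = (needle[0], needle[needle.len() - 1]);
        haystack
            .windows(needle.len())
            // Cheap filter: compare only the first and last byte of each window.
            .filter(|w| w[0] == first && w[w.len() - 1] == last)
            // Full comparison only for the surviving candidates.
            .any(|w| w == needle)
    }

    /// Safe rendition of small_slice_eq: compare 4-byte chunks, with a final
    /// right-aligned chunk that may overlap the previous one so lengths that
    /// aren't multiples of four are still fully covered.
    fn small_eq(x: &[u8], y: &[u8]) -> bool {
        assert_eq!(x.len(), y.len());
        if x.len() < 4 {
            return x == y;
        }
        let mut i = 0;
        while i + 4 < x.len() {
            if x[i..i + 4] != y[i..i + 4] {
                return false;
            }
            i += 4;
        }
        x[x.len() - 4..] == y[y.len() - 4..]
    }

The probe keeps the filter cheap, but degenerate inputs (for example, a needle whose candidate probe bytes are all identical) defeat it, which is why the original picked a second probe offset with a differing byte and returned None, deferring to Two-Way, when it could not find one.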