Skip to content

Commit 3dc1b1e

Browse files
committed
Auto merge of #127007 - krtab:improv_binary_search, r=<try>
Improve slice::binary_search_by This PR aims to improve the performances of std::slice::binary_search. **EDIT: The proposed implementation changed so the rest of this comment is outdated. See #127007 (comment) for an up to date presentation of the PR.** It reduces the total instruction count for the `u32` monomorphization, but maybe more remarkably, removes 2 of the 12 instructions of the main loop (on x86). It changes `test_binary_search_implementation_details()` so may warrant a crater run. I will document it much more if this is shown to be interesting on benchmarks. Could we start with a timer run first? **Before the PR** ```asm mov eax, 1 test rsi, rsi je .LBB0_1 mov rcx, rdx mov rdx, rsi mov ecx, dword ptr [rcx] xor esi, esi mov r8, rdx .LBB0_3: shr rdx add rdx, rsi mov r9d, dword ptr [rdi + 4*rdx] cmp r9d, ecx je .LBB0_4 lea r10, [rdx + 1] cmp r9d, ecx cmova r8, rdx cmovb rsi, r10 mov rdx, r8 sub rdx, rsi ja .LBB0_3 mov rdx, rsi ret .LBB0_1: xor edx, edx ret .LBB0_4: xor eax, eax ret ``` **After the PR** ```asm mov ecx, dword ptr [rdx] xor eax, eax xor edx, edx .LBB1_1: cmp rsi, 1 jbe .LBB1_2 mov r9, rsi shr r9 lea r8, [r9 + rdx] sub rsi, r9 cmp dword ptr [rdi + 4*r8], ecx cmovb rdx, r8 cmova rsi, r9 jne .LBB1_1 mov rdx, r8 ret .LBB1_2: test rsi, rsi je .LBB1_3 xor eax, eax cmp dword ptr [rdi + 4*rdx], ecx setne al adc rdx, 0 ret .LBB1_3: mov eax, 1 ret ```
2 parents d38cd22 + 2f5eec9 commit 3dc1b1e

File tree

2 files changed

+21
-11
lines changed

2 files changed

+21
-11
lines changed

library/core/src/slice/mod.rs

+19-9
Original file line numberDiff line numberDiff line change
@@ -2787,15 +2787,28 @@ impl<T> [T] {
27872787
where
27882788
F: FnMut(&'a T) -> Ordering,
27892789
{
2790+
if T::IS_ZST {
2791+
let res = if self.len() == 0 {
2792+
Err(0)
2793+
} else {
2794+
match f(&self[0]) {
2795+
Less => Err(self.len()),
2796+
Equal => Ok(0),
2797+
Greater => Err(0),
2798+
}
2799+
};
2800+
return res;
2801+
}
27902802
// INVARIANTS:
2791-
// - 0 <= left <= left + size = right <= self.len()
2803+
// - 0 <= left <= right <= self.len()
27922804
// - f returns Less for everything in self[..left]
27932805
// - f returns Greater for everything in self[right..]
2794-
let mut size = self.len();
2806+
let mut right = self.len();
27952807
let mut left = 0;
2796-
let mut right = size;
27972808
while left < right {
2798-
let mid = left + size / 2;
2809+
// This is an okay way to compute the mean because left and right are
2810+
// <= isize::MAX so the addition won't overflow
2811+
let mid = (left + right) / 2;
27992812

28002813
// SAFETY: the while condition means `size` is strictly positive, so
28012814
// `size/2 < size`. Thus `left + size/2 < left + size`, which
@@ -2807,19 +2820,16 @@ impl<T> [T] {
28072820
// fewer branches and instructions than if/else or matching on
28082821
// cmp::Ordering.
28092822
// This is x86 asm for u8: https://rust.godbolt.org/z/698eYffTx.
2823+
28102824
left = if cmp == Less { mid + 1 } else { left };
28112825
right = if cmp == Greater { mid } else { right };
28122826
if cmp == Equal {
28132827
// SAFETY: same as the `get_unchecked` above
28142828
unsafe { hint::assert_unchecked(mid < self.len()) };
28152829
return Ok(mid);
28162830
}
2817-
2818-
size = right - left;
28192831
}
2820-
2821-
// SAFETY: directly true from the overall invariant.
2822-
// Note that this is `<=`, unlike the assume in the `Ok` path.
2832+
// SAFETY: yolo
28232833
unsafe { hint::assert_unchecked(left <= self.len()) };
28242834
Err(left)
28252835
}

library/core/tests/slice.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -69,13 +69,13 @@ fn test_binary_search() {
6969
assert_eq!(b.binary_search(&8), Err(5));
7070

7171
let b = [(); usize::MAX];
72-
assert_eq!(b.binary_search(&()), Ok(usize::MAX / 2));
72+
assert_eq!(b.binary_search(&()), Ok(0));
7373
}
7474

7575
#[test]
7676
fn test_binary_search_by_overflow() {
7777
let b = [(); usize::MAX];
78-
assert_eq!(b.binary_search_by(|_| Ordering::Equal), Ok(usize::MAX / 2));
78+
assert_eq!(b.binary_search_by(|_| Ordering::Equal), Ok(0));
7979
assert_eq!(b.binary_search_by(|_| Ordering::Greater), Err(0));
8080
assert_eq!(b.binary_search_by(|_| Ordering::Less), Err(usize::MAX));
8181
}

0 commit comments

Comments
 (0)