Skip to content

Commit 053ad81

Browse files
committed
slices: rework the APIs of BinarySearch*
For golang/go#50340 Change-Id: If115b2b66d463d5f3788d017924f8dd38867551c Reviewed-on: https://go-review.googlesource.com/c/exp/+/395414 Reviewed-by: Ian Lance Taylor <[email protected]> Trust: Eli Bendersky‎ <[email protected]>
1 parent 054d857 commit 053ad81

File tree

2 files changed

+127
-35
lines changed

2 files changed

+127
-35
lines changed

slices/sort.go

Lines changed: 26 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -46,25 +46,34 @@ func IsSortedFunc[E any](x []E, less func(a, b E) bool) bool {
4646
return true
4747
}
4848

49-
// BinarySearch searches for target in a sorted slice and returns the smallest
50-
// index at which target is found. If the target is not found, the index at
51-
// which it could be inserted into the slice is returned; therefore, if the
52-
// intention is to find target itself a separate check for equality with the
53-
// element at the returned index is required.
54-
func BinarySearch[E constraints.Ordered](x []E, target E) int {
55-
return search(len(x), func(i int) bool { return x[i] >= target })
49+
// BinarySearch searches for target in a sorted slice and returns the position
50+
// where target is found, or the position where target would appear in the
51+
// sort order; it also returns a bool saying whether the target is really found
52+
// in the slice. The slice must be sorted in increasing order.
53+
func BinarySearch[E constraints.Ordered](x []E, target E) (int, bool) {
54+
// search returns the leftmost position where f returns true, or len(x) if f
55+
// returns false for all x. This is the insertion position for target in x,
56+
// and could point to an element that's either == target or not.
57+
pos := search(len(x), func(i int) bool { return x[i] >= target })
58+
if pos >= len(x) || x[pos] != target {
59+
return pos, false
60+
} else {
61+
return pos, true
62+
}
5663
}
5764

58-
// BinarySearchFunc uses binary search to find and return the smallest index i
59-
// in [0, n) at which ok(i) is true, assuming that on the range [0, n),
60-
// ok(i) == true implies ok(i+1) == true. That is, BinarySearchFunc requires
61-
// that ok is false for some (possibly empty) prefix of the input range [0, n)
62-
// and then true for the (possibly empty) remainder; BinarySearchFunc returns
63-
// the first true index. If there is no such index, BinarySearchFunc returns n.
64-
// (Note that the "not found" return value is not -1 as in, for instance,
65-
// strings.Index.) Search calls ok(i) only for i in the range [0, n).
66-
func BinarySearchFunc[E any](x []E, ok func(E) bool) int {
67-
return search(len(x), func(i int) bool { return ok(x[i]) })
65+
// BinarySearchFunc works like BinarySearch, but uses a custom comparison
66+
// function. The slice must be sorted in increasing order, where "increasing" is
67+
// defined by cmp. cmp(a, b) is expected to return an integer comparing the two
68+
// parameters: 0 if a == b, a negative number if a < b and a positive number if
69+
// a > b.
70+
func BinarySearchFunc[E any](x []E, target E, cmp func(E, E) int) (int, bool) {
71+
pos := search(len(x), func(i int) bool { return cmp(x[i], target) >= 0 })
72+
if pos >= len(x) || cmp(x[pos], target) != 0 {
73+
return pos, false
74+
} else {
75+
return pos, true
76+
}
6877
}
6978

7079
// maxDepth returns a threshold at which quicksort should switch

slices/sort_test.go

Lines changed: 101 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ package slices
77
import (
88
"math"
99
"math/rand"
10+
"strconv"
11+
"strings"
1012
"testing"
1113
)
1214

@@ -151,31 +153,112 @@ func TestStability(t *testing.T) {
151153
}
152154

153155
func TestBinarySearch(t *testing.T) {
154-
data := []string{"aa", "ad", "ca", "xy"}
156+
str1 := []string{"foo"}
157+
str2 := []string{"ab", "ca"}
158+
str3 := []string{"mo", "qo", "vo"}
159+
str4 := []string{"ab", "ad", "ca", "xy"}
160+
161+
// slice with repeating elements
162+
strRepeats := []string{"ba", "ca", "da", "da", "da", "ka", "ma", "ma", "ta"}
163+
164+
// slice with all element equal
165+
strSame := []string{"xx", "xx", "xx"}
166+
155167
tests := []struct {
156-
target string
157-
want int
168+
data []string
169+
target string
170+
wantPos int
171+
wantFound bool
158172
}{
159-
{"aa", 0},
160-
{"ab", 1},
161-
{"ad", 1},
162-
{"ax", 2},
163-
{"ca", 2},
164-
{"cc", 3},
165-
{"dd", 3},
166-
{"xy", 3},
167-
{"zz", 4},
173+
{[]string{}, "foo", 0, false},
174+
{[]string{}, "", 0, false},
175+
176+
{str1, "foo", 0, true},
177+
{str1, "bar", 0, false},
178+
{str1, "zx", 1, false},
179+
180+
{str2, "aa", 0, false},
181+
{str2, "ab", 0, true},
182+
{str2, "ad", 1, false},
183+
{str2, "ca", 1, true},
184+
{str2, "ra", 2, false},
185+
186+
{str3, "bb", 0, false},
187+
{str3, "mo", 0, true},
188+
{str3, "nb", 1, false},
189+
{str3, "qo", 1, true},
190+
{str3, "tr", 2, false},
191+
{str3, "vo", 2, true},
192+
{str3, "xr", 3, false},
193+
194+
{str4, "aa", 0, false},
195+
{str4, "ab", 0, true},
196+
{str4, "ac", 1, false},
197+
{str4, "ad", 1, true},
198+
{str4, "ax", 2, false},
199+
{str4, "ca", 2, true},
200+
{str4, "cc", 3, false},
201+
{str4, "dd", 3, false},
202+
{str4, "xy", 3, true},
203+
{str4, "zz", 4, false},
204+
205+
{strRepeats, "da", 2, true},
206+
{strRepeats, "db", 5, false},
207+
{strRepeats, "ma", 6, true},
208+
{strRepeats, "mb", 8, false},
209+
210+
{strSame, "xx", 0, true},
211+
{strSame, "ab", 0, false},
212+
{strSame, "zz", 3, false},
168213
}
169214
for _, tt := range tests {
170215
t.Run(tt.target, func(t *testing.T) {
171-
i := BinarySearch(data, tt.target)
172-
if i != tt.want {
173-
t.Errorf("BinarySearch want %d, got %d", tt.want, i)
216+
{
217+
pos, found := BinarySearch(tt.data, tt.target)
218+
if pos != tt.wantPos || found != tt.wantFound {
219+
t.Errorf("BinarySearch got (%v, %v), want (%v, %v)", pos, found, tt.wantPos, tt.wantFound)
220+
}
221+
}
222+
223+
{
224+
pos, found := BinarySearchFunc(tt.data, tt.target, strings.Compare)
225+
if pos != tt.wantPos || found != tt.wantFound {
226+
t.Errorf("BinarySearchFunc got (%v, %v), want (%v, %v)", pos, found, tt.wantPos, tt.wantFound)
227+
}
228+
}
229+
})
230+
}
231+
}
232+
233+
func TestBinarySearchInts(t *testing.T) {
234+
data := []int{20, 30, 40, 50, 60, 70, 80, 90}
235+
tests := []struct {
236+
target int
237+
wantPos int
238+
wantFound bool
239+
}{
240+
{20, 0, true},
241+
{23, 1, false},
242+
{43, 3, false},
243+
{80, 6, true},
244+
}
245+
for _, tt := range tests {
246+
t.Run(strconv.Itoa(tt.target), func(t *testing.T) {
247+
{
248+
pos, found := BinarySearch(data, tt.target)
249+
if pos != tt.wantPos || found != tt.wantFound {
250+
t.Errorf("BinarySearch got (%v, %v), want (%v, %v)", pos, found, tt.wantPos, tt.wantFound)
251+
}
174252
}
175253

176-
j := BinarySearchFunc(data, func(s string) bool { return s >= tt.target })
177-
if j != tt.want {
178-
t.Errorf("BinarySearchFunc want %d, got %d", tt.want, j)
254+
{
255+
cmp := func(a, b int) int {
256+
return a - b
257+
}
258+
pos, found := BinarySearchFunc(data, tt.target, cmp)
259+
if pos != tt.wantPos || found != tt.wantFound {
260+
t.Errorf("BinarySearchFunc got (%v, %v), want (%v, %v)", pos, found, tt.wantPos, tt.wantFound)
261+
}
179262
}
180263
})
181264
}

0 commit comments

Comments
 (0)