Skip to content

Commit b2f9bea

Browse files
bashtagemattip
authored andcommitted
ENH: Improvce choice without replacement
Improve performance in all cases Large improvement with size is small xref numpy#5299 xref numpy#2764 xref numpy#9855 xref numpy#7810
1 parent 0f931b3 commit b2f9bea

File tree

3 files changed

+106
-22
lines changed

3 files changed

+106
-22
lines changed

numpy/random/generator.pyx

Lines changed: 78 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -515,9 +515,9 @@ cdef class RandomGenerator:
515515
return self.randint(0, 4294967296, size=n_uint32, dtype=np.uint32).tobytes()[:length]
516516

517517
@cython.wraparound(True)
518-
def choice(self, a, size=None, replace=True, p=None):
518+
def choice(self, a, size=None, replace=True, p=None, axis=0):
519519
"""
520-
choice(a, size=None, replace=True, p=None)
520+
choice(a, size=None, replace=True, p=None, axis=0):
521521
522522
Generates a random sample from a given 1-D array
523523
@@ -538,6 +538,9 @@ cdef class RandomGenerator:
538538
The probabilities associated with each entry in a.
539539
If not given the sample assumes a uniform distribution over all
540540
entries in a.
541+
axis : int, optional
542+
The axis along which the selection is performed. The default, 0,
543+
selects by row.
541544
542545
Returns
543546
-------
@@ -547,11 +550,11 @@ cdef class RandomGenerator:
547550
Raises
548551
------
549552
ValueError
550-
If a is an int and less than zero, if a or p are not 1-dimensional,
551-
if a is an array-like of size 0, if p is not a vector of
553+
If a is an int and less than zero, if p is not 1-dimensional, if
554+
a is array-like with a size 0, if p is not a vector of
552555
probabilities, if a and p have different lengths, or if
553556
replace=False and the sample size is greater than the population
554-
size
557+
size.
555558
556559
See Also
557560
--------
@@ -592,7 +595,14 @@ cdef class RandomGenerator:
592595
dtype='<U11')
593596
594597
"""
595-
598+
cdef char* idx_ptr
599+
cdef int64_t buf
600+
cdef char* buf_ptr
601+
602+
cdef set idx_set
603+
cdef int64_t val, t, loc, size_i, pop_size_i
604+
cdef int64_t *idx_data
605+
cdef np.npy_intp j
596606
# Format and Verify input
597607
a = np.array(a, copy=False)
598608
if a.ndim == 0:
@@ -603,11 +613,9 @@ cdef class RandomGenerator:
603613
raise ValueError("a must be 1-dimensional or an integer")
604614
if pop_size <= 0 and np.prod(size) != 0:
605615
raise ValueError("a must be greater than 0 unless no samples are taken")
606-
elif a.ndim != 1:
607-
raise ValueError("a must be 1-dimensional")
608616
else:
609-
pop_size = a.shape[0]
610-
if pop_size is 0 and np.prod(size) != 0:
617+
pop_size = a.shape[axis]
618+
if pop_size == 0 and np.prod(size) != 0:
611619
raise ValueError("'a' cannot be empty unless no samples are taken")
612620

613621
if p is not None:
@@ -677,7 +685,39 @@ cdef class RandomGenerator:
677685
n_uniq += new.size
678686
idx = found
679687
else:
680-
idx = (self.permutation(pop_size)[:size]).astype(np.int64)
688+
size_i = size
689+
pop_size_i = pop_size
690+
# This is a heuristic tuning. should be improvable
691+
if pop_size_i > 200 and (size > 200 or size > (10 * pop_size // size)):
692+
# Tail shuffle size elements
693+
idx = np.arange(pop_size, dtype=np.int64)
694+
idx_ptr = np.PyArray_BYTES(<np.ndarray>idx)
695+
buf_ptr = <char*>&buf
696+
self._shuffle_raw(pop_size_i, max(pop_size_i - size_i,1),
697+
8, 8, idx_ptr, buf_ptr)
698+
# Copy to allow potentially large array backing idx to be gc
699+
idx = idx[(pop_size - size):].copy()
700+
else:
701+
# Floyds's algorithm with precomputed indices
702+
# Worst case, O(n**2) when size is close to pop_size
703+
idx = np.empty(size, dtype=np.int64)
704+
idx_data = <int64_t*>np.PyArray_DATA(<np.ndarray>idx)
705+
idx_set = set()
706+
loc = 0
707+
# Sample indices with one pass to avoid reacquiring the lock
708+
with self.lock:
709+
for j in range(pop_size_i - size_i, pop_size_i):
710+
idx_data[loc] = random_interval(self._brng, j)
711+
loc += 1
712+
loc = 0
713+
while len(idx_set) < size_i:
714+
for j in range(pop_size_i - size_i, pop_size_i):
715+
if idx_data[loc] not in idx_set:
716+
val = idx_data[loc]
717+
else:
718+
idx_data[loc] = val = j
719+
idx_set.add(val)
720+
loc += 1
681721
if shape is not None:
682722
idx.shape = shape
683723

@@ -699,7 +739,9 @@ cdef class RandomGenerator:
699739
res[()] = a[idx]
700740
return res
701741

702-
return a[idx]
742+
# asarray downcasts on 32-bit platforms, always safe
743+
# no-op on 64-bit platforms
744+
return a.take(np.asarray(idx, dtype=np.intp), axis=axis)
703745

704746
def uniform(self, low=0.0, high=1.0, size=None):
705747
"""
@@ -3971,9 +4013,9 @@ cdef class RandomGenerator:
39714013
# the most common case, yielding a ~33% performance improvement.
39724014
# Note that apparently, only one branch can ever be specialized.
39734015
if itemsize == sizeof(np.npy_intp):
3974-
self._shuffle_raw(n, sizeof(np.npy_intp), stride, x_ptr, buf_ptr)
4016+
self._shuffle_raw(n, 1, sizeof(np.npy_intp), stride, x_ptr, buf_ptr)
39754017
else:
3976-
self._shuffle_raw(n, itemsize, stride, x_ptr, buf_ptr)
4018+
self._shuffle_raw(n, 1, itemsize, stride, x_ptr, buf_ptr)
39774019
elif isinstance(x, np.ndarray) and x.ndim and x.size:
39784020
buf = np.empty_like(x[0, ...])
39794021
with self.lock:
@@ -3992,10 +4034,29 @@ cdef class RandomGenerator:
39924034
j = random_interval(self._brng, i)
39934035
x[i], x[j] = x[j], x[i]
39944036

3995-
cdef inline _shuffle_raw(self, np.npy_intp n, np.npy_intp itemsize,
3996-
np.npy_intp stride, char* data, char* buf):
4037+
cdef inline _shuffle_raw(self, np.npy_intp n, np.npy_intp first,
4038+
np.npy_intp itemsize, np.npy_intp stride,
4039+
char* data, char* buf):
4040+
"""
4041+
Parameters
4042+
----------
4043+
n
4044+
Number of elements in data
4045+
first
4046+
First observation to shuffle. Shuffles n-1,
4047+
n-2, ..., first, so that when first=1 the entire
4048+
array is shuffled
4049+
itemsize
4050+
Size in bytes of item
4051+
stride
4052+
Array stride
4053+
data
4054+
Location of data
4055+
buf
4056+
Location of buffer (itemsize)
4057+
"""
39974058
cdef np.npy_intp i, j
3998-
for i in reversed(range(1, n)):
4059+
for i in reversed(range(first, n)):
39994060
j = random_interval(self._brng, i)
40004061
string.memcpy(buf, data + j * stride, itemsize)
40014062
string.memcpy(data + j * stride, data + i * stride, itemsize)

numpy/random/tests/test_against_numpy.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,7 @@ def test_multinomial(self):
328328
g(100, np.array(p), size=(7, 23)))
329329
self._is_state_common()
330330

331+
@pytest.mark.xfail(reason='Stream broken for performance')
331332
def test_choice(self):
332333
self._set_common_state()
333334
self._is_state_common()

numpy/random/tests/test_generator_mt19937.py

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -542,25 +542,25 @@ def test_random_sample_unsupported_type(self):
542542
def test_choice_uniform_replace(self):
543543
random.brng.seed(self.seed)
544544
actual = random.choice(4, 4)
545-
desired = np.array([2, 3, 2, 3])
545+
desired = np.array([2, 3, 2, 3], dtype=np.int64)
546546
assert_array_equal(actual, desired)
547547

548548
def test_choice_nonuniform_replace(self):
549549
random.brng.seed(self.seed)
550550
actual = random.choice(4, 4, p=[0.4, 0.4, 0.1, 0.1])
551-
desired = np.array([1, 1, 2, 2])
551+
desired = np.array([1, 1, 2, 2], dtype=np.int64)
552552
assert_array_equal(actual, desired)
553553

554554
def test_choice_uniform_noreplace(self):
555555
random.brng.seed(self.seed)
556556
actual = random.choice(4, 3, replace=False)
557-
desired = np.array([0, 1, 3])
557+
desired = np.array([0, 2, 3], dtype=np.int64)
558558
assert_array_equal(actual, desired)
559559

560560
def test_choice_nonuniform_noreplace(self):
561561
random.brng.seed(self.seed)
562562
actual = random.choice(4, 3, replace=False, p=[0.1, 0.3, 0.5, 0.1])
563-
desired = np.array([2, 3, 1])
563+
desired = np.array([2, 3, 1], dtype=np.int64)
564564
assert_array_equal(actual, desired)
565565

566566
def test_choice_noninteger(self):
@@ -569,11 +569,22 @@ def test_choice_noninteger(self):
569569
desired = np.array(['c', 'd', 'c', 'd'])
570570
assert_array_equal(actual, desired)
571571

572+
def test_choice_multidimensional_default_axis(self):
573+
random.brng.seed(self.seed)
574+
actual = random.choice([[0, 1], [2, 3], [4, 5], [6, 7]], 3)
575+
desired = np.array([[4, 5], [6, 7], [4, 5]])
576+
assert_array_equal(actual, desired)
577+
578+
def test_choice_multidimensional_custom_axis(self):
579+
random.brng.seed(self.seed)
580+
actual = random.choice([[0, 1], [2, 3], [4, 5], [6, 7]], 1, axis=1)
581+
desired = np.array([[0], [2], [4], [6]])
582+
assert_array_equal(actual, desired)
583+
572584
def test_choice_exceptions(self):
573585
sample = random.choice
574586
assert_raises(ValueError, sample, -1, 3)
575587
assert_raises(ValueError, sample, 3., 3)
576-
assert_raises(ValueError, sample, [[1, 2], [3, 4]], 3)
577588
assert_raises(ValueError, sample, [], 3)
578589
assert_raises(ValueError, sample, [1, 2, 3, 4], 3,
579590
p=[[0.25, 0.25], [0.25, 0.25]])
@@ -651,6 +662,17 @@ def test_choice_return_type(self):
651662
actual = random.choice(4, 2, p=p, replace=False)
652663
assert actual.dtype == np.int64
653664

665+
def test_choice_large_sample(self):
666+
import hashlib
667+
668+
choice_hash = '6395868be877d27518c832213c17977c'
669+
random.brng.seed(self.seed)
670+
actual = random.choice(10000, 5000, replace=False)
671+
if sys.byteorder != 'little':
672+
actual = actual.byteswap()
673+
res = hashlib.md5(actual.view(np.int8)).hexdigest()
674+
assert_(choice_hash == res)
675+
654676
def test_bytes(self):
655677
random.brng.seed(self.seed)
656678
actual = random.bytes(10)

0 commit comments

Comments
 (0)