Skip to content

Commit 0e7589e

Browse files
authored
Merge pull request #9 from bashtage/choice-dtype
MAINT: Simplify return types
2 parents cb6f40f + 6c4bc0c commit 0e7589e

File tree

4 files changed

+125
-43
lines changed

4 files changed

+125
-43
lines changed

numpy/random/generator.pyx

Lines changed: 83 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -368,26 +368,11 @@ cdef class RandomGenerator:
368368
[ True, True]]])
369369
370370
"""
371-
cdef np.npy_intp n
372-
cdef np.ndarray randoms
373-
cdef int64_t *randoms_data
374-
375-
if size is None:
376-
with self.lock:
377-
return random_positive_int(self._brng)
378-
379-
randoms = <np.ndarray>np.empty(size, dtype=np.int64)
380-
randoms_data = <int64_t*>np.PyArray_DATA(randoms)
381-
n = np.PyArray_SIZE(randoms)
382-
383-
for i in range(n):
384-
with self.lock, nogil:
385-
randoms_data[i] = random_positive_int(self._brng)
386-
return randoms
371+
return self.randint(0, np.iinfo(np.int).max + 1, dtype=np.int, size=size)
387372

388-
def randint(self, low, high=None, size=None, dtype=int, use_masked=True):
373+
def randint(self, low, high=None, size=None, dtype=np.int64, use_masked=True):
389374
"""
390-
randint(low, high=None, size=None, dtype='l', use_masked=True)
375+
randint(low, high=None, size=None, dtype='int64', use_masked=True)
391376
392377
Return random integers from `low` (inclusive) to `high` (exclusive).
393378
@@ -530,9 +515,9 @@ cdef class RandomGenerator:
530515
return self.randint(0, 4294967296, size=n_uint32, dtype=np.uint32).tobytes()[:length]
531516

532517
@cython.wraparound(True)
533-
def choice(self, a, size=None, replace=True, p=None):
518+
def choice(self, a, size=None, replace=True, p=None, axis=0):
534519
"""
535-
choice(a, size=None, replace=True, p=None)
520+
choice(a, size=None, replace=True, p=None, axis=0):
536521
537522
Generates a random sample from a given 1-D array
538523
@@ -553,6 +538,9 @@ cdef class RandomGenerator:
553538
The probabilities associated with each entry in a.
554539
If not given the sample assumes a uniform distribution over all
555540
entries in a.
541+
axis : int, optional
542+
The axis along which the selection is performed. The default, 0,
543+
selects by row.
556544
557545
Returns
558546
-------
@@ -562,11 +550,11 @@ cdef class RandomGenerator:
562550
Raises
563551
------
564552
ValueError
565-
If a is an int and less than zero, if a or p are not 1-dimensional,
566-
if a is an array-like of size 0, if p is not a vector of
553+
If a is an int and less than zero, if p is not 1-dimensional, if
554+
a is array-like with a size 0, if p is not a vector of
567555
probabilities, if a and p have different lengths, or if
568556
replace=False and the sample size is greater than the population
569-
size
557+
size.
570558
571559
See Also
572560
--------
@@ -607,7 +595,14 @@ cdef class RandomGenerator:
607595
dtype='<U11')
608596
609597
"""
610-
598+
cdef char* idx_ptr
599+
cdef int64_t buf
600+
cdef char* buf_ptr
601+
602+
cdef set idx_set
603+
cdef int64_t val, t, loc, size_i, pop_size_i
604+
cdef int64_t *idx_data
605+
cdef np.npy_intp j
611606
# Format and Verify input
612607
a = np.array(a, copy=False)
613608
if a.ndim == 0:
@@ -618,11 +613,9 @@ cdef class RandomGenerator:
618613
raise ValueError("a must be 1-dimensional or an integer")
619614
if pop_size <= 0 and np.prod(size) != 0:
620615
raise ValueError("a must be greater than 0 unless no samples are taken")
621-
elif a.ndim != 1:
622-
raise ValueError("a must be 1-dimensional")
623616
else:
624-
pop_size = a.shape[0]
625-
if pop_size is 0 and np.prod(size) != 0:
617+
pop_size = a.shape[axis]
618+
if pop_size == 0 and np.prod(size) != 0:
626619
raise ValueError("'a' cannot be empty unless no samples are taken")
627620

628621
if p is not None:
@@ -661,9 +654,9 @@ cdef class RandomGenerator:
661654
cdf /= cdf[-1]
662655
uniform_samples = self.random_sample(shape)
663656
idx = cdf.searchsorted(uniform_samples, side='right')
664-
idx = np.array(idx, copy=False) # searchsorted returns a scalar
657+
idx = np.array(idx, copy=False, dtype=np.int64) # searchsorted returns a scalar
665658
else:
666-
idx = self.randint(0, pop_size, size=shape)
659+
idx = self.randint(0, pop_size, size=shape, dtype=np.int64)
667660
else:
668661
if size > pop_size:
669662
raise ValueError("Cannot take a larger sample than "
@@ -692,7 +685,39 @@ cdef class RandomGenerator:
692685
n_uniq += new.size
693686
idx = found
694687
else:
695-
idx = self.permutation(pop_size)[:size]
688+
size_i = size
689+
pop_size_i = pop_size
690+
# This is a heuristic tuning. should be improvable
691+
if pop_size_i > 200 and (size > 200 or size > (10 * pop_size // size)):
692+
# Tail shuffle size elements
693+
idx = np.arange(pop_size, dtype=np.int64)
694+
idx_ptr = np.PyArray_BYTES(<np.ndarray>idx)
695+
buf_ptr = <char*>&buf
696+
self._shuffle_raw(pop_size_i, max(pop_size_i - size_i,1),
697+
8, 8, idx_ptr, buf_ptr)
698+
# Copy to allow potentially large array backing idx to be gc
699+
idx = idx[(pop_size - size):].copy()
700+
else:
701+
# Floyds's algorithm with precomputed indices
702+
# Worst case, O(n**2) when size is close to pop_size
703+
idx = np.empty(size, dtype=np.int64)
704+
idx_data = <int64_t*>np.PyArray_DATA(<np.ndarray>idx)
705+
idx_set = set()
706+
loc = 0
707+
# Sample indices with one pass to avoid reacquiring the lock
708+
with self.lock:
709+
for j in range(pop_size_i - size_i, pop_size_i):
710+
idx_data[loc] = random_interval(self._brng, j)
711+
loc += 1
712+
loc = 0
713+
while len(idx_set) < size_i:
714+
for j in range(pop_size_i - size_i, pop_size_i):
715+
if idx_data[loc] not in idx_set:
716+
val = idx_data[loc]
717+
else:
718+
idx_data[loc] = val = j
719+
idx_set.add(val)
720+
loc += 1
696721
if shape is not None:
697722
idx.shape = shape
698723

@@ -714,7 +739,9 @@ cdef class RandomGenerator:
714739
res[()] = a[idx]
715740
return res
716741

717-
return a[idx]
742+
# asarray downcasts on 32-bit platforms, always safe
743+
# no-op on 64-bit platforms
744+
return a.take(np.asarray(idx, dtype=np.intp), axis=axis)
718745

719746
def uniform(self, low=0.0, high=1.0, size=None):
720747
"""
@@ -3986,9 +4013,9 @@ cdef class RandomGenerator:
39864013
# the most common case, yielding a ~33% performance improvement.
39874014
# Note that apparently, only one branch can ever be specialized.
39884015
if itemsize == sizeof(np.npy_intp):
3989-
self._shuffle_raw(n, sizeof(np.npy_intp), stride, x_ptr, buf_ptr)
4016+
self._shuffle_raw(n, 1, sizeof(np.npy_intp), stride, x_ptr, buf_ptr)
39904017
else:
3991-
self._shuffle_raw(n, itemsize, stride, x_ptr, buf_ptr)
4018+
self._shuffle_raw(n, 1, itemsize, stride, x_ptr, buf_ptr)
39924019
elif isinstance(x, np.ndarray) and x.ndim and x.size:
39934020
buf = np.empty_like(x[0, ...])
39944021
with self.lock:
@@ -4007,10 +4034,29 @@ cdef class RandomGenerator:
40074034
j = random_interval(self._brng, i)
40084035
x[i], x[j] = x[j], x[i]
40094036

4010-
cdef inline _shuffle_raw(self, np.npy_intp n, np.npy_intp itemsize,
4011-
np.npy_intp stride, char* data, char* buf):
4037+
cdef inline _shuffle_raw(self, np.npy_intp n, np.npy_intp first,
4038+
np.npy_intp itemsize, np.npy_intp stride,
4039+
char* data, char* buf):
4040+
"""
4041+
Parameters
4042+
----------
4043+
n
4044+
Number of elements in data
4045+
first
4046+
First observation to shuffle. Shuffles n-1,
4047+
n-2, ..., first, so that when first=1 the entire
4048+
array is shuffled
4049+
itemsize
4050+
Size in bytes of item
4051+
stride
4052+
Array stride
4053+
data
4054+
Location of data
4055+
buf
4056+
Location of buffer (itemsize)
4057+
"""
40124058
cdef np.npy_intp i, j
4013-
for i in reversed(range(1, n)):
4059+
for i in reversed(range(first, n)):
40144060
j = random_interval(self._brng, i)
40154061
string.memcpy(buf, data + j * stride, itemsize)
40164062
string.memcpy(data + j * stride, data + i * stride, itemsize)

numpy/random/src/distributions/distributions.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1070,7 +1070,7 @@ int64_t random_zipf(brng_t *brng_state, double a) {
10701070

10711071
T = pow(1.0 + 1.0 / X, am1);
10721072
if (V * X * (T - 1.0) / (b - 1.0) <= T / b) {
1073-
return (long)X;
1073+
return (int64_t)X;
10741074
}
10751075
}
10761076
}

numpy/random/tests/test_against_numpy.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,7 @@ def test_standard_exponential(self):
183183
self.rs.standard_exponential)
184184
self._is_state_common_legacy()
185185

186+
@pytest.mark.xfail(reason='Stream broken for simplicity')
186187
def test_tomaxint(self):
187188
self._set_common_state()
188189
self._is_state_common()
@@ -327,6 +328,7 @@ def test_multinomial(self):
327328
g(100, np.array(p), size=(7, 23)))
328329
self._is_state_common()
329330

331+
@pytest.mark.xfail(reason='Stream broken for performance')
330332
def test_choice(self):
331333
self._set_common_state()
332334
self._is_state_common()

numpy/random/tests/test_generator_mt19937.py

Lines changed: 39 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -542,25 +542,25 @@ def test_random_sample_unsupported_type(self):
542542
def test_choice_uniform_replace(self):
543543
random.brng.seed(self.seed)
544544
actual = random.choice(4, 4)
545-
desired = np.array([2, 3, 2, 3])
545+
desired = np.array([2, 3, 2, 3], dtype=np.int64)
546546
assert_array_equal(actual, desired)
547547

548548
def test_choice_nonuniform_replace(self):
549549
random.brng.seed(self.seed)
550550
actual = random.choice(4, 4, p=[0.4, 0.4, 0.1, 0.1])
551-
desired = np.array([1, 1, 2, 2])
551+
desired = np.array([1, 1, 2, 2], dtype=np.int64)
552552
assert_array_equal(actual, desired)
553553

554554
def test_choice_uniform_noreplace(self):
555555
random.brng.seed(self.seed)
556556
actual = random.choice(4, 3, replace=False)
557-
desired = np.array([0, 1, 3])
557+
desired = np.array([0, 2, 3], dtype=np.int64)
558558
assert_array_equal(actual, desired)
559559

560560
def test_choice_nonuniform_noreplace(self):
561561
random.brng.seed(self.seed)
562562
actual = random.choice(4, 3, replace=False, p=[0.1, 0.3, 0.5, 0.1])
563-
desired = np.array([2, 3, 1])
563+
desired = np.array([2, 3, 1], dtype=np.int64)
564564
assert_array_equal(actual, desired)
565565

566566
def test_choice_noninteger(self):
@@ -569,11 +569,22 @@ def test_choice_noninteger(self):
569569
desired = np.array(['c', 'd', 'c', 'd'])
570570
assert_array_equal(actual, desired)
571571

572+
def test_choice_multidimensional_default_axis(self):
573+
random.brng.seed(self.seed)
574+
actual = random.choice([[0, 1], [2, 3], [4, 5], [6, 7]], 3)
575+
desired = np.array([[4, 5], [6, 7], [4, 5]])
576+
assert_array_equal(actual, desired)
577+
578+
def test_choice_multidimensional_custom_axis(self):
579+
random.brng.seed(self.seed)
580+
actual = random.choice([[0, 1], [2, 3], [4, 5], [6, 7]], 1, axis=1)
581+
desired = np.array([[0], [2], [4], [6]])
582+
assert_array_equal(actual, desired)
583+
572584
def test_choice_exceptions(self):
573585
sample = random.choice
574586
assert_raises(ValueError, sample, -1, 3)
575587
assert_raises(ValueError, sample, 3., 3)
576-
assert_raises(ValueError, sample, [[1, 2], [3, 4]], 3)
577588
assert_raises(ValueError, sample, [], 3)
578589
assert_raises(ValueError, sample, [1, 2, 3, 4], 3,
579590
p=[[0.25, 0.25], [0.25, 0.25]])
@@ -639,6 +650,29 @@ def test_choice_nan_probabilities(self):
639650
p = [None, None, None]
640651
assert_raises(ValueError, random.choice, a, p=p)
641652

653+
def test_choice_return_type(self):
654+
# gh 9867
655+
p = np.ones(4) / 4.
656+
actual = random.choice(4, 2)
657+
assert actual.dtype == np.int64
658+
actual = random.choice(4, 2, replace=False)
659+
assert actual.dtype == np.int64
660+
actual = random.choice(4, 2, p=p)
661+
assert actual.dtype == np.int64
662+
actual = random.choice(4, 2, p=p, replace=False)
663+
assert actual.dtype == np.int64
664+
665+
def test_choice_large_sample(self):
666+
import hashlib
667+
668+
choice_hash = '6395868be877d27518c832213c17977c'
669+
random.brng.seed(self.seed)
670+
actual = random.choice(10000, 5000, replace=False)
671+
if sys.byteorder != 'little':
672+
actual = actual.byteswap()
673+
res = hashlib.md5(actual.view(np.int8)).hexdigest()
674+
assert_(choice_hash == res)
675+
642676
def test_bytes(self):
643677
random.brng.seed(self.seed)
644678
actual = random.bytes(10)

0 commit comments

Comments
 (0)