Skip to content

Commit f11921d

Browse files
bashtagemattip
authored andcommitted
MAINT: Simplify return types
Standardize returns types for Windows and 32-bit platforms on int64 in choice and randint (default). Refactor tomaxint to call randint
1 parent bb7abf2 commit f11921d

File tree

3 files changed

+19
-21
lines changed

3 files changed

+19
-21
lines changed

numpy/random/generator.pyx

Lines changed: 6 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -368,26 +368,11 @@ cdef class RandomGenerator:
368368
[ True, True]]])
369369
370370
"""
371-
cdef np.npy_intp n
372-
cdef np.ndarray randoms
373-
cdef int64_t *randoms_data
374-
375-
if size is None:
376-
with self.lock:
377-
return random_positive_int(self._brng)
378-
379-
randoms = <np.ndarray>np.empty(size, dtype=np.int64)
380-
randoms_data = <int64_t*>np.PyArray_DATA(randoms)
381-
n = np.PyArray_SIZE(randoms)
382-
383-
for i in range(n):
384-
with self.lock, nogil:
385-
randoms_data[i] = random_positive_int(self._brng)
386-
return randoms
371+
return self.randint(0, np.iinfo(np.int).max + 1, dtype=np.int, size=size)
387372

388-
def randint(self, low, high=None, size=None, dtype=int, use_masked=True):
373+
def randint(self, low, high=None, size=None, dtype=np.int64, use_masked=True):
389374
"""
390-
randint(low, high=None, size=None, dtype='l', use_masked=True)
375+
randint(low, high=None, size=None, dtype='int64', use_masked=True)
391376
392377
Return random integers from `low` (inclusive) to `high` (exclusive).
393378
@@ -661,9 +646,9 @@ cdef class RandomGenerator:
661646
cdf /= cdf[-1]
662647
uniform_samples = self.random_sample(shape)
663648
idx = cdf.searchsorted(uniform_samples, side='right')
664-
idx = np.array(idx, copy=False) # searchsorted returns a scalar
649+
idx = np.array(idx, copy=False, dtype=np.int64) # searchsorted returns a scalar
665650
else:
666-
idx = self.randint(0, pop_size, size=shape)
651+
idx = self.randint(0, pop_size, size=shape, dtype=np.int64)
667652
else:
668653
if size > pop_size:
669654
raise ValueError("Cannot take a larger sample than "
@@ -692,7 +677,7 @@ cdef class RandomGenerator:
692677
n_uniq += new.size
693678
idx = found
694679
else:
695-
idx = self.permutation(pop_size)[:size]
680+
idx = (self.permutation(pop_size)[:size]).astype(np.int64)
696681
if shape is not None:
697682
idx.shape = shape
698683

numpy/random/tests/test_against_numpy.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,7 @@ def test_standard_exponential(self):
183183
self.rs.standard_exponential)
184184
self._is_state_common_legacy()
185185

186+
@pytest.mark.xfail(reason='Stream broken for simplicity')
186187
def test_tomaxint(self):
187188
self._set_common_state()
188189
self._is_state_common()

numpy/random/tests/test_generator_mt19937.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -639,6 +639,18 @@ def test_choice_nan_probabilities(self):
639639
p = [None, None, None]
640640
assert_raises(ValueError, random.choice, a, p=p)
641641

642+
def test_choice_return_type(self):
643+
# gh 9867
644+
p = np.ones(4) / 4.
645+
actual = random.choice(4, 2)
646+
assert actual.dtype == np.int64
647+
actual = random.choice(4, 2, replace=False)
648+
assert actual.dtype == np.int64
649+
actual = random.choice(4, 2, p=p)
650+
assert actual.dtype == np.int64
651+
actual = random.choice(4, 2, p=p, replace=False)
652+
assert actual.dtype == np.int64
653+
642654
def test_bytes(self):
643655
random.brng.seed(self.seed)
644656
actual = random.bytes(10)

0 commit comments

Comments
 (0)