Skip to content

Commit 0a56c45

Browse files
committed
ENH: Add axis to choice
Add axis to choice xref numpy/numpy#7810
1 parent 6b6d6a6 commit 0a56c45

File tree

2 files changed

+22
-10
lines changed

2 files changed

+22
-10
lines changed

randomgen/generator.pyx

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -577,7 +577,7 @@ cdef class RandomGenerator:
577577
return self.randint(0, 4294967296, size=n_uint32, dtype=np.uint32).tobytes()[:length]
578578

579579
@cython.wraparound(True)
580-
def choice(self, a, size=None, replace=True, p=None):
580+
def choice(self, a, size=None, replace=True, p=None, axis=0):
581581
"""
582582
choice(a, size=None, replace=True, p=None)
583583
@@ -600,6 +600,9 @@ cdef class RandomGenerator:
600600
The probabilities associated with each entry in a.
601601
If not given the sample assumes a uniform distribution over all
602602
entries in a.
603+
axis : int, optional
604+
The axis along which the selection is performed. The default, 0,
605+
selects by row.
603606
604607
Returns
605608
-------
@@ -609,11 +612,11 @@ cdef class RandomGenerator:
609612
Raises
610613
------
611614
ValueError
612-
If a is an int and less than zero, if a or p are not 1-dimensional,
613-
if a is an array-like of size 0, if p is not a vector of
615+
If a is an int and less than zero, if p is not 1-dimensional, if
616+
a is array-like with a size 0, if p is not a vector of
614617
probabilities, if a and p have different lengths, or if
615618
replace=False and the sample size is greater than the population
616-
size
619+
size.
617620
618621
See Also
619622
--------
@@ -665,11 +668,9 @@ cdef class RandomGenerator:
665668
raise ValueError("a must be 1-dimensional or an integer")
666669
if pop_size <= 0 and np.prod(size) != 0:
667670
raise ValueError("a must be greater than 0 unless no samples are taken")
668-
elif a.ndim != 1:
669-
raise ValueError("a must be 1-dimensional")
670671
else:
671-
pop_size = a.shape[0]
672-
if pop_size is 0 and np.prod(size) != 0:
672+
pop_size = a.shape[axis]
673+
if pop_size == 0 and np.prod(size) != 0:
673674
raise ValueError("'a' cannot be empty unless no samples are taken")
674675

675676
if p is not None:
@@ -761,7 +762,7 @@ cdef class RandomGenerator:
761762
res[()] = a[idx]
762763
return res
763764

764-
return a[idx]
765+
return a.take(idx, axis=axis)
765766

766767
def uniform(self, low=0.0, high=1.0, size=None):
767768
"""

randomgen/tests/test_generator_mt19937.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -568,11 +568,22 @@ def test_choice_noninteger(self):
568568
desired = np.array(['c', 'd', 'c', 'd'])
569569
assert_array_equal(actual, desired)
570570

571+
def test_choice_multidimensional_default_axis(self):
572+
random.seed(self.seed)
573+
actual = random.choice([[0, 1], [2, 3], [4, 5], [6, 7]], 3)
574+
desired = np.array([[4, 5], [6, 7], [4, 5]])
575+
assert_array_equal(actual, desired)
576+
577+
def test_choice_multidimensional_custom_axis(self):
578+
random.seed(self.seed)
579+
actual = random.choice([[0, 1], [2, 3], [4, 5], [6, 7]], 1, axis=1)
580+
desired = np.array([[0], [2], [4], [6]])
581+
assert_array_equal(actual, desired)
582+
571583
def test_choice_exceptions(self):
572584
sample = random.choice
573585
assert_raises(ValueError, sample, -1, 3)
574586
assert_raises(ValueError, sample, 3., 3)
575-
assert_raises(ValueError, sample, [[1, 2], [3, 4]], 3)
576587
assert_raises(ValueError, sample, [], 3)
577588
assert_raises(ValueError, sample, [1, 2, 3, 4], 3,
578589
p=[[0.25, 0.25], [0.25, 0.25]])

0 commit comments

Comments
 (0)