@@ -577,7 +577,7 @@ cdef class RandomGenerator:
577
577
return self .randint (0 , 4294967296 , size = n_uint32 , dtype = np .uint32 ).tobytes ()[:length ]
578
578
579
579
@cython .wraparound (True )
580
- def choice (self , a , size = None , replace = True , p = None ):
580
+ def choice (self , a , size = None , replace = True , p = None , axis = 0 ):
581
581
"""
582
582
choice(a, size=None, replace=True, p=None)
583
583
@@ -600,6 +600,9 @@ cdef class RandomGenerator:
600
600
The probabilities associated with each entry in a.
601
601
If not given the sample assumes a uniform distribution over all
602
602
entries in a.
603
+ axis : int, optional
604
+ The axis along which the selection is performed. The default, 0,
605
+ selects by row.
603
606
604
607
Returns
605
608
-------
@@ -609,11 +612,11 @@ cdef class RandomGenerator:
609
612
Raises
610
613
------
611
614
ValueError
612
- If a is an int and less than zero, if a or p are not 1-dimensional,
613
- if a is an array-like of size 0, if p is not a vector of
615
+ If a is an int and less than zero, if p is not 1-dimensional, if
616
+ a is array-like with a size 0, if p is not a vector of
614
617
probabilities, if a and p have different lengths, or if
615
618
replace=False and the sample size is greater than the population
616
- size
619
+ size.
617
620
618
621
See Also
619
622
--------
@@ -665,11 +668,9 @@ cdef class RandomGenerator:
665
668
raise ValueError ("a must be 1-dimensional or an integer" )
666
669
if pop_size <= 0 and np .prod (size ) != 0 :
667
670
raise ValueError ("a must be greater than 0 unless no samples are taken" )
668
- elif a .ndim != 1 :
669
- raise ValueError ("a must be 1-dimensional" )
670
671
else :
671
- pop_size = a .shape [0 ]
672
- if pop_size is 0 and np .prod (size ) != 0 :
672
+ pop_size = a .shape [axis ]
673
+ if pop_size == 0 and np .prod (size ) != 0 :
673
674
raise ValueError ("'a' cannot be empty unless no samples are taken" )
674
675
675
676
if p is not None :
@@ -761,7 +762,7 @@ cdef class RandomGenerator:
761
762
res [()] = a [idx ]
762
763
return res
763
764
764
- return a [ idx ]
765
+ return a . take ( idx , axis = axis )
765
766
766
767
def uniform (self , low = 0.0 , high = 1.0 , size = None ):
767
768
"""
0 commit comments