@@ -515,9 +515,9 @@ cdef class RandomGenerator:
515
515
return self .randint (0 , 4294967296 , size = n_uint32 , dtype = np .uint32 ).tobytes ()[:length ]
516
516
517
517
@cython .wraparound (True )
518
- def choice (self , a , size = None , replace = True , p = None ):
518
+ def choice (self , a , size = None , replace = True , p = None , axis = 0 ):
519
519
"""
520
- choice(a, size=None, replace=True, p=None)
520
+ choice(a, size=None, replace=True, p=None, axis=0):
521
521
522
522
Generates a random sample from a given 1-D array
523
523
@@ -538,6 +538,9 @@ cdef class RandomGenerator:
538
538
The probabilities associated with each entry in a.
539
539
If not given the sample assumes a uniform distribution over all
540
540
entries in a.
541
+ axis : int, optional
542
+ The axis along which the selection is performed. The default, 0,
543
+ selects by row.
541
544
542
545
Returns
543
546
-------
@@ -547,11 +550,11 @@ cdef class RandomGenerator:
547
550
Raises
548
551
------
549
552
ValueError
550
- If a is an int and less than zero, if a or p are not 1-dimensional,
551
- if a is an array-like of size 0, if p is not a vector of
553
+ If a is an int and less than zero, if p is not 1-dimensional, if
554
+ a is array-like with a size 0, if p is not a vector of
552
555
probabilities, if a and p have different lengths, or if
553
556
replace=False and the sample size is greater than the population
554
- size
557
+ size.
555
558
556
559
See Also
557
560
--------
@@ -592,7 +595,14 @@ cdef class RandomGenerator:
592
595
dtype='<U11')
593
596
594
597
"""
595
-
598
+ cdef char * idx_ptr
599
+ cdef int64_t buf
600
+ cdef char * buf_ptr
601
+
602
+ cdef set idx_set
603
+ cdef int64_t val , t , loc , size_i , pop_size_i
604
+ cdef int64_t * idx_data
605
+ cdef np .npy_intp j
596
606
# Format and Verify input
597
607
a = np .array (a , copy = False )
598
608
if a .ndim == 0 :
@@ -603,11 +613,9 @@ cdef class RandomGenerator:
603
613
raise ValueError ("a must be 1-dimensional or an integer" )
604
614
if pop_size <= 0 and np .prod (size ) != 0 :
605
615
raise ValueError ("a must be greater than 0 unless no samples are taken" )
606
- elif a .ndim != 1 :
607
- raise ValueError ("a must be 1-dimensional" )
608
616
else :
609
- pop_size = a .shape [0 ]
610
- if pop_size is 0 and np .prod (size ) != 0 :
617
+ pop_size = a .shape [axis ]
618
+ if pop_size == 0 and np .prod (size ) != 0 :
611
619
raise ValueError ("'a' cannot be empty unless no samples are taken" )
612
620
613
621
if p is not None :
@@ -677,7 +685,39 @@ cdef class RandomGenerator:
677
685
n_uniq += new .size
678
686
idx = found
679
687
else :
680
- idx = (self .permutation (pop_size )[:size ]).astype (np .int64 )
688
+ size_i = size
689
+ pop_size_i = pop_size
690
+ # This is a heuristic tuning. should be improvable
691
+ if pop_size_i > 200 and (size > 200 or size > (10 * pop_size // size )):
692
+ # Tail shuffle size elements
693
+ idx = np .arange (pop_size , dtype = np .int64 )
694
+ idx_ptr = np .PyArray_BYTES (< np .ndarray > idx )
695
+ buf_ptr = < char * > & buf
696
+ self ._shuffle_raw (pop_size_i , max (pop_size_i - size_i ,1 ),
697
+ 8 , 8 , idx_ptr , buf_ptr )
698
+ # Copy to allow potentially large array backing idx to be gc
699
+ idx = idx [(pop_size - size ):].copy ()
700
+ else :
701
+ # Floyds's algorithm with precomputed indices
702
+ # Worst case, O(n**2) when size is close to pop_size
703
+ idx = np .empty (size , dtype = np .int64 )
704
+ idx_data = < int64_t * > np .PyArray_DATA (< np .ndarray > idx )
705
+ idx_set = set ()
706
+ loc = 0
707
+ # Sample indices with one pass to avoid reacquiring the lock
708
+ with self .lock :
709
+ for j in range (pop_size_i - size_i , pop_size_i ):
710
+ idx_data [loc ] = random_interval (self ._brng , j )
711
+ loc += 1
712
+ loc = 0
713
+ while len (idx_set ) < size_i :
714
+ for j in range (pop_size_i - size_i , pop_size_i ):
715
+ if idx_data [loc ] not in idx_set :
716
+ val = idx_data [loc ]
717
+ else :
718
+ idx_data [loc ] = val = j
719
+ idx_set .add (val )
720
+ loc += 1
681
721
if shape is not None :
682
722
idx .shape = shape
683
723
@@ -699,7 +739,9 @@ cdef class RandomGenerator:
699
739
res [()] = a [idx ]
700
740
return res
701
741
702
- return a [idx ]
742
+ # asarray downcasts on 32-bit platforms, always safe
743
+ # no-op on 64-bit platforms
744
+ return a .take (np .asarray (idx , dtype = np .intp ), axis = axis )
703
745
704
746
def uniform (self , low = 0.0 , high = 1.0 , size = None ):
705
747
"""
@@ -3971,9 +4013,9 @@ cdef class RandomGenerator:
3971
4013
# the most common case, yielding a ~33% performance improvement.
3972
4014
# Note that apparently, only one branch can ever be specialized.
3973
4015
if itemsize == sizeof (np .npy_intp ):
3974
- self ._shuffle_raw (n , sizeof (np .npy_intp ), stride , x_ptr , buf_ptr )
4016
+ self ._shuffle_raw (n , 1 , sizeof (np .npy_intp ), stride , x_ptr , buf_ptr )
3975
4017
else :
3976
- self ._shuffle_raw (n , itemsize , stride , x_ptr , buf_ptr )
4018
+ self ._shuffle_raw (n , 1 , itemsize , stride , x_ptr , buf_ptr )
3977
4019
elif isinstance (x , np .ndarray ) and x .ndim and x .size :
3978
4020
buf = np .empty_like (x [0 , ...])
3979
4021
with self .lock :
@@ -3992,10 +4034,29 @@ cdef class RandomGenerator:
3992
4034
j = random_interval (self ._brng , i )
3993
4035
x [i ], x [j ] = x [j ], x [i ]
3994
4036
3995
- cdef inline _shuffle_raw (self , np .npy_intp n , np .npy_intp itemsize ,
3996
- np .npy_intp stride , char * data , char * buf ):
4037
+ cdef inline _shuffle_raw (self , np .npy_intp n , np .npy_intp first ,
4038
+ np .npy_intp itemsize , np .npy_intp stride ,
4039
+ char * data , char * buf ):
4040
+ """
4041
+ Parameters
4042
+ ----------
4043
+ n
4044
+ Number of elements in data
4045
+ first
4046
+ First observation to shuffle. Shuffles n-1,
4047
+ n-2, ..., first, so that when first=1 the entire
4048
+ array is shuffled
4049
+ itemsize
4050
+ Size in bytes of item
4051
+ stride
4052
+ Array stride
4053
+ data
4054
+ Location of data
4055
+ buf
4056
+ Location of buffer (itemsize)
4057
+ """
3997
4058
cdef np .npy_intp i , j
3998
- for i in reversed (range (1 , n )):
4059
+ for i in reversed (range (first , n )):
3999
4060
j = random_interval (self ._brng , i )
4000
4061
string .memcpy (buf , data + j * stride , itemsize )
4001
4062
string .memcpy (data + j * stride , data + i * stride , itemsize )
0 commit comments