Skip to content

ENH: Add multidimensional array support to numpy.random.choice #7810

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 12 additions & 12 deletions numpy/random/mtrand/mtrand.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -1290,7 +1290,7 @@ cdef class RandomState:
return bytestring


def choice(self, a, size=None, replace=True, p=None):
def choice(self, a, size=None, replace=True, p=None, axis=0):
"""
choice(a, size=None, replace=True, p=None)

Expand All @@ -1313,6 +1313,9 @@ cdef class RandomState:
The probabilities associated with each entry in a.
If not given the sample assumes a uniform distribution over all
entries in a.
axis : int, optional
axis along which the selection is performed. The default, 0, will
select by row.

Returns
--------
Expand All @@ -1322,11 +1325,10 @@ cdef class RandomState:
Raises
-------
ValueError
If a is an int and less than zero, if a or p are not 1-dimensional,
if a is an array-like of size 0, if p is not a vector of
probabilities, if a and p have different lengths, or if
replace=False and the sample size is greater than the population
size
If a is an int and less than zero, if p is not 1-dimensional, if a
is an array-like of size 0, if p is not a vector of probabilities,
if a and p have different lengths, or if replace=False and the
sample size is greater than the population size

See Also
---------
Expand Down Expand Up @@ -1378,11 +1380,9 @@ cdef class RandomState:
raise ValueError("a must be 1-dimensional or an integer")
if pop_size <= 0:
raise ValueError("a must be greater than 0")
elif a.ndim != 1:
raise ValueError("a must be 1-dimensional")
else:
pop_size = a.shape[0]
if pop_size is 0:
pop_size = a.shape[axis]
if pop_size == 0:
raise ValueError("a must be non-empty")

if p is not None:
Expand Down Expand Up @@ -1455,7 +1455,7 @@ cdef class RandomState:
# In most cases a scalar will have been made an array
idx = idx.item(0)

#Use samples as indices for a if a is array-like
# Use samples as indices for a if a is array-like
if a.ndim == 0:
return idx

Expand All @@ -1469,7 +1469,7 @@ cdef class RandomState:
res[()] = a[idx]
return res

return a[idx]
return a.take(idx, axis=axis)


def uniform(self, low=0.0, high=1.0, size=None):
Expand Down
13 changes: 12 additions & 1 deletion numpy/random/tests/test_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,11 +334,22 @@ def test_choice_noninteger(self):
desired = np.array(['c', 'd', 'c', 'd'])
assert_array_equal(actual, desired)

def test_choice_multidimensional_default_axis(self):
np.random.seed(self.seed)
actual = np.random.choice([[0, 1], [2, 3], [4, 5], [6, 7]], 3)
desired = np.array([[4, 5], [6, 7], [4, 5]])
assert_array_equal(actual, desired)

def test_choice_multidimensional_custom_axis(self):
np.random.seed(self.seed)
actual = np.random.choice([[0, 1], [2, 3], [4, 5], [6, 7]], 1, axis=1)
desired = np.array([[0], [2], [4], [6]])
assert_array_equal(actual, desired)

def test_choice_exceptions(self):
sample = np.random.choice
assert_raises(ValueError, sample, -1, 3)
assert_raises(ValueError, sample, 3., 3)
assert_raises(ValueError, sample, [[1, 2], [3, 4]], 3)
assert_raises(ValueError, sample, [], 3)
assert_raises(ValueError, sample, [1, 2, 3, 4], 3,
p=[[0.25, 0.25], [0.25, 0.25]])
Expand Down