Skip to content

Commit d8c7713

Browse files
committed
ENH: Extend multinomial
Extend multinomial to allow broadcasting
1 parent 610b5e0 commit d8c7713

File tree

6 files changed

+97
-40
lines changed

6 files changed

+97
-40
lines changed

numpy/random/distributions.pxd

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,3 +144,6 @@ cdef extern from "src/distributions/distributions.h":
144144
np.npy_bool off, np.npy_bool rng, np.npy_intp cnt,
145145
bint use_masked,
146146
np.npy_bool *out) nogil
147+
148+
void random_multinomial(brng_t *brng_state, int64_t n, int64_t *mnix,
149+
double *pix, np.npy_intp d, binomial_t *binomial) nogil

numpy/random/generator.pyx

Lines changed: 45 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -3658,7 +3658,7 @@ cdef class RandomGenerator:
36583658
36593659
Parameters
36603660
----------
3661-
n : int
3661+
n : int or array-like of ints
36623662
Number of experiments.
36633663
pvals : sequence of floats, length p
36643664
Probabilities of each of the ``p`` different outcomes. These
@@ -3697,6 +3697,18 @@ cdef class RandomGenerator:
36973697
For the first run, we threw 3 times 1, 4 times 2, etc. For the second,
36983698
we threw 2 times 1, 4 times 2, etc.
36993699
3700+
Now, do one experiment throwing the dice 10 times, and 10 times again,
3701+
and another throwing the dice 20 times, and 20 times again:
3702+
3703+
>>> np.random.multinomial([[10], [20]], [1/6.]*6, size=2)
3704+
array([[[2, 4, 0, 1, 2, 1],
3705+
[1, 3, 0, 3, 1, 2]],
3706+
[[1, 4, 4, 4, 4, 3],
3707+
[3, 3, 2, 5, 5, 2]]]) # random
3708+
3709+
The first array shows the outcomes of throwing the dice 10 times, and
3710+
the second shows the outcomes from throwing the dice 20 times.
3711+
37003712
A loaded die is more likely to land on number 6:
37013713
37023714
>>> np.random.multinomial(100, [1/7.]*5 + [2/7.])
@@ -3717,19 +3729,43 @@ cdef class RandomGenerator:
37173729
array([100, 0])
37183730
37193731
"""
3720-
cdef np.npy_intp d, i, j, dn, sz
3721-
cdef np.ndarray parr "arrayObject_parr", mnarr "arrayObject_mnarr"
3732+
3733+
cdef np.npy_intp d, i, sz, offset
3734+
cdef np.ndarray parr, mnarr, on, temp_arr
37223735
cdef double *pix
37233736
cdef int64_t *mnix
3724-
cdef double Sum
3737+
cdef int64_t ni
3738+
cdef np.broadcast it
37253739

37263740
d = len(pvals)
3741+
on = <np.ndarray>np.PyArray_FROM_OTF(n, np.NPY_INT64, np.NPY_ALIGNED)
37273742
parr = <np.ndarray>np.PyArray_FROM_OTF(pvals, np.NPY_DOUBLE, np.NPY_ALIGNED)
37283743
pix = <double*>np.PyArray_DATA(parr)
37293744

37303745
if kahan_sum(pix, d-1) > (1.0 + 1e-12):
37313746
raise ValueError("sum(pvals[:-1]) > 1.0")
37323747

3748+
if np.PyArray_NDIM(on) != 0: # vector
3749+
if size is None:
3750+
it = np.PyArray_MultiIterNew1(on)
3751+
else:
3752+
temp = np.empty(size, dtype=np.int8)
3753+
temp_arr = <np.ndarray>temp
3754+
it = np.PyArray_MultiIterNew2(on, temp_arr)
3755+
shape = it.shape + (d,)
3756+
multin = np.zeros(shape, dtype=np.int64)
3757+
mnarr = <np.ndarray>multin
3758+
mnix = <int64_t*>np.PyArray_DATA(mnarr)
3759+
offset = 0
3760+
sz = it.size
3761+
with self.lock, nogil:
3762+
for i in range(sz):
3763+
ni = (<int64_t*>np.PyArray_MultiIter_DATA(it, 0))[0]
3764+
random_multinomial(self._brng, ni, &mnix[offset], pix, d, self._binomial)
3765+
offset += d
3766+
np.PyArray_MultiIter_NEXT(it)
3767+
return multin
3768+
37333769
if size is None:
37343770
shape = (d,)
37353771
else:
@@ -3742,23 +3778,12 @@ cdef class RandomGenerator:
37423778
mnarr = <np.ndarray>multin
37433779
mnix = <int64_t*>np.PyArray_DATA(mnarr)
37443780
sz = np.PyArray_SIZE(mnarr)
3745-
3781+
ni = n
3782+
offset = 0
37463783
with self.lock, nogil:
3747-
i = 0
3748-
while i < sz:
3749-
Sum = 1.0
3750-
dn = n
3751-
for j in range(d-1):
3752-
mnix[i+j] = random_binomial(self._brng, pix[j]/Sum, dn,
3753-
self._binomial)
3754-
dn = dn - mnix[i+j]
3755-
if dn <= 0:
3756-
break
3757-
Sum = Sum - pix[j]
3758-
if dn > 0:
3759-
mnix[i+d-1] = dn
3760-
3761-
i = i + d
3784+
for i in range(sz // d):
3785+
random_multinomial(self._brng, ni, &mnix[offset], pix, d, self._binomial)
3786+
offset += d
37623787

37633788
return multin
37643789

numpy/random/mtrand.pyx

Lines changed: 8 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -3790,11 +3790,11 @@ cdef class RandomState:
37903790
array([100, 0])
37913791
37923792
"""
3793-
cdef np.npy_intp d, i, j, dn, sz
3794-
cdef np.ndarray parr "arrayObject_parr", mnarr "arrayObject_mnarr"
3793+
cdef np.npy_intp d, i, sz, offset
3794+
cdef np.ndarray parr, mnarr
37953795
cdef double *pix
37963796
cdef int64_t *mnix
3797-
cdef double Sum
3797+
cdef int64_t ni
37983798

37993799
d = len(pvals)
38003800
parr = <np.ndarray>np.PyArray_FROM_OTF(pvals, np.NPY_DOUBLE, np.NPY_ALIGNED)
@@ -3815,23 +3815,12 @@ cdef class RandomState:
38153815
mnarr = <np.ndarray>multin
38163816
mnix = <int64_t*>np.PyArray_DATA(mnarr)
38173817
sz = np.PyArray_SIZE(mnarr)
3818-
3818+
ni = n
3819+
offset = 0
38193820
with self.lock, nogil:
3820-
i = 0
3821-
while i < sz:
3822-
Sum = 1.0
3823-
dn = n
3824-
for j in range(d-1):
3825-
mnix[i+j] = random_binomial(self._brng, pix[j]/Sum, dn,
3826-
self._binomial)
3827-
dn = dn - mnix[i+j]
3828-
if dn <= 0:
3829-
break
3830-
Sum = Sum - pix[j]
3831-
if dn > 0:
3832-
mnix[i+d-1] = dn
3833-
3834-
i = i + d
3821+
for i in range(sz // d):
3822+
random_multinomial(self._brng, ni, &mnix[offset], pix, d, self._binomial)
3823+
offset += d
38353824

38363825
return multin
38373826

numpy/random/src/distributions/distributions.c

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1179,8 +1179,10 @@ int64_t random_hypergeometric(brng_t *brng_state, int64_t good, int64_t bad,
11791179
int64_t sample) {
11801180
if (sample > 10) {
11811181
return random_hypergeometric_hrua(brng_state, good, bad, sample);
1182-
} else {
1182+
} else if (sample > 0) {
11831183
return random_hypergeometric_hyp(brng_state, good, bad, sample);
1184+
} else {
1185+
return 0;
11841186
}
11851187
}
11861188

@@ -1809,3 +1811,21 @@ void random_bounded_bool_fill(brng_t *brng_state, npy_bool off, npy_bool rng,
18091811
out[i] = buffered_bounded_bool(brng_state, off, rng, mask, &bcnt, &buf);
18101812
}
18111813
}
1814+
1815+
/*
 * Draw a single multinomial sample of n trials over d categories.
 *
 * Category j (for j < d - 1) is drawn as a binomial on the trials still
 * unassigned, using pix[j] rescaled by the probability mass not yet consumed.
 * Whatever is left once the loop finishes is assigned to the last category.
 *
 *   brng_state  basic RNG state, passed through to random_binomial
 *   n           number of trials to distribute
 *   mnix        output counts, d entries; entries that are not written here
 *               are left untouched (the callers shown alongside this function
 *               pass zero-filled buffers)
 *   pix         category probabilities, d entries
 *   d           number of categories
 *   binomial    scratch state, passed through to random_binomial
 */
void random_multinomial(brng_t *brng_state, int64_t n, int64_t *mnix,
                        double *pix, npy_intp d, binomial_t *binomial) {
  double remaining_p = 1.0;
  npy_intp j;
  int64_t dn = n;

  /* No categories: nothing to fill, and mnix[d - 1] would be out of bounds. */
  if (d < 1) {
    return;
  }

  for (j = 0; j < (d - 1); j++) {
    mnix[j] = random_binomial(brng_state, pix[j] / remaining_p, dn, binomial);
    dn = dn - mnix[j];
    if (dn <= 0) {
      /* All trials are assigned; the remaining categories keep their value. */
      break;
    }
    remaining_p -= pix[j];
  }

  /*
   * Assign the remainder only after the loop. Writing it inside the loop
   * leaves a stale intermediate remainder in mnix[d - 1] whenever a later
   * iteration exhausts the trials, and never fills the single category when
   * d == 1.
   */
  if (dn > 0) {
    mnix[d - 1] = dn;
  }
}

numpy/random/src/distributions/distributions.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,4 +217,7 @@ DECLDIR void random_bounded_bool_fill(brng_t *brng_state, npy_bool off,
217217
npy_bool rng, npy_intp cnt,
218218
bool use_masked, npy_bool *out);
219219

220+
DECLDIR void random_multinomial(brng_t *brng_state, int64_t n, int64_t *mnix,
221+
double *pix, npy_intp d, binomial_t *binomial);
222+
220223
#endif

numpy/random/tests/test_generator_mt19937.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1849,6 +1849,23 @@ def test_logseries(self):
18491849
assert_raises(ValueError, logseries, bad_p_one * 3)
18501850
assert_raises(ValueError, logseries, bad_p_two * 3)
18511851

1852+
def test_multinomial(self):
    """Check ``multinomial`` with an array-like ``n``, with and without ``size``.

    The generator is reseeded with ``self.seed`` before each draw and the
    result is compared against pinned counts.
    """
    n_trials = [5, 20]
    pvals = [1 / 6.] * 6

    # Two-element n combined with an explicit size of (3, 2).
    random.seed(self.seed)
    sampled = random.multinomial(n_trials, pvals, size=(3, 2))
    # NOTE(review): the pinned row [1, 2, 0, 0, 2, 2] sums to 7 although it
    # appears to be paired with n=5; every other row sums to its n (5 or 20).
    # Confirm this value against the sampler before relying on it.
    expected = np.array(
        [
            [[1, 1, 1, 1, 0, 1], [4, 5, 1, 4, 3, 3]],
            [[1, 1, 1, 0, 0, 2], [2, 0, 4, 3, 7, 4]],
            [[1, 2, 0, 0, 2, 2], [3, 2, 3, 4, 2, 6]],
        ],
        dtype=np.int64,
    )
    assert_array_equal(sampled, expected)

    # Same seed and same n, but no size argument.
    random.seed(self.seed)
    sampled = random.multinomial(n_trials, pvals)
    expected = np.array(
        [
            [1, 1, 1, 1, 0, 1],
            [4, 5, 1, 4, 3, 3],
        ],
        dtype=np.int64,
    )
    assert_array_equal(sampled, expected)
1868+
18521869

18531870
class TestThread(object):
18541871
# make sure each state produces the same sequence even in threads

0 commit comments

Comments
 (0)