Skip to content

Commit 09bdc6e

Browse files
committed
BUG/ENH: Fix zipf changes missed in NumPy
Fix zipf changes missed in NumPy Enable 0 as valid input for hypergeometric
1 parent d8c7713 commit 09bdc6e

File tree

5 files changed

+39
-25
lines changed

5 files changed

+39
-25
lines changed

numpy/random/generator.pyx

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3642,7 +3642,7 @@ cdef class RandomGenerator:
36423642
x.shape = tuple(final_shape)
36433643
return x
36443644

3645-
def multinomial(self, np.npy_intp n, object pvals, size=None):
3645+
def multinomial(self, object n, object pvals, size=None):
36463646
"""
36473647
multinomial(n, pvals, size=None)
36483648
@@ -3741,11 +3741,12 @@ cdef class RandomGenerator:
37413741
on = <np.ndarray>np.PyArray_FROM_OTF(n, np.NPY_INT64, np.NPY_ALIGNED)
37423742
parr = <np.ndarray>np.PyArray_FROM_OTF(pvals, np.NPY_DOUBLE, np.NPY_ALIGNED)
37433743
pix = <double*>np.PyArray_DATA(parr)
3744-
3744+
check_array_constraint(parr, 'pvals', CONS_BOUNDED_0_1)
37453745
if kahan_sum(pix, d-1) > (1.0 + 1e-12):
37463746
raise ValueError("sum(pvals[:-1]) > 1.0")
37473747

37483748
if np.PyArray_NDIM(on) != 0: # vector
3749+
check_array_constraint(on, 'n', CONS_NON_NEGATIVE)
37493750
if size is None:
37503751
it = np.PyArray_MultiIterNew1(on)
37513752
else:
@@ -3779,6 +3780,7 @@ cdef class RandomGenerator:
37793780
mnix = <int64_t*>np.PyArray_DATA(mnarr)
37803781
sz = np.PyArray_SIZE(mnarr)
37813782
ni = n
3783+
check_constraint(ni, 'n', CONS_NON_NEGATIVE)
37823784
offset = 0
37833785
with self.lock, nogil:
37843786
for i in range(sz // d):

numpy/random/mtrand.pyx

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3799,7 +3799,7 @@ cdef class RandomState:
37993799
d = len(pvals)
38003800
parr = <np.ndarray>np.PyArray_FROM_OTF(pvals, np.NPY_DOUBLE, np.NPY_ALIGNED)
38013801
pix = <double*>np.PyArray_DATA(parr)
3802-
3802+
check_array_constraint(parr, 'pvals', CONS_BOUNDED_0_1)
38033803
if kahan_sum(pix, d-1) > (1.0 + 1e-12):
38043804
raise ValueError("sum(pvals[:-1]) > 1.0")
38053805

@@ -3816,6 +3816,7 @@ cdef class RandomState:
38163816
mnix = <int64_t*>np.PyArray_DATA(mnarr)
38173817
sz = np.PyArray_SIZE(mnarr)
38183818
ni = n
3819+
check_constraint(ni, 'n', CONS_NON_NEGATIVE)
38193820
offset = 0
38203821
with self.lock, nogil:
38213822
for i in range(sz // d):

numpy/random/src/distributions/distributions.c

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1048,25 +1048,31 @@ int64_t random_geometric(brng_t *brng_state, double p) {
10481048
}
10491049

10501050
int64_t random_zipf(brng_t *brng_state, double a) {
1051-
double T, U, V;
1052-
int64_t X;
10531051
double am1, b;
10541052

10551053
am1 = a - 1.0;
10561054
b = pow(2.0, am1);
1057-
do {
1058-
U = 1.0 - next_double(brng_state);
1059-
V = next_double(brng_state);
1060-
X = (int64_t)floor(pow(U, -1.0 / am1));
1061-
/* The real result may be above what can be represented in a int64.
1062-
* It will get casted to -sys.maxint-1. Since this is
1063-
* a straightforward rejection algorithm, we can just reject this value
1064-
* in the rejection condition below. This function then models a Zipf
1055+
while (1) {
1056+
double T, U, V, X;
1057+
1058+
U = 1.0 - random_double(brng_state);
1059+
V = random_double(brng_state);
1060+
X = floor(pow(U, -1.0 / am1));
1061+
/*
1062+
* The real result may be above what can be represented in a signed
1063+
* long. Since this is a straightforward rejection algorithm, we can
1064+
* just reject this value. This function then models a Zipf
10651065
* distribution truncated to sys.maxint.
10661066
*/
1067+
if (X > LONG_MAX || X < 1.0) {
1068+
continue;
1069+
}
1070+
10671071
T = pow(1.0 + 1.0 / X, am1);
1068-
} while (((V * X * (T - 1.0) / (b - 1.0)) > (T / b)) || X < 1);
1069-
return X;
1072+
if (V * X * (T - 1.0) / (b - 1.0) <= T / b) {
1073+
return (long)X;
1074+
}
1075+
}
10701076
}
10711077

10721078
double random_triangular(brng_t *brng_state, double left, double mode,

numpy/random/tests/test_generator_mt19937.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,11 @@ def test_size(self):
9090

9191
def test_invalid_prob(self):
9292
assert_raises(ValueError, random.multinomial, 100, [1.1, 0.2])
93+
assert_raises(ValueError, random.multinomial, 100, [-.1, 0.9])
94+
95+
def test_invalid_n(self):
96+
assert_raises(ValueError, random.multinomial, -1, [0.8, 0.2])
97+
assert_raises(ValueError, random.multinomial, [-1] * 10, [0.8, 0.2])
9398

9499

95100
class TestSetState(object):
@@ -804,8 +809,7 @@ def test_geometric_exceptions(self):
804809
assert_raises(ValueError, random.geometric, [1.1] * 10)
805810
assert_raises(ValueError, random.geometric, -0.1)
806811
assert_raises(ValueError, random.geometric, [-0.1] * 10)
807-
with suppress_warnings() as sup:
808-
sup.record(RuntimeWarning)
812+
with np.errstate(invalid='ignore'):
809813
assert_raises(ValueError, random.geometric, np.nan)
810814
assert_raises(ValueError, random.geometric, [np.nan] * 10)
811815

@@ -888,8 +892,7 @@ def test_logseries(self):
888892
assert_array_equal(actual, desired)
889893

890894
def test_logseries_exceptions(self):
891-
with suppress_warnings() as sup:
892-
sup.record(RuntimeWarning)
895+
with np.errstate(invalid='ignore'):
893896
assert_raises(ValueError, random.logseries, np.nan)
894897
assert_raises(ValueError, random.logseries, [np.nan] * 10)
895898

@@ -964,8 +967,7 @@ def test_negative_binomial(self):
964967
assert_array_equal(actual, desired)
965968

966969
def test_negative_binomial_exceptions(self):
967-
with suppress_warnings() as sup:
968-
sup.record(RuntimeWarning)
970+
with np.errstate(invalid='ignore'):
969971
assert_raises(ValueError, random.negative_binomial, 100, np.nan)
970972
assert_raises(ValueError, random.negative_binomial, 100,
971973
[np.nan] * 10)
@@ -1046,8 +1048,7 @@ def test_poisson_exceptions(self):
10461048
assert_raises(ValueError, random.poisson, [lamneg] * 10)
10471049
assert_raises(ValueError, random.poisson, lambig)
10481050
assert_raises(ValueError, random.poisson, [lambig] * 10)
1049-
with suppress_warnings() as sup:
1050-
sup.record(RuntimeWarning)
1051+
with np.errstate(invalid='ignore'):
10511052
assert_raises(ValueError, random.poisson, np.nan)
10521053
assert_raises(ValueError, random.poisson, [np.nan] * 10)
10531054

@@ -1850,7 +1851,7 @@ def test_logseries(self):
18501851
assert_raises(ValueError, logseries, bad_p_two * 3)
18511852

18521853
def test_multinomial(self):
1853-
random.seed(self.seed)
1854+
random.brng.seed(self.seed)
18541855
actual = random.multinomial([5, 20], [1 / 6.] * 6, size=(3, 2))
18551856
desired = np.array([[[1, 1, 1, 1, 0, 1],
18561857
[4, 5, 1, 4, 3, 3]],
@@ -1860,7 +1861,7 @@ def test_multinomial(self):
18601861
[3, 2, 3, 4, 2, 6]]], dtype=np.int64)
18611862
assert_array_equal(actual, desired)
18621863

1863-
random.seed(self.seed)
1864+
random.brng.seed(self.seed)
18641865
actual = random.multinomial([5, 20], [1 / 6.] * 6)
18651866
desired = np.array([[1, 1, 1, 1, 0, 1],
18661867
[4, 5, 1, 4, 3, 3]], dtype=np.int64)

numpy/random/tests/test_randomstate.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,10 @@ def test_size(self):
112112

113113
def test_invalid_prob(self):
114114
assert_raises(ValueError, random.multinomial, 100, [1.1, 0.2])
115+
assert_raises(ValueError, random.multinomial, 100, [-.1, 0.9])
116+
117+
def test_invalid_n(self):
118+
assert_raises(ValueError, random.multinomial, -1, [0.8, 0.2])
115119

116120

117121
class TestSetState(object):

0 commit comments

Comments
 (0)