Skip to content

Commit 732b62a

Browse files
committed
BUG/ENH: Fix zipf changes missed in NumPy
Fix zipf changes missed in NumPy Enable 0 as valid input for hypergeometric
1 parent bbdf80d commit 732b62a

File tree

3 files changed

+23
-21
lines changed

3 files changed

+23
-21
lines changed

numpy/random/generator.pyx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3639,7 +3639,7 @@ cdef class RandomGenerator:
36393639
x.shape = tuple(final_shape)
36403640
return x
36413641

3642-
def multinomial(self, np.npy_intp n, object pvals, size=None):
3642+
def multinomial(self, object n, object pvals, size=None):
36433643
"""
36443644
multinomial(n, pvals, size=None)
36453645

numpy/random/src/distributions/distributions.c

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1048,25 +1048,31 @@ int64_t random_geometric(brng_t *brng_state, double p) {
10481048
}
10491049

10501050
int64_t random_zipf(brng_t *brng_state, double a) {
1051-
double T, U, V;
1052-
int64_t X;
10531051
double am1, b;
10541052

10551053
am1 = a - 1.0;
10561054
b = pow(2.0, am1);
1057-
do {
1058-
U = 1.0 - next_double(brng_state);
1059-
V = next_double(brng_state);
1060-
X = (int64_t)floor(pow(U, -1.0 / am1));
1061-
/* The real result may be above what can be represented in a int64.
1062-
* It will get casted to -sys.maxint-1. Since this is
1063-
* a straightforward rejection algorithm, we can just reject this value
1064-
* in the rejection condition below. This function then models a Zipf
1055+
while (1) {
1056+
double T, U, V, X;
1057+
1058+
U = 1.0 - random_double(brng_state);
1059+
V = random_double(brng_state);
1060+
X = floor(pow(U, -1.0 / am1));
1061+
/*
1062+
* The real result may be above what can be represented in a signed
1063+
* long. Since this is a straightforward rejection algorithm, we can
1064+
* just reject this value. This function then models a Zipf
10651065
* distribution truncated to sys.maxint.
10661066
*/
1067+
if (X > LONG_MAX || X < 1.0) {
1068+
continue;
1069+
}
1070+
10671071
T = pow(1.0 + 1.0 / X, am1);
1068-
} while (((V * X * (T - 1.0) / (b - 1.0)) > (T / b)) || X < 1);
1069-
return X;
1072+
if (V * X * (T - 1.0) / (b - 1.0) <= T / b) {
1073+
return (long)X;
1074+
}
1075+
}
10701076
}
10711077

10721078
double random_triangular(brng_t *brng_state, double left, double mode,

numpy/random/tests/test_generator_mt19937.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -804,8 +804,7 @@ def test_geometric_exceptions(self):
804804
assert_raises(ValueError, random.geometric, [1.1] * 10)
805805
assert_raises(ValueError, random.geometric, -0.1)
806806
assert_raises(ValueError, random.geometric, [-0.1] * 10)
807-
with suppress_warnings() as sup:
808-
sup.record(RuntimeWarning)
807+
with np.errstate(invalid='ignore'):
809808
assert_raises(ValueError, random.geometric, np.nan)
810809
assert_raises(ValueError, random.geometric, [np.nan] * 10)
811810

@@ -888,8 +887,7 @@ def test_logseries(self):
888887
assert_array_equal(actual, desired)
889888

890889
def test_logseries_exceptions(self):
891-
with suppress_warnings() as sup:
892-
sup.record(RuntimeWarning)
890+
with np.errstate(invalid='ignore'):
893891
assert_raises(ValueError, random.logseries, np.nan)
894892
assert_raises(ValueError, random.logseries, [np.nan] * 10)
895893

@@ -964,8 +962,7 @@ def test_negative_binomial(self):
964962
assert_array_equal(actual, desired)
965963

966964
def test_negative_binomial_exceptions(self):
967-
with suppress_warnings() as sup:
968-
sup.record(RuntimeWarning)
965+
with np.errstate(invalid='ignore'):
969966
assert_raises(ValueError, random.negative_binomial, 100, np.nan)
970967
assert_raises(ValueError, random.negative_binomial, 100,
971968
[np.nan] * 10)
@@ -1046,8 +1043,7 @@ def test_poisson_exceptions(self):
10461043
assert_raises(ValueError, random.poisson, [lamneg] * 10)
10471044
assert_raises(ValueError, random.poisson, lambig)
10481045
assert_raises(ValueError, random.poisson, [lambig] * 10)
1049-
with suppress_warnings() as sup:
1050-
sup.record(RuntimeWarning)
1046+
with np.errstate(invalid='ignore'):
10511047
assert_raises(ValueError, random.poisson, np.nan)
10521048
assert_raises(ValueError, random.poisson, [np.nan] * 10)
10531049

0 commit comments

Comments
 (0)