Skip to content

Commit 8e9aee2

Browse files
authored
Merge pull request #14 from bashtage/small-sync
ENH: Add fast path for randint broadcasting
2 parents 563c258 + 64b1f62 commit 8e9aee2

File tree

4 files changed

+63
-18
lines changed

4 files changed

+63
-18
lines changed

numpy/random/bounded_integers.pyx.in

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -125,19 +125,29 @@ cdef object _rand_{{nptype}}_broadcast(object low, object high, object size,
125125

126126
if np.any(np.less(low_arr, {{lb}})):
127127
raise ValueError('low is out of bounds for {{nptype}}')
128-
129-
highm1_arr = <np.ndarray>np.empty_like(high_arr, dtype=np.{{nptype}})
130-
highm1_data = <{{nptype}}_t *>np.PyArray_DATA(highm1_arr)
131-
cnt = np.PyArray_SIZE(high_arr)
132-
flat = high_arr.flat
133-
for i in range(cnt):
134-
# Subtract 1 since generator produces values on the closed int [off, off+rng]
135-
closed_upper = int(flat[i]) - 1
136-
if closed_upper > {{ub}}:
137-
raise ValueError('high is out of bounds for {{nptype}}')
138-
if closed_upper < {{lb}}:
128+
dt = high_arr.dtype
129+
if np.issubdtype(dt, np.integer):
130+
# Avoid object dtype path if already an integer
131+
if np.any(np.less_equal(high_arr, {{lb}})):
139132
raise ValueError('low >= high')
140-
highm1_data[i] = <{{nptype}}_t>closed_upper
133+
high_m1 = high_arr - dt.type(1)
134+
if np.any(np.greater(high_m1, {{ub}})):
135+
raise ValueError('high is out of bounds for {{nptype}}')
136+
highm1_arr = <np.ndarray>np.PyArray_FROM_OTF(high_m1, np.{{npctype}}, np.NPY_ALIGNED | np.NPY_FORCECAST)
137+
else:
138+
# If input is object or a floating type
139+
highm1_arr = <np.ndarray>np.empty_like(high_arr, dtype=np.{{nptype}})
140+
highm1_data = <{{nptype}}_t *>np.PyArray_DATA(highm1_arr)
141+
cnt = np.PyArray_SIZE(high_arr)
142+
flat = high_arr.flat
143+
for i in range(cnt):
144+
# Subtract 1 since generator produces values on the closed int [off, off+rng]
145+
closed_upper = int(flat[i]) - 1
146+
if closed_upper > {{ub}}:
147+
raise ValueError('high is out of bounds for {{nptype}}')
148+
if closed_upper < {{lb}}:
149+
raise ValueError('low >= high')
150+
highm1_data[i] = <{{nptype}}_t>closed_upper
141151

142152
if np.any(np.greater(low_arr, highm1_arr)):
143153
raise ValueError('low >= high')

numpy/random/generator.pyx

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -349,7 +349,7 @@ cdef class RandomGenerator:
349349
--------
350350
randint : Uniform sampling over a given half-open interval of integers.
351351
random_integers : Uniform sampling over a given closed interval of
352-
integers.
352+
integers.
353353
354354
Examples
355355
--------
@@ -413,12 +413,17 @@ cdef class RandomGenerator:
413413
`size`-shaped array of random integers from the appropriate
414414
distribution, or a single such random int if `size` not provided.
415415
416+
Notes
417+
-----
418+
When using broadcasting with uint64 dtypes, the maximum value (2**64)
419+
cannot be represented as a standard integer type. The high array (or
420+
low if high is None) must have object dtype, e.g., array([2**64]).
421+
416422
See Also
417423
--------
418-
random_integers : similar to `randint`, only for the closed
419-
interval [`low`, `high`], and 1 is the lowest value if `high` is
420-
omitted. In particular, this other one is the one to use to generate
421-
uniformly distributed discrete non-integers.
424+
random_integers : similar to `randint`, only for the closed interval
425+
[`low`, `high`], where 1 is the lowest value if
426+
`high` is omitted.
422427
423428
Examples
424429
--------

numpy/random/tests/test_direct.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ def uniform32_from_uint64(x):
5151
out = (joined >> np.uint32(9)) * (1.0 / 2 ** 23)
5252
return out.astype(np.float32)
5353

54+
5455
def uniform32_from_uint53(x):
5556
x = np.uint64(x) >> np.uint64(16)
5657
x = np.uint32(x & np.uint64(0xffffffff))
@@ -92,6 +93,7 @@ def uniform_from_uint32(x):
9293
out[i // 2] = (a * 67108864.0 + b) / 9007199254740992.0
9394
return out
9495

96+
9597
def uniform_from_dsfmt(x):
9698
return x.view(np.double) - 1.0
9799

@@ -414,7 +416,8 @@ def test_seed_float_array(self):
414416
rs = RandomGenerator(self.brng(*self.data1['seed']))
415417
assert_raises(self.seed_error_type, rs.brng.seed, np.array([np.pi]))
416418
assert_raises(self.seed_error_type, rs.brng.seed, np.array([-np.pi]))
417-
assert_raises(self.seed_error_type, rs.brng.seed, np.array([np.pi, -np.pi]))
419+
assert_raises(self.seed_error_type, rs.brng.seed,
420+
np.array([np.pi, -np.pi]))
418421
assert_raises(self.seed_error_type, rs.brng.seed, np.array([0, np.pi]))
419422
assert_raises(self.seed_error_type, rs.brng.seed, [np.pi])
420423
assert_raises(self.seed_error_type, rs.brng.seed, [0, np.pi])

numpy/random/tests/test_generator_mt19937.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -335,6 +335,33 @@ def test_repeatability_broadcasting(self):
335335

336336
assert_array_equal(val, val_bc)
337337

338+
def test_int64_uint64_broadcast_exceptions(self):
339+
configs = {np.uint64: ((0, 2**65), (-1, 2**62), (10, 9), (0, 0)),
340+
np.int64: ((0, 2**64), (-(2**64), 2**62), (10, 9), (0, 0),
341+
(-2**63-1, -2**63-1))}
342+
for dtype in configs:
343+
for config in configs[dtype]:
344+
low, high = config
345+
low_a = np.array([[low]*10])
346+
high_a = np.array([high] * 10)
347+
assert_raises(ValueError, random.randint, low, high,
348+
dtype=dtype)
349+
assert_raises(ValueError, random.randint, low_a, high,
350+
dtype=dtype)
351+
assert_raises(ValueError, random.randint, low, high_a,
352+
dtype=dtype)
353+
assert_raises(ValueError, random.randint, low_a, high_a,
354+
dtype=dtype)
355+
356+
low_o = np.array([[low]*10], dtype=np.object)
357+
high_o = np.array([high] * 10, dtype=np.object)
358+
assert_raises(ValueError, random.randint, low_o, high,
359+
dtype=dtype)
360+
assert_raises(ValueError, random.randint, low, high_o,
361+
dtype=dtype)
362+
assert_raises(ValueError, random.randint, low_o, high_o,
363+
dtype=dtype)
364+
338365
def test_int64_uint64_corner_case(self):
339366
# When stored in Numpy arrays, `lbnd` is casted
340367
# as np.int64, and `ubnd` is casted as np.uint64.

0 commit comments

Comments
 (0)