Skip to content

Commit 8b86eb5

Browse files
author
Diptorup Deb
committed
Porting the random.py from sklearn-numba-dpex.
1 parent 4708ac7 commit 8b86eb5

File tree

2 files changed

+344
-0
lines changed

2 files changed

+344
-0
lines changed
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# SPDX-FileCopyrightText: 2024 Intel Corporation
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
Lines changed: 341 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,341 @@
1+
# SPDX-FileCopyrightText: 2022 - 2023 Julien Jerphanion <[email protected]>
2+
# SPDX-FileCopyrightText: 2022 - 2023 Olivier Grisel <[email protected]>
3+
# SPDX-FileCopyrightText: 2022 - 2023 Franck Charras <[email protected]>
4+
# SPDX-FileCopyrightText: 2024 Intel Corporation
5+
#
6+
# SPDX-License-Identifier: BSD-3-Clause
7+
8+
"""
9+
This code is largely inspired from the numba.cuda.random module and the
10+
# numba/cuda/random.py where it's defined (v<0.57), and by the implementation of the
11+
# same algorithm in the package `randomgen`.
12+
13+
# numba.cuda.random: https://github.com/numba/numba/blob/0.56.3/numba/cuda/random.py
14+
# randomgen: https://github.com/bashtage/randomgen/blob/v1.26.0/randomgen/xoroshiro128.pyx # noqa
15+
16+
# NB1: we implement xoroshiro128++ rather than just xoroshiro128+, which is preferred.
17+
# Reference resource about PRNG: https://prng.di.unimi.it/
18+
19+
# NB2: original numba.cuda.random code also includes functions for generating normally
20+
# distributed floats but we don't include it here as long as it's not needed.
21+
"""
22+
23+
import random
24+
import warnings
25+
from functools import lru_cache
26+
27+
import dpctl
28+
import dpctl.tensor as dpt
29+
import numpy as np
30+
from numba import float32, float64, int64, uint32, uint64
31+
32+
import numba_dpex as dpex
33+
34+
35+
def _get_sequential_processing_device(device: dpctl.SyclDevice):
    """Return a device best suited for sequential processing.

    A cpu device is preferred over a gpu one for sequential workloads.
    When the input device is not a cpu and no cpu device can be created,
    the input device is returned unchanged.

    Also returns a boolean telling whether the returned device differs
    from the input device.
    """
    if not device.has_aspect_cpu:
        try:
            return dpctl.SyclDevice("cpu"), True
        except dpctl.SyclDeviceCreationError:
            warnings.warn(
                "No CPU found, falling back to GPU for sequential instructions."
            )

    # Either the input device already is a cpu, or no cpu exists: keep it.
    return device, False
54+
55+
56+
# Pre-typed int64 index constants shared by the kernels and device
# functions below when indexing into the (n_states, 2) states array.
zero_idx = int64(0)
one_idx = int64(1)
58+
59+
60+
def get_random_raw(states):
    """Return a single pseudo-random `uint64` integer value.

    Similar to numpy.random.BitGenerator.random_raw(size=1).

    Note, this always uses and updates state states[0].
    """
    kernel = make_random_raw_kernel()
    out = dpt.empty((1,), dtype=np.uint64, device=states.device)
    dpex.call_kernel(kernel, dpex.Range(1), states, out)
    return out
70+
71+
72+
@lru_cache
def make_random_raw_kernel():
    """Build (and cache) a kernel producing one pseudo-random `uint64`.

    The returned kernel reads and advances states[0] and writes the next
    raw generator output into result[0]. See get_random_raw for the
    user-facing wrapper that allocates the result array and launches it.
    """

    @dpex.kernel
    def _get_random_raw_kernel(
        item, states, result
    ):  # pylint: disable=unused-argument
        result[zero_idx] = _xoroshiro128pp_next(states, zero_idx)

    return _get_random_raw_kernel
87+
88+
89+
def make_rand_uniform_kernel_func(dtype):
    """Instantiate a kernel function that returns a random float in [0, 1)

    This factory returns a kernel function specialized to either generate
    float32 or float64 values. Care has been taken so that the float32
    variant can be compiled to target a device that does not support the
    float64 aspect.

    The returned kernel function takes two arguments:

    - states : a state array. See create_xoroshiro128pp_states for details.

    - state_idx : the index of the RNG state to use to generate the next
      random float.

    Raises
    ------
    ValueError
        If `dtype` has no `name` attribute, or if its name is neither
        "float32" nor "float64".
    """
    if not hasattr(dtype, "name"):
        raise ValueError(
            "dtype is expected to have an attribute 'name', like np.dtype "
            "or numba types."
        )

    if dtype.name == "float64":
        # Keep the 53 most significant bits (64 - 11), matching the
        # float64 mantissa width, and scale them down into [0, 1).
        convert_rshift = uint32(11)
        convert_const = float64(uint64(1) << uint32(53))
        convert_const_one = float64(1)

        @dpex.device_func
        def uint64_to_unit_float(x):
            """Convert uint64 to float64 value in the range [0.0, 1.0)"""
            return float64(x >> convert_rshift) * (
                convert_const_one / convert_const
            )

    elif dtype.name == "float32":
        # Keep the 24 most significant bits (64 - 40), matching the
        # float32 mantissa width, with no float64 intermediate.
        convert_rshift = uint32(40)
        convert_const = float32(uint32(1) << uint32(24))
        convert_const_one = float32(1)

        @dpex.device_func
        def uint64_to_unit_float(x):
            """Convert uint64 to float32 value in the range [0.0, 1.0)

            NB: this is different than original numba.cuda.random code. Instead
            of generating a float64 random number before casting it to float32,
            a float32 number is generated from uint64 without intermediate
            float64. This change enables compatibility with devices that do not
            support float64 numbers. However is seems to be exactly equivalent
            e.g it passes the float precision test in sklearn.
            """
            return float32(x >> convert_rshift) * (
                convert_const_one / convert_const
            )

    else:
        raise ValueError(
            "Expected dtype.name in {float32, float64} but got "
            f"dtype.name == {dtype.name}"
        )

    @dpex.device_func
    def xoroshiro128pp_uniform(states, state_idx):
        """Return one random float in [0, 1)

        Calling this function advances the states[state_idx] by a single RNG
        step and leaves the other states unchanged.
        """
        return uint64_to_unit_float(_xoroshiro128pp_next(states, state_idx))

    return xoroshiro128pp_uniform
159+
160+
161+
def create_xoroshiro128pp_states(
    n_states, subsequence_start=0, seed=None, device=None
):
    """Returns a new device array initialized for n random number generators.

    The RNG states in the returned array correspond to subsequences of the
    main xoroshiro128++ sequence that are 2**64 steps apart, so each state
    produces an independent stream as long as no single thread requests
    more than 2**64 random numbers.

    Parameters
    ----------
    n_states : int
        Number of RNG states to create. Each RNG state is meant to be used
        by a distinct thread, so n_states controls the amount of
        parallelism when generating a large sequence of pseudo-random
        values. Subsequent states are initialized 2**64 RNG steps apart.

    subsequence_start : int
        Advance the first RNG state by this multiple of 2**64 steps past
        the state induced by the seed; the following states each start
        2**64 steps after their predecessor in the array.

    seed : int or None
        Starting seed for the list of generators. An object exposing a
        `randint` method (e.g. a RandomState) may also be passed, in which
        case a seed is drawn from it.

    device : str or None (default)
        A SYCL device or if None, takes the default sycl device.
    """
    # Normalize `seed` into a concrete uint64 value.
    if seed is None:
        seed = uint64(random.randint(0, np.iinfo(np.int64).max - 1))
    if hasattr(seed, "randint"):
        seed = uint64(seed.randint(0, np.iinfo(np.int64).max - 1))

    if device is None:
        device = dpctl.SyclDevice()

    # State initialization is purely sequential, hence faster on a cpu
    # device when one is available; move results back afterwards if so.
    (seq_device, used_other_device) = _get_sequential_processing_device(device)

    init_kernel = _make_init_xoroshiro128pp_states_kernel(
        n_states, subsequence_start
    )

    states = dpt.empty((n_states, 2), dtype=np.uint64, device=seq_device)
    seed = dpt.asarray([seed], dtype=np.uint64, device=seq_device)

    dpex.call_kernel(init_kernel, dpex.Range(1), states, seed)

    return states.to_device(device) if used_other_device else states
228+
229+
230+
@lru_cache
def _make_init_xoroshiro128pp_states_kernel(
    n_states, subsequence_start
):  # pylint: disable=too-many-locals
    """Build (and cache) the kernel that seeds an xoroshiro128pp states array.

    The kernel is closure-specialized on `n_states` and `subsequence_start`,
    hence the lru_cache keyed on those two arguments.
    """
    n_states = int64(n_states)

    # SplitMix64 constants, used to expand the single uint64 seed into the
    # first 128-bit xoroshiro state (see the randomgen/numba references in
    # the module docstring).
    splitmix64_const_1 = uint64(0x9E3779B97F4A7C15)
    splitmix64_const_2 = uint64(0xBF58476D1CE4E5B9)
    splitmix64_const_3 = uint64(0x94D049BB133111EB)
    splitmix64_rshift_1 = uint32(30)
    splitmix64_rshift_2 = uint32(27)
    splitmix64_rshift_3 = uint32(31)

    @dpex.device_func
    def _splitmix64_next(state):
        # Returns (new_state, output); the caller threads `new_state`
        # through successive calls.
        new_state = z = state + splitmix64_const_1
        z = (z ^ (z >> splitmix64_rshift_1)) * splitmix64_const_2
        z = (z ^ (z >> splitmix64_rshift_2)) * splitmix64_const_3
        return new_state, z ^ (z >> splitmix64_rshift_3)

    # Constants of the xoroshiro128 2**64 jump function (presumably the
    # jump polynomial from the reference implementation — see the
    # randomgen link in the module docstring).
    jump_const_1 = uint64(0x2BD7A6A6E99C2DDC)
    jump_const_2 = uint64(0x0992CCAF6A6FCA05)
    jump_const_3 = uint64(1)
    jump_init = uint64(0)
    long_2 = int64(2)
    long_64 = int64(64)

    @dpex.device_func
    def _xoroshiro128pp_jump(states, state_idx):
        """Advance the RNG in ``states[state_idx]`` by 2**64 steps."""
        s0 = jump_init
        s1 = jump_init

        # Accumulate the jumped state while stepping the generator once per
        # bit of the two jump constants.
        for i in range(long_2):
            if i == zero_idx:
                jump_const = jump_const_1
            else:
                jump_const = jump_const_2
            for b in range(long_64):
                if jump_const & jump_const_3 << uint32(b):
                    s0 ^= states[state_idx, zero_idx]
                    s1 ^= states[state_idx, one_idx]
                _xoroshiro128pp_next(states, state_idx)

        states[state_idx, zero_idx] = s0
        states[state_idx, one_idx] = s1

    init_const_1 = np.uint64(0)

    @dpex.kernel
    def init_xoroshiro128pp_states(
        item, states, seed
    ):  # pylint: disable=unused-argument
        """
        Use SplitMix64 to generate an xoroshiro128p state from a uint64 seed.

        This ensures that manually set small seeds don't result in a predictable
        initial sequence from the random number generator.
        """
        # Nothing to initialize for an empty states array.
        if n_states < one_idx:
            return

        splitmix64_state = init_const_1 ^ seed[zero_idx]
        splitmix64_state, states[zero_idx, zero_idx] = _splitmix64_next(
            splitmix64_state
        )
        _, states[zero_idx, one_idx] = _splitmix64_next(splitmix64_state)

        # advance to starting subsequence number
        for _ in range(subsequence_start):
            _xoroshiro128pp_jump(states, zero_idx)

        # populate the rest of the array
        for idx in range(one_idx, n_states):
            # take state of previous generator
            states[idx, zero_idx] = states[idx - one_idx, zero_idx]
            states[idx, one_idx] = states[idx - one_idx, one_idx]
            # and jump forward 2**64 steps
            _xoroshiro128pp_jump(states, idx)

    return init_xoroshiro128pp_states
311+
312+
313+
# Word size (in bits) of the generator state halves, pre-typed for _rotl.
_64_as_uint32 = uint32(64)
314+
315+
316+
@dpex.device_func
def _rotl(x, k):
    """Rotate the uint64 integer ``x`` left by ``k`` bits."""
    upper = x << k
    lower = x >> (_64_as_uint32 - k)
    return upper | lower
320+
321+
322+
# Rotation and shift constants of the xoroshiro128++ next-step update,
# pre-typed as uint32 for the device function below.
next_rot_1 = uint32(17)
next_rot_2 = uint32(49)
next_rot_3 = uint32(28)
shift_1 = uint32(21)
326+
327+
328+
@dpex.device_func
def _xoroshiro128pp_next(states, state_idx):
    """Return the next random uint64 and advance the RNG in states[state_idx]."""
    word0 = states[state_idx, zero_idx]
    word1 = states[state_idx, one_idx]

    # "++" output scrambler: rotate the sum of the halves, then add word0.
    out = _rotl(word0 + word1, next_rot_1) + word0

    # xoroshiro128 state transition (must happen after computing `out`).
    word1 ^= word0
    states[state_idx, zero_idx] = (
        _rotl(word0, next_rot_2) ^ word1 ^ (word1 << shift_1)
    )
    states[state_idx, one_idx] = _rotl(word1, next_rot_3)

    return out

0 commit comments

Comments
 (0)