@@ -17,24 +17,8 @@ np.import_array()
17
17
18
18
# IF PCG_EMULATED_MATH==1:
19
19
cdef extern from " src/pcg64/pcg64.h" :
20
-
21
- ctypedef struct pcg128_t:
22
- uint64_t high
23
- uint64_t low
24
- # ELSE:
25
- # cdef extern from "inttypes.h":
26
- # ctypedef unsigned long long __uint128_t
27
- #
28
- # cdef extern from "src/pcg64/pcg64.h":
29
- # ctypedef __uint128_t pcg128_t
30
-
31
- cdef extern from " src/pcg64/pcg64.h" :
32
-
33
- cdef struct pcg_state_setseq_128:
34
- pcg128_t state
35
- pcg128_t inc
36
-
37
- ctypedef pcg_state_setseq_128 pcg64_random_t
20
+ # Use int as generic type, actual type read from pcg64.h and is platform dependent
21
+ ctypedef int pcg64_random_t
38
22
39
23
struct s_pcg64_state:
40
24
pcg64_random_t * pcg_state
@@ -48,6 +32,8 @@ cdef extern from "src/pcg64/pcg64.h":
48
32
void pcg64_jump(pcg64_state * state)
49
33
void pcg64_advance(pcg64_state * state, uint64_t * step)
50
34
void pcg64_set_seed(pcg64_state * state, uint64_t * seed, uint64_t * inc)
35
+ void pcg64_get_state(pcg64_state * state, uint64_t * state_arr, int * has_uint32, uint32_t * uinteger)
36
+ void pcg64_set_state(pcg64_state * state, uint64_t * state_arr, int has_uint32, uint32_t uinteger)
51
37
52
38
cdef uint64_t pcg64_uint64(void * st) nogil:
53
39
return pcg64_next64(< pcg64_state * > st)
@@ -280,40 +266,39 @@ cdef class PCG64:
280
266
Dictionary containing the information required to describe the
281
267
state of the RNG
282
268
"""
283
- # IF PCG_EMULATED_MATH==1:
284
- # TODO: push this into an #ifdef in the C code
285
- state = 2 ** 64 * self .rng_state.pcg_state.state.high
286
- state += self .rng_state.pcg_state.state.low
287
- inc = 2 ** 64 * self .rng_state.pcg_state.inc.high
288
- inc += self .rng_state.pcg_state.inc.low
289
- # ELSE:
290
- # state = self.rng_state.pcg_state.state
291
- # inc = self.rng_state.pcg_state.inc
292
-
269
+ cdef np.ndarray state_vec
270
+ cdef int has_uint32
271
+ cdef uint32_t uinteger
272
+
273
+ # state_vec is state.high, state.low, inc.high, inc.low
274
+ state_vec = < np.ndarray> np.empty(4 , dtype = np.uint64)
275
+ pcg64_get_state(self .rng_state, < uint64_t * > state_vec.data, & has_uint32, & uinteger)
276
+ state = int (state_vec[0 ]) * 2 ** 64 + int (state_vec[1 ])
277
+ inc = int (state_vec[2 ]) * 2 ** 64 + int (state_vec[3 ])
293
278
return {' brng' : self .__class__.__name__ ,
294
279
' state' : {' state' : state, ' inc' : inc},
295
- ' has_uint32' : self .rng_state. has_uint32,
296
- ' uinteger' : self .rng_state. uinteger}
280
+ ' has_uint32' : has_uint32,
281
+ ' uinteger' : uinteger}
297
282
298
283
@state.setter
299
284
def state (self , value ):
285
+ cdef np.ndarray state_vec
286
+ cdef int has_uint32
287
+ cdef uint32_t uinteger
300
288
if not isinstance (value, dict ):
301
289
raise TypeError (' state must be a dict' )
302
290
brng = value.get(' brng' , ' ' )
303
291
if brng != self .__class__.__name__ :
304
292
raise ValueError (' state must be for a {0} '
305
293
' RNG' .format(self .__class__.__name__ ))
306
- # IF PCG_EMULATED_MATH==1:
307
- self .rng_state.pcg_state.state.high = value[' state' ][' state' ] // 2 ** 64
308
- self .rng_state.pcg_state.state.low = value[' state' ][' state' ] % 2 ** 64
309
- self .rng_state.pcg_state.inc.high = value[' state' ][' inc' ] // 2 ** 64
310
- self .rng_state.pcg_state.inc.low = value[' state' ][' inc' ] % 2 ** 64
311
- # ELSE:
312
- # self.rng_state.pcg_state.state = value['state']['state']
313
- # self.rng_state.pcg_state.inc = value['state']['inc']
314
-
315
- self .rng_state.has_uint32 = value[' has_uint32' ]
316
- self .rng_state.uinteger = value[' uinteger' ]
294
+ state_vec = < np.ndarray> np.empty(4 , dtype = np.uint64)
295
+ state_vec[0 ] = value[' state' ][' state' ] // 2 ** 64
296
+ state_vec[1 ] = value[' state' ][' state' ] % 2 ** 64
297
+ state_vec[2 ] = value[' state' ][' inc' ] // 2 ** 64
298
+ state_vec[3 ] = value[' state' ][' inc' ] % 2 ** 64
299
+ has_uint32 = value[' has_uint32' ]
300
+ uinteger = value[' uinteger' ]
301
+ pcg64_set_state(self .rng_state, < uint64_t * > state_vec.data, has_uint32, uinteger)
317
302
318
303
def advance (self , delta ):
319
304
"""
0 commit comments