Skip to content

Commit 85465c2

Browse files
committed
Fixing bugs in random and random tests
1 parent 4b3369f commit 85465c2

File tree

4 files changed

+13
-11
lines changed

4 files changed

+13
-11
lines changed

arrayfire/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@
7272
from .index import *
7373
from .interop import *
7474
from .timer import *
75+
from .random import *
7576

7677
# do not export default modules as part of arrayfire
7778
del ct

arrayfire/random.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ def get_seed(self):
8181
safe_call(backend.get().af_random_engine_get_seed(ct.pointer(seed), self.engine))
8282
return seed.value
8383

84-
def randu(d0, d1=None, d2=None, d3=None, dtype=Dtype.f32, random_engine=None):
84+
def randu(d0, d1=None, d2=None, d3=None, dtype=Dtype.f32, engine=None):
8585
"""
8686
Create a multi dimensional array containing values from a uniform distribution.
8787
@@ -102,8 +102,8 @@ def randu(d0, d1=None, d2=None, d3=None, dtype=Dtype.f32, random_engine=None):
102102
dtype : optional: af.Dtype. default: af.Dtype.f32.
103103
Data type of the array.
104104
105-
random_engine : optional: Random_Engine. default: None.
106-
If random_engine is None, uses a default engine created by arrayfire.
105+
engine : optional: Random_Engine. default: None.
106+
If engine is None, uses a default engine created by arrayfire.
107107
108108
Returns
109109
-------
@@ -118,14 +118,14 @@ def randu(d0, d1=None, d2=None, d3=None, dtype=Dtype.f32, random_engine=None):
118118
out = Array()
119119
dims = dim4(d0, d1, d2, d3)
120120

121-
if random_engine is None:
121+
if engine is None:
122122
safe_call(backend.get().af_randu(ct.pointer(out.arr), 4, ct.pointer(dims), dtype.value))
123123
else:
124-
safe_call(backend.get().af_random_uniform(ct.pointer(out.arr), 4, ct.pointer(dims), random_engine.engine))
124+
safe_call(backend.get().af_random_uniform(ct.pointer(out.arr), 4, ct.pointer(dims), dtype.value, engine.engine))
125125

126126
return out
127127

128-
def randn(d0, d1=None, d2=None, d3=None, dtype=Dtype.f32, random_engine=None):
128+
def randn(d0, d1=None, d2=None, d3=None, dtype=Dtype.f32, engine=None):
129129
"""
130130
Create a multi dimensional array containing values from a normal distribution.
131131
@@ -146,8 +146,8 @@ def randn(d0, d1=None, d2=None, d3=None, dtype=Dtype.f32, random_engine=None):
146146
dtype : optional: af.Dtype. default: af.Dtype.f32.
147147
Data type of the array.
148148
149-
random_engine : optional: Random_Engine. default: None.
150-
If random_engine is None, uses a default engine created by arrayfire.
149+
engine : optional: Random_Engine. default: None.
150+
If engine is None, uses a default engine created by arrayfire.
151151
152152
Returns
153153
-------
@@ -163,10 +163,10 @@ def randn(d0, d1=None, d2=None, d3=None, dtype=Dtype.f32, random_engine=None):
163163
out = Array()
164164
dims = dim4(d0, d1, d2, d3)
165165

166-
if random_engine is None:
166+
if engine is None:
167167
safe_call(backend.get().af_randn(ct.pointer(out.arr), 4, ct.pointer(dims), dtype.value))
168168
else:
169-
safe_call(backend.get().af_random_normal(ct.pointer(out.arr), 4, ct.pointer(dims), random_engine.engine))
169+
safe_call(backend.get().af_random_normal(ct.pointer(out.arr), 4, ct.pointer(dims), dtype.value, engine.engine))
170170

171171
return out
172172

arrayfire/tests/simple/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,4 +18,5 @@
1818
from .lapack import *
1919
from .signal import *
2020
from .statistics import *
21+
from .random import *
2122
from ._util import tests

arrayfire/tests/simple/random.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def simple_random(verbose=False):
2525
af.set_seed(1024)
2626
assert(af.get_seed() == 1024)
2727

28-
engine = Random_Engine(RANDOM_ENGINE.MERSENNE_GP11213, 100)
28+
engine = af.Random_Engine(af.RANDOM_ENGINE.MERSENNE_GP11213, 100)
2929

3030
display_func(af.randu(3, 3, 1, 2, engine=engine))
3131
display_func(af.randu(3, 3, 1, 2, af.Dtype.s32, engine=engine))

0 commit comments

Comments
 (0)