Skip to content

Commit b5c3ee7

Browse files
committed
Adds tests for inplace division behavior
1 parent eab67fb commit b5c3ee7

File tree

2 files changed

+137
-6
lines changed

2 files changed

+137
-6
lines changed

dpctl/tests/elementwise/test_divide.py

+70-1
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,16 @@
2121

2222
import dpctl
2323
import dpctl.tensor as dpt
24+
from dpctl.tensor._type_utils import _can_cast
2425
from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported
2526

26-
from .utils import _all_dtypes, _compare_dtypes, _usm_types
27+
from .utils import (
28+
_all_dtypes,
29+
_compare_dtypes,
30+
_complex_fp_dtypes,
31+
_real_fp_dtypes,
32+
_usm_types,
33+
)
2734

2835

2936
@pytest.mark.parametrize("op1_dtype", _all_dtypes)
@@ -187,3 +194,65 @@ def __sycl_usm_array_interface__(self):
187194
c = Canary()
188195
with pytest.raises(ValueError):
189196
dpt.divide(a, c)
197+
198+
199+
@pytest.mark.parametrize("dtype", _real_fp_dtypes + _complex_fp_dtypes)
200+
def test_divide_inplace_python_scalar(dtype):
201+
q = get_queue_or_skip()
202+
skip_if_dtype_not_supported(dtype, q)
203+
X = dpt.zeros((10, 10), dtype=dtype, sycl_queue=q)
204+
dt_kind = X.dtype.kind
205+
if dt_kind == "f":
206+
X /= float(1)
207+
elif dt_kind == "c":
208+
X /= complex(1)
209+
210+
211+
@pytest.mark.parametrize("op1_dtype", _all_dtypes)
212+
@pytest.mark.parametrize("op2_dtype", _all_dtypes)
213+
def test_divide_inplace_dtype_matrix(op1_dtype, op2_dtype):
214+
q = get_queue_or_skip()
215+
skip_if_dtype_not_supported(op1_dtype, q)
216+
skip_if_dtype_not_supported(op2_dtype, q)
217+
218+
sz = 127
219+
ar1 = dpt.ones(sz, dtype=op1_dtype, sycl_queue=q)
220+
ar2 = dpt.ones_like(ar1, dtype=op2_dtype, sycl_queue=q)
221+
222+
dev = q.sycl_device
223+
_fp16 = dev.has_aspect_fp16
224+
_fp64 = dev.has_aspect_fp64
225+
# out array only valid if it is inexact
226+
if (
227+
_can_cast(ar2.dtype, ar1.dtype, _fp16, _fp64)
228+
and dpt.dtype(op1_dtype).kind in "fc"
229+
):
230+
ar1 /= ar2
231+
assert dpt.all(ar1 == 1)
232+
233+
ar3 = dpt.ones(sz, dtype=op1_dtype, sycl_queue=q)[::-1]
234+
ar4 = dpt.ones(2 * sz, dtype=op2_dtype, sycl_queue=q)[::2]
235+
ar3 /= ar4
236+
assert dpt.all(ar3 == 1)
237+
else:
238+
with pytest.raises(TypeError):
239+
ar1 /= ar2
240+
dpt.divide(ar1, ar2, out=ar1)
241+
242+
# out is second arg
243+
ar1 = dpt.ones(sz, dtype=op1_dtype, sycl_queue=q)
244+
ar2 = dpt.ones_like(ar1, dtype=op2_dtype, sycl_queue=q)
245+
if (
246+
_can_cast(ar1.dtype, ar2.dtype, _fp16, _fp64)
247+
and dpt.dtype(op2_dtype).kind in "fc"
248+
):
249+
dpt.divide(ar1, ar2, out=ar2)
250+
assert dpt.all(ar2 == 1)
251+
252+
ar3 = dpt.ones(sz, dtype=op1_dtype, sycl_queue=q)[::-1]
253+
ar4 = dpt.ones(2 * sz, dtype=op2_dtype, sycl_queue=q)[::2]
254+
dpt.divide(ar3, ar4, out=ar4)
255+
dpt.all(ar4 == 1)
256+
else:
257+
with pytest.raises(TypeError):
258+
dpt.divide(ar1, ar2, out=ar2)

dpctl/tests/elementwise/test_floor_divide.py

+67-5
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,19 @@
2121

2222
import dpctl
2323
import dpctl.tensor as dpt
24+
from dpctl.tensor._type_utils import _can_cast
2425
from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported
2526

26-
from .utils import _compare_dtypes, _no_complex_dtypes, _usm_types
27+
from .utils import (
28+
_compare_dtypes,
29+
_integral_dtypes,
30+
_no_complex_dtypes,
31+
_usm_types,
32+
)
2733

2834

29-
@pytest.mark.parametrize("op1_dtype", _no_complex_dtypes)
30-
@pytest.mark.parametrize("op2_dtype", _no_complex_dtypes)
35+
@pytest.mark.parametrize("op1_dtype", _no_complex_dtypes[1:])
36+
@pytest.mark.parametrize("op2_dtype", _no_complex_dtypes[1:])
3137
def test_floor_divide_dtype_matrix(op1_dtype, op2_dtype):
3238
q = get_queue_or_skip()
3339
skip_if_dtype_not_supported(op1_dtype, q)
@@ -133,7 +139,7 @@ def test_floor_divide_broadcasting():
133139
assert (dpt.asnumpy(r2) == expected2.astype(r2.dtype)).all()
134140

135141

136-
@pytest.mark.parametrize("arr_dt", _no_complex_dtypes)
142+
@pytest.mark.parametrize("arr_dt", _no_complex_dtypes[1:])
137143
def test_floor_divide_python_scalar(arr_dt):
138144
q = get_queue_or_skip()
139145
skip_if_dtype_not_supported(arr_dt, q)
@@ -204,7 +210,7 @@ def test_floor_divide_gh_1247():
204210
)
205211

206212

207-
@pytest.mark.parametrize("dtype", _no_complex_dtypes[1:9])
213+
@pytest.mark.parametrize("dtype", _integral_dtypes)
208214
def test_floor_divide_integer_zero(dtype):
209215
q = get_queue_or_skip()
210216
skip_if_dtype_not_supported(dtype, q)
@@ -255,3 +261,59 @@ def test_floor_divide_special_cases():
255261
res = dpt.floor_divide(x, y)
256262
res_np = np.floor_divide(dpt.asnumpy(x), dpt.asnumpy(y))
257263
np.testing.assert_array_equal(dpt.asnumpy(res), res_np)
264+
265+
266+
@pytest.mark.parametrize("dtype", _no_complex_dtypes[1:])
267+
def test_divide_inplace_python_scalar(dtype):
268+
q = get_queue_or_skip()
269+
skip_if_dtype_not_supported(dtype, q)
270+
X = dpt.zeros((10, 10), dtype=dtype, sycl_queue=q)
271+
dt_kind = X.dtype.kind
272+
if dt_kind in "ui":
273+
X //= int(1)
274+
elif dt_kind == "f":
275+
X //= float(1)
276+
277+
278+
@pytest.mark.parametrize("op1_dtype", _no_complex_dtypes[1:])
279+
@pytest.mark.parametrize("op2_dtype", _no_complex_dtypes[1:])
280+
def test_floor_divide_inplace_dtype_matrix(op1_dtype, op2_dtype):
281+
q = get_queue_or_skip()
282+
skip_if_dtype_not_supported(op1_dtype, q)
283+
skip_if_dtype_not_supported(op2_dtype, q)
284+
285+
sz = 127
286+
ar1 = dpt.ones(sz, dtype=op1_dtype, sycl_queue=q)
287+
ar2 = dpt.ones_like(ar1, dtype=op2_dtype, sycl_queue=q)
288+
289+
dev = q.sycl_device
290+
_fp16 = dev.has_aspect_fp16
291+
_fp64 = dev.has_aspect_fp64
292+
# out array only valid if it is inexact
293+
if _can_cast(ar2.dtype, ar1.dtype, _fp16, _fp64):
294+
ar1 //= ar2
295+
assert dpt.all(ar1 == 1)
296+
297+
ar3 = dpt.ones(sz, dtype=op1_dtype, sycl_queue=q)[::-1]
298+
ar4 = dpt.ones(2 * sz, dtype=op2_dtype, sycl_queue=q)[::2]
299+
ar3 //= ar4
300+
assert dpt.all(ar3 == 1)
301+
else:
302+
with pytest.raises(TypeError):
303+
ar1 //= ar2
304+
dpt.floor_divide(ar1, ar2, out=ar1)
305+
306+
# out is second arg
307+
ar1 = dpt.ones(sz, dtype=op1_dtype, sycl_queue=q)
308+
ar2 = dpt.ones_like(ar1, dtype=op2_dtype, sycl_queue=q)
309+
if _can_cast(ar1.dtype, ar2.dtype, _fp16, _fp64):
310+
dpt.floor_divide(ar1, ar2, out=ar2)
311+
assert dpt.all(ar2 == 1)
312+
313+
ar3 = dpt.ones(sz, dtype=op1_dtype, sycl_queue=q)[::-1]
314+
ar4 = dpt.ones(2 * sz, dtype=op2_dtype, sycl_queue=q)[::2]
315+
dpt.floor_divide(ar3, ar4, out=ar4)
316+
dpt.all(ar4 == 1)
317+
else:
318+
with pytest.raises(TypeError):
319+
dpt.floor_divide(ar1, ar2, out=ar2)

0 commit comments

Comments
 (0)