|
21 | 21 |
|
22 | 22 | import dpctl
|
23 | 23 | import dpctl.tensor as dpt
|
| 24 | +from dpctl.tensor._type_utils import _can_cast |
24 | 25 | from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported
|
25 | 26 |
|
26 |
| -from .utils import _compare_dtypes, _no_complex_dtypes, _usm_types |
| 27 | +from .utils import ( |
| 28 | + _compare_dtypes, |
| 29 | + _integral_dtypes, |
| 30 | + _no_complex_dtypes, |
| 31 | + _usm_types, |
| 32 | +) |
27 | 33 |
|
28 | 34 |
|
29 |
| -@pytest.mark.parametrize("op1_dtype", _no_complex_dtypes) |
30 |
| -@pytest.mark.parametrize("op2_dtype", _no_complex_dtypes) |
| 35 | +@pytest.mark.parametrize("op1_dtype", _no_complex_dtypes[1:]) |
| 36 | +@pytest.mark.parametrize("op2_dtype", _no_complex_dtypes[1:]) |
31 | 37 | def test_floor_divide_dtype_matrix(op1_dtype, op2_dtype):
|
32 | 38 | q = get_queue_or_skip()
|
33 | 39 | skip_if_dtype_not_supported(op1_dtype, q)
|
@@ -133,7 +139,7 @@ def test_floor_divide_broadcasting():
|
133 | 139 | assert (dpt.asnumpy(r2) == expected2.astype(r2.dtype)).all()
|
134 | 140 |
|
135 | 141 |
|
136 |
| -@pytest.mark.parametrize("arr_dt", _no_complex_dtypes) |
| 142 | +@pytest.mark.parametrize("arr_dt", _no_complex_dtypes[1:]) |
137 | 143 | def test_floor_divide_python_scalar(arr_dt):
|
138 | 144 | q = get_queue_or_skip()
|
139 | 145 | skip_if_dtype_not_supported(arr_dt, q)
|
@@ -204,7 +210,7 @@ def test_floor_divide_gh_1247():
|
204 | 210 | )
|
205 | 211 |
|
206 | 212 |
|
207 |
| -@pytest.mark.parametrize("dtype", _no_complex_dtypes[1:9]) |
| 213 | +@pytest.mark.parametrize("dtype", _integral_dtypes) |
208 | 214 | def test_floor_divide_integer_zero(dtype):
|
209 | 215 | q = get_queue_or_skip()
|
210 | 216 | skip_if_dtype_not_supported(dtype, q)
|
@@ -255,3 +261,59 @@ def test_floor_divide_special_cases():
|
255 | 261 | res = dpt.floor_divide(x, y)
|
256 | 262 | res_np = np.floor_divide(dpt.asnumpy(x), dpt.asnumpy(y))
|
257 | 263 | np.testing.assert_array_equal(dpt.asnumpy(res), res_np)
|
| 264 | + |
| 265 | + |
| 266 | +@pytest.mark.parametrize("dtype", _no_complex_dtypes[1:]) |
| 267 | +def test_divide_inplace_python_scalar(dtype): |
| 268 | + q = get_queue_or_skip() |
| 269 | + skip_if_dtype_not_supported(dtype, q) |
| 270 | + X = dpt.zeros((10, 10), dtype=dtype, sycl_queue=q) |
| 271 | + dt_kind = X.dtype.kind |
| 272 | + if dt_kind in "ui": |
| 273 | + X //= int(1) |
| 274 | + elif dt_kind == "f": |
| 275 | + X //= float(1) |
| 276 | + |
| 277 | + |
| 278 | +@pytest.mark.parametrize("op1_dtype", _no_complex_dtypes[1:]) |
| 279 | +@pytest.mark.parametrize("op2_dtype", _no_complex_dtypes[1:]) |
| 280 | +def test_floor_divide_inplace_dtype_matrix(op1_dtype, op2_dtype): |
| 281 | + q = get_queue_or_skip() |
| 282 | + skip_if_dtype_not_supported(op1_dtype, q) |
| 283 | + skip_if_dtype_not_supported(op2_dtype, q) |
| 284 | + |
| 285 | + sz = 127 |
| 286 | + ar1 = dpt.ones(sz, dtype=op1_dtype, sycl_queue=q) |
| 287 | + ar2 = dpt.ones_like(ar1, dtype=op2_dtype, sycl_queue=q) |
| 288 | + |
| 289 | + dev = q.sycl_device |
| 290 | + _fp16 = dev.has_aspect_fp16 |
| 291 | + _fp64 = dev.has_aspect_fp64 |
| 292 | + # out array only valid if it is inexact |
| 293 | + if _can_cast(ar2.dtype, ar1.dtype, _fp16, _fp64): |
| 294 | + ar1 //= ar2 |
| 295 | + assert dpt.all(ar1 == 1) |
| 296 | + |
| 297 | + ar3 = dpt.ones(sz, dtype=op1_dtype, sycl_queue=q)[::-1] |
| 298 | + ar4 = dpt.ones(2 * sz, dtype=op2_dtype, sycl_queue=q)[::2] |
| 299 | + ar3 //= ar4 |
| 300 | + assert dpt.all(ar3 == 1) |
| 301 | + else: |
| 302 | + with pytest.raises(TypeError): |
| 303 | + ar1 //= ar2 |
| 304 | + dpt.floor_divide(ar1, ar2, out=ar1) |
| 305 | + |
| 306 | + # out is second arg |
| 307 | + ar1 = dpt.ones(sz, dtype=op1_dtype, sycl_queue=q) |
| 308 | + ar2 = dpt.ones_like(ar1, dtype=op2_dtype, sycl_queue=q) |
| 309 | + if _can_cast(ar1.dtype, ar2.dtype, _fp16, _fp64): |
| 310 | + dpt.floor_divide(ar1, ar2, out=ar2) |
| 311 | + assert dpt.all(ar2 == 1) |
| 312 | + |
| 313 | + ar3 = dpt.ones(sz, dtype=op1_dtype, sycl_queue=q)[::-1] |
| 314 | + ar4 = dpt.ones(2 * sz, dtype=op2_dtype, sycl_queue=q)[::2] |
| 315 | + dpt.floor_divide(ar3, ar4, out=ar4) |
| 316 | + dpt.all(ar4 == 1) |
| 317 | + else: |
| 318 | + with pytest.raises(TypeError): |
| 319 | + dpt.floor_divide(ar1, ar2, out=ar2) |
0 commit comments