|
1 | 1 | import pytest
|
| 2 | +from .helper import get_all_dtypes |
2 | 3 |
|
3 | 4 | import dpnp
|
4 | 5 | import dpctl
|
5 | 6 | import numpy
|
6 | 7 |
|
| 8 | +from numpy.testing import ( |
| 9 | + assert_array_equal |
| 10 | +) |
| 11 | + |
7 | 12 |
|
8 | 13 | list_of_backend_str = [
|
9 | 14 | "host",
|
@@ -155,7 +160,7 @@ def test_array_creation_like(func, kwargs, device_x, device_y):
|
155 | 160 |
|
156 | 161 | dpnp_kwargs = dict(kwargs)
|
157 | 162 | dpnp_kwargs['device'] = device_y
|
158 |
| - |
| 163 | + |
159 | 164 | y = getattr(dpnp, func)(x, **dpnp_kwargs)
|
160 | 165 | numpy.testing.assert_array_equal(y_orig, y)
|
161 | 166 | assert_sycl_queue_equal(y.sycl_queue, x.to_device(device_y).sycl_queue)
|
@@ -637,7 +642,7 @@ def test_eig(device):
|
637 | 642 | dpnp_val_queue = dpnp_val.get_array().sycl_queue
|
638 | 643 | dpnp_vec_queue = dpnp_vec.get_array().sycl_queue
|
639 | 644 |
|
640 |
| - # compare queue and device |
| 645 | + # compare queue and device |
641 | 646 | assert_sycl_queue_equal(dpnp_val_queue, expected_queue)
|
642 | 647 | assert_sycl_queue_equal(dpnp_vec_queue, expected_queue)
|
643 | 648 |
|
@@ -806,3 +811,37 @@ def test_array_copy(device, func, device_param, queue_param):
|
806 | 811 | result = dpnp.array(dpnp_data, **kwargs)
|
807 | 812 |
|
808 | 813 | assert_sycl_queue_equal(result.sycl_queue, dpnp_data.sycl_queue)
|
| 814 | + |
| 815 | + |
| 816 | +@pytest.mark.parametrize("device", |
| 817 | + valid_devices, |
| 818 | + ids=[device.filter_string for device in valid_devices]) |
| 819 | +#TODO need to delete no_bool=True when use dlpack > 0.7 version |
| 820 | +@pytest.mark.parametrize("arr_dtype", get_all_dtypes(no_float16=True, no_bool=True)) |
| 821 | +@pytest.mark.parametrize("shape", [tuple(), (2,), (3, 0, 1), (2, 2, 2)]) |
| 822 | +def test_from_dlpack(arr_dtype, shape, device): |
| 823 | + X = dpnp.empty(shape=shape, dtype=arr_dtype, device=device) |
| 824 | + Y = dpnp.from_dlpack(X) |
| 825 | + assert_array_equal(X, Y) |
| 826 | + assert X.__dlpack_device__() == Y.__dlpack_device__() |
| 827 | + assert X.sycl_device == Y.sycl_device |
| 828 | + assert X.usm_type == Y.usm_type |
| 829 | + if Y.ndim: |
| 830 | + V = Y[::-1] |
| 831 | + W = dpnp.from_dlpack(V) |
| 832 | + assert V.strides == W.strides |
| 833 | + |
| 834 | + |
| 835 | +@pytest.mark.parametrize("device", |
| 836 | + valid_devices, |
| 837 | + ids=[device.filter_string for device in valid_devices]) |
| 838 | +#TODO need to delete no_bool=True when use dlpack > 0.7 version |
| 839 | +@pytest.mark.parametrize("arr_dtype", get_all_dtypes(no_float16=True, no_bool=True)) |
| 840 | +def test_from_dlpack_with_dpt(arr_dtype, device): |
| 841 | + X = dpctl.tensor.empty((64,), dtype=arr_dtype, device=device) |
| 842 | + Y = dpnp.from_dlpack(X) |
| 843 | + assert_array_equal(X, Y) |
| 844 | + assert isinstance(Y, dpnp.dpnp_array.dpnp_array) |
| 845 | + assert X.__dlpack_device__() == Y.__dlpack_device__() |
| 846 | + assert X.sycl_device == Y.sycl_device |
| 847 | + assert X.usm_type == Y.usm_type |
0 commit comments