diff --git a/dpctl/tests/test_usm_ndarray_dlpack.py b/dpctl/tests/test_usm_ndarray_dlpack.py index c82a27807c..a6d7eab975 100644 --- a/dpctl/tests/test_usm_ndarray_dlpack.py +++ b/dpctl/tests/test_usm_ndarray_dlpack.py @@ -127,9 +127,11 @@ def test_from_dlpack(shape, typestr, usm_type): Y = dpt.from_dlpack(X) assert X.shape == Y.shape assert X.dtype == Y.dtype - assert X.sycl_device == Y.sycl_device assert X.usm_type == Y.usm_type assert X._pointer == Y._pointer + # we can only expect device to round-trip for USM-device and + # USM-shared allocations, which are made for specific device + assert (Y.usm_type == "host") or (X.sycl_device == Y.sycl_device) if Y.ndim: V = Y[::-1] W = dpt.from_dlpack(V) @@ -149,9 +151,11 @@ def test_from_dlpack_strides(mod, typestr, usm_type): Y = dpt.from_dlpack(X) assert X.shape == Y.shape assert X.dtype == Y.dtype - assert X.sycl_device == Y.sycl_device assert X.usm_type == Y.usm_type assert X._pointer == Y._pointer + # we can only expect device to round-trip for USM-device and + # USM-shared allocations, which are made for specific device + assert (Y.usm_type == "host") or (X.sycl_device == Y.sycl_device) if Y.ndim: V = Y[::-1] W = dpt.from_dlpack(V)