Skip to content

Add dlpack support #1296

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 18 additions & 2 deletions dpnp/dpnp_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def __bool__(self):
return self._array_obj.__bool__()

# '__class__',

def __complex__(self):
return self._array_obj.__complex__()

Expand All @@ -153,6 +153,12 @@ def __complex__(self):
# '__divmod__',
# '__doc__',

def __dlpack__(self, stream=None):
return self._array_obj.__dlpack__(stream=stream)

def __dlpack_device__(self):
return self._array_obj.__dlpack_device__()

def __eq__(self, other):
return dpnp.equal(self, other)

Expand Down Expand Up @@ -190,7 +196,7 @@ def __gt__(self, other):
# '__imatmul__',
# '__imod__',
# '__imul__',

def __index__(self):
return self._array_obj.__index__()

Expand Down Expand Up @@ -313,6 +319,16 @@ def __truediv__(self, other):

# '__xor__',

@staticmethod
def _create_from_usm_ndarray(usm_ary : dpt.usm_ndarray):
if not isinstance(usm_ary, dpt.usm_ndarray):
raise TypeError(
f"Expected dpctl.tensor.usm_ndarray, got {type(usm_ary)}"
)
res = dpnp_array.__new__(dpnp_array)
res._array_obj = usm_ary
return res

def all(self, axis=None, out=None, keepdims=False):
"""
Returns True if all elements evaluate to True.
Expand Down
26 changes: 26 additions & 0 deletions dpnp/dpnp_iface.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
"default_float_type",
"dpnp_queue_initialize",
"dpnp_queue_is_cpu",
"from_dlpack",
"get_dpnp_descriptor",
"get_include",
"get_normalized_queue_device"
Expand Down Expand Up @@ -222,6 +223,31 @@ def default_float_type(device=None, sycl_queue=None):
return map_dtype_to_device(float64, _sycl_queue.sycl_device)


def from_dlpack(obj, /):
"""
Create a dpnp array from a Python object implementing the ``__dlpack__``
protocol.

See https://dmlc.github.io/dlpack/latest/ for more details.

Parameters
----------
obj : object
A Python object representing an array that implements the ``__dlpack__``
and ``__dlpack_device__`` methods.

Returns
-------
out : dpnp_array
Returns a new dpnp array containing the data from another array
(obj) with the ``__dlpack__`` method on the same device as object.

"""

usm_ary = dpt.from_dlpack(obj)
return dpnp_array._create_from_usm_ndarray(usm_ary)


def get_dpnp_descriptor(ext_obj,
copy_when_strides=True,
copy_when_nondefault_queue=True,
Expand Down
2 changes: 1 addition & 1 deletion tests/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def get_all_dtypes(no_bool=False,
dtypes.append(dpnp.complex64)
if dev.has_aspect_fp64:
dtypes.append(dpnp.complex128)

# add None value to validate a default dtype
if not no_none:
dtypes.append(None)
Expand Down
45 changes: 43 additions & 2 deletions tests/test_sycl_queue.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
import pytest
from .helper import get_all_dtypes

import dpnp
import dpctl
import numpy

from numpy.testing import (
assert_array_equal
)


list_of_backend_str = [
"host",
Expand Down Expand Up @@ -155,7 +160,7 @@ def test_array_creation_like(func, kwargs, device_x, device_y):

dpnp_kwargs = dict(kwargs)
dpnp_kwargs['device'] = device_y

y = getattr(dpnp, func)(x, **dpnp_kwargs)
numpy.testing.assert_array_equal(y_orig, y)
assert_sycl_queue_equal(y.sycl_queue, x.to_device(device_y).sycl_queue)
Expand Down Expand Up @@ -647,7 +652,7 @@ def test_eig(device):
dpnp_val_queue = dpnp_val.get_array().sycl_queue
dpnp_vec_queue = dpnp_vec.get_array().sycl_queue

# compare queue and device
# compare queue and device
assert_sycl_queue_equal(dpnp_val_queue, expected_queue)
assert_sycl_queue_equal(dpnp_vec_queue, expected_queue)

Expand Down Expand Up @@ -816,3 +821,39 @@ def test_array_copy(device, func, device_param, queue_param):
result = dpnp.array(dpnp_data, **kwargs)

assert_sycl_queue_equal(result.sycl_queue, dpnp_data.sycl_queue)


@pytest.mark.parametrize("device",
valid_devices,
ids=[device.filter_string for device in valid_devices])
#TODO need to delete no_bool=True when use dlpack > 0.7 version
@pytest.mark.parametrize("arr_dtype", get_all_dtypes(no_float16=True, no_bool=True))
@pytest.mark.parametrize("shape", [tuple(), (2,), (3, 0, 1), (2, 2, 2)])
def test_from_dlpack(arr_dtype, shape, device):
X = dpnp.empty(shape=shape, dtype=arr_dtype, device=device)
Y = dpnp.from_dlpack(X)
assert_array_equal(X, Y)
assert X.__dlpack_device__() == Y.__dlpack_device__()
assert X.sycl_device == Y.sycl_device
assert X.sycl_context == Y.sycl_context
assert X.usm_type == Y.usm_type
if Y.ndim:
V = Y[::-1]
W = dpnp.from_dlpack(V)
assert V.strides == W.strides


@pytest.mark.parametrize("device",
valid_devices,
ids=[device.filter_string for device in valid_devices])
#TODO need to delete no_bool=True when use dlpack > 0.7 version
@pytest.mark.parametrize("arr_dtype", get_all_dtypes(no_float16=True, no_bool=True))
def test_from_dlpack_with_dpt(arr_dtype, device):
X = dpctl.tensor.empty((64,), dtype=arr_dtype, device=device)
Y = dpnp.from_dlpack(X)
assert_array_equal(X, Y)
assert isinstance(Y, dpnp.dpnp_array.dpnp_array)
assert X.__dlpack_device__() == Y.__dlpack_device__()
assert X.sycl_device == Y.sycl_device
assert X.sycl_context == Y.sycl_context
assert X.usm_type == Y.usm_type