Skip to content

Commit 60682de

Browse files
committed
support CAI too
1 parent 94ec937 commit 60682de

File tree

2 files changed

+52
-4
lines changed

2 files changed

+52
-4
lines changed

cuda_py/cuda/py/_dlpack.pyx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ cdef void versioned_deleter(DLManagedTensorVersioned* tensor) noexcept with gil:
4141
stdlib.free(tensor)
4242

4343

44-
cpdef object make_py_capsule(object buf, bint versioned) except +:
44+
cpdef object make_py_capsule(object buf, bint versioned):
4545
cdef DLManagedTensor* dlm_tensor
4646
cdef DLManagedTensorVersioned* dlm_tensor_ver
4747
cdef DLTensor* dl_tensor

cuda_py/cuda/py/_memoryview.pyx

Lines changed: 51 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,14 @@ from ._dlpack cimport *
99
import functools
1010
from typing import Any, Optional
1111

12+
from cuda import cuda
1213
import numpy
1314

15+
from cuda.py._utils import handle_return
16+
17+
18+
# TODO(leofang): support NumPy structured dtypes
19+
1420

1521
@cython.dataclasses.dataclass
1622
cdef class GPUMemoryView:
@@ -37,6 +43,7 @@ cdef class GPUMemoryView:
3743

3844

3945
cdef str get_simple_repr(obj):
46+
# TODO: better handling in np.dtype objects
4047
cdef object obj_class
4148
cdef str obj_repr
4249
if isinstance(obj, type):
@@ -71,8 +78,7 @@ cdef class _GPUMemoryViewProxy:
7178
if self.has_dlpack:
7279
return view_as_dlpack(self.obj, stream_ptr)
7380
else:
74-
# TODO: Support CAI
75-
raise NotImplementedError("TODO")
81+
return view_as_cai(self.obj, stream_ptr)
7682

7783

7884
cdef GPUMemoryView view_as_dlpack(obj, stream_ptr):
@@ -216,7 +222,49 @@ cdef object dtype_dlpack_to_numpy(DLDataType* dtype):
216222
else:
217223
raise TypeError('Unsupported dtype. dtype code: {}'.format(dtype.code))
218224

219-
return np_dtype
225+
# We want the dtype object not just the type object
226+
return numpy.dtype(np_dtype)
227+
228+
229+
cdef GPUMemoryView view_as_cai(obj, stream_ptr):
    """Create a GPUMemoryView over ``obj`` via the CUDA Array Interface.

    Parameters
    ----------
    obj : object exposing ``__cuda_array_interface__`` (CAI v3 or above)
    stream_ptr : int
        The consumer's stream handle, as a raw pointer value. Must not be
        None; the producer's work is stream-ordered against this stream.

    Returns
    -------
    GPUMemoryView

    Raises
    ------
    BufferError
        If the CAI version is below 3, a mask is present, or
        ``stream_ptr`` is None.
    """
    cdef dict cai_data = obj.__cuda_array_interface__
    if cai_data["version"] < 3:
        raise BufferError("only CUDA Array Interface v3 or above is supported")
    if cai_data.get("mask") is not None:
        raise BufferError("mask is not supported")
    if stream_ptr is None:
        raise BufferError("stream=None is ambiguous with view()")

    cdef GPUMemoryView buf = GPUMemoryView()
    buf.obj = obj
    # CAI "data" is a (pointer, read_only) pair.
    buf.ptr, buf.readonly = cai_data["data"]
    buf.shape = cai_data["shape"]
    # TODO: this only works for built-in numeric types
    buf.dtype = numpy.dtype(cai_data["typestr"])
    buf.strides = cai_data.get("strides")
    if buf.strides is not None:
        # CAI strides are in bytes; convert to element counts.
        buf.strides = tuple(
            s // buf.dtype.itemsize for s in buf.strides)
    buf.device_accessible = True
    buf.device_id = handle_return(
        cuda.cuPointerGetAttribute(
            cuda.CUpointer_attribute.CU_POINTER_ATTRIBUTE_DEVICE_ORDINAL,
            buf.ptr))

    cdef intptr_t producer_s, consumer_s
    stream = cai_data.get("stream")
    if stream is not None:
        producer_s = <intptr_t>(stream)
        consumer_s = <intptr_t>(stream_ptr)
        # CAI v3 forbids stream=0 (ambiguous between legacy/per-thread
        # default streams), so a valid producer handle is always positive.
        assert producer_s > 0
        # Establish stream order: make the consumer stream wait on the
        # producer's pending work.
        if producer_s != consumer_s:
            e = handle_return(cuda.cuEventCreate(
                cuda.CUevent_flags.CU_EVENT_DISABLE_TIMING))
            handle_return(cuda.cuEventRecord(e, producer_s))
            handle_return(cuda.cuStreamWaitEvent(consumer_s, e, 0))
            # Fix: destroy the event so it is not leaked on every
            # cross-stream view. Per the CUDA driver API, destroying an
            # event that is still enqueued is safe — the resources are
            # released only once the event completes.
            handle_return(cuda.cuEventDestroy(e))

    return buf
220268

221269

222270
def viewable(tuple arg_indices):

0 commit comments

Comments
 (0)