@@ -9,8 +9,14 @@ from ._dlpack cimport *
9
9
import functools
10
10
from typing import Any, Optional
11
11
12
+ from cuda import cuda
12
13
import numpy
13
14
15
+ from cuda.py._utils import handle_return
16
+
17
+
18
+ # TODO(leofang): support NumPy structured dtypes
19
+
14
20
15
21
@cython.dataclasses.dataclass
16
22
cdef class GPUMemoryView:
@@ -37,6 +43,7 @@ cdef class GPUMemoryView:
37
43
38
44
39
45
cdef str get_simple_repr(obj):
46
+ # TODO: better handling in np.dtype objects
40
47
cdef object obj_class
41
48
cdef str obj_repr
42
49
if isinstance (obj, type ):
@@ -71,8 +78,7 @@ cdef class _GPUMemoryViewProxy:
71
78
if self .has_dlpack:
72
79
return view_as_dlpack(self .obj, stream_ptr)
73
80
else :
74
- # TODO: Support CAI
75
- raise NotImplementedError (" TODO" )
81
+ return view_as_cai(self .obj, stream_ptr)
76
82
77
83
78
84
cdef GPUMemoryView view_as_dlpack(obj, stream_ptr):
@@ -216,7 +222,49 @@ cdef object dtype_dlpack_to_numpy(DLDataType* dtype):
216
222
else :
217
223
raise TypeError (' Unsupported dtype. dtype code: {}' .format(dtype.code))
218
224
219
- return np_dtype
225
+ # We want the dtype object not just the type object
226
+ return numpy.dtype(np_dtype)
227
+
228
+
229
+ cdef GPUMemoryView view_as_cai(obj, stream_ptr):
230
+ cdef dict cai_data = obj.__cuda_array_interface__
231
+ if cai_data[" version" ] < 3 :
232
+ raise BufferError(" only CUDA Array Interface v3 or above is supported" )
233
+ if cai_data.get(" mask" ) is not None :
234
+ raise BufferError(" mask is not supported" )
235
+ if stream_ptr is None :
236
+ raise BufferError(" stream=None is ambiguous with view()" )
237
+
238
+ cdef GPUMemoryView buf = GPUMemoryView()
239
+ buf.obj = obj
240
+ buf.ptr, buf.readonly = cai_data[" data" ]
241
+ buf.shape = cai_data[" shape" ]
242
+ # TODO: this only works for built-in numeric types
243
+ buf.dtype = numpy.dtype(cai_data[" typestr" ])
244
+ buf.strides = cai_data.get(" strides" )
245
+ if buf.strides is not None :
246
+ # convert to counts
247
+ buf.strides = tuple (s // buf.dtype.itemsize for s in buf.strides)
248
+ buf.device_accessible = True
249
+ buf.device_id = handle_return(
250
+ cuda.cuPointerGetAttribute(
251
+ cuda.CUpointer_attribute.CU_POINTER_ATTRIBUTE_DEVICE_ORDINAL,
252
+ buf.ptr))
253
+
254
+ cdef intptr_t producer_s, consumer_s
255
+ stream = cai_data.get(" stream" )
256
+ if stream is not None :
257
+ producer_s = < intptr_t> (stream)
258
+ consumer_s = < intptr_t> (stream_ptr)
259
+ assert producer_s > 0
260
+ # establish stream order
261
+ if producer_s != consumer_s:
262
+ e = handle_return(cuda.cuEventCreate(
263
+ cuda.CUevent_flags.CU_EVENT_DISABLE_TIMING))
264
+ handle_return(cuda.cuEventRecord(e, producer_s))
265
+ handle_return(cuda.cuStreamWaitEvent(consumer_s, e, 0 ))
266
+
267
+ return buf
220
268
221
269
222
270
def viewable (tuple arg_indices ):
0 commit comments