@@ -20,14 +20,67 @@ from cuda.core.experimental._utils import handle_return
20
20
21
21
@cython.dataclasses.dataclass
22
22
cdef class StridedMemoryView:
23
-
23
+ """ A dataclass holding metadata of a strided dense array/tensor.
24
+
25
+ A :obj:`StridedMemoryView` instance can be created in two ways:
26
+
27
+ 1. Using the :obj:`args_viewable_as_strided_memory` decorator (recommended)
28
+ 2. Explicit construction, see below
29
+
30
+ This object supports both DLPack (up to v1.0) and CUDA Array Interface
31
+ (CAI) v3. When wrapping an arbitrary object it will try the DLPack protocol
32
+ first, then the CAI protocol. A :obj:`BufferError` is raised if neither is
33
+ supported.
34
+
35
+ Since either way would take a consumer stream, for DLPack it is passed to
36
+ ``obj.__dlpack__()`` as-is (except for :obj:`None`, see below); for CAI, a
37
+ stream order will be established between the consumer stream and the
38
+ producer stream (from ``obj.__cuda_array_interface__()["stream"]``), as if
39
+ ``cudaStreamWaitEvent`` is called by this method.
40
+
41
+ To opt-out of the stream ordering operation in either DLPack or CAI,
42
+ please pass ``stream_ptr=-1``. Note that this deviates (on purpose)
43
+ from the semantics of ``obj.__dlpack__(stream=None, ...)`` since ``cuda.core``
44
+ does not encourage using the (legacy) default/null stream, but is
45
+ consistent with the CAI's semantics. For DLPack, ``stream=-1`` will be
46
+ internally passed to ``obj.__dlpack__()`` instead.
47
+
48
+ Attributes
49
+ ----------
50
+ ptr : int
51
+ Pointer to the tensor buffer (as a Python `int`).
52
+ shape: tuple
53
+ Shape of the tensor.
54
+ strides: tuple
55
+ Strides of the tensor (in **counts**, not bytes).
56
+ dtype: numpy.dtype
57
+ Data type of the tensor.
58
+ device_id: int
59
+ The device ID for where the tensor is located. It is -1 for CPU tensors
60
+ (meaning those only accessible from the host).
61
+ is_device_accessible: bool
62
+ Whether the tensor data can be accessed on the GPU.
63
+ readonly: bool
64
+ Whether the tensor data can be modified in place.
65
+ exporting_obj: Any
66
+ A reference to the original tensor object that is being viewed.
67
+
68
+ Parameters
69
+ ----------
70
+ obj : Any
71
+ Any object that supports either DLPack (up to v1.0) or CUDA Array
72
+ Interface (v3).
73
+ stream_ptr: int
74
+ The pointer address (as Python `int`) to the **consumer** stream.
75
+ Stream ordering will be properly established unless ``-1`` is passed.
76
+ """
24
77
# TODO: switch to use Cython's cdef typing?
25
78
ptr: int = None
26
79
shape: tuple = None
27
80
strides: tuple = None # in counts, not bytes
28
81
dtype: numpy.dtype = None
29
82
device_id: int = None # -1 for CPU
30
- device_accessible : bool = None
83
+ is_device_accessible : bool = None
31
84
readonly: bool = None
32
85
exporting_obj: Any = None
33
86
@@ -48,7 +101,7 @@ cdef class StridedMemoryView:
48
101
+ f" strides={self.strides},\n "
49
102
+ f" dtype={get_simple_repr(self.dtype)},\n "
50
103
+ f" device_id={self.device_id},\n "
51
- + f" device_accessible ={self.device_accessible },\n "
104
+ + f" is_device_accessible ={self.is_device_accessible },\n "
52
105
+ f" readonly={self.readonly},\n "
53
106
+ f" exporting_obj={get_simple_repr(self.exporting_obj)})" )
54
107
@@ -99,28 +152,25 @@ cdef class _StridedMemoryViewProxy:
99
152
100
153
cdef StridedMemoryView view_as_dlpack(obj, stream_ptr, view = None ):
101
154
cdef int dldevice, device_id, i
102
- cdef bint device_accessible, versioned, is_readonly
155
+ cdef bint is_device_accessible, versioned, is_readonly
156
+ is_device_accessible = False
103
157
dldevice, device_id = obj.__dlpack_device__()
104
158
if dldevice == _kDLCPU:
105
- device_accessible = False
106
159
assert device_id == 0
160
+ device_id = - 1
107
161
if stream_ptr is None :
108
162
raise BufferError(" stream=None is ambiguous with view()" )
109
163
elif stream_ptr == - 1 :
110
164
stream_ptr = None
111
165
elif dldevice == _kDLCUDA:
112
- device_accessible = True
166
+ assert device_id >= 0
167
+ is_device_accessible = True
113
168
# no need to check other stream values, it's a pass-through
114
169
if stream_ptr is None :
115
170
raise BufferError(" stream=None is ambiguous with view()" )
116
- elif dldevice == _kDLCUDAHost:
117
- device_accessible = True
118
- assert device_id == 0
119
- # just do a pass-through without any checks, as pinned memory can be
120
- # accessed on both host and device
121
- elif dldevice == _kDLCUDAManaged:
122
- device_accessible = True
123
- # just do a pass-through without any checks, as managed memory can be
171
+ elif dldevice in (_kDLCUDAHost, _kDLCUDAManaged):
172
+ is_device_accessible = True
173
+ # just do a pass-through without any checks, as pinned/managed memory can be
124
174
# accessed on both host and device
125
175
else :
126
176
raise BufferError(" device not supported" )
@@ -171,7 +221,7 @@ cdef StridedMemoryView view_as_dlpack(obj, stream_ptr, view=None):
171
221
buf.strides = None
172
222
buf.dtype = dtype_dlpack_to_numpy(& dl_tensor.dtype)
173
223
buf.device_id = device_id
174
- buf.device_accessible = device_accessible
224
+ buf.is_device_accessible = is_device_accessible
175
225
buf.readonly = is_readonly
176
226
buf.exporting_obj = obj
177
227
@@ -261,7 +311,7 @@ cdef StridedMemoryView view_as_cai(obj, stream_ptr, view=None):
261
311
if buf.strides is not None :
262
312
# convert to counts
263
313
buf.strides = tuple (s // buf.dtype.itemsize for s in buf.strides)
264
- buf.device_accessible = True
314
+ buf.is_device_accessible = True
265
315
buf.device_id = handle_return(
266
316
cuda.cuPointerGetAttribute(
267
317
cuda.CUpointer_attribute.CU_POINTER_ATTRIBUTE_DEVICE_ORDINAL,
@@ -284,7 +334,34 @@ cdef StridedMemoryView view_as_cai(obj, stream_ptr, view=None):
284
334
return buf
285
335
286
336
287
- def viewable (tuple arg_indices ):
337
+ def args_viewable_as_strided_memory (tuple arg_indices ):
338
+ """ Decorator to create proxy objects to :obj:`StridedMemoryView` for the
339
+ specified positional arguments.
340
+
341
+ This allows array/tensor attributes to be accessed inside the function
342
+ implementation, while keeping the function body array-library-agnostic (if
343
+ desired).
344
+
345
+ Inside the decorated function, the specified arguments become instances
346
+ of an (undocumented) proxy type, regardless of their original source. A
347
+ :obj:`StridedMemoryView` instance can be obtained by passing the (consumer)
348
+ stream pointer (as a Python `int`) to the proxies's ``view()`` method. For
349
+ example:
350
+
351
+ .. code-block:: python
352
+
353
+ @args_viewable_as_strided_memory((1,))
354
+ def my_func(arg0, arg1, arg2, stream: Stream):
355
+ # arg1 can be any object supporting DLPack or CUDA Array Interface
356
+ view = arg1.view(stream.handle)
357
+ assert isinstance(view, StridedMemoryView)
358
+ ...
359
+
360
+ Parameters
361
+ ----------
362
+ arg_indices : tuple
363
+ The indices of the target positional arguments.
364
+ """
288
365
def wrapped_func_with_indices (func ):
289
366
@ functools.wraps (func)
290
367
def wrapped_func (*args , **kwargs ):
0 commit comments