Skip to content

Commit 786140a

Browse files
committed
feat: add from_* style constructor classmethods to StridedMemoryView and make constructor amenable to future from_*-style constructors
1 parent 69aac67 commit 786140a

File tree

3 files changed

+146
-134
lines changed

3 files changed

+146
-134
lines changed

cuda_core/cuda/core/experimental/_memoryview.pyx

Lines changed: 141 additions & 128 deletions
Original file line numberDiff line numberDiff line change
@@ -88,20 +88,145 @@ cdef class StridedMemoryView:
8888
cdef DLTensor *dl_tensor
8989

9090
# Memoized properties
91-
cdef tuple _shape
92-
cdef tuple _strides
93-
cdef bint _strides_init # Has the strides tuple been init'ed?
94-
cdef object _dtype
95-
96-
def __init__(self, obj=None, stream_ptr=None):
97-
if obj is not None:
98-
# populate self's attributes
99-
if check_has_dlpack(obj):
100-
view_as_dlpack(obj, stream_ptr, self)
101-
else:
102-
view_as_cai(obj, stream_ptr, self)
91+
cdef:
92+
tuple _shape
93+
tuple _strides
94+
bint _strides_init # Has the strides tuple been init'ed?
95+
object _dtype
96+
97+
def __init__(
    self,
    *,
    ptr: int,
    device_id: int,
    is_device_accessible: bool,
    readonly: bool,
    metadata: object,
    exporting_obj: object,
    dl_tensor: intptr_t = -1,
) -> None:
    """Initialize a view from pre-extracted buffer attributes.

    Not intended to be called with a raw producer object; use the
    ``from_dlpack``, ``from_cuda_array_interface``, or
    ``from_any_interface`` classmethods instead.

    Parameters
    ----------
    ptr : int
        Device/host pointer to the start of the buffer.
    device_id : int
        Ordinal of the device owning the memory (-1 for CPU).
    is_device_accessible : bool
        Whether the memory can be accessed from the device.
    readonly : bool
        Whether the exporter marked the buffer read-only.
    metadata : object
        Keep-alive object describing the buffer (DLPack capsule or the
        ``__cuda_array_interface__`` dict).
    exporting_obj : object
        The producer object; held to keep the memory alive.
    dl_tensor : intptr_t, optional
        Address of a ``DLTensor``; -1 (the default) means "no DLPack
        tensor" and stores NULL.
    """
    self.ptr = ptr
    self.device_id = device_id
    self.is_device_accessible = is_device_accessible
    self.readonly = readonly
    self.metadata = metadata
    self.exporting_obj = exporting_obj
    # Use ``==`` for the sentinel check: ``is`` performs identity
    # comparison, which is undefined for integer values (and meaningless
    # for a C-typed intptr_t).
    self.dl_tensor = (<DLTensor*>NULL) if dl_tensor == -1 else (<DLTensor*>dl_tensor)
115+
116+
@classmethod
def from_dlpack(cls, obj: object, stream_ptr: int | None=None) -> StridedMemoryView:
    """Build a view over *obj* via the DLPack protocol.

    Queries ``obj.__dlpack_device__()`` to validate the device and the
    consumer stream, requests a (versioned if available) DLPack capsule,
    and unpacks it into a new view.

    Raises
    ------
    BufferError
        If ``stream_ptr`` is None (ambiguous) or the producer's device
        type is unsupported.
    """
    cdef int dldevice, device_id
    cdef bint is_device_accessible, is_readonly
    is_device_accessible = False
    dldevice, device_id = obj.__dlpack_device__()
    if dldevice == _kDLCPU:
        assert device_id == 0
        device_id = -1
        if stream_ptr is None:
            raise BufferError("stream=None is ambiguous with view()")
        elif stream_ptr == -1:
            # -1 means "consumer does not care"; host memory needs no
            # stream, so drop it before calling __dlpack__.
            stream_ptr = None
    elif dldevice == _kDLCUDA:
        assert device_id >= 0
        is_device_accessible = True
        # other stream values are passed through unchanged
        if stream_ptr is None:
            raise BufferError("stream=None is ambiguous with view()")
    elif dldevice in (_kDLCUDAHost, _kDLCUDAManaged):
        # pinned/managed memory is reachable from both host and device;
        # pass through without further checks
        is_device_accessible = True
    else:
        raise BufferError("device not supported")

    # Prefer the versioned protocol; fall back for older producers that
    # reject the max_version keyword.
    cdef object capsule
    try:
        capsule = obj.__dlpack__(
            stream=int(stream_ptr) if stream_ptr else None,
            max_version=(DLPACK_MAJOR_VERSION, DLPACK_MINOR_VERSION))
    except TypeError:
        capsule = obj.__dlpack__(
            stream=int(stream_ptr) if stream_ptr else None)

    cdef void* data = NULL
    cdef DLTensor* dl_tensor
    cdef DLManagedTensorVersioned* dlm_tensor_ver
    cdef DLManagedTensor* dlm_tensor
    cdef const char *used_name
    if cpython.PyCapsule_IsValid(
            capsule, DLPACK_VERSIONED_TENSOR_UNUSED_NAME):
        # versioned capsule: read-only flag comes from the tensor flags
        data = cpython.PyCapsule_GetPointer(
            capsule, DLPACK_VERSIONED_TENSOR_UNUSED_NAME)
        dlm_tensor_ver = <DLManagedTensorVersioned*>data
        dl_tensor = &dlm_tensor_ver.dl_tensor
        is_readonly = (dlm_tensor_ver.flags & DLPACK_FLAG_BITMASK_READ_ONLY) != 0
        used_name = DLPACK_VERSIONED_TENSOR_USED_NAME
    else:
        # unversioned capsule: no read-only metadata exists
        assert cpython.PyCapsule_IsValid(
            capsule, DLPACK_TENSOR_UNUSED_NAME)
        data = cpython.PyCapsule_GetPointer(
            capsule, DLPACK_TENSOR_UNUSED_NAME)
        dlm_tensor = <DLManagedTensor*>data
        dl_tensor = &dlm_tensor.dl_tensor
        is_readonly = False
        used_name = DLPACK_TENSOR_USED_NAME

    # mark the capsule consumed so it is not unpacked twice
    cpython.PyCapsule_SetName(capsule, used_name)

    return cls(
        ptr=<intptr_t>dl_tensor.data,
        device_id=int(device_id),
        is_device_accessible=is_device_accessible,
        readonly=is_readonly,
        metadata=capsule,
        exporting_obj=obj,
        dl_tensor=<intptr_t>dl_tensor,
    )
185+
186+
@classmethod
def from_any_interface(cls, obj: object, stream_ptr: int | None = None) -> StridedMemoryView:
    """Build a view over *obj* using whichever protocol it supports.

    DLPack is preferred; the CUDA Array Interface is the fallback.
    """
    if not check_has_dlpack(obj):
        return cls.from_cuda_array_interface(obj, stream_ptr)
    return cls.from_dlpack(obj, stream_ptr)
192+
@classmethod
def from_cuda_array_interface(cls, obj: object, stream_ptr: int | None=None) -> StridedMemoryView:
    """Build a view over *obj* via ``__cuda_array_interface__`` (v3+).

    Establishes stream order between the producer's stream (if any) and
    the consumer stream before returning the view.

    Raises
    ------
    BufferError
        If the interface version is < 3, a mask is present, or
        ``stream_ptr`` is None (ambiguous with ``view()``).
    """
    cdef dict cai_data = obj.__cuda_array_interface__
    if cai_data["version"] < 3:
        raise BufferError("only CUDA Array Interface v3 or above is supported")
    if cai_data.get("mask") is not None:
        raise BufferError("mask is not supported")
    if stream_ptr is None:
        raise BufferError("stream=None is ambiguous with view()")

    # Extract the pointer once: it is needed both for the device-ordinal
    # query and as the view's ptr. (The previous code referenced a
    # nonexistent local ``buf`` here, which raised NameError at runtime.)
    cdef intptr_t ptr = int(cai_data["data"][0])

    cdef intptr_t producer_s, consumer_s
    stream_ptr = int(stream_ptr)
    if stream_ptr != -1:  # -1 means "consumer does not care"
        stream = cai_data.get("stream")
        if stream is not None:
            producer_s = <intptr_t>(stream)
            consumer_s = <intptr_t>(stream_ptr)
            assert producer_s > 0
            # establish stream order
            if producer_s != consumer_s:
                e = handle_return(driver.cuEventCreate(
                    driver.CUevent_flags.CU_EVENT_DISABLE_TIMING))
                handle_return(driver.cuEventRecord(e, producer_s))
                handle_return(driver.cuStreamWaitEvent(consumer_s, e, 0))
                handle_return(driver.cuEventDestroy(e))
    return cls(
        ptr=ptr,
        device_id=handle_return(
            driver.cuPointerGetAttribute(
                driver.CUpointer_attribute.CU_POINTER_ATTRIBUTE_DEVICE_ORDINAL,
                ptr
            )
        ),
        is_device_accessible=True,
        readonly=cai_data["data"][1],
        metadata=cai_data,
        exporting_obj=obj,
    )
105230

106231
def __dealloc__(self):
107232
if self.dl_tensor == NULL:
@@ -206,8 +331,7 @@ cdef bint check_has_dlpack(obj) except*:
206331

207332

208333
cdef class _StridedMemoryViewProxy:
209-
210-
cdef:
334+
cdef readonly:
211335
object obj
212336
bint has_dlpack
213337

@@ -217,82 +341,11 @@ cdef class _StridedMemoryViewProxy:
217341

218342
cpdef StridedMemoryView view(self, stream_ptr=None):
    """Materialize a StridedMemoryView for the wrapped object.

    Dispatches to the DLPack path when the object supports it,
    otherwise to the CUDA Array Interface path.
    """
    if not self.has_dlpack:
        return StridedMemoryView.from_cuda_array_interface(self.obj, stream_ptr)
    return StridedMemoryView.from_dlpack(self.obj, stream_ptr)
223347

224348

225-
cdef StridedMemoryView view_as_dlpack(obj, stream_ptr, view=None):
226-
cdef int dldevice, device_id
227-
cdef bint is_device_accessible, is_readonly
228-
is_device_accessible = False
229-
dldevice, device_id = obj.__dlpack_device__()
230-
if dldevice == _kDLCPU:
231-
assert device_id == 0
232-
device_id = -1
233-
if stream_ptr is None:
234-
raise BufferError("stream=None is ambiguous with view()")
235-
elif stream_ptr == -1:
236-
stream_ptr = None
237-
elif dldevice == _kDLCUDA:
238-
assert device_id >= 0
239-
is_device_accessible = True
240-
# no need to check other stream values, it's a pass-through
241-
if stream_ptr is None:
242-
raise BufferError("stream=None is ambiguous with view()")
243-
elif dldevice in (_kDLCUDAHost, _kDLCUDAManaged):
244-
is_device_accessible = True
245-
# just do a pass-through without any checks, as pinned/managed memory can be
246-
# accessed on both host and device
247-
else:
248-
raise BufferError("device not supported")
249-
250-
cdef object capsule
251-
try:
252-
capsule = obj.__dlpack__(
253-
stream=int(stream_ptr) if stream_ptr else None,
254-
max_version=(DLPACK_MAJOR_VERSION, DLPACK_MINOR_VERSION))
255-
except TypeError:
256-
capsule = obj.__dlpack__(
257-
stream=int(stream_ptr) if stream_ptr else None)
258-
259-
cdef void* data = NULL
260-
cdef DLTensor* dl_tensor
261-
cdef DLManagedTensorVersioned* dlm_tensor_ver
262-
cdef DLManagedTensor* dlm_tensor
263-
cdef const char *used_name
264-
if cpython.PyCapsule_IsValid(
265-
capsule, DLPACK_VERSIONED_TENSOR_UNUSED_NAME):
266-
data = cpython.PyCapsule_GetPointer(
267-
capsule, DLPACK_VERSIONED_TENSOR_UNUSED_NAME)
268-
dlm_tensor_ver = <DLManagedTensorVersioned*>data
269-
dl_tensor = &dlm_tensor_ver.dl_tensor
270-
is_readonly = bool((dlm_tensor_ver.flags & DLPACK_FLAG_BITMASK_READ_ONLY) != 0)
271-
used_name = DLPACK_VERSIONED_TENSOR_USED_NAME
272-
elif cpython.PyCapsule_IsValid(
273-
capsule, DLPACK_TENSOR_UNUSED_NAME):
274-
data = cpython.PyCapsule_GetPointer(
275-
capsule, DLPACK_TENSOR_UNUSED_NAME)
276-
dlm_tensor = <DLManagedTensor*>data
277-
dl_tensor = &dlm_tensor.dl_tensor
278-
is_readonly = False
279-
used_name = DLPACK_TENSOR_USED_NAME
280-
else:
281-
assert False
282-
283-
cpython.PyCapsule_SetName(capsule, used_name)
284-
285-
cdef StridedMemoryView buf = StridedMemoryView() if view is None else view
286-
buf.dl_tensor = dl_tensor
287-
buf.metadata = capsule
288-
buf.ptr = <intptr_t>(dl_tensor.data)
289-
buf.device_id = device_id
290-
buf.is_device_accessible = is_device_accessible
291-
buf.readonly = is_readonly
292-
buf.exporting_obj = obj
293-
294-
return buf
295-
296349

297350
cdef object dtype_dlpack_to_numpy(DLDataType* dtype):
298351
cdef int bits = dtype.bits
@@ -354,46 +407,6 @@ cdef object dtype_dlpack_to_numpy(DLDataType* dtype):
354407
return numpy.dtype(np_dtype)
355408

356409

357-
# Also generate for Python so we can test this code path
358-
cpdef StridedMemoryView view_as_cai(obj, stream_ptr, view=None):
359-
cdef dict cai_data = obj.__cuda_array_interface__
360-
if cai_data["version"] < 3:
361-
raise BufferError("only CUDA Array Interface v3 or above is supported")
362-
if cai_data.get("mask") is not None:
363-
raise BufferError("mask is not supported")
364-
if stream_ptr is None:
365-
raise BufferError("stream=None is ambiguous with view()")
366-
367-
cdef StridedMemoryView buf = StridedMemoryView() if view is None else view
368-
buf.exporting_obj = obj
369-
buf.metadata = cai_data
370-
buf.dl_tensor = NULL
371-
buf.ptr, buf.readonly = cai_data["data"]
372-
buf.is_device_accessible = True
373-
buf.device_id = handle_return(
374-
driver.cuPointerGetAttribute(
375-
driver.CUpointer_attribute.CU_POINTER_ATTRIBUTE_DEVICE_ORDINAL,
376-
buf.ptr))
377-
378-
cdef intptr_t producer_s, consumer_s
379-
stream_ptr = int(stream_ptr)
380-
if stream_ptr != -1:
381-
stream = cai_data.get("stream")
382-
if stream is not None:
383-
producer_s = <intptr_t>(stream)
384-
consumer_s = <intptr_t>(stream_ptr)
385-
assert producer_s > 0
386-
# establish stream order
387-
if producer_s != consumer_s:
388-
e = handle_return(driver.cuEventCreate(
389-
driver.CUevent_flags.CU_EVENT_DISABLE_TIMING))
390-
handle_return(driver.cuEventRecord(e, producer_s))
391-
handle_return(driver.cuStreamWaitEvent(consumer_s, e, 0))
392-
handle_return(driver.cuEventDestroy(e))
393-
394-
return buf
395-
396-
397410
def args_viewable_as_strided_memory(tuple arg_indices):
398411
"""
399412
Decorator to create proxy objects to :obj:`StridedMemoryView` for the

cuda_core/tests/test_memory.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -609,15 +609,15 @@ def test_strided_memory_view_leak():
609609
arr = np.zeros(1048576, dtype=np.uint8)
610610
before = sys.getrefcount(arr)
611611
for idx in range(10):
612-
StridedMemoryView(arr, stream_ptr=-1)
612+
StridedMemoryView.from_any_interface(arr, stream_ptr=-1)
613613
after = sys.getrefcount(arr)
614614
assert before == after
615615

616616

617617
def test_strided_memory_view_refcnt():
618618
# Use Fortran ordering so strides is used
619619
a = np.zeros((64, 4), dtype=np.uint8, order="F")
620-
av = StridedMemoryView(a, stream_ptr=-1)
620+
av = StridedMemoryView.from_any_interface(a, stream_ptr=-1)
621621
# segfaults if refcnt is wrong
622622
assert av.shape[0] == 64
623623
assert sys.getrefcount(av.shape) >= 2

cuda_core/tests/test_utils.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
import numpy as np
1515
import pytest
1616
from cuda.core.experimental import Device
17-
from cuda.core.experimental._memoryview import view_as_cai
1817
from cuda.core.experimental.utils import StridedMemoryView, args_viewable_as_strided_memory
1918

2019

@@ -78,7 +77,7 @@ def my_func(arr):
7877

7978
def test_strided_memory_view_cpu(self, in_arr):
8079
# stream_ptr=-1 means "the consumer does not care"
81-
view = StridedMemoryView(in_arr, stream_ptr=-1)
80+
view = StridedMemoryView.from_any_interface(in_arr, stream_ptr=-1)
8281
self._check_view(view, in_arr)
8382

8483
def _check_view(self, view, in_arr):
@@ -147,7 +146,7 @@ def test_strided_memory_view_cpu(self, in_arr, use_stream):
147146
# This is the consumer stream
148147
s = dev.create_stream() if use_stream else None
149148

150-
view = StridedMemoryView(in_arr, stream_ptr=s.handle if s else -1)
149+
view = StridedMemoryView.from_any_interface(in_arr, stream_ptr=s.handle if s else -1)
151150
self._check_view(view, in_arr, dev)
152151

153152
def _check_view(self, view, in_arr, dev):
@@ -179,7 +178,7 @@ def test_cuda_array_interface_gpu(self, in_arr, use_stream):
179178
# The usual path in `StridedMemoryView` prefers the DLPack interface
180179
# over __cuda_array_interface__, so we call `view_as_cai` directly
181180
# here so we can test the CAI code path.
182-
view = view_as_cai(in_arr, stream_ptr=s.handle if s else -1)
181+
view = StridedMemoryView.from_cuda_array_interface(in_arr, stream_ptr=s.handle if s else -1)
183182
self._check_view(view, in_arr, dev)
184183

185184
def _check_view(self, view, in_arr, dev):

0 commit comments

Comments
 (0)