@@ -88,20 +88,163 @@ cdef class StridedMemoryView:
     cdef DLTensor* dl_tensor

     # Memoized properties
-    cdef tuple _shape
-    cdef tuple _strides
-    cdef bint _strides_init  # Has the strides tuple been init'ed?
-    cdef object _dtype
-
-    def __init__(self, obj=None, stream_ptr=None):
-        if obj is not None:
-            # populate self's attributes
-            if check_has_dlpack(obj):
-                view_as_dlpack(obj, stream_ptr, self)
-            else:
-                view_as_cai(obj, stream_ptr, self)
+    cdef:
+        tuple _shape
+        tuple _strides
+        bint _strides_init  # Has the strides tuple been init'ed?
+        object _dtype
+
+    def __init__(
+        self,
+        *,
+        ptr: int,
+        device_id: int,
+        is_device_accessible: bool,
+        readonly: bool,
+        metadata: object,
+        exporting_obj: object,
+        dl_tensor: intptr_t = -1,
+    ) -> None:
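+        # `dl_tensor` travels as an integer address (with -1 meaning "no
+        # DLPack tensor", i.e. the CUDA Array Interface path) because a
+        # Python-level signature cannot carry a C pointer type directly.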
+        self.ptr = ptr
+        self.device_id = device_id
+        self.is_device_accessible = is_device_accessible
+        self.readonly = readonly
+        self.metadata = metadata
+        self.exporting_obj = exporting_obj
+        self.dl_tensor = <DLTensor*>NULL if dl_tensor == -1 else <DLTensor*>dl_tensor
+
+    @classmethod
+    def from_dlpack(cls, obj: object, stream_ptr: int | None = None) -> StridedMemoryView:
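+        # DLPack import path: ask the producer where the data lives via
+        # __dlpack_device__(), then request a (possibly versioned) capsule
+        # via __dlpack__() and unpack the DLTensor it carries.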
+        cdef int dldevice, device_id
+        cdef bint is_device_accessible, is_readonly
+        is_device_accessible = False
+        dldevice, device_id = obj.__dlpack_device__()
+        if dldevice == _kDLCPU:
+            assert device_id == 0
+            device_id = -1
+            if stream_ptr is None:
+                raise BufferError("stream=None is ambiguous with view()")
+            elif stream_ptr == -1:
+                stream_ptr = None
+        elif dldevice == _kDLCUDA:
+            assert device_id >= 0
+            is_device_accessible = True
+            # no need to check other stream values, it's a pass-through
+            if stream_ptr is None:
+                raise BufferError("stream=None is ambiguous with view()")
+        elif dldevice in (_kDLCUDAHost, _kDLCUDAManaged):
+            is_device_accessible = True
+            # just do a pass-through without any checks, as pinned/managed memory can be
+            # accessed on both host and device
         else:
-            pass
+            raise BufferError("device not supported")
+
+        cdef object capsule
+        try:
+            capsule = obj.__dlpack__(
+                stream=int(stream_ptr) if stream_ptr else None,
+                max_version=(DLPACK_MAJOR_VERSION, DLPACK_MINOR_VERSION))
+        except TypeError:
+            capsule = obj.__dlpack__(
+                stream=int(stream_ptr) if stream_ptr else None)
+
+        cdef void* data = NULL
+        cdef DLTensor* dl_tensor
+        cdef DLManagedTensorVersioned* dlm_tensor_ver
+        cdef DLManagedTensor* dlm_tensor
+        cdef const char* used_name
+        if cpython.PyCapsule_IsValid(
+                capsule, DLPACK_VERSIONED_TENSOR_UNUSED_NAME):
+            data = cpython.PyCapsule_GetPointer(
+                capsule, DLPACK_VERSIONED_TENSOR_UNUSED_NAME)
+            dlm_tensor_ver = <DLManagedTensorVersioned*>data
+            dl_tensor = &dlm_tensor_ver.dl_tensor
+            is_readonly = (dlm_tensor_ver.flags & DLPACK_FLAG_BITMASK_READ_ONLY) != 0
+            used_name = DLPACK_VERSIONED_TENSOR_USED_NAME
+        else:
+            assert cpython.PyCapsule_IsValid(
+                capsule, DLPACK_TENSOR_UNUSED_NAME)
+            data = cpython.PyCapsule_GetPointer(
+                capsule, DLPACK_TENSOR_UNUSED_NAME)
+            dlm_tensor = <DLManagedTensor*>data
+            dl_tensor = &dlm_tensor.dl_tensor
+            is_readonly = False
+            used_name = DLPACK_TENSOR_USED_NAME
+
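+        # Per DLPack convention, rename the capsule to its "used" name so
+        # that it cannot be consumed a second time.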
+        cpython.PyCapsule_SetName(capsule, used_name)
+
+        return cls(
+            ptr=<intptr_t>dl_tensor.data,
+            device_id=int(device_id),
+            is_device_accessible=is_device_accessible,
+            readonly=is_readonly,
+            metadata=capsule,
+            exporting_obj=obj,
+            dl_tensor=<intptr_t>dl_tensor,
+        )
+
+    @classmethod
+    def from_any_interface(cls, obj: object, stream_ptr: int | None = None) -> StridedMemoryView:
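+        # Prefer DLPack when the producer supports it; otherwise fall back
+        # to the CUDA Array Interface. Usage sketch (`arr` and `s` are
+        # illustrative placeholders, not names from this module):
+        #   view = StridedMemoryView.from_any_interface(arr, stream_ptr=s)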
+        if check_has_dlpack(obj):
+            return cls.from_dlpack(obj, stream_ptr)
+        return cls.from_cuda_array_interface(obj, stream_ptr)
+
+    @classmethod
+    def from_cuda_array_interface(cls, obj: object, stream_ptr: int | None = None) -> StridedMemoryView:
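+        # CUDA Array Interface (v3+) import path; see
+        # https://numba.readthedocs.io/en/stable/cuda/cuda_array_interface.html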
+        cdef dict cai_data = obj.__cuda_array_interface__
+        if cai_data["version"] < 3:
+            raise BufferError("only CUDA Array Interface v3 or above is supported")
+        if cai_data.get("mask") is not None:
+            raise BufferError("mask is not supported")
+        if stream_ptr is None:
+            raise BufferError("stream=None is ambiguous with view()")
+
+        cdef intptr_t producer_s, consumer_s
+        stream_ptr = int(stream_ptr)
+        if stream_ptr != -1:
+            stream = cai_data.get("stream")
+            if stream is not None:
+                producer_s = <intptr_t>stream
+                consumer_s = <intptr_t>stream_ptr
+                assert producer_s > 0
+                # establish stream order
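+                # (record an event on the producer's stream and make the
+                # consumer's stream wait on it, so work queued on the consumer
+                # cannot start before the producer's pending writes finish)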
+                if producer_s != consumer_s:
+                    e = handle_return(driver.cuEventCreate(
+                        driver.CUevent_flags.CU_EVENT_DISABLE_TIMING))
+                    handle_return(driver.cuEventRecord(e, producer_s))
+                    handle_return(driver.cuStreamWaitEvent(consumer_s, e, 0))
+                    handle_return(driver.cuEventDestroy(e))
+        ptr = int(cai_data["data"][0])
+        return cls(
+            ptr=ptr,
+            device_id=handle_return(
+                driver.cuPointerGetAttribute(
+                    driver.CUpointer_attribute.CU_POINTER_ATTRIBUTE_DEVICE_ORDINAL,
+                    ptr
+                )
+            ),
+            is_device_accessible=True,
+            readonly=cai_data["data"][1],
+            metadata=cai_data,
+            exporting_obj=obj,
+        )

     def __dealloc__(self):
         if self.dl_tensor == NULL:
@@ -206,8 +349,7 @@ cdef bint check_has_dlpack(obj) except*:


 cdef class _StridedMemoryViewProxy:
-
-    cdef:
+    cdef readonly:
         object obj
         bint has_dlpack

@@ -217,82 +359,11 @@ cdef class _StridedMemoryViewProxy:

     cpdef StridedMemoryView view(self, stream_ptr=None):
         if self.has_dlpack:
-            return view_as_dlpack(self.obj, stream_ptr)
+            return StridedMemoryView.from_dlpack(self.obj, stream_ptr)
         else:
-            return view_as_cai(self.obj, stream_ptr)
+            return StridedMemoryView.from_cuda_array_interface(self.obj, stream_ptr)


-cdef StridedMemoryView view_as_dlpack(obj, stream_ptr, view=None):
-    cdef int dldevice, device_id
-    cdef bint is_device_accessible, is_readonly
-    is_device_accessible = False
-    dldevice, device_id = obj.__dlpack_device__()
-    if dldevice == _kDLCPU:
-        assert device_id == 0
-        device_id = -1
-        if stream_ptr is None:
-            raise BufferError("stream=None is ambiguous with view()")
-        elif stream_ptr == -1:
-            stream_ptr = None
-    elif dldevice == _kDLCUDA:
-        assert device_id >= 0
-        is_device_accessible = True
-        # no need to check other stream values, it's a pass-through
-        if stream_ptr is None:
-            raise BufferError("stream=None is ambiguous with view()")
-    elif dldevice in (_kDLCUDAHost, _kDLCUDAManaged):
-        is_device_accessible = True
-        # just do a pass-through without any checks, as pinned/managed memory can be
-        # accessed on both host and device
-    else:
-        raise BufferError("device not supported")
-
-    cdef object capsule
-    try:
-        capsule = obj.__dlpack__(
-            stream=int(stream_ptr) if stream_ptr else None,
-            max_version=(DLPACK_MAJOR_VERSION, DLPACK_MINOR_VERSION))
-    except TypeError:
-        capsule = obj.__dlpack__(
-            stream=int(stream_ptr) if stream_ptr else None)
-
-    cdef void* data = NULL
-    cdef DLTensor* dl_tensor
-    cdef DLManagedTensorVersioned* dlm_tensor_ver
-    cdef DLManagedTensor* dlm_tensor
-    cdef const char* used_name
-    if cpython.PyCapsule_IsValid(
-            capsule, DLPACK_VERSIONED_TENSOR_UNUSED_NAME):
-        data = cpython.PyCapsule_GetPointer(
-            capsule, DLPACK_VERSIONED_TENSOR_UNUSED_NAME)
-        dlm_tensor_ver = <DLManagedTensorVersioned*>data
-        dl_tensor = &dlm_tensor_ver.dl_tensor
-        is_readonly = bool((dlm_tensor_ver.flags & DLPACK_FLAG_BITMASK_READ_ONLY) != 0)
-        used_name = DLPACK_VERSIONED_TENSOR_USED_NAME
-    elif cpython.PyCapsule_IsValid(
-            capsule, DLPACK_TENSOR_UNUSED_NAME):
-        data = cpython.PyCapsule_GetPointer(
-            capsule, DLPACK_TENSOR_UNUSED_NAME)
-        dlm_tensor = <DLManagedTensor*>data
-        dl_tensor = &dlm_tensor.dl_tensor
-        is_readonly = False
-        used_name = DLPACK_TENSOR_USED_NAME
-    else:
-        assert False
-
-    cpython.PyCapsule_SetName(capsule, used_name)
-
-    cdef StridedMemoryView buf = StridedMemoryView() if view is None else view
-    buf.dl_tensor = dl_tensor
-    buf.metadata = capsule
-    buf.ptr = <intptr_t>(dl_tensor.data)
-    buf.device_id = device_id
-    buf.is_device_accessible = is_device_accessible
-    buf.readonly = is_readonly
-    buf.exporting_obj = obj
-
-    return buf
-

 cdef object dtype_dlpack_to_numpy(DLDataType* dtype):
     cdef int bits = dtype.bits
@@ -354,46 +425,6 @@ cdef object dtype_dlpack_to_numpy(DLDataType* dtype):
     return numpy.dtype(np_dtype)


-# Also generate for Python so we can test this code path
-cpdef StridedMemoryView view_as_cai(obj, stream_ptr, view=None):
-    cdef dict cai_data = obj.__cuda_array_interface__
-    if cai_data["version"] < 3:
-        raise BufferError("only CUDA Array Interface v3 or above is supported")
-    if cai_data.get("mask") is not None:
-        raise BufferError("mask is not supported")
-    if stream_ptr is None:
-        raise BufferError("stream=None is ambiguous with view()")
-
-    cdef StridedMemoryView buf = StridedMemoryView() if view is None else view
-    buf.exporting_obj = obj
-    buf.metadata = cai_data
-    buf.dl_tensor = NULL
-    buf.ptr, buf.readonly = cai_data["data"]
-    buf.is_device_accessible = True
-    buf.device_id = handle_return(
-        driver.cuPointerGetAttribute(
-            driver.CUpointer_attribute.CU_POINTER_ATTRIBUTE_DEVICE_ORDINAL,
-            buf.ptr))
-
-    cdef intptr_t producer_s, consumer_s
-    stream_ptr = int(stream_ptr)
-    if stream_ptr != -1:
-        stream = cai_data.get("stream")
-        if stream is not None:
-            producer_s = <intptr_t>(stream)
-            consumer_s = <intptr_t>(stream_ptr)
-            assert producer_s > 0
-            # establish stream order
-            if producer_s != consumer_s:
-                e = handle_return(driver.cuEventCreate(
-                    driver.CUevent_flags.CU_EVENT_DISABLE_TIMING))
-                handle_return(driver.cuEventRecord(e, producer_s))
-                handle_return(driver.cuStreamWaitEvent(consumer_s, e, 0))
-                handle_return(driver.cuEventDestroy(e))
-
-    return buf
-
-
 def args_viewable_as_strided_memory(tuple arg_indices):
     """
     Decorator to create proxy objects to :obj:`StridedMemoryView` for the