Skip to content

Commit 8c49acc

Browse files
committed
support dlpack 1.0
1 parent 16f541d commit 8c49acc

File tree

2 files changed

+90
-35
lines changed

2 files changed

+90
-35
lines changed

cuda_py/cuda/py/_dlpack.pyx

Lines changed: 77 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -9,17 +9,20 @@ cimport cpython
99
from libc cimport stdlib
1010
from libc.stdint cimport uint8_t
1111
from libc.stdint cimport uint16_t
12+
from libc.stdint cimport uint32_t
1213
from libc.stdint cimport int32_t
1314
from libc.stdint cimport int64_t
1415
from libc.stdint cimport uint64_t
1516
from libc.stdint cimport intptr_t
16-
from libcpp.vector cimport vector
1717

1818
from enum import IntEnum
1919

2020

2121
cdef extern from "dlpack.h" nogil:
22-
22+
"""
23+
#define DLPACK_TENSOR_UNUSED_NAME "dltensor"
24+
#define DLPACK_VERSIONED_TENSOR_UNUSED_NAME "dltensor_versioned"
25+
"""
2326
ctypedef enum _DLDeviceType "DLDeviceType":
2427
_kDLCPU "kDLCPU"
2528
_kDLCUDA "kDLCUDA"
@@ -52,33 +55,89 @@ cdef extern from "dlpack.h" nogil:
5255
void* manager_ctx
5356
void (*deleter)(DLManagedTensor*)
5457

58+
ctypedef struct DLPackVersion:
59+
uint32_t major
60+
uint32_t minor
61+
62+
ctypedef struct DLManagedTensorVersioned:
63+
DLPackVersion version
64+
void* manager_ctx
65+
void (*deleter)(DLManagedTensorVersioned*)
66+
uint64_t flags
67+
DLTensor dl_tensor
68+
69+
int DLPACK_MAJOR_VERSION
70+
int DLPACK_MINOR_VERSION
71+
72+
const char* DLPACK_TENSOR_UNUSED_NAME
73+
const char* DLPACK_VERSIONED_TENSOR_UNUSED_NAME
74+
5575

56-
cdef void pycapsule_deleter(object dltensor):
76+
cdef void pycapsule_deleter(object capsule):
5777
cdef DLManagedTensor* dlm_tensor
58-
# Do not invoke the deleter on a used capsule
59-
if cpython.PyCapsule_IsValid(dltensor, 'dltensor'):
60-
dlm_tensor = <DLManagedTensor*>cpython.PyCapsule_GetPointer(
61-
dltensor, 'dltensor')
62-
dlm_tensor.deleter(dlm_tensor)
78+
cdef DLManagedTensorVersioned* dlm_tensor_ver
79+
# Do not invoke the deleter on a used capsule.
80+
if cpython.PyCapsule_IsValid(
81+
capsule, DLPACK_TENSOR_UNUSED_NAME):
82+
dlm_tensor = <DLManagedTensor*>(
83+
cpython.PyCapsule_GetPointer(
84+
capsule, DLPACK_TENSOR_UNUSED_NAME))
85+
if dlm_tensor.deleter:
86+
dlm_tensor.deleter(dlm_tensor)
87+
elif cpython.PyCapsule_IsValid(
88+
capsule, DLPACK_VERSIONED_TENSOR_UNUSED_NAME):
89+
dlm_tensor_ver = <DLManagedTensorVersioned*>(
90+
cpython.PyCapsule_GetPointer(
91+
capsule, DLPACK_VERSIONED_TENSOR_UNUSED_NAME))
92+
if dlm_tensor_ver.deleter:
93+
dlm_tensor_ver.deleter(dlm_tensor_ver)
6394

6495

6596
cdef void deleter(DLManagedTensor* tensor) with gil:
66-
if tensor.manager_ctx is NULL:
67-
return
6897
stdlib.free(tensor.dl_tensor.shape)
69-
cpython.Py_DECREF(<object>tensor.manager_ctx)
70-
tensor.manager_ctx = NULL
98+
if tensor.manager_ctx:
99+
cpython.Py_DECREF(<object>tensor.manager_ctx)
100+
tensor.manager_ctx = NULL
71101
stdlib.free(tensor)
72102

73103

74-
cpdef object make_py_capsule(object buf) except +:
75-
cdef DLManagedTensor* dlm_tensor = \
76-
<DLManagedTensor*>stdlib.malloc(sizeof(DLManagedTensor))
104+
cdef void versioned_deleter(DLManagedTensorVersioned* tensor) with gil:
105+
stdlib.free(tensor.dl_tensor.shape)
106+
if tensor.manager_ctx:
107+
cpython.Py_DECREF(<object>tensor.manager_ctx)
108+
tensor.manager_ctx = NULL
109+
stdlib.free(tensor)
110+
111+
112+
cpdef object make_py_capsule(object buf, bint versioned) except +:
113+
cdef DLManagedTensor* dlm_tensor
114+
cdef DLManagedTensorVersioned* dlm_tensor_ver
115+
cdef DLTensor* dl_tensor
116+
cdef void* tensor_ptr
117+
cdef const char* capsule_name
118+
119+
if versioned:
120+
dlm_tensor_ver = <DLManagedTensorVersioned*>(
121+
stdlib.malloc(sizeof(DLManagedTensorVersioned)))
122+
dlm_tensor_ver.version.major = DLPACK_MAJOR_VERSION
123+
dlm_tensor_ver.version.minor = DLPACK_MINOR_VERSION
124+
dlm_tensor_ver.manager_ctx = <void*>buf
125+
dlm_tensor_ver.deleter = versioned_deleter
126+
dlm_tensor_ver.flags = 0
127+
dl_tensor = &dlm_tensor_ver.dl_tensor
128+
tensor_ptr = dlm_tensor_ver
129+
capsule_name = DLPACK_VERSIONED_TENSOR_UNUSED_NAME
130+
else:
131+
dlm_tensor = <DLManagedTensor*>(
132+
stdlib.malloc(sizeof(DLManagedTensor)))
133+
dl_tensor = &dlm_tensor.dl_tensor
134+
dlm_tensor.manager_ctx = <void*>buf
135+
dlm_tensor.deleter = deleter
136+
tensor_ptr = dlm_tensor
137+
capsule_name = DLPACK_TENSOR_UNUSED_NAME
77138

78-
cdef DLTensor* dl_tensor = &dlm_tensor.dl_tensor
79139
dl_tensor.data = <void*><intptr_t>(int(buf.handle))
80140
dl_tensor.ndim = 1
81-
82141
cdef int64_t* shape_strides = \
83142
<int64_t*>stdlib.malloc(sizeof(int64_t) * 2)
84143
shape_strides[0] = <int64_t>buf.size
@@ -106,11 +165,8 @@ cpdef object make_py_capsule(object buf) except +:
106165
dtype.lanes = <uint16_t>1
107166
dtype.bits = <uint8_t>8
108167

109-
dlm_tensor.manager_ctx = <void*>buf
110168
cpython.Py_INCREF(buf)
111-
dlm_tensor.deleter = deleter
112-
113-
return cpython.PyCapsule_New(dlm_tensor, 'dltensor', pycapsule_deleter)
169+
return cpython.PyCapsule_New(tensor_ptr, capsule_name, pycapsule_deleter)
114170

115171

116172
class DLDeviceType(IntEnum):

cuda_py/cuda/py/_memory.py

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -107,20 +107,19 @@ def __dlpack__(self, *,
107107
max_version: Optional[Tuple[int, int]] = None,
108108
dl_device: Optional[Tuple[int, int]] = None,
109109
copy: Optional[bool] = None) -> PyCapsule:
110-
# Support for Python-level DLPack protocol.
111-
if stream is not None:
112-
warnings.warn("stream != None is ignored")
113-
# TODO: add checks for dl_device and copy
114-
# FIXME: fix v1.0 support
115-
#if max_version is None:
116-
# versioned = False
117-
#else:
118-
# assert len(max_version) == 2
119-
# if max_version >= (1, 0):
120-
# versioned = True
121-
# else:
122-
# versioned = False
123-
capsule = make_py_capsule(self)#, versioned)
110+
# Note: we ignore the stream argument entirely (as if it is -1).
111+
# It is the user's responsibility to maintain stream order.
112+
if dl_device is not None or copy is True:
113+
raise BufferError
114+
if max_version is None:
115+
versioned = False
116+
else:
117+
assert len(max_version) == 2
118+
if max_version >= (1, 0):
119+
versioned = True
120+
else:
121+
versioned = False
122+
capsule = make_py_capsule(self, versioned)
124123
return capsule
125124

126125
def __dlpack_device__(self) -> Tuple[int, int]:

0 commit comments

Comments
 (0)