@@ -9,17 +9,20 @@ cimport cpython
9
9
from libc cimport stdlib
10
10
from libc.stdint cimport uint8_t
11
11
from libc.stdint cimport uint16_t
12
+ from libc.stdint cimport uint32_t
12
13
from libc.stdint cimport int32_t
13
14
from libc.stdint cimport int64_t
14
15
from libc.stdint cimport uint64_t
15
16
from libc.stdint cimport intptr_t
16
- from libcpp.vector cimport vector
17
17
18
18
from enum import IntEnum
19
19
20
20
21
21
cdef extern from " dlpack.h" nogil:
22
-
22
+ """
23
+ #define DLPACK_TENSOR_UNUSED_NAME "dltensor"
24
+ #define DLPACK_VERSIONED_TENSOR_UNUSED_NAME "dltensor_versioned"
25
+ """
23
26
ctypedef enum _DLDeviceType " DLDeviceType" :
24
27
_kDLCPU " kDLCPU"
25
28
_kDLCUDA " kDLCUDA"
@@ -52,33 +55,89 @@ cdef extern from "dlpack.h" nogil:
52
55
void * manager_ctx
53
56
void (* deleter)(DLManagedTensor* )
54
57
58
+ ctypedef struct DLPackVersion:
59
+ uint32_t major
60
+ uint32_t minor
61
+
62
+ ctypedef struct DLManagedTensorVersioned:
63
+ DLPackVersion version
64
+ void * manager_ctx
65
+ void (* deleter)(DLManagedTensorVersioned* )
66
+ uint64_t flags
67
+ DLTensor dl_tensor
68
+
69
+ int DLPACK_MAJOR_VERSION
70
+ int DLPACK_MINOR_VERSION
71
+
72
+ const char * DLPACK_TENSOR_UNUSED_NAME
73
+ const char * DLPACK_VERSIONED_TENSOR_UNUSED_NAME
74
+
55
75
56
- cdef void pycapsule_deleter(object dltensor ):
76
+ cdef void pycapsule_deleter(object capsule ):
57
77
cdef DLManagedTensor* dlm_tensor
58
- # Do not invoke the deleter on a used capsule
59
- if cpython.PyCapsule_IsValid(dltensor, ' dltensor' ):
60
- dlm_tensor = < DLManagedTensor* > cpython.PyCapsule_GetPointer(
61
- dltensor, ' dltensor' )
62
- dlm_tensor.deleter(dlm_tensor)
78
+ cdef DLManagedTensorVersioned* dlm_tensor_ver
79
+ # Do not invoke the deleter on a used capsule.
80
+ if cpython.PyCapsule_IsValid(
81
+ capsule, DLPACK_TENSOR_UNUSED_NAME):
82
+ dlm_tensor = < DLManagedTensor* > (
83
+ cpython.PyCapsule_GetPointer(
84
+ capsule, DLPACK_TENSOR_UNUSED_NAME))
85
+ if dlm_tensor.deleter:
86
+ dlm_tensor.deleter(dlm_tensor)
87
+ elif cpython.PyCapsule_IsValid(
88
+ capsule, DLPACK_VERSIONED_TENSOR_UNUSED_NAME):
89
+ dlm_tensor_ver = < DLManagedTensorVersioned* > (
90
+ cpython.PyCapsule_GetPointer(
91
+ capsule, DLPACK_VERSIONED_TENSOR_UNUSED_NAME))
92
+ if dlm_tensor_ver.deleter:
93
+ dlm_tensor_ver.deleter(dlm_tensor_ver)
63
94
64
95
65
96
cdef void deleter(DLManagedTensor* tensor) with gil:
66
- if tensor.manager_ctx is NULL :
67
- return
68
97
stdlib.free(tensor.dl_tensor.shape)
69
- cpython.Py_DECREF(< object > tensor.manager_ctx)
70
- tensor.manager_ctx = NULL
98
+ if tensor.manager_ctx:
99
+ cpython.Py_DECREF(< object > tensor.manager_ctx)
100
+ tensor.manager_ctx = NULL
71
101
stdlib.free(tensor)
72
102
73
103
74
- cpdef object make_py_capsule(object buf) except + :
75
- cdef DLManagedTensor* dlm_tensor = \
76
- < DLManagedTensor* > stdlib.malloc(sizeof(DLManagedTensor))
104
+ cdef void versioned_deleter(DLManagedTensorVersioned* tensor) with gil:
105
+ stdlib.free(tensor.dl_tensor.shape)
106
+ if tensor.manager_ctx:
107
+ cpython.Py_DECREF(< object > tensor.manager_ctx)
108
+ tensor.manager_ctx = NULL
109
+ stdlib.free(tensor)
110
+
111
+
112
+ cpdef object make_py_capsule(object buf, bint versioned) except + :
113
+ cdef DLManagedTensor* dlm_tensor
114
+ cdef DLManagedTensorVersioned* dlm_tensor_ver
115
+ cdef DLTensor* dl_tensor
116
+ cdef void * tensor_ptr
117
+ cdef const char * capsule_name
118
+
119
+ if versioned:
120
+ dlm_tensor_ver = < DLManagedTensorVersioned* > (
121
+ stdlib.malloc(sizeof(DLManagedTensorVersioned)))
122
+ dlm_tensor_ver.version.major = DLPACK_MAJOR_VERSION
123
+ dlm_tensor_ver.version.minor = DLPACK_MINOR_VERSION
124
+ dlm_tensor_ver.manager_ctx = < void * > buf
125
+ dlm_tensor_ver.deleter = versioned_deleter
126
+ dlm_tensor_ver.flags = 0
127
+ dl_tensor = & dlm_tensor_ver.dl_tensor
128
+ tensor_ptr = dlm_tensor_ver
129
+ capsule_name = DLPACK_VERSIONED_TENSOR_UNUSED_NAME
130
+ else :
131
+ dlm_tensor = < DLManagedTensor* > (
132
+ stdlib.malloc(sizeof(DLManagedTensor)))
133
+ dl_tensor = & dlm_tensor.dl_tensor
134
+ dlm_tensor.manager_ctx = < void * > buf
135
+ dlm_tensor.deleter = deleter
136
+ tensor_ptr = dlm_tensor
137
+ capsule_name = DLPACK_TENSOR_UNUSED_NAME
77
138
78
- cdef DLTensor* dl_tensor = & dlm_tensor.dl_tensor
79
139
dl_tensor.data = < void * >< intptr_t> (int (buf.handle))
80
140
dl_tensor.ndim = 1
81
-
82
141
cdef int64_t* shape_strides = \
83
142
< int64_t* > stdlib.malloc(sizeof(int64_t) * 2 )
84
143
shape_strides[0 ] = < int64_t> buf.size
@@ -106,11 +165,8 @@ cpdef object make_py_capsule(object buf) except +:
106
165
dtype.lanes = < uint16_t> 1
107
166
dtype.bits = < uint8_t> 8
108
167
109
- dlm_tensor.manager_ctx = < void * > buf
110
168
cpython.Py_INCREF(buf)
111
- dlm_tensor.deleter = deleter
112
-
113
- return cpython.PyCapsule_New(dlm_tensor, ' dltensor' , pycapsule_deleter)
169
+ return cpython.PyCapsule_New(tensor_ptr, capsule_name, pycapsule_deleter)
114
170
115
171
116
172
class DLDeviceType (IntEnum ):
0 commit comments