@@ -13,11 +13,7 @@ from libcpp cimport vector
13
13
14
14
import ctypes
15
15
16
- # this might be an unnecessary assumption that NumPy does not exist...
17
- try :
18
- import numpy
19
- except ImportError :
20
- numpy = None
16
+ import numpy
21
17
22
18
from cuda.core._memory import Buffer
23
19
@@ -43,7 +39,32 @@ ctypedef fused supported_type:
43
39
cpp_double_complex
44
40
45
41
46
- # TODO: cache ctypes/numpy type objects to avoid attribute access
42
+ # cache ctypes/numpy type objects to avoid attribute access
43
+ cdef object ctypes_bool = ctypes.c_bool
44
+ cdef object ctypes_int8 = ctypes.c_int8
45
+ cdef object ctypes_int16 = ctypes.c_int16
46
+ cdef object ctypes_int32 = ctypes.c_int32
47
+ cdef object ctypes_int64 = ctypes.c_int64
48
+ cdef object ctypes_uint8 = ctypes.c_uint8
49
+ cdef object ctypes_uint16 = ctypes.c_uint16
50
+ cdef object ctypes_uint32 = ctypes.c_uint32
51
+ cdef object ctypes_uint64 = ctypes.c_uint64
52
+ cdef object ctypes_float = ctypes.c_float
53
+ cdef object ctypes_double = ctypes.c_double
54
+ cdef object numpy_bool = numpy.bool_
55
+ cdef object numpy_int8 = numpy.int8
56
+ cdef object numpy_int16 = numpy.int16
57
+ cdef object numpy_int32 = numpy.int32
58
+ cdef object numpy_int64 = numpy.int64
59
+ cdef object numpy_uint8 = numpy.uint8
60
+ cdef object numpy_uint16 = numpy.uint16
61
+ cdef object numpy_uint32 = numpy.uint32
62
+ cdef object numpy_uint64 = numpy.uint64
63
+ cdef object numpy_float16 = numpy.float16
64
+ cdef object numpy_float32 = numpy.float32
65
+ cdef object numpy_float64 = numpy.float64
66
+ cdef object numpy_complex64 = numpy.complex64
67
+ cdef object numpy_complex128 = numpy.complex128
47
68
48
69
49
70
# limitation due to cython/cython#534
@@ -76,27 +97,27 @@ cdef inline int prepare_ctypes_arg(
76
97
vector.vector[void * ]& data_addresses,
77
98
arg,
78
99
const size_t idx) except - 1 :
79
- if isinstance (arg, ctypes.c_bool ):
100
+ if isinstance (arg, ctypes_bool ):
80
101
return prepare_arg[cpp_bool](data, data_addresses, arg.value, idx)
81
- elif isinstance (arg, ctypes.c_int8 ):
102
+ elif isinstance (arg, ctypes_int8 ):
82
103
return prepare_arg[int8_t](data, data_addresses, arg.value, idx)
83
- elif isinstance (arg, ctypes.c_int16 ):
104
+ elif isinstance (arg, ctypes_int16 ):
84
105
return prepare_arg[int16_t](data, data_addresses, arg.value, idx)
85
- elif isinstance (arg, ctypes.c_int32 ):
106
+ elif isinstance (arg, ctypes_int32 ):
86
107
return prepare_arg[int32_t](data, data_addresses, arg.value, idx)
87
- elif isinstance (arg, ctypes.c_int64 ):
108
+ elif isinstance (arg, ctypes_int64 ):
88
109
return prepare_arg[int64_t](data, data_addresses, arg.value, idx)
89
- elif isinstance (arg, ctypes.c_uint8 ):
110
+ elif isinstance (arg, ctypes_uint8 ):
90
111
return prepare_arg[uint8_t](data, data_addresses, arg.value, idx)
91
- elif isinstance (arg, ctypes.c_uint16 ):
112
+ elif isinstance (arg, ctypes_uint16 ):
92
113
return prepare_arg[uint16_t](data, data_addresses, arg.value, idx)
93
- elif isinstance (arg, ctypes.c_uint32 ):
114
+ elif isinstance (arg, ctypes_uint32 ):
94
115
return prepare_arg[uint32_t](data, data_addresses, arg.value, idx)
95
- elif isinstance (arg, ctypes.c_uint64 ):
116
+ elif isinstance (arg, ctypes_uint64 ):
96
117
return prepare_arg[uint64_t](data, data_addresses, arg.value, idx)
97
- elif isinstance (arg, ctypes.c_float ):
118
+ elif isinstance (arg, ctypes_float ):
98
119
return prepare_arg[float ](data, data_addresses, arg.value, idx)
99
- elif isinstance (arg, ctypes.c_double ):
120
+ elif isinstance (arg, ctypes_double ):
100
121
return prepare_arg[double ](data, data_addresses, arg.value, idx)
101
122
else :
102
123
return 1
@@ -107,37 +128,34 @@ cdef inline int prepare_numpy_arg(
107
128
vector.vector[void * ]& data_addresses,
108
129
arg,
109
130
const size_t idx) except - 1 :
110
- if not numpy:
111
- return 1
112
-
113
- if isinstance (arg, numpy.bool_):
131
+ if isinstance (arg, numpy_bool):
114
132
return prepare_arg[cpp_bool](data, data_addresses, arg, idx)
115
- elif isinstance (arg, numpy.int8 ):
133
+ elif isinstance (arg, numpy_int8 ):
116
134
return prepare_arg[int8_t](data, data_addresses, arg, idx)
117
- elif isinstance (arg, numpy.int16 ):
135
+ elif isinstance (arg, numpy_int16 ):
118
136
return prepare_arg[int16_t](data, data_addresses, arg, idx)
119
- elif isinstance (arg, numpy.int32 ):
137
+ elif isinstance (arg, numpy_int32 ):
120
138
return prepare_arg[int32_t](data, data_addresses, arg, idx)
121
- elif isinstance (arg, numpy.int64 ):
139
+ elif isinstance (arg, numpy_int64 ):
122
140
return prepare_arg[int64_t](data, data_addresses, arg, idx)
123
- elif isinstance (arg, numpy.uint8 ):
141
+ elif isinstance (arg, numpy_uint8 ):
124
142
return prepare_arg[uint8_t](data, data_addresses, arg, idx)
125
- elif isinstance (arg, numpy.uint16 ):
143
+ elif isinstance (arg, numpy_uint16 ):
126
144
return prepare_arg[uint16_t](data, data_addresses, arg, idx)
127
- elif isinstance (arg, numpy.uint32 ):
145
+ elif isinstance (arg, numpy_uint32 ):
128
146
return prepare_arg[uint32_t](data, data_addresses, arg, idx)
129
- elif isinstance (arg, numpy.uint64 ):
147
+ elif isinstance (arg, numpy_uint64 ):
130
148
return prepare_arg[uint64_t](data, data_addresses, arg, idx)
131
- elif isinstance (arg, numpy.float16 ):
149
+ elif isinstance (arg, numpy_float16 ):
132
150
# use int16 as a proxy
133
151
return prepare_arg[int16_t](data, data_addresses, arg, idx)
134
- elif isinstance (arg, numpy.float32 ):
152
+ elif isinstance (arg, numpy_float32 ):
135
153
return prepare_arg[float ](data, data_addresses, arg, idx)
136
- elif isinstance (arg, numpy.float64 ):
154
+ elif isinstance (arg, numpy_float64 ):
137
155
return prepare_arg[double ](data, data_addresses, arg, idx)
138
- elif isinstance (arg, numpy.complex64 ):
156
+ elif isinstance (arg, numpy_complex64 ):
139
157
return prepare_arg[cpp_single_complex](data, data_addresses, arg, idx)
140
- elif isinstance (arg, numpy.complex128 ):
158
+ elif isinstance (arg, numpy_complex128 ):
141
159
return prepare_arg[cpp_double_complex](data, data_addresses, arg, idx)
142
160
else :
143
161
return 1
@@ -185,9 +203,9 @@ cdef class ParamHolder:
185
203
continue
186
204
187
205
not_prepared = prepare_numpy_arg(self .data, self .data_addresses, arg, i)
188
- if not_prepared ! = 0 :
206
+ if not_prepared:
189
207
not_prepared = prepare_ctypes_arg(self .data, self .data_addresses, arg, i)
190
- if not_prepared ! = 0 :
208
+ if not_prepared:
191
209
# TODO: support ctypes/numpy struct
192
210
raise TypeError
193
211
0 commit comments