Skip to content

Commit df017dd

Browse files
committed
micro-optimization + make numpy a required dependency
1 parent f0c155c commit df017dd

File tree

3 files changed

+61
-36
lines changed

3 files changed

+61
-36
lines changed

cuda_core/cuda/core/_kernel_arg_handler.pyx

Lines changed: 54 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,7 @@ from libcpp cimport vector
1313

1414
import ctypes
1515

16-
# this might be an unnecessary assumption that NumPy does not exist...
17-
try:
18-
import numpy
19-
except ImportError:
20-
numpy = None
16+
import numpy
2117

2218
from cuda.core._memory import Buffer
2319

@@ -43,7 +39,32 @@ ctypedef fused supported_type:
4339
cpp_double_complex
4440

4541

46-
# TODO: cache ctypes/numpy type objects to avoid attribute access
42+
# cache ctypes/numpy type objects at module level to avoid repeated attribute lookups in the isinstance checks below
43+
cdef object ctypes_bool = ctypes.c_bool
44+
cdef object ctypes_int8 = ctypes.c_int8
45+
cdef object ctypes_int16 = ctypes.c_int16
46+
cdef object ctypes_int32 = ctypes.c_int32
47+
cdef object ctypes_int64 = ctypes.c_int64
48+
cdef object ctypes_uint8 = ctypes.c_uint8
49+
cdef object ctypes_uint16 = ctypes.c_uint16
50+
cdef object ctypes_uint32 = ctypes.c_uint32
51+
cdef object ctypes_uint64 = ctypes.c_uint64
52+
cdef object ctypes_float = ctypes.c_float
53+
cdef object ctypes_double = ctypes.c_double
54+
cdef object numpy_bool = numpy.bool_
55+
cdef object numpy_int8 = numpy.int8
56+
cdef object numpy_int16 = numpy.int16
57+
cdef object numpy_int32 = numpy.int32
58+
cdef object numpy_int64 = numpy.int64
59+
cdef object numpy_uint8 = numpy.uint8
60+
cdef object numpy_uint16 = numpy.uint16
61+
cdef object numpy_uint32 = numpy.uint32
62+
cdef object numpy_uint64 = numpy.uint64
63+
cdef object numpy_float16 = numpy.float16
64+
cdef object numpy_float32 = numpy.float32
65+
cdef object numpy_float64 = numpy.float64
66+
cdef object numpy_complex64 = numpy.complex64
67+
cdef object numpy_complex128 = numpy.complex128
4768

4869

4970
# limitation due to cython/cython#534
@@ -76,27 +97,27 @@ cdef inline int prepare_ctypes_arg(
7697
vector.vector[void*]& data_addresses,
7798
arg,
7899
const size_t idx) except -1:
79-
if isinstance(arg, ctypes.c_bool):
100+
if isinstance(arg, ctypes_bool):
80101
return prepare_arg[cpp_bool](data, data_addresses, arg.value, idx)
81-
elif isinstance(arg, ctypes.c_int8):
102+
elif isinstance(arg, ctypes_int8):
82103
return prepare_arg[int8_t](data, data_addresses, arg.value, idx)
83-
elif isinstance(arg, ctypes.c_int16):
104+
elif isinstance(arg, ctypes_int16):
84105
return prepare_arg[int16_t](data, data_addresses, arg.value, idx)
85-
elif isinstance(arg, ctypes.c_int32):
106+
elif isinstance(arg, ctypes_int32):
86107
return prepare_arg[int32_t](data, data_addresses, arg.value, idx)
87-
elif isinstance(arg, ctypes.c_int64):
108+
elif isinstance(arg, ctypes_int64):
88109
return prepare_arg[int64_t](data, data_addresses, arg.value, idx)
89-
elif isinstance(arg, ctypes.c_uint8):
110+
elif isinstance(arg, ctypes_uint8):
90111
return prepare_arg[uint8_t](data, data_addresses, arg.value, idx)
91-
elif isinstance(arg, ctypes.c_uint16):
112+
elif isinstance(arg, ctypes_uint16):
92113
return prepare_arg[uint16_t](data, data_addresses, arg.value, idx)
93-
elif isinstance(arg, ctypes.c_uint32):
114+
elif isinstance(arg, ctypes_uint32):
94115
return prepare_arg[uint32_t](data, data_addresses, arg.value, idx)
95-
elif isinstance(arg, ctypes.c_uint64):
116+
elif isinstance(arg, ctypes_uint64):
96117
return prepare_arg[uint64_t](data, data_addresses, arg.value, idx)
97-
elif isinstance(arg, ctypes.c_float):
118+
elif isinstance(arg, ctypes_float):
98119
return prepare_arg[float](data, data_addresses, arg.value, idx)
99-
elif isinstance(arg, ctypes.c_double):
120+
elif isinstance(arg, ctypes_double):
100121
return prepare_arg[double](data, data_addresses, arg.value, idx)
101122
else:
102123
return 1
@@ -107,37 +128,34 @@ cdef inline int prepare_numpy_arg(
107128
vector.vector[void*]& data_addresses,
108129
arg,
109130
const size_t idx) except -1:
110-
if not numpy:
111-
return 1
112-
113-
if isinstance(arg, numpy.bool_):
131+
if isinstance(arg, numpy_bool):
114132
return prepare_arg[cpp_bool](data, data_addresses, arg, idx)
115-
elif isinstance(arg, numpy.int8):
133+
elif isinstance(arg, numpy_int8):
116134
return prepare_arg[int8_t](data, data_addresses, arg, idx)
117-
elif isinstance(arg, numpy.int16):
135+
elif isinstance(arg, numpy_int16):
118136
return prepare_arg[int16_t](data, data_addresses, arg, idx)
119-
elif isinstance(arg, numpy.int32):
137+
elif isinstance(arg, numpy_int32):
120138
return prepare_arg[int32_t](data, data_addresses, arg, idx)
121-
elif isinstance(arg, numpy.int64):
139+
elif isinstance(arg, numpy_int64):
122140
return prepare_arg[int64_t](data, data_addresses, arg, idx)
123-
elif isinstance(arg, numpy.uint8):
141+
elif isinstance(arg, numpy_uint8):
124142
return prepare_arg[uint8_t](data, data_addresses, arg, idx)
125-
elif isinstance(arg, numpy.uint16):
143+
elif isinstance(arg, numpy_uint16):
126144
return prepare_arg[uint16_t](data, data_addresses, arg, idx)
127-
elif isinstance(arg, numpy.uint32):
145+
elif isinstance(arg, numpy_uint32):
128146
return prepare_arg[uint32_t](data, data_addresses, arg, idx)
129-
elif isinstance(arg, numpy.uint64):
147+
elif isinstance(arg, numpy_uint64):
130148
return prepare_arg[uint64_t](data, data_addresses, arg, idx)
131-
elif isinstance(arg, numpy.float16):
149+
elif isinstance(arg, numpy_float16):
132150
# use int16 as a proxy
133151
return prepare_arg[int16_t](data, data_addresses, arg, idx)
134-
elif isinstance(arg, numpy.float32):
152+
elif isinstance(arg, numpy_float32):
135153
return prepare_arg[float](data, data_addresses, arg, idx)
136-
elif isinstance(arg, numpy.float64):
154+
elif isinstance(arg, numpy_float64):
137155
return prepare_arg[double](data, data_addresses, arg, idx)
138-
elif isinstance(arg, numpy.complex64):
156+
elif isinstance(arg, numpy_complex64):
139157
return prepare_arg[cpp_single_complex](data, data_addresses, arg, idx)
140-
elif isinstance(arg, numpy.complex128):
158+
elif isinstance(arg, numpy_complex128):
141159
return prepare_arg[cpp_double_complex](data, data_addresses, arg, idx)
142160
else:
143161
return 1
@@ -185,9 +203,9 @@ cdef class ParamHolder:
185203
continue
186204

187205
not_prepared = prepare_numpy_arg(self.data, self.data_addresses, arg, i)
188-
if not_prepared != 0:
206+
if not_prepared:
189207
not_prepared = prepare_ctypes_arg(self.data, self.data_addresses, arg, i)
190-
if not_prepared != 0:
208+
if not_prepared:
191209
# TODO: support ctypes/numpy struct
192210
raise TypeError
193211

cuda_core/cuda/core/utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,5 @@
1+
# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED.
2+
#
3+
# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE
4+
15
from cuda.core._memoryview import GPUMemoryView, viewable

cuda_core/pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,9 @@ classifiers = [
4141
"Environment :: GPU :: NVIDIA CUDA :: 11",
4242
"Environment :: GPU :: NVIDIA CUDA :: 12",
4343
]
44+
dependencies = [
45+
"numpy",
46+
]
4447

4548

4649
[tool.setuptools]

0 commit comments

Comments
 (0)