Skip to content

Commit 702fbaa

Browse files
committed
handle culink and nvjitlink differences in the backend and test
1 parent d7bf4cb commit 702fbaa

File tree

2 files changed

+55
-46
lines changed

2 files changed

+55
-46
lines changed

cuda_core/cuda/core/experimental/_linker.py

+19-31
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@ def _lazy_init():
2929
_driver_ver = handle_return(cuda.cuDriverGetVersion())
3030
_driver_ver = (_driver_ver // 1000, (_driver_ver % 1000) // 10)
3131
try:
32-
raise ImportError
3332
from cuda.bindings import nvjitlink
3433
from cuda.bindings._internal import nvjitlink as inner_nvjitlink
3534
except ImportError:
@@ -247,7 +246,7 @@ def _init_nvjitlink(self):
247246
self.formatted_options.append(f"-split-compile={self.split_compile}")
248247
if self.split_compile_extended is not None:
249248
self.formatted_options.append(f"-split-compile-extended={self.split_compile_extended}")
250-
if self.no_cache is not None:
249+
if self.no_cache is True:
251250
self.formatted_options.append("-no-cache")
252251

253252
def _init_driver(self):
@@ -272,57 +271,46 @@ def _init_driver(self):
272271
self.formatted_options.append(self.max_register_count)
273272
self.option_keys.append(_driver.CUjit_option.CU_JIT_MAX_REGISTERS)
274273
if self.time is not None:
275-
self.formatted_options.append(1) # ctypes.c_int32(1)
276-
self.option_keys.append(_driver.CUjit_option.CU_JIT_WALL_TIME)
274+
raise ValueError("time option is not supported by the driver API")
277275
if self.verbose is not None:
278-
self.formatted_options.append(1) # ctypes.c_int32(1)
276+
self.formatted_options.append(1)
279277
self.option_keys.append(_driver.CUjit_option.CU_JIT_LOG_VERBOSE)
280278
if self.link_time_optimization is not None:
281-
self.formatted_options.append(1) # ctypes.c_int32(1)
279+
self.formatted_options.append(1)
282280
self.option_keys.append(_driver.CUjit_option.CU_JIT_LTO)
283281
if self.ptx is not None:
284-
self.formatted_options.append(1) # ctypes.c_int32(1)
285-
self.option_keys.append(_driver.CUjit_option.CU_JIT_GENERATE_LINE_INFO)
282+
raise ValueError("ptx option is not supported by the driver API")
286283
if self.optimization_level is not None:
287284
self.formatted_options.append(self.optimization_level)
288285
self.option_keys.append(_driver.CUjit_option.CU_JIT_OPTIMIZATION_LEVEL)
289286
if self.debug is not None:
290-
self.formatted_options.append(1) # ctypes.c_int32(1)
287+
self.formatted_options.append(1)
291288
self.option_keys.append(_driver.CUjit_option.CU_JIT_GENERATE_DEBUG_INFO)
292289
if self.lineinfo is not None:
293-
self.formatted_options.append(1) # ctypes.c_int32(1)
290+
self.formatted_options.append(1)
294291
self.option_keys.append(_driver.CUjit_option.CU_JIT_GENERATE_LINE_INFO)
295292
if self.ftz is not None:
296-
self.formatted_options.append(1 if self.ftz else 0)
297-
self.option_keys.append(_driver.CUjit_option.CU_JIT_FTZ)
293+
raise ValueError("ftz option is deprecated in the driver API")
298294
if self.prec_div is not None:
299-
self.formatted_options.append(1 if self.prec_div else 0)
300-
self.option_keys.append(_driver.CUjit_option.CU_JIT_PREC_DIV)
295+
raise ValueError("prec_div option is deprecated in the driver API")
301296
if self.prec_sqrt is not None:
302-
self.formatted_options.append(1 if self.prec_sqrt else 0)
303-
self.option_keys.append(_driver.CUjit_option.CU_JIT_PREC_SQRT)
297+
raise ValueError("prec_sqrt option is deprecated in the driver API")
304298
if self.fma is not None:
305-
self.formatted_options.append(1 if self.fma else 0)
306-
self.option_keys.append(_driver.CUjit_option.CU_JIT_FMA)
299+
raise ValueError("fma options is deprecated in the driver API")
307300
if self.kernels_used is not None:
308-
for kernel in self.kernels_used:
309-
self.formatted_options.append(kernel.encode())
310-
self.option_keys.append(_driver.CUjit_option.CU_JIT_REFERENCED_KERNEL_NAMES)
301+
raise ValueError("kernels_used is deprecated in the driver API")
311302
if self.variables_used is not None:
312-
for variable in self.variables_used:
313-
self.formatted_options.append(variable.encode())
314-
self.option_keys.append(_driver.CUjit_option.CU_JIT_REFERENCED_VARIABLE_NAMES)
303+
raise ValueError("variables_used is deprecated in the driver API")
315304
if self.optimize_unused_variables is not None:
316-
self.formatted_options.append(1) # ctypes.c_int32(1)
317-
self.option_keys.append(_driver.CUjit_option.CU_JIT_OPTIMIZE_UNUSED_DEVICE_VARIABLES)
305+
raise ValueError("optimize_unused_variables is deprecated in the driver API")
318306
if self.xptxas is not None:
319-
for opt in self.xptxas:
320-
raise NotImplementedError("TODO: implement xptxas option")
307+
raise ValueError("xptxas option is not supported by the driver API")
308+
if self.split_compile is not None:
309+
raise ValueError("split_compile option is not supported by the driver API")
321310
if self.split_compile_extended is not None:
322-
self.formatted_options.append(self.split_compile_extended)
323-
self.option_keys.append(_driver.CUjit_option.CU_JIT_MIN_CTA_PER_SM)
311+
raise ValueError("split_compile_extended option is not supported by the driver API")
324312
if self.no_cache is not None:
325-
self.formatted_options.append(1) # ctypes.c_int32(1)
313+
self.formatted_options.append(_driver.CUjit_cacheMode.CU_JIT_CACHE_OPTION_NONE)
326314
self.option_keys.append(_driver.CUjit_option.CU_JIT_CACHE_MODE)
327315

328316

cuda_core/tests/test_linker.py

+36-15
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,17 @@
88
basic_kernel = "__device__ int B() { return 0; }"
99
addition_kernel = "__device__ int C(int a, int b) { return a + b; }"
1010

11+
try:
12+
from cuda.bindings import nvjitlink # noqa F401
13+
from cuda.bindings._internal import nvjitlink as inner_nvjitlink
14+
except ImportError:
15+
# binding is not available
16+
culink_backend = True
17+
else:
18+
if inner_nvjitlink._inspect_function_pointer("__nvJitLinkVersion") == 0:
19+
# binding is available, but nvJitLink is not installed
20+
culink_backend = True
21+
1122

1223
@pytest.fixture(scope="function")
1324
def compile_ptx_functions(init_cuda):
@@ -27,27 +38,36 @@ def compile_ltoir_functions(init_cuda):
2738
return object_code_a_ltoir, object_code_b_ltoir, object_code_c_ltoir
2839

2940

41+
culink_options = [
42+
LinkerOptions(arch=ARCH),
43+
LinkerOptions(arch=ARCH, max_register_count=32),
44+
LinkerOptions(arch=ARCH, verbose=True),
45+
LinkerOptions(arch=ARCH, optimization_level=3),
46+
LinkerOptions(arch=ARCH, debug=True),
47+
LinkerOptions(arch=ARCH, lineinfo=True),
48+
LinkerOptions(arch=ARCH, no_cache=True),
49+
]
50+
51+
3052
@pytest.mark.parametrize(
3153
"options",
32-
[
33-
LinkerOptions(arch=ARCH),
34-
LinkerOptions(arch=ARCH, max_register_count=32),
54+
culink_options
55+
if culink_backend
56+
else culink_options
57+
+ [
3558
LinkerOptions(arch=ARCH, time=True),
36-
LinkerOptions(arch=ARCH, verbose=True),
37-
LinkerOptions(arch=ARCH, optimization_level=3),
38-
LinkerOptions(arch=ARCH, debug=True),
39-
LinkerOptions(arch=ARCH, lineinfo=True),
4059
LinkerOptions(arch=ARCH, ftz=True),
4160
LinkerOptions(arch=ARCH, prec_div=True),
4261
LinkerOptions(arch=ARCH, prec_sqrt=True),
4362
LinkerOptions(arch=ARCH, fma=True),
4463
LinkerOptions(arch=ARCH, kernels_used=["kernel1"]),
64+
LinkerOptions(arch=ARCH, kernels_used=["kernel1", "kernel2"]),
4565
LinkerOptions(arch=ARCH, variables_used=["var1"]),
66+
LinkerOptions(arch=ARCH, variables_used=["var1", "var2"]),
4667
LinkerOptions(arch=ARCH, optimize_unused_variables=True),
47-
# LinkerOptions(arch=ARCH, xptxas=["-v"]),
48-
# LinkerOptions(arch=ARCH, split_compile=0),
68+
LinkerOptions(arch=ARCH, xptxas=["-v"]),
69+
LinkerOptions(arch=ARCH, split_compile=0),
4970
LinkerOptions(arch=ARCH, split_compile_extended=1),
50-
# LinkerOptions(arch=ARCH, no_cache=True),
5171
],
5272
)
5373
def test_linker_init(compile_ptx_functions, options):
@@ -62,11 +82,12 @@ def test_linker_init_invalid_arch():
6282
Linker(options)
6383

6484

65-
# def test_linker_link_ptx(compile_ltoir_functions):
66-
# options = LinkerOptions(arch=ARCH, link_time_optimization=True, ptx=True)
67-
# linker = Linker(*compile_ltoir_functions, options=options)
68-
# linked_code = linker.link("ptx")
69-
# assert isinstance(linked_code, ObjectCode)
85+
@pytest.mark.skipif(culink_backend, reason="culink does not support ptx option")
86+
def test_linker_link_ptx(compile_ltoir_functions):
87+
options = LinkerOptions(arch=ARCH, link_time_optimization=True, ptx=True)
88+
linker = Linker(*compile_ltoir_functions, options=options)
89+
linked_code = linker.link("ptx")
90+
assert isinstance(linked_code, ObjectCode)
7091

7192

7293
def test_linker_link_cubin(compile_ptx_functions):

0 commit comments

Comments
 (0)