Skip to content

Commit 93b495f

Browse files
committed
update nvjitlink test
1 parent fdc76e8 commit 93b495f

File tree

1 file changed

+76
-36
lines changed

1 file changed

+76
-36
lines changed

cuda_bindings/tests/test_nvjitlink.py

Lines changed: 76 additions & 36 deletions
Original file line number | Diff line number | Diff line change
@@ -4,13 +4,22 @@
44

55
import pytest
66

7-
from cuda.bindings import nvjitlink
7+
from cuda.bindings import nvjitlink, nvrtc
88

9-
ptx_kernel = """
10-
.version 8.5
11-
.target sm_90
9+
# Establish a handful of compatible architectures and PTX versions to test with
10+
ARCHITECTURES = ["sm_60", "sm_75", "sm_80", "sm_90"]
11+
PTX_VERSIONS = ["5.0", "6.4", "7.0", "8.5"]
12+
13+
14+
def ptx_header(version, arch):
15+
return f"""
16+
.version {version}
17+
.target {arch}
1218
.address_size 64
19+
"""
20+
1321

22+
ptx_kernel = """
1423
.visible .entry _Z6kernelPi(
1524
.param .u64 _Z6kernelPi_param_0
1625
)
@@ -28,18 +37,40 @@
2837
"""
2938

3039
minimal_ptx_kernel = """
31-
.version 8.5
32-
.target sm_90
33-
.address_size 64
34-
3540
.func _MinimalKernel()
3641
{
3742
ret;
3843
}
3944
"""
4045

41-
ptx_kernel_bytes = ptx_kernel.encode("utf-8")
42-
minimal_ptx_kernel_bytes = minimal_ptx_kernel.encode("utf-8")
46+
ptx_kernel_bytes = [
47+
(ptx_header(version, arch) + ptx_kernel).encode("utf-8") for version, arch in zip(PTX_VERSIONS, ARCHITECTURES)
48+
]
49+
minimal_ptx_kernel_bytes = [
50+
(ptx_header(version, arch) + minimal_ptx_kernel).encode("utf-8")
51+
for version, arch in zip(PTX_VERSIONS, ARCHITECTURES)
52+
]
53+
54+
55+
# create a valid LTOIR input for testing
56+
@pytest.fixture
57+
def get_dummy_ltoir():
58+
def CHECK_NVRTC(err):
59+
if err != nvrtc.nvrtcResult.NVRTC_SUCCESS:
60+
raise RuntimeError(f"Nvrtc Error: {err}")
61+
62+
empty_cplusplus_kernel = "__global__ void A() {}"
63+
err, program_handle = nvrtc.nvrtcCreateProgram(empty_cplusplus_kernel.encode(), b"", 0, [], [])
64+
CHECK_NVRTC(err)
65+
nvrtc.nvrtcCompileProgram(program_handle, 1, [b"-dlto"])
66+
err, size = nvrtc.nvrtcGetLTOIRSize(program_handle)
67+
CHECK_NVRTC(err)
68+
empty_kernel_ltoir = b" " * size
69+
(err,) = nvrtc.nvrtcGetLTOIR(program_handle, empty_kernel_ltoir)
70+
CHECK_NVRTC(err)
71+
(err,) = nvrtc.nvrtcDestroyProgram(program_handle)
72+
CHECK_NVRTC(err)
73+
return empty_kernel_ltoir
4374

4475

4576
def test_unrecognized_option_error():
@@ -52,39 +83,41 @@ def test_invalid_arch_error():
5283
nvjitlink.create(1, ["-arch=sm_XX"])
5384

5485

55-
def test_create_and_destroy():
56-
handle = nvjitlink.create(1, ["-arch=sm_53"])
86+
@pytest.mark.parametrize("option", ARCHITECTURES)
87+
def test_create_and_destroy(option):
88+
handle = nvjitlink.create(1, [f"-arch={option}"])
5789
assert handle != 0
5890
nvjitlink.destroy(handle)
5991

6092

61-
def test_complete_empty():
62-
handle = nvjitlink.create(1, ["-arch=sm_90"])
93+
@pytest.mark.parametrize("option", ARCHITECTURES)
94+
def test_complete_empty(option):
95+
handle = nvjitlink.create(1, [f"-arch={option}"])
6396
nvjitlink.complete(handle)
6497
nvjitlink.destroy(handle)
6598

6699

67-
def test_add_data():
68-
handle = nvjitlink.create(1, ["-arch=sm_90"])
69-
nvjitlink.add_data(handle, nvjitlink.InputType.ANY, ptx_kernel_bytes, len(ptx_kernel_bytes), "test_data")
70-
nvjitlink.add_data(
71-
handle, nvjitlink.InputType.ANY, minimal_ptx_kernel_bytes, len(minimal_ptx_kernel_bytes), "minimal_test_data"
72-
)
100+
@pytest.mark.parametrize("option, ptx_bytes", zip(ARCHITECTURES, ptx_kernel_bytes))
101+
def test_add_data(option, ptx_bytes):
102+
handle = nvjitlink.create(1, [f"-arch={option}"])
103+
nvjitlink.add_data(handle, nvjitlink.InputType.ANY, ptx_bytes, len(ptx_bytes), "test_data")
73104
nvjitlink.complete(handle)
74105
nvjitlink.destroy(handle)
75106

76107

77-
def test_add_file(tmp_path):
78-
handle = nvjitlink.create(1, ["-arch=sm_90"])
108+
@pytest.mark.parametrize("option, ptx_bytes", zip(ARCHITECTURES, ptx_kernel_bytes))
109+
def test_add_file(option, ptx_bytes, tmp_path):
110+
handle = nvjitlink.create(1, [f"-arch={option}"])
79111
file_path = tmp_path / "test_file.cubin"
80-
file_path.write_bytes(ptx_kernel_bytes)
112+
file_path.write_bytes(ptx_bytes)
81113
nvjitlink.add_file(handle, nvjitlink.InputType.ANY, str(file_path))
82114
nvjitlink.complete(handle)
83115
nvjitlink.destroy(handle)
84116

85117

86-
def test_get_error_log():
87-
handle = nvjitlink.create(1, ["-arch=sm_90"])
118+
@pytest.mark.parametrize("option", ARCHITECTURES)
119+
def test_get_error_log(option):
120+
handle = nvjitlink.create(1, [f"-arch={option}"])
88121
nvjitlink.complete(handle)
89122
log_size = nvjitlink.get_error_log_size(handle)
90123
log = bytearray(log_size)
@@ -93,9 +126,10 @@ def test_get_error_log():
93126
nvjitlink.destroy(handle)
94127

95128

96-
def test_get_info_log():
97-
handle = nvjitlink.create(1, ["-arch=sm_90"])
98-
nvjitlink.add_data(handle, nvjitlink.InputType.ANY, ptx_kernel_bytes, len(ptx_kernel_bytes), "test_data")
129+
@pytest.mark.parametrize("option, ptx_bytes", zip(ARCHITECTURES, ptx_kernel_bytes))
130+
def test_get_info_log(option, ptx_bytes):
131+
handle = nvjitlink.create(1, [f"-arch={option}"])
132+
nvjitlink.add_data(handle, nvjitlink.InputType.ANY, ptx_bytes, len(ptx_bytes), "test_data")
99133
nvjitlink.complete(handle)
100134
log_size = nvjitlink.get_info_log_size(handle)
101135
log = bytearray(log_size)
@@ -104,9 +138,10 @@ def test_get_info_log():
104138
nvjitlink.destroy(handle)
105139

106140

107-
def test_get_linked_cubin():
108-
handle = nvjitlink.create(1, ["-arch=sm_90"])
109-
nvjitlink.add_data(handle, nvjitlink.InputType.ANY, ptx_kernel_bytes, len(ptx_kernel_bytes), "test_data")
141+
@pytest.mark.parametrize("option, ptx_bytes", zip(ARCHITECTURES, ptx_kernel_bytes))
142+
def test_get_linked_cubin(option, ptx_bytes):
143+
handle = nvjitlink.create(1, [f"-arch={option}"])
144+
nvjitlink.add_data(handle, nvjitlink.InputType.ANY, ptx_bytes, len(ptx_bytes), "test_data")
110145
nvjitlink.complete(handle)
111146
cubin_size = nvjitlink.get_linked_cubin_size(handle)
112147
cubin = bytearray(cubin_size)
@@ -115,11 +150,16 @@ def test_get_linked_cubin():
115150
nvjitlink.destroy(handle)
116151

117152

118-
def test_get_linked_ptx():
119-
# TODO improve this test to call get_linked_ptx without this error
120-
handle = nvjitlink.create(2, ["-arch=sm_90", "-lto"])
121-
with pytest.raises(nvjitlink.nvJitLinkError, match="ERROR_NVVM_COMPILE"):
122-
nvjitlink.complete(handle)
153+
@pytest.mark.parametrize("option", ARCHITECTURES)
154+
def test_get_linked_ptx(option, get_dummy_ltoir):
155+
handle = nvjitlink.create(3, [f"-arch={option}", "-lto", "-ptx"])
156+
nvjitlink.add_data(handle, nvjitlink.InputType.LTOIR, get_dummy_ltoir, len(get_dummy_ltoir), "test_data")
157+
nvjitlink.complete(handle)
158+
ptx_size = nvjitlink.get_linked_ptx_size(handle)
159+
ptx = bytearray(ptx_size)
160+
nvjitlink.get_linked_ptx(handle, ptx)
161+
assert len(ptx) == ptx_size
162+
nvjitlink.destroy(handle)
123163

124164

125165
def test_package_version():

0 commit comments

Comments
 (0)