Skip to content

Commit 0cb12f1

Browse files
committed
update test
1 parent d4bd29c commit 0cb12f1

File tree

1 file changed

+84
-134
lines changed

1 file changed

+84
-134
lines changed

cuda_bindings/tests/test_nvjitlink.py

Lines changed: 84 additions & 134 deletions
Original file line numberDiff line numberDiff line change
@@ -1,166 +1,116 @@
11
import pytest
22
from cuda.bindings import nvjitlink
33

4-
dir(nvjitlink)
5-
6-
def test_create_no_arch_error():
7-
# nvjitlink expects at least the architecture to be specified.
8-
with pytest.raises(RuntimeError, match="NVJITLINK_ERROR_MISSING_ARCH error"):
9-
nvjitlink.create()
10-
11-
12-
def test_invalid_arch_error():
13-
# sm_XX is not a valid architecture
14-
with pytest.raises(RuntimeError, match="NVJITLINK_ERROR_UNRECOGNIZED_OPTION error"):
15-
nvjitlink.create("-arch=sm_XX")
16-
17-
18-
def test_unrecognized_option_error():
19-
with pytest.raises(RuntimeError, match="NVJITLINK_ERROR_UNRECOGNIZED_OPTION error"):
20-
nvjitlink.create("-fictitious_option")
21-
4+
import pytest
5+
from cuda.bindings import nvjitlink
226

23-
def test_invalid_option_type_error():
24-
with pytest.raises(TypeError, match="Expecting only strings"):
25-
nvjitlink.create("-arch", 53)
7+
ptx_code = """
8+
.version 8.5
9+
.target sm_90
10+
.address_size 64
2611
12+
.visible .entry _Z6kernelPi(
13+
.param .u64 _Z6kernelPi_param_0
14+
)
15+
{
16+
.reg .pred %p<2>;
17+
.reg .b32 %r<3>;
18+
.reg .b64 %rd<3>;
19+
20+
ld.param.u64 %rd1, [_Z6kernelPi_param_0];
21+
cvta.to.global.u64 %rd2, %rd1;
22+
mov.u32 %r1, %tid.x;
23+
st.global.u32 [%rd2+0], %r1;
24+
ret;
25+
}
26+
"""
27+
28+
minimal_kernel = """
29+
.version 6.4
30+
.target sm_75
31+
.address_size 64
32+
33+
.visible .entry _kernel() {
34+
ret;
35+
}
36+
"""
37+
38+
# Convert PTX code to bytes
39+
ptx_bytes = ptx_code.encode('utf-8')
40+
minimal_kernel_bytes = minimal_kernel.encode('utf-8')
2741

2842
def test_create_and_destroy():
29-
handle = nvjitlink.create("-arch=sm_53")
43+
handle = nvjitlink.create(1, ["-arch=sm_53"])
3044
assert handle != 0
3145
nvjitlink.destroy(handle)
3246

33-
3447
def test_complete_empty():
35-
handle = nvjitlink.create("-arch=sm_75")
48+
handle = nvjitlink.create(1, ["-arch=sm_90"])
3649
nvjitlink.complete(handle)
3750
nvjitlink.destroy(handle)
3851

39-
40-
@pytest.mark.parametrize(
41-
"input_file,input_type",
42-
[
43-
("device_functions_cubin", nvjitlink.InputType.CUBIN),
44-
("device_functions_fatbin", InputType.FATBIN),
45-
("device_functions_ptx", InputType.PTX),
46-
("device_functions_object", InputType.OBJECT),
47-
("device_functions_archive", InputType.LIBRARY),
48-
],
49-
)
50-
def test_add_file(input_file, input_type, gpu_arch_flag, request):
51-
filename, data = request.getfixturevalue(input_file)
52-
53-
handle = nvjitlink.create(gpu_arch_flag)
54-
nvjitlink.add_data(handle, input_type.value, data, filename)
55-
nvjitlink.destroy(handle)
56-
57-
58-
# We test the LTO input case separately as it requires the `-lto` flag. The
59-
# OBJECT input type is used because the LTO-IR container is packaged in an ELF
60-
# object when produced by NVCC.
61-
def test_add_file_lto(device_functions_ltoir_object, gpu_arch_flag):
62-
filename, data = device_functions_ltoir_object
63-
64-
handle = nvjitlink.create(gpu_arch_flag, "-lto")
65-
nvjitlink.add_data(handle, InputType.OBJECT.value, data, filename)
66-
nvjitlink.destroy(handle)
67-
68-
69-
def test_get_error_log(undefined_extern_cubin, gpu_arch_flag):
70-
handle = nvjitlink.create(gpu_arch_flag)
71-
filename, data = undefined_extern_cubin
72-
input_type = InputType.CUBIN.value
73-
nvjitlink.add_data(handle, input_type, data, filename)
74-
with pytest.raises(RuntimeError):
75-
nvjitlink.complete(handle)
76-
error_log = nvjitlink.get_error_log(handle)
52+
def test_add_data():
53+
handle = nvjitlink.create(1, ["-arch=sm_90"])
54+
data = ptx_bytes
55+
nvjitlink.add_data(handle, nvjitlink.InputType.ANY, data, len(data), "test_data")
56+
nvjitlink.complete(handle)
7757
nvjitlink.destroy(handle)
78-
assert (
79-
"Undefined reference to '_Z5undefff' "
80-
"in 'undefined_extern.cubin'" in error_log
81-
)
8258

59+
def test_add_file():
60+
handle = nvjitlink.create(1, ["-arch=sm_90"])
61+
file_path = "test_file.cubin"
62+
with open (file_path, "wb") as f:
63+
f.write(ptx_bytes)
8364

84-
def test_get_info_log(device_functions_cubin, gpu_arch_flag):
85-
handle = nvjitlink.create(gpu_arch_flag)
86-
filename, data = device_functions_cubin
87-
input_type = InputType.CUBIN.value
88-
nvjitlink.add_data(handle, input_type, data, filename)
65+
nvjitlink.add_file(handle, nvjitlink.InputType.ANY, str(file_path))
8966
nvjitlink.complete(handle)
90-
info_log = nvjitlink.get_info_log(handle)
9167
nvjitlink.destroy(handle)
92-
# Info log is empty
93-
assert "" == info_log
9468

95-
96-
def test_get_linked_cubin(device_functions_cubin, gpu_arch_flag):
97-
handle = nvjitlink.create(gpu_arch_flag)
98-
filename, data = device_functions_cubin
99-
input_type = InputType.CUBIN.value
100-
nvjitlink.add_data(handle, input_type, data, filename)
69+
def test_get_error_log():
70+
handle = nvjitlink.create(1, ["-arch=sm_90"])
10171
nvjitlink.complete(handle)
102-
cubin = nvjitlink.get_linked_cubin(handle)
72+
log_size = nvjitlink.get_error_log_size(handle)
73+
log = nvjitlink.get_error_log(handle)
74+
assert len(log) == log_size
10375
nvjitlink.destroy(handle)
10476

105-
# Just check we got something that looks like an ELF
106-
assert cubin[:4] == b"\x7fELF"
107-
108-
109-
def test_get_linked_cubin_link_not_complete_error(
110-
device_functions_cubin, gpu_arch_flag
111-
):
112-
handle = nvjitlink.create(gpu_arch_flag)
113-
filename, data = device_functions_cubin
114-
input_type = InputType.CUBIN.value
115-
nvjitlink.add_data(handle, input_type, data, filename)
116-
with pytest.raises(RuntimeError, match="NVJITLINK_ERROR_INTERNAL error"):
117-
nvjitlink.get_linked_cubin(handle)
77+
def test_get_info_log():
78+
handle = nvjitlink.create(1, ["-arch=sm_90"])
79+
nvjitlink.complete(handle)
80+
log_size = nvjitlink.get_info_log_size(handle)
81+
log = nvjitlink.get_info_log(handle)
82+
assert len(log) == log_size
11883
nvjitlink.destroy(handle)
11984

120-
121-
def test_get_linked_cubin_from_lto(device_functions_ltoir_object, gpu_arch_flag):
122-
filename, data = device_functions_ltoir_object
123-
# device_functions_ltoir_object is a host object containing a fatbin
124-
# containing an LTOIR container, because that is what NVCC produces when
125-
# LTO is requested. So we need to use the OBJECT input type, and the linker
126-
# retrieves the LTO IR from it because we passed the -lto flag.
127-
input_type = InputType.OBJECT.value
128-
handle = nvjitlink.create(gpu_arch_flag, "-lto")
129-
nvjitlink.add_data(handle, input_type, data, filename)
85+
def test_get_linked_cubin():
86+
handle = nvjitlink.create(1, ["-arch=sm_90"])
13087
nvjitlink.complete(handle)
88+
cubin_size = nvjitlink.get_linked_cubin_size(handle)
13189
cubin = nvjitlink.get_linked_cubin(handle)
90+
assert len(cubin) == cubin_size
13291
nvjitlink.destroy(handle)
13392

134-
# Just check we got something that looks like an ELF
135-
assert cubin[:4] == b"\x7fELF"
136-
137-
138-
def test_get_linked_ptx_from_lto(device_functions_ltoir_object, gpu_arch_flag):
139-
filename, data = device_functions_ltoir_object
140-
# device_functions_ltoir_object is a host object containing a fatbin
141-
# containing an LTOIR container, because that is what NVCC produces when
142-
# LTO is requested. So we need to use the OBJECT input type, and the linker
143-
# retrieves the LTO IR from it because we passed the -lto flag.
144-
input_type = InputType.OBJECT.value
145-
handle = nvjitlink.create(gpu_arch_flag, "-lto", "-ptx")
146-
nvjitlink.add_data(handle, input_type, data, filename)
93+
def test_get_linked_ptx():
94+
handle = nvjitlink.create(2, ["-arch=sm_90", "-lto"])
95+
data = minimal_kernel_bytes
96+
nvjitlink.add_data(handle, nvjitlink.InputType.ANY, data, len(data), "test_data")
14797
nvjitlink.complete(handle)
148-
nvjitlink.get_linked_ptx(handle)
98+
ptx_size = nvjitlink.get_linked_ptx_size(handle)
99+
ptx = nvjitlink.get_linked_ptx(handle)
100+
assert len(ptx) == ptx_size
149101
nvjitlink.destroy(handle)
150102

151-
152-
def test_get_linked_ptx_link_not_complete_error(
153-
device_functions_ltoir_object, gpu_arch_flag
154-
):
155-
handle = nvjitlink.create(gpu_arch_flag, "-lto", "-ptx")
156-
filename, data = device_functions_ltoir_object
157-
input_type = InputType.OBJECT.value
158-
nvjitlink.add_data(handle, input_type, data, filename)
159-
with pytest.raises(RuntimeError, match="NVJITLINK_ERROR_INTERNAL error"):
160-
nvjitlink.get_linked_ptx(handle)
161-
nvjitlink.destroy(handle)
162-
163-
164-
def test_package_version():
165-
assert pynvjitlink.__version__ is not None
166-
assert len(str(pynvjitlink.__version__)) > 0
103+
def test_version():
104+
major, minor = nvjitlink.version()
105+
assert major >= 0
106+
assert minor >= 0
107+
108+
test_create_and_destroy()
109+
test_complete_empty()
110+
test_add_data()
111+
test_add_file()
112+
test_get_error_log()
113+
test_get_info_log()
114+
test_get_linked_cubin()
115+
test_get_linked_ptx()
116+
test_version()

0 commit comments

Comments
 (0)