4
4
5
5
import pytest
6
6
7
- from cuda .bindings import nvjitlink
7
+ from cuda .bindings import nvjitlink , nvrtc
8
8
9
- ptx_kernel = """
10
- .version 8.5
11
- .target sm_90
9
# Establish a handful of compatible architectures and PTX versions to test with.
# The two lists are index-aligned: PTX_VERSIONS[i] is the PTX ISA version that
# is paired with ARCHITECTURES[i] when the PTX inputs are generated, so any
# addition must be made to both lists at the same position.
ARCHITECTURES = ["sm_60", "sm_75", "sm_80", "sm_90"]
PTX_VERSIONS = ["5.0", "6.4", "7.0", "8.5"]
12
+
13
+
14
def ptx_header(version, arch):
    """Return the PTX module preamble for a given ISA version and target.

    Args:
        version: PTX ISA version string emitted after ``.version``
            (for example ``"8.5"``).
        arch: Target architecture emitted after ``.target``
            (for example ``"sm_90"``).

    Returns:
        A string that starts with a newline and then contains the
        ``.version``, ``.target`` and ``.address_size 64`` directives,
        one per line, ending with a trailing newline.
    """
    return f"""
.version {version}
.target {arch}
.address_size 64
"""
20
+
13
21
22
+ ptx_kernel = """
14
23
.visible .entry _Z6kernelPi(
15
24
.param .u64 _Z6kernelPi_param_0
16
25
)
28
37
"""
29
38
30
39
# Smallest PTX body used by the tests: one function that returns immediately.
# It deliberately carries no .version/.target/.address_size preamble; that is
# prepended per architecture with ptx_header() when the byte inputs are built.
minimal_ptx_kernel = """
.func _MinimalKernel()
{
    ret;
}
"""
40
45
41
# One UTF-8 encoded PTX module per (PTX version, architecture) pair. Entry i of
# each list is built for ARCHITECTURES[i], so the lists can be zipped with
# ARCHITECTURES to parametrize tests.
ptx_kernel_bytes = [
    (ptx_header(version, arch) + ptx_kernel).encode("utf-8")
    for version, arch in zip(PTX_VERSIONS, ARCHITECTURES)
]
minimal_ptx_kernel_bytes = [
    (ptx_header(version, arch) + minimal_ptx_kernel).encode("utf-8")
    for version, arch in zip(PTX_VERSIONS, ARCHITECTURES)
]
53
+
54
+
55
# create a valid LTOIR input for testing
@pytest.fixture
def get_dummy_ltoir():
    """Compile an empty CUDA kernel with NVRTC and return its LTO-IR bytes.

    Returns:
        The buffer handed to ``nvrtcGetLTOIR``, sized from ``nvrtcGetLTOIRSize``.

    Raises:
        RuntimeError: If any NVRTC call reports a status other than
            ``NVRTC_SUCCESS``.
    """

    def CHECK_NVRTC(err):
        if err != nvrtc.nvrtcResult.NVRTC_SUCCESS:
            raise RuntimeError(f"Nvrtc Error: {err}")

    empty_cplusplus_kernel = "__global__ void A() {}"
    err, program_handle = nvrtc.nvrtcCreateProgram(empty_cplusplus_kernel.encode(), b"", 0, [], [])
    CHECK_NVRTC(err)
    try:
        # The compile status must be checked: without it a failed compile
        # would only surface later as a confusing size/retrieval error.
        (err,) = nvrtc.nvrtcCompileProgram(program_handle, 1, [b"-dlto"])
        CHECK_NVRTC(err)
        err, size = nvrtc.nvrtcGetLTOIRSize(program_handle)
        CHECK_NVRTC(err)
        # Pre-sized buffer that is passed to nvrtcGetLTOIR and then returned.
        empty_kernel_ltoir = b" " * size
        (err,) = nvrtc.nvrtcGetLTOIR(program_handle, empty_kernel_ltoir)
        CHECK_NVRTC(err)
    finally:
        # Always release the program, even when one of the checks above raised.
        (err,) = nvrtc.nvrtcDestroyProgram(program_handle)
    # Only reached on the success path; a failure above has already propagated.
    CHECK_NVRTC(err)
    return empty_kernel_ltoir
43
74
44
75
45
76
def test_unrecognized_option_error ():
@@ -52,39 +83,41 @@ def test_invalid_arch_error():
52
83
nvjitlink .create (1 , ["-arch=sm_XX" ])
53
84
54
85
55
@pytest.mark.parametrize("option", ARCHITECTURES)
def test_create_and_destroy(option):
    """Create a linker handle for each architecture and destroy it again."""
    handle = nvjitlink.create(1, [f"-arch={option}"])
    assert handle != 0
    nvjitlink.destroy(handle)
59
91
60
92
61
@pytest.mark.parametrize("option", ARCHITECTURES)
def test_complete_empty(option):
    """Complete a link with no inputs for each architecture."""
    handle = nvjitlink.create(1, [f"-arch={option}"])
    try:
        nvjitlink.complete(handle)
    finally:
        # Release the handle even if the link step raises.
        nvjitlink.destroy(handle)
65
98
66
99
67
@pytest.mark.parametrize("option, ptx_bytes", zip(ARCHITECTURES, ptx_kernel_bytes))
def test_add_data(option, ptx_bytes):
    """Add an in-memory PTX module built for ``option`` and complete the link."""
    handle = nvjitlink.create(1, [f"-arch={option}"])
    try:
        nvjitlink.add_data(handle, nvjitlink.InputType.ANY, ptx_bytes, len(ptx_bytes), "test_data")
        nvjitlink.complete(handle)
    finally:
        # Release the handle even if adding the input or linking raises.
        nvjitlink.destroy(handle)
75
106
76
107
77
@pytest.mark.parametrize("option, ptx_bytes", zip(ARCHITECTURES, ptx_kernel_bytes))
def test_add_file(option, ptx_bytes, tmp_path):
    """Write the PTX module to a temporary file, add it by path, and link."""
    handle = nvjitlink.create(1, [f"-arch={option}"])
    try:
        # NOTE(review): the file carries PTX text despite the .cubin suffix;
        # the input is added with InputType.ANY — confirm the name is intended.
        file_path = tmp_path / "test_file.cubin"
        file_path.write_bytes(ptx_bytes)
        nvjitlink.add_file(handle, nvjitlink.InputType.ANY, str(file_path))
        nvjitlink.complete(handle)
    finally:
        # Release the handle even if adding the file or linking raises.
        nvjitlink.destroy(handle)
84
116
85
117
86
- def test_get_error_log ():
87
- handle = nvjitlink .create (1 , ["-arch=sm_90" ])
118
+ @pytest .mark .parametrize ("option" , ARCHITECTURES )
119
+ def test_get_error_log (option ):
120
+ handle = nvjitlink .create (1 , [f"-arch={ option } " ])
88
121
nvjitlink .complete (handle )
89
122
log_size = nvjitlink .get_error_log_size (handle )
90
123
log = bytearray (log_size )
@@ -93,9 +126,10 @@ def test_get_error_log():
93
126
nvjitlink .destroy (handle )
94
127
95
128
96
- def test_get_info_log ():
97
- handle = nvjitlink .create (1 , ["-arch=sm_90" ])
98
- nvjitlink .add_data (handle , nvjitlink .InputType .ANY , ptx_kernel_bytes , len (ptx_kernel_bytes ), "test_data" )
129
+ @pytest .mark .parametrize ("option, ptx_bytes" , zip (ARCHITECTURES , ptx_kernel_bytes ))
130
+ def test_get_info_log (option , ptx_bytes ):
131
+ handle = nvjitlink .create (1 , [f"-arch={ option } " ])
132
+ nvjitlink .add_data (handle , nvjitlink .InputType .ANY , ptx_bytes , len (ptx_bytes ), "test_data" )
99
133
nvjitlink .complete (handle )
100
134
log_size = nvjitlink .get_info_log_size (handle )
101
135
log = bytearray (log_size )
@@ -104,9 +138,10 @@ def test_get_info_log():
104
138
nvjitlink .destroy (handle )
105
139
106
140
107
- def test_get_linked_cubin ():
108
- handle = nvjitlink .create (1 , ["-arch=sm_90" ])
109
- nvjitlink .add_data (handle , nvjitlink .InputType .ANY , ptx_kernel_bytes , len (ptx_kernel_bytes ), "test_data" )
141
+ @pytest .mark .parametrize ("option, ptx_bytes" , zip (ARCHITECTURES , ptx_kernel_bytes ))
142
+ def test_get_linked_cubin (option , ptx_bytes ):
143
+ handle = nvjitlink .create (1 , [f"-arch={ option } " ])
144
+ nvjitlink .add_data (handle , nvjitlink .InputType .ANY , ptx_bytes , len (ptx_bytes ), "test_data" )
110
145
nvjitlink .complete (handle )
111
146
cubin_size = nvjitlink .get_linked_cubin_size (handle )
112
147
cubin = bytearray (cubin_size )
@@ -115,11 +150,16 @@ def test_get_linked_cubin():
115
150
nvjitlink .destroy (handle )
116
151
117
152
118
@pytest.mark.parametrize("option", ARCHITECTURES)
def test_get_linked_ptx(option, get_dummy_ltoir):
    """Link the dummy LTO-IR with ``-lto -ptx`` and fetch the linked PTX."""
    handle = nvjitlink.create(3, [f"-arch={option}", "-lto", "-ptx"])
    try:
        nvjitlink.add_data(handle, nvjitlink.InputType.LTOIR, get_dummy_ltoir, len(get_dummy_ltoir), "test_data")
        nvjitlink.complete(handle)
        ptx_size = nvjitlink.get_linked_ptx_size(handle)
        ptx = bytearray(ptx_size)
        nvjitlink.get_linked_ptx(handle, ptx)
        # NOTE(review): this holds by construction of the bytearray above, so
        # it only guards against the call resizing the buffer; consider also
        # asserting on the reported size or the buffer contents.
        assert len(ptx) == ptx_size
    finally:
        # Release the handle even if any linker call or the assertion fails.
        nvjitlink.destroy(handle)
123
163
124
164
125
165
def test_package_version ():
0 commit comments