1
1
import pytest
2
2
from cuda .bindings import nvjitlink
3
3
4
- dir (nvjitlink )
5
-
6
- def test_create_no_arch_error ():
7
- # nvjitlink expects at least the architecture to be specified.
8
- with pytest .raises (RuntimeError , match = "NVJITLINK_ERROR_MISSING_ARCH error" ):
9
- nvjitlink .create ()
10
-
11
-
12
- def test_invalid_arch_error ():
13
- # sm_XX is not a valid architecture
14
- with pytest .raises (RuntimeError , match = "NVJITLINK_ERROR_UNRECOGNIZED_OPTION error" ):
15
- nvjitlink .create ("-arch=sm_XX" )
16
-
17
-
18
- def test_unrecognized_option_error ():
19
- with pytest .raises (RuntimeError , match = "NVJITLINK_ERROR_UNRECOGNIZED_OPTION error" ):
20
- nvjitlink .create ("-fictitious_option" )
21
-
4
+ import pytest
5
+ from cuda .bindings import nvjitlink
22
6
23
- def test_invalid_option_type_error ():
24
- with pytest .raises (TypeError , match = "Expecting only strings" ):
25
- nvjitlink .create ("-arch" , 53 )
7
# PTX source for a single kernel entry point, `_Z6kernelPi`, that stores
# %tid.x through the global pointer passed as its only parameter.
ptx_code = """
.version 8.5
.target sm_90
.address_size 64

.visible .entry _Z6kernelPi(
    .param .u64 _Z6kernelPi_param_0
)
{
    .reg .pred %p<2>;
    .reg .b32 %r<3>;
    .reg .b64 %rd<3>;

    ld.param.u64 %rd1, [_Z6kernelPi_param_0];
    cvta.to.global.u64 %rd2, %rd1;
    mov.u32 %r1, %tid.x;
    st.global.u32 [%rd2+0], %r1;
    ret;
}
"""

# Smallest possible PTX module: one entry point that returns immediately.
minimal_kernel = """
.version 6.4
.target sm_75
.address_size 64

.visible .entry _kernel() {
    ret;
}
"""

# The tests hand the PTX to nvjitlink as bytes together with an explicit
# length, so encode both sources once at import time.
ptx_bytes = ptx_code.encode("utf-8")
minimal_kernel_bytes = minimal_kernel.encode("utf-8")
27
41
28
42
def test_create_and_destroy():
    """Create a linker handle from one option, check it is non-zero, destroy it."""
    # create() takes the option count followed by the list of option strings.
    handle = nvjitlink.create(1, ["-arch=sm_53"])
    assert handle != 0
    nvjitlink.destroy(handle)
32
46
33
-
34
47
def test_complete_empty():
    """Call complete() on a handle that has had no inputs added."""
    handle = nvjitlink.create(1, ["-arch=sm_90"])
    try:
        nvjitlink.complete(handle)
    finally:
        # Release the handle even when complete() raises.
        nvjitlink.destroy(handle)
38
51
39
-
40
- @pytest .mark .parametrize (
41
- "input_file,input_type" ,
42
- [
43
- ("device_functions_cubin" , nvjitlink .InputType .CUBIN ),
44
- ("device_functions_fatbin" , InputType .FATBIN ),
45
- ("device_functions_ptx" , InputType .PTX ),
46
- ("device_functions_object" , InputType .OBJECT ),
47
- ("device_functions_archive" , InputType .LIBRARY ),
48
- ],
49
- )
50
- def test_add_file (input_file , input_type , gpu_arch_flag , request ):
51
- filename , data = request .getfixturevalue (input_file )
52
-
53
- handle = nvjitlink .create (gpu_arch_flag )
54
- nvjitlink .add_data (handle , input_type .value , data , filename )
55
- nvjitlink .destroy (handle )
56
-
57
-
58
- # We test the LTO input case separately as it requires the `-lto` flag. The
59
- # OBJECT input type is used because the LTO-IR container is packaged in an ELF
60
- # object when produced by NVCC.
61
- def test_add_file_lto (device_functions_ltoir_object , gpu_arch_flag ):
62
- filename , data = device_functions_ltoir_object
63
-
64
- handle = nvjitlink .create (gpu_arch_flag , "-lto" )
65
- nvjitlink .add_data (handle , InputType .OBJECT .value , data , filename )
66
- nvjitlink .destroy (handle )
67
-
68
-
69
- def test_get_error_log (undefined_extern_cubin , gpu_arch_flag ):
70
- handle = nvjitlink .create (gpu_arch_flag )
71
- filename , data = undefined_extern_cubin
72
- input_type = InputType .CUBIN .value
73
- nvjitlink .add_data (handle , input_type , data , filename )
74
- with pytest .raises (RuntimeError ):
75
- nvjitlink .complete (handle )
76
- error_log = nvjitlink .get_error_log (handle )
52
def test_add_data():
    """Add the in-memory PTX bytes as an input and complete the link."""
    handle = nvjitlink.create(1, ["-arch=sm_90"])
    try:
        data = ptx_bytes
        # InputType.ANY is passed rather than a specific input type.
        nvjitlink.add_data(
            handle, nvjitlink.InputType.ANY, data, len(data), "test_data"
        )
        nvjitlink.complete(handle)
    finally:
        # Release the handle even when adding or linking fails.
        nvjitlink.destroy(handle)
78
- assert (
79
- "Undefined reference to '_Z5undefff' "
80
- "in 'undefined_extern.cubin'" in error_log
81
- )
82
58
59
def test_add_file():
    """Write the PTX bytes to a file, add that file as an input, and link."""
    import os
    import tempfile

    handle = nvjitlink.create(1, ["-arch=sm_90"])
    try:
        # Use a temporary directory so the scratch file is removed afterwards
        # instead of being left behind in the current working directory.
        with tempfile.TemporaryDirectory() as tmp_dir:
            # NOTE(review): the file holds PTX text despite the ".cubin"
            # suffix; the original file name is kept unchanged -- confirm
            # whether a ".ptx" suffix would be more appropriate.
            file_path = os.path.join(tmp_dir, "test_file.cubin")
            with open(file_path, "wb") as f:
                f.write(ptx_bytes)

            nvjitlink.add_file(handle, nvjitlink.InputType.ANY, str(file_path))
            nvjitlink.complete(handle)
    finally:
        # Release the handle even when adding or linking fails.
        nvjitlink.destroy(handle)
92
- # Info log is empty
93
- assert "" == info_log
94
68
95
-
96
- def test_get_linked_cubin (device_functions_cubin , gpu_arch_flag ):
97
- handle = nvjitlink .create (gpu_arch_flag )
98
- filename , data = device_functions_cubin
99
- input_type = InputType .CUBIN .value
100
- nvjitlink .add_data (handle , input_type , data , filename )
69
def test_get_error_log():
    """Check the error log length matches the reported error log size."""
    handle = nvjitlink.create(1, ["-arch=sm_90"])
    try:
        nvjitlink.complete(handle)
        log_size = nvjitlink.get_error_log_size(handle)
        log = nvjitlink.get_error_log(handle)
        assert len(log) == log_size
    finally:
        # Release the handle even when linking or the assertion fails.
        nvjitlink.destroy(handle)
104
76
105
- # Just check we got something that looks like an ELF
106
- assert cubin [:4 ] == b"\x7f ELF"
107
-
108
-
109
- def test_get_linked_cubin_link_not_complete_error (
110
- device_functions_cubin , gpu_arch_flag
111
- ):
112
- handle = nvjitlink .create (gpu_arch_flag )
113
- filename , data = device_functions_cubin
114
- input_type = InputType .CUBIN .value
115
- nvjitlink .add_data (handle , input_type , data , filename )
116
- with pytest .raises (RuntimeError , match = "NVJITLINK_ERROR_INTERNAL error" ):
117
- nvjitlink .get_linked_cubin (handle )
77
def test_get_info_log():
    """Check the info log length matches the reported info log size."""
    handle = nvjitlink.create(1, ["-arch=sm_90"])
    try:
        nvjitlink.complete(handle)
        log_size = nvjitlink.get_info_log_size(handle)
        log = nvjitlink.get_info_log(handle)
        assert len(log) == log_size
    finally:
        # Release the handle even when linking or the assertion fails.
        nvjitlink.destroy(handle)
119
84
120
-
121
- def test_get_linked_cubin_from_lto (device_functions_ltoir_object , gpu_arch_flag ):
122
- filename , data = device_functions_ltoir_object
123
- # device_functions_ltoir_object is a host object containing a fatbin
124
- # containing an LTOIR container, because that is what NVCC produces when
125
- # LTO is requested. So we need to use the OBJECT input type, and the linker
126
- # retrieves the LTO IR from it because we passed the -lto flag.
127
- input_type = InputType .OBJECT .value
128
- handle = nvjitlink .create (gpu_arch_flag , "-lto" )
129
- nvjitlink .add_data (handle , input_type , data , filename )
85
def test_get_linked_cubin():
    """Check the linked cubin length matches the reported cubin size."""
    handle = nvjitlink.create(1, ["-arch=sm_90"])
    try:
        nvjitlink.complete(handle)
        cubin_size = nvjitlink.get_linked_cubin_size(handle)
        cubin = nvjitlink.get_linked_cubin(handle)
        assert len(cubin) == cubin_size
    finally:
        # Release the handle even when linking or the assertion fails.
        nvjitlink.destroy(handle)
133
92
134
- # Just check we got something that looks like an ELF
135
- assert cubin [:4 ] == b"\x7f ELF"
136
-
137
-
138
- def test_get_linked_ptx_from_lto (device_functions_ltoir_object , gpu_arch_flag ):
139
- filename , data = device_functions_ltoir_object
140
- # device_functions_ltoir_object is a host object containing a fatbin
141
- # containing an LTOIR container, because that is what NVCC produces when
142
- # LTO is requested. So we need to use the OBJECT input type, and the linker
143
- # retrieves the LTO IR from it because we passed the -lto flag.
144
- input_type = InputType .OBJECT .value
145
- handle = nvjitlink .create (gpu_arch_flag , "-lto" , "-ptx" )
146
- nvjitlink .add_data (handle , input_type , data , filename )
93
def test_get_linked_ptx():
    """Check the linked PTX length matches the reported PTX size."""
    # NOTE(review): only "-arch" and "-lto" are passed and the input is plain
    # PTX targeting sm_75 -- confirm against the nvJitLink documentation that
    # this combination is enough for get_linked_ptx() to succeed.
    handle = nvjitlink.create(2, ["-arch=sm_90", "-lto"])
    try:
        data = minimal_kernel_bytes
        nvjitlink.add_data(
            handle, nvjitlink.InputType.ANY, data, len(data), "test_data"
        )
        nvjitlink.complete(handle)
        ptx_size = nvjitlink.get_linked_ptx_size(handle)
        ptx = nvjitlink.get_linked_ptx(handle)
        assert len(ptx) == ptx_size
    finally:
        # Release the handle even when linking or the assertion fails.
        nvjitlink.destroy(handle)
150
102
151
-
152
- def test_get_linked_ptx_link_not_complete_error (
153
- device_functions_ltoir_object , gpu_arch_flag
154
- ):
155
- handle = nvjitlink .create (gpu_arch_flag , "-lto" , "-ptx" )
156
- filename , data = device_functions_ltoir_object
157
- input_type = InputType .OBJECT .value
158
- nvjitlink .add_data (handle , input_type , data , filename )
159
- with pytest .raises (RuntimeError , match = "NVJITLINK_ERROR_INTERNAL error" ):
160
- nvjitlink .get_linked_ptx (handle )
161
- nvjitlink .destroy (handle )
162
-
163
-
164
- def test_package_version ():
165
- assert pynvjitlink .__version__ is not None
166
- assert len (str (pynvjitlink .__version__ )) > 0
103
def test_version():
    """Check that version() yields a non-negative (major, minor) pair."""
    major, minor = nvjitlink.version()
    assert major >= 0
    assert minor >= 0
107
+
108
# Run every test when the module is executed directly as a script. The guard
# keeps the calls from firing on import, so pytest collection does not execute
# each test an extra time.
if __name__ == "__main__":
    test_create_and_destroy()
    test_complete_empty()
    test_add_data()
    test_add_file()
    test_get_error_log()
    test_get_info_log()
    test_get_linked_cubin()
    test_get_linked_ptx()
    test_version()
0 commit comments