|
6 | 6 | import weakref
|
7 | 7 | from contextlib import contextmanager
|
8 | 8 | from dataclasses import dataclass
|
9 |
| -from typing import List, Optional |
| 9 | +from typing import List, Optional, Tuple, Union |
10 | 10 | from warnings import warn
|
11 | 11 |
|
12 | 12 | from cuda.core.experimental._device import Device
|
13 | 13 | from cuda.core.experimental._module import ObjectCode
|
14 |
| -from cuda.core.experimental._utils import check_or_create_options, driver, handle_return |
| 14 | +from cuda.core.experimental._utils import check_or_create_options, driver, handle_return, is_sequence |
15 | 15 |
|
16 | 16 | # TODO: revisit this treatment for py313t builds
|
17 | 17 | _driver = None # populated if nvJitLink cannot be used
|
@@ -130,15 +130,14 @@ class LinkerOptions:
|
130 | 130 | fma : bool, optional
|
131 | 131 | Use fast multiply-add.
|
132 | 132 | Default: True.
|
133 |
| - kernels_used : List[str], optional |
134 |
| - Pass list of kernels that are used; any not in the list can be removed. This option can be specified multiple |
135 |
| - times. |
136 |
| - variables_used : List[str], optional |
137 |
| - Pass a list of variables that are used; any not in the list can be removed. |
| 133 | + kernels_used : [Union[str, Tuple[str], List[str]]], optional |
| 134 | + Pass a kernel or sequence of kernels that are used; any not in the list can be removed. |
| 135 | + variables_used : [Union[str, Tuple[str], List[str]]], optional |
| 136 | + Pass a variable or sequence of variables that are used; any not in the list can be removed. |
138 | 137 | optimize_unused_variables : bool, optional
|
139 | 138 | Assume that if a variable is not referenced in device code, it can be removed.
|
140 | 139 | Default: False.
|
141 |
| - ptxas_options : List[str], optional |
| 140 | + ptxas_options : [Union[str, Tuple[str], List[str]]], optional |
142 | 141 | Pass options to PTXAS.
|
143 | 142 | split_compile : int, optional
|
144 | 143 | Split compilation maximum thread count. Use 0 to use all available processors. Value of 1 disables split
|
@@ -167,10 +166,10 @@ class LinkerOptions:
|
167 | 166 | prec_div: Optional[bool] = None
|
168 | 167 | prec_sqrt: Optional[bool] = None
|
169 | 168 | fma: Optional[bool] = None
|
170 |
| - kernels_used: Optional[List[str]] = None |
171 |
| - variables_used: Optional[List[str]] = None |
| 169 | + kernels_used: Optional[Union[str, Tuple[str], List[str]]] = None |
| 170 | + variables_used: Optional[Union[str, Tuple[str], List[str]]] = None |
172 | 171 | optimize_unused_variables: Optional[bool] = None
|
173 |
| - ptxas_options: Optional[List[str]] = None |
| 172 | + ptxas_options: Optional[Union[str, Tuple[str], List[str]]] = None |
174 | 173 | split_compile: Optional[int] = None
|
175 | 174 | split_compile_extended: Optional[int] = None
|
176 | 175 | no_cache: Optional[bool] = None
|
@@ -213,16 +212,25 @@ def _init_nvjitlink(self):
|
213 | 212 | if self.fma is not None:
|
214 | 213 | self.formatted_options.append(f"-fma={'true' if self.fma else 'false'}")
|
215 | 214 | if self.kernels_used is not None:
|
216 |
| - for kernel in self.kernels_used: |
217 |
| - self.formatted_options.append(f"-kernels-used={kernel}") |
| 215 | + if isinstance(self.kernels_used, str): |
| 216 | + self.formatted_options.append(f"-kernels-used={self.kernels_used}") |
| 217 | + elif isinstance(self.kernels_used, list): |
| 218 | + for kernel in self.kernels_used: |
| 219 | + self.formatted_options.append(f"-kernels-used={kernel}") |
218 | 220 | if self.variables_used is not None:
|
219 |
| - for variable in self.variables_used: |
220 |
| - self.formatted_options.append(f"-variables-used={variable}") |
| 221 | + if isinstance(self.variables_used, str): |
| 222 | + self.formatted_options.append(f"-variables-used={self.variables_used}") |
| 223 | + elif isinstance(self.variables_used, list): |
| 224 | + for variable in self.variables_used: |
| 225 | + self.formatted_options.append(f"-variables-used={variable}") |
221 | 226 | if self.optimize_unused_variables is not None:
|
222 | 227 | self.formatted_options.append("-optimize-unused-variables")
|
223 | 228 | if self.ptxas_options is not None:
|
224 |
| - for opt in self.ptxas_options: |
225 |
| - self.formatted_options.append(f"-Xptxas={opt}") |
| 229 | + if isinstance(self.ptxas_options, str): |
| 230 | + self.formatted_options.append(f"-Xptxas={self.ptxas_options}") |
| 231 | + elif is_sequence(self.ptxas_options): |
| 232 | + for opt in self.ptxas_options: |
| 233 | + self.formatted_options.append(f"-Xptxas={opt}") |
226 | 234 | if self.split_compile is not None:
|
227 | 235 | self.formatted_options.append(f"-split-compile={self.split_compile}")
|
228 | 236 | if self.split_compile_extended is not None:
|
|
0 commit comments