|
1 | 1 | """Helper functioms for target auto-detection."""
|
2 | 2 | import logging
|
| 3 | +import os |
3 | 4 | from typing import TYPE_CHECKING, Callable, Optional, Tuple
|
4 | 5 |
|
5 | 6 | from tvm import IRModule, relax
|
6 | 7 | from tvm._ffi import register_func
|
7 | 8 | from tvm.contrib import tar, xcode
|
8 | 9 | from tvm.target import Target
|
9 | 10 |
|
10 |
| -from .style import green, red |
| 11 | +from .style import bold, green, red |
11 | 12 |
|
12 | 13 | if TYPE_CHECKING:
|
13 | 14 | from mlc_chat.compiler.compile import CompileArgs
|
@@ -38,6 +39,8 @@ def detect_target_and_host(target_hint: str, host_hint: str) -> Tuple[Target, Bu
|
38 | 39 | target, build_func = _detect_target_gpu(target_hint)
|
39 | 40 | if target.host is None:
|
40 | 41 | target = Target(target, host=_detect_target_host(host_hint))
|
| 42 | + if target.kind.name == "cuda": |
| 43 | + _register_cuda_hook(target) |
41 | 44 | return target, build_func
|
42 | 45 |
|
43 | 46 |
|
@@ -223,6 +226,37 @@ def build(mod: IRModule, args: "CompileArgs"):
|
223 | 226 | return build
|
224 | 227 |
|
225 | 228 |
|
| 229 | +def _register_cuda_hook(target: Target): |
| 230 | + env_multi_arch = os.environ.get("MLC_MULTI_ARCH", None) |
| 231 | + if env_multi_arch is None: |
| 232 | + default_arch = target.attrs.get("arch", None) |
| 233 | + logger.info("Generating code for CUDA architecture: %s", bold(default_arch)) |
| 234 | + logger.info( |
| 235 | + "To produce multi-arch fatbin, set environment variable %s. " |
| 236 | + "Example: MLC_MULTI_ARCH=70,72,75,80,86,87,89,90", |
| 237 | + bold("MLC_MULTI_ARCH"), |
| 238 | + ) |
| 239 | + multi_arch = None |
| 240 | + else: |
| 241 | + logger.info("%s %s: %s", FOUND, bold("MLC_MULTI_ARCH"), env_multi_arch) |
| 242 | + multi_arch = [int(x.strip()) for x in env_multi_arch.split(",")] |
| 243 | + logger.info("Generating code for CUDA architecture: %s", multi_arch) |
| 244 | + |
| 245 | + @register_func("tvm_callback_cuda_compile", override=True) |
| 246 | + def tvm_callback_cuda_compile(code, target): # pylint: disable=unused-argument |
| 247 | + """use nvcc to generate fatbin code for better optimization""" |
| 248 | + from tvm.contrib import nvcc # pylint: disable=import-outside-toplevel |
| 249 | + |
| 250 | + if multi_arch is None: |
| 251 | + ptx = nvcc.compile_cuda(code, target_format="fatbin") |
| 252 | + else: |
| 253 | + arch = [] |
| 254 | + for compute_version in multi_arch: |
| 255 | + arch += ["-gencode", f"arch=compute_{compute_version},code=sm_{compute_version}"] |
| 256 | + ptx = nvcc.compile_cuda(code, target_format="fatbin", arch=arch) |
| 257 | + return ptx |
| 258 | + |
| 259 | + |
226 | 260 | AUTO_DETECT_DEVICES = ["cuda", "rocm", "metal", "vulkan"]
|
227 | 261 |
|
228 | 262 | PRESET = {
|
|
0 commit comments