Skip to content

Commit 02d1e57

Browse files
authored
Support CUDA Multi-Arch Compilation (mlc-ai#1166)
1 parent 8438b27 commit 02d1e57

File tree

2 files changed

+36
-2
lines changed

2 files changed

+36
-2
lines changed

python/mlc_chat/compiler/compile.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def _compile(args: CompileArgs):
5656
mod = relax.get_pipeline("mlc_llm")(mod)
5757
logger.info("Generating code using TVM Unity")
5858
args.build_func(mod, args)
59-
logger.info("Code dumped to: %s", args.output)
59+
logger.info("Code dumped to: %s", bold(str(args.output)))
6060

6161

6262
def compile( # pylint: disable=too-many-arguments,redefined-builtin

python/mlc_chat/support/auto_target.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
"""Helper functions for target auto-detection."""
22
import logging
3+
import os
34
from typing import TYPE_CHECKING, Callable, Optional, Tuple
45

56
from tvm import IRModule, relax
67
from tvm._ffi import register_func
78
from tvm.contrib import tar, xcode
89
from tvm.target import Target
910

10-
from .style import green, red
11+
from .style import bold, green, red
1112

1213
if TYPE_CHECKING:
1314
from mlc_chat.compiler.compile import CompileArgs
@@ -38,6 +39,8 @@ def detect_target_and_host(target_hint: str, host_hint: str) -> Tuple[Target, Bu
3839
target, build_func = _detect_target_gpu(target_hint)
3940
if target.host is None:
4041
target = Target(target, host=_detect_target_host(host_hint))
42+
if target.kind.name == "cuda":
43+
_register_cuda_hook(target)
4144
return target, build_func
4245

4346

@@ -223,6 +226,37 @@ def build(mod: IRModule, args: "CompileArgs"):
223226
return build
224227

225228

229+
def _register_cuda_hook(target: Target):
    """Register the ``tvm_callback_cuda_compile`` hook that builds CUDA fatbins with nvcc.

    The ``MLC_MULTI_ARCH`` environment variable selects the architectures:

    * unset: ``nvcc.compile_cuda`` is called without an explicit ``arch``;
    * set: it is parsed as a comma-separated list of integer compute
      capabilities (e.g. ``70,72,75,80,86,87,89,90``) and one ``-gencode``
      flag pair is passed to nvcc per listed capability.

    Parameters
    ----------
    target : Target
        The CUDA target being compiled for. Only its ``arch`` attribute is
        read here, and only to log it when ``MLC_MULTI_ARCH`` is unset.

    Raises
    ------
    ValueError
        If ``MLC_MULTI_ARCH`` is set but contains a non-integer entry, or
        contains no entries at all.
    """
    env_multi_arch = os.environ.get("MLC_MULTI_ARCH", None)
    if env_multi_arch is None:
        default_arch = target.attrs.get("arch", None)
        logger.info("Generating code for CUDA architecture: %s", bold(default_arch))
        logger.info(
            "To produce multi-arch fatbin, set environment variable %s. "
            "Example: MLC_MULTI_ARCH=70,72,75,80,86,87,89,90",
            bold("MLC_MULTI_ARCH"),
        )
        multi_arch = None
    else:
        logger.info("%s %s: %s", FOUND, bold("MLC_MULTI_ARCH"), env_multi_arch)
        # Tolerate stray whitespace and empty entries (e.g. a trailing comma),
        # which would otherwise fail inside int() with an unhelpful message.
        entries = [entry.strip() for entry in env_multi_arch.split(",")]
        try:
            multi_arch = [int(entry) for entry in entries if entry]
        except ValueError as err:
            raise ValueError(
                "MLC_MULTI_ARCH must be a comma-separated list of integer compute "
                f"capabilities (e.g. 80,86,90), but got: {env_multi_arch!r}"
            ) from err
        if not multi_arch:
            # An empty list would reach nvcc as "no architecture flags at all",
            # so reject it explicitly instead.
            raise ValueError(
                "MLC_MULTI_ARCH is set but lists no compute capabilities: "
                f"{env_multi_arch!r}"
            )
        logger.info("Generating code for CUDA architecture: %s", multi_arch)

    @register_func("tvm_callback_cuda_compile", override=True)
    def tvm_callback_cuda_compile(code, target):  # pylint: disable=unused-argument
        """use nvcc to generate fatbin code for better optimization"""
        from tvm.contrib import nvcc  # pylint: disable=import-outside-toplevel

        if multi_arch is None:
            return nvcc.compile_cuda(code, target_format="fatbin")
        # One "-gencode arch=compute_XX,code=sm_XX" pair per requested capability.
        arch = []
        for compute_version in multi_arch:
            arch += ["-gencode", f"arch=compute_{compute_version},code=sm_{compute_version}"]
        return nvcc.compile_cuda(code, target_format="fatbin", arch=arch)
258+
259+
226260
AUTO_DETECT_DEVICES = ["cuda", "rocm", "metal", "vulkan"]
227261

228262
PRESET = {

0 commit comments

Comments
 (0)