
Commit 0a25374

Migrate Compiler Passes (mlc-ai#1150)

1 parent 2193767 · commit 0a25374
14 files changed: +808 −96 lines

python/mlc_chat/compiler/__init__.py
Lines changed: 1 addition & 0 deletions

@@ -2,6 +2,7 @@
 A compiler for MLC Chat. By default, it is not imported to MLC Chat to avoid unnecessary dependency,
 but users could optionally import it if they want to use the compiler.
 """
+from . import compiler_pass
 from .compile import (  # pylint: disable=redefined-builtin
     CompileArgs,
     OptimizationFlags,
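The new import exists for its side effect: pulling in compiler_pass imports its pipeline module (see the new compiler_pass/__init__.py below), which is presumably where the "mlc_llm" pipeline used by compile.py gets registered. A minimal sketch of the resulting usage contract, under that assumption:

# Importing the package is assumed to be enough to register the passes;
# nothing from compiler_pass is referenced by name afterwards.
import mlc_chat.compiler  # noqa: F401

from tvm import relax

pipeline = relax.get_pipeline("mlc_llm")  # resolvable after the import above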
python/mlc_chat/compiler/compile.py
Lines changed: 18 additions & 88 deletions

@@ -1,63 +1,15 @@
 """Python entrypoint of compilation."""
-import argparse
 import dataclasses
-import logging
 from io import StringIO
 from pathlib import Path
 from typing import Callable
 
-from mlc_chat.compiler.model import Model
-from tvm import IRModule  # pylint: disable=wrong-import-order
-from tvm.target import Target  # pylint: disable=wrong-import-order
+from tvm import IRModule, relax
+from tvm.target import Target
 
+from ..compiler.model import Model
 from ..support.style import bold
-
-logger = logging.getLogger(__name__)
-
-
-@dataclasses.dataclass
-class OptimizationFlags:
-    """Optiization flags"""
-
-    cutlass_attn: bool = True
-    cutlass_norm: bool = True
-    cublas_gemm: bool = False
-    cudagraph: bool = False
-
-    def __repr__(self) -> str:
-        out = StringIO()
-        print(f"cutlass_attn={int(self.cutlass_attn)}", file=out, end="")
-        print(f";cutlass_norm={int(self.cutlass_norm)}", file=out, end="")
-        print(f";cublas_gemm={int(self.cublas_gemm)}", file=out, end="")
-        print(f";cudagraph={int(self.cudagraph)}", file=out, end="")
-        return out.getvalue().rstrip()
-
-    @staticmethod
-    def from_str(source: str) -> "OptimizationFlags":
-        """Parse optimization flags from a string."""
-
-        if source in OPT_FLAG_PRESET:
-            return OPT_FLAG_PRESET[source]
-
-        def boolean(value: str) -> bool:
-            if value == "0":
-                return False
-            if value == "1":
-                return True
-            raise ValueError(f"Invalid boolean value: {value}")
-
-        parser = argparse.ArgumentParser(description="optimization flags")
-        parser.add_argument("--cutlass_attn", type=boolean, default=True)
-        parser.add_argument("--cutlass_norm", type=boolean, default=True)
-        parser.add_argument("--cublas_gemm", type=boolean, default=False)
-        parser.add_argument("--cudagraph", type=boolean, default=False)
-        results = parser.parse_args([f"--{i}" for i in source.split(";") if i])
-        return OptimizationFlags(
-            cutlass_attn=results.cutlass_attn,
-            cutlass_norm=results.cutlass_norm,
-            cublas_gemm=results.cublas_gemm,
-            cudagraph=results.cudagraph,
-        )
+from .flags_optimization import OptimizationFlags
 
 
 @dataclasses.dataclass
@@ -86,6 +38,19 @@ def _echo_args(args: CompileArgs) -> None:
     print(out.getvalue().rstrip())
 
 
+def _compile(args: CompileArgs):
+    model_config = args.model.config.from_file(args.config)
+    model = args.model.model(model_config)
+    mod, named_params = model.export_tvm(
+        spec=model.get_default_spec(),  # type: ignore
+    )
+    with args.target:
+        mod = relax.get_pipeline("mlc_llm")(mod)
+    mod.show(black_format=False)
+    for name, param in named_params:
+        print(f"{name}: {param.shape} {param.dtype}")
+
+
 def compile(  # pylint: disable=too-many-arguments,redefined-builtin
     config: Path,
     quantization,
@@ -101,39 +66,4 @@ def compile(  # pylint: disable=too-many-arguments,redefined-builtin
         config, quantization, model_type, target, opt, build_func, prefix_symbols, output
     )
     _echo_args(args)
-    model_config = args.model.config.from_file(args.config)
-    model = args.model.model(model_config)
-    mod, named_params = model.export_tvm(
-        spec=model.get_default_spec(),  # type: ignore
-    )
-    mod.show(black_format=False)
-    for name, param in named_params:
-        print(f"{name}: {param.shape} {param.dtype}")
-
-
-OPT_FLAG_PRESET = {
-    "O0": OptimizationFlags(
-        cutlass_attn=False,
-        cutlass_norm=False,
-        cublas_gemm=False,
-        cudagraph=False,
-    ),
-    "O1": OptimizationFlags(
-        cutlass_attn=False,
-        cutlass_norm=True,
-        cublas_gemm=False,
-        cudagraph=False,
-    ),
-    "O2": OptimizationFlags(
-        cutlass_attn=True,
-        cutlass_norm=True,
-        cublas_gemm=False,
-        cudagraph=False,
-    ),
-    "O3": OptimizationFlags(
-        cutlass_attn=True,
-        cutlass_norm=True,
-        cublas_gemm=False,
-        cudagraph=True,
-    ),
-}
+    _compile(args)
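Two things change in this file: the exported module is now run through the registered "mlc_llm" Relax pipeline under the build target, and the OptimizationFlags parser with its O0–O3 presets moves out of compile.py into flags_optimization.py. A minimal usage sketch of the relocated flags, assuming the new module keeps the behavior shown in the removed lines (the expected output follows the O2 preset and the __repr__ above):

from mlc_chat.compiler import OptimizationFlags

flags = OptimizationFlags.from_str("O2")
print(flags)  # cutlass_attn=1;cutlass_norm=1;cublas_gemm=0;cudagraph=0

# Flags can also be given "name=0/1" style, ";"-separated; unspecified
# flags keep their defaults.
custom = OptimizationFlags.from_str("cublas_gemm=1;cudagraph=1")
print(custom.cublas_gemm, custom.cudagraph)  # True True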
python/mlc_chat/compiler/compiler_pass/__init__.py
Lines changed: 2 additions & 0 deletions

@@ -0,0 +1,2 @@
+"""Compiler passes used in MLC LLM."""
+from . import pipeline as _pipeline
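pipeline.py itself is one of the 14 changed files but is not captured on this page. A minimal sketch of what the registration plausibly looks like, assuming TVM's standard relax.register_pipeline mechanism so that relax.get_pipeline("mlc_llm") in compile.py resolves; the pass order, the attribute list, and the module file names below are illustrative assumptions, not the commit's code:

import tvm
from tvm import relax

# File names inferred from the class docstrings below; they are assumptions.
from .clean_up_tir_attrs import CleanUpTIRAttrs
from .fuse_decode_matmul_ewise import FuseDecodeMatmulEwise
from .fuse_decode_take import FuseDecodeTake


@relax.register_pipeline("mlc_llm")  # makes relax.get_pipeline("mlc_llm") resolve
def _pipeline():
    return tvm.transform.Sequential(
        [
            FuseDecodeTake(),
            FuseDecodeMatmulEwise(),
            CleanUpTIRAttrs(["some_undesired_attr"]),  # hypothetical attribute list
        ],
        name="mlc_llm",
    )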
python/mlc_chat/compiler/compiler_pass/clean_up_tir_attrs.py (file name inferred from the docstring)
Lines changed: 31 additions & 0 deletions

@@ -0,0 +1,31 @@
+"""A compiler pass that cleans up undesired TIR attrs."""
+from typing import List
+
+import tvm
+from tvm.ir.module import IRModule
+
+
+@tvm.transform.module_pass(opt_level=0, name="CleanUpTIRAttrs")
+class CleanUpTIRAttrs:  # pylint: disable=too-few-public-methods
+    """A compiler pass that cleans up undesired TIR attrs."""
+
+    def __init__(self, attrs: List[str]):
+        self.attrs = attrs
+
+    def transform_module(
+        self,
+        mod: IRModule,
+        _ctx: tvm.transform.PassContext,
+    ) -> IRModule:
+        """IRModule-level transformation"""
+        for g_var in list(mod.functions):
+            func = mod[g_var]
+            changed = False
+            for attr in self.attrs:
+                if func.attrs is not None and attr in func.attrs:
+                    func = func.without_attr(attr)
+                    changed = True
+                    break
+            if changed:
+                mod[g_var] = func
+        return mod
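Because the class is decorated with @tvm.transform.module_pass, its instances are ordinary TVM passes that can be called on an IRModule directly. A self-contained usage sketch with a hypothetical attribute name (note that, as written, the pass strips at most one listed attribute per function before breaking out of the loop):

import tvm
from tvm.script import tir as T


@tvm.script.ir_module
class Example:
    @T.prim_func
    def main(A: T.Buffer((8,), "float32")):
        T.func_attr({"some_attr": 1})  # hypothetical undesired attribute
        for i in range(8):
            A[i] = T.float32(0)


mod = CleanUpTIRAttrs(attrs=["some_attr"])(Example)
assert "some_attr" not in mod["main"].attrs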
python/mlc_chat/compiler/compiler_pass/fuse_decode_matmul_ewise.py (file name inferred from the docstring)
Lines changed: 81 additions & 0 deletions

@@ -0,0 +1,81 @@
+"""A compiler pass that fuses decode + matmul + elementwise."""
+import tvm
+from tvm import IRModule, relax
+from tvm.relax.dpl.pattern import GlobalVarPattern, TuplePattern, is_op, wildcard
+
+
+@tvm.transform.module_pass(opt_level=0, name="FuseDecodeMatmulEwise")
+class FuseDecodeMatmulEwise:  # pylint: disable=too-few-public-methods
+    """A compiler pass that fuses decode + matmul + elementwise."""
+
+    def transform_module(
+        self,
+        mod: IRModule,
+        _ctx: tvm.transform.PassContext,
+    ) -> IRModule:
+        """IRModule-level transformation"""
+        for n_aux_tensor in [1, 2, 3, 4]:
+            for match_ewise in [0, 1, 2, 6]:
+                if match_ewise == 6 and n_aux_tensor != 4:
+                    continue
+                mod = relax.transform.FuseOpsByPattern(
+                    [
+                        (
+                            "decode_matmul",
+                            *_pattern(match_ewise, n_aux_tensor),
+                        )
+                    ]
+                )(mod)
+        mod = relax.transform.FuseTIR()(mod)
+        return mod
+
+
+def _pattern(match_ewise: int, n_aux_tensor: int):
+    # pylint: disable=invalid-name
+    w_scaled = wildcard()
+    x = wildcard()
+    w = is_op("relax.call_tir")(
+        GlobalVarPattern(),
+        TuplePattern([w_scaled] + [wildcard() for _ in range(n_aux_tensor)]),
+        add_constraint=False,
+    )
+    matmul = is_op("relax.call_tir")(
+        GlobalVarPattern(),
+        TuplePattern([x, w] + [wildcard() for _ in range(match_ewise)]),
+        add_constraint=False,
+    )
+    # pylint: enable=invalid-name
+    annotations = {
+        "w_scaled": w_scaled,
+        "x": x,
+        "w": w,
+        "matmul": matmul,
+    }
+
+    def _check_decoding(ctx: relax.transform.PatternCheckContext) -> bool:
+        call = ctx.annotated_expr["w"]
+        if not isinstance(call, relax.Call):
+            return False
+        g_var = call.args[0]
+        if not isinstance(g_var, relax.GlobalVar):
+            return False
+        return g_var.name_hint.startswith("decode") or g_var.name_hint.startswith("fused_decode")
+
+    def _check_matmul(ctx: relax.transform.PatternCheckContext) -> bool:
+        call = ctx.annotated_expr["matmul"]
+        if not isinstance(call, relax.Call):
+            return False
+        g_var = call.args[0]
+        if not isinstance(g_var, relax.GlobalVar):
+            return False
+        return (
+            g_var.name_hint.startswith("matmul")
+            or g_var.name_hint.startswith("fused_matmul")
+            or g_var.name_hint.startswith("NT_matmul")
+            or g_var.name_hint.startswith("fused_NT_matmul")
+        )
+
+    def _check(ctx: relax.transform.PatternCheckContext) -> bool:
+        return _check_decoding(ctx) and _check_matmul(ctx)
+
+    return matmul, annotations, _check
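The nested loops enumerate the supported kernel shapes: n_aux_tensor in 1–4 covers quantization formats whose decode kernel takes different numbers of auxiliary tensors (scales, zero points, and the like), and match_ewise in {0, 1, 2, 6} covers how many extra operands the fused elementwise epilogue consumes, with the 6-operand case only paired with 4 auxiliary tensors. A hedged application sketch; `mod` is an assumed input, not constructed here:

# `mod` is assumed to be a quantized Relax IRModule whose TIR kernels follow
# the decode*/fused_decode* and matmul*/NT_matmul* naming checked above.
mod = FuseDecodeMatmulEwise()(mod)
print([gv.name_hint for gv in mod.functions])  # expect fused decode_matmul kernels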
python/mlc_chat/compiler/compiler_pass/fuse_decode_take.py (file name inferred from the docstring)
Lines changed: 83 additions & 0 deletions

@@ -0,0 +1,83 @@
+"""A compiler pass that fuses decode + take."""
+import tvm
+from tvm import IRModule, relax, tir
+from tvm.relax.dpl.pattern import (
+    GlobalVarPattern,
+    TuplePattern,
+    is_const,
+    is_op,
+    wildcard,
+)
+
+
+@tvm.transform.module_pass(opt_level=0, name="FuseDecodeTake")
+class FuseDecodeTake:  # pylint: disable=too-few-public-methods
+    """A compiler pass that fuses decode + take."""
+
+    def transform_module(
+        self,
+        mod: IRModule,
+        _ctx: tvm.transform.PassContext,
+    ) -> IRModule:
+        """IRModule-level transformation"""
+        for n_aux_tensor in [2, 3]:
+            for match_tir_vars in [False, True]:
+                mod = relax.transform.FuseOpsByPattern(
+                    [
+                        (
+                            "decode_take",
+                            *_pattern(n_aux_tensor, match_tir_vars),
+                        )
+                    ]
+                )(mod)
+        mod = relax.transform.FuseTIR()(mod)
+        for g_var, func in mod.functions.items():
+            name = g_var.name_hint
+            if isinstance(func, tir.PrimFunc) and (("fused_decode" in name) and ("take" in name)):
+                sch_mod = tvm.IRModule({"main": func})  # do not shadow `mod` here
+                sch = tir.Schedule(sch_mod)
+                sch.compute_inline("decode")
+                mod[g_var] = sch.mod["main"]
+        return mod
+
+
+def _pattern(n_aux_tensor: int, match_tir_vars: bool):
+    decode = is_op("relax.call_tir")(
+        GlobalVarPattern(),
+        TuplePattern([wildcard() for _ in range(n_aux_tensor)]),
+        add_constraint=False,
+    )
+    indices = ~is_const()
+    if match_tir_vars:
+        call_tir_args_take = [
+            GlobalVarPattern(),
+            TuplePattern([decode, indices]),
+            wildcard(),
+        ]
+    else:
+        call_tir_args_take = [
+            GlobalVarPattern(),
+            TuplePattern([decode, indices]),
+        ]
+    take = is_op("relax.call_tir")(
+        *call_tir_args_take,
+        add_constraint=False,
+    )
+    annotations = {
+        "take": take,
+        "decode": decode,
+        "indices": indices,
+    }
+
+    def _check(ctx: relax.transform.PatternCheckContext) -> bool:
+        take = ctx.annotated_expr["take"]
+        decode = ctx.annotated_expr["decode"]
+        if not isinstance(decode, relax.expr.Call):
+            return False
+        if not isinstance(take.args[0], relax.GlobalVar) or not isinstance(
+            decode.args[0], relax.GlobalVar
+        ):
+            return False
+        return "take" in take.args[0].name_hint and "decode" in decode.args[0].name_hint
+
+    return take, annotations, _check
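The pass works in two stages: FuseOpsByPattern and FuseTIR collapse the decode and take call_tir kernels into one PrimFunc, then the tir.Schedule step compute-inlines the "decode" block into its consumer, so dequantization runs only for the rows that take actually gathers (the embedding-lookup case). A hedged application sketch; `mod` is an assumed input, not constructed here:

# `mod` is assumed to contain take(decode(w_q, scale, ...), token_ids)
# expressed as call_tir kernels named decode*/take*.
mod = FuseDecodeTake()(mod)
mod.show()  # expect a fused kernel that decodes only the gathered rows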
