
Commit 2ca7d15

[Fix] TIR block name of dequantization (mlc-ai#1177)
1 parent 1757777 commit 2ca7d15

File tree

3 files changed: +12 -9 lines changed

python/mlc_chat/compiler/compiler_pass/fuse_decode_take.py

Lines changed: 3 additions & 2 deletions

@@ -38,8 +38,9 @@ def transform_module(
         for g_var, func in mod.functions_items():
             name = g_var.name_hint
             if isinstance(func, tir.PrimFunc) and (("fused_decode" in name) and ("take" in name)):
-                mod = tvm.IRModule({"main": func})
-                sch = tir.Schedule(mod)
+                sch_mod = tvm.IRModule({"main": func})
+                sch_mod = tir.transform.ForceNarrowIndexToInt32()(sch_mod)
+                sch = tir.Schedule(sch_mod)
                 sch.compute_inline("decode")
                 mod[g_var] = sch.mod["main"]
         return mod
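
Context for the fix: `tir.Schedule.compute_inline` resolves its target block by name, so this pass only succeeds if the dequantization stage (defined in `group_quantization.py` below) lowers to a TIR block literally named "decode". Two smaller cleanups ride along: renaming the temporary to `sch_mod` stops the loop from overwriting the `mod` it is transforming, and `tir.transform.ForceNarrowIndexToInt32` narrows 64-bit indices to 32-bit before scheduling. Below is a minimal, self-contained sketch of the name-based inlining mechanism; the tensors `src` and `out` and the toy arithmetic are hypothetical, not from this repo:

import tvm
from tvm import te, tir

n = 16
src = te.placeholder((n,), dtype="float16", name="src")
# The `name` passed to te.compute becomes the TIR block name.
decoded = te.compute((n,), lambda i: src[i] * tir.const(2, "float16"), name="decode")
out = te.compute((n,), lambda i: decoded[i] + tir.const(1, "float16"), name="out")

mod = tvm.IRModule({"main": te.create_prim_func([src, out])})
sch = tir.Schedule(mod)
# Succeeds only because a block named "decode" exists; with te.compute's
# default block name ("compute") this lookup would raise.
sch.compute_inline("decode")
print(sch.mod.script())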

python/mlc_chat/compiler/quantization/group_quantization.py

Lines changed: 8 additions & 6 deletions

@@ -1,6 +1,6 @@
 """The group quantization config"""
 from dataclasses import dataclass, field
-from typing import Any, Callable, Dict, List, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple

 import numpy as np
 from tvm import DataType, DataTypeCode
@@ -128,7 +128,7 @@ def _dequantize(
     ):
         tir_bin_mask = tir.const((1 << DataType(self.quantize_dtype).bits) - 1, self.storage_dtype)
         tir_max_int = tir.const(self.max_int_value, self.model_dtype)
-        dequantized_weight = te.compute(
+        return te.compute(
             shape=[weight.shape[0], weight.shape[1] * self.num_elem_per_storage]
             if out_shape is None
             else out_shape,
@@ -149,8 +149,8 @@ def _dequantize(
                 ),
                 scale[i, j // self.group_size],
             ),
+            name="decode",
         )
-        return dequantized_weight

     def quantize_weight(self, weight: NDArray) -> List[NDArray]:
         """
@@ -186,8 +186,10 @@ def quantize_weight(self, weight: NDArray) -> List[NDArray]:
             if target is None:
                 target = Target.from_device(dev)
             with target:
-                mod = dl.ApplyDefaultSchedule(  # pylint: disable=not-callable
-                    dl.gpu.Reduction(), dl.gpu.GeneralReduction(), dl.gpu.Fallback()
+                mod = dl.ApplyDefaultSchedule(  # type: ignore  # pylint: disable=not-callable
+                    dl.gpu.Reduction(),
+                    dl.gpu.GeneralReduction(),
+                    dl.gpu.Fallback(),
                 )(mod)
         elif device_type == "cpu":
             target = "llvm"
@@ -400,7 +402,7 @@ def from_multilinear(
             out_dtype=multi_linear.out_dtype,
         )

-    def forward(self, x: nn.Tensor) -> nn.Tensor:  # pylint: disable=invalid-name
+    def forward(self, x: nn.Tensor) -> Sequence[nn.Tensor]:  # pylint: disable=invalid-name
         """
         Forward method for multi linear layer.
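
The TE expression above packs four scalar operations into one lambda: shift the storage word, mask out one field, recenter by `max_int_value`, and multiply by the group's scale. Here is a plain-Python sketch of that per-element arithmetic, assuming 4-bit symmetric quantization packed into 32-bit storage words (`bits = 4`, `max_int = 7`); these are illustrative values, not taken from this diff:

BITS = 4
MAX_INT = (1 << (BITS - 1)) - 1  # 7, mirrors max_int_value
BIN_MASK = (1 << BITS) - 1       # 0b1111, mirrors tir_bin_mask

def dequantize_element(storage_word: int, lane: int, scale: float) -> float:
    """Extract the lane-th 4-bit field, recenter around zero, rescale."""
    quantized = (storage_word >> (lane * BITS)) & BIN_MASK
    return (quantized - MAX_INT) * scale

# 0x73 packs lane 0 = 0b0011 (3) and lane 1 = 0b0111 (7):
print(dequantize_element(0x73, 0, scale=0.5))  # (3 - 7) * 0.5 = -2.0
print(dequantize_element(0x73, 1, scale=0.5))  # (7 - 7) * 0.5 =  0.0

The `name="decode"` argument is the actual fix: without it, `te.compute` falls back to its default block name ("compute"), and `compute_inline("decode")` in the fusion pass above cannot find the block. The new `Sequence` import backs the corrected `forward` annotation, since a multi-linear layer returns a list of output tensors rather than a single one.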

python/mlc_chat/support/auto_config.py

Lines changed: 1 addition & 1 deletion

@@ -95,7 +95,7 @@ def detect_model_type(model_type: str, config: Path) -> "Model":
                 f"Please explicitly specify `--model-type` instead"
             )
         model_type = cfg["model_type"]
-        logger.info("%s Model type: %s", FOUND, model_type)
+        logger.info("%s model type: %s", FOUND, model_type)
     if model_type not in MODELS:
         raise ValueError(f"Unknown model type: {model_type}. Available ones: {list(MODELS.keys())}")
     return MODELS[model_type]
