
Commit 973f9fc

[ParamManager][Redo] Use BundleModelParams for transform_dequantize (mlc-ai#1127)
Prior to this commit, the `ParamManager.transform_dequantize` function took as input functions with separate parameters for each weight tensor, and produced output functions with a single tuple parameter for all weights. Because `LiftTransformParams` followed the same convention, the two could not be applied as part of the same build flow. This commit updates `ParamManager.transform_dequantize` to produce outputs with separate tensor parameters, relying on the `BundleModelParams` transform to later combine them into a single tuple parameter. The analogous change was made to `LiftTransformParams` in apache/tvm#15657.

In addition, `ParamManager.transform_dequantize` previously operated directly on an `IRModule` object, so debug instrumentation (e.g. before/after printouts for each pass, or before/after verification with `relax.analysis.well_formed`) did not apply to it. This commit updates `ParamManager.transform_dequantize` to return an `ir.transform.Pass` instead.

This commit is a repeat of the reverted PR mlc-ai#1056. It resolves the bug in the earlier implementation by removing the call to `.without_attr("num_input")` in `ParamReplacer.rewrite_func`, following an analogous update in `LiftTransformParams` and preserving the `"num_input"` attribute for use by `BundleModelParams`.
1 parent a4279e3 commit 973f9fc
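Because `transform_dequantize` now returns a `tvm.ir.transform.Pass`, it runs through the regular pass infrastructure, so pass instrumentation can observe it. The snippet below is a minimal sketch, not part of the commit, of how a well-formedness check could be attached; the `WellFormedCheck` class and the pre-existing `param_manager` and `mod` variables are assumptions for illustration.

    # Sketch only: run the returned pass under a PassContext whose instrument
    # checks that the IRModule stays well-formed after each pass.
    import tvm
    from tvm import relax
    from tvm.ir.instrument import pass_instrument

    @pass_instrument
    class WellFormedCheck:
        """Hypothetical instrument: assert well-formedness after every pass."""

        def run_after_pass(self, mod, info):
            assert relax.analysis.well_formed(mod), f"IRModule ill-formed after {info.name}"

    with tvm.transform.PassContext(opt_level=0, instruments=[WellFormedCheck()]):
        mod = param_manager.transform_dequantize()(mod)   # now an ordinary Pass
        mod = relax.transform.BundleModelParams()(mod)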

2 files changed (+56, -63 lines)


mlc_llm/core.py

Lines changed: 2 additions & 1 deletion

@@ -420,7 +420,8 @@ def mod_transform_before_build(
     if args.model.lower().startswith("rwkv-"):
         model_names += ["reset_kv_cache"]
 
-    mod = param_manager.transform_dequantize(mod)
+    mod = param_manager.transform_dequantize()(mod)
+    mod = relax.transform.BundleModelParams()(mod)
 
     use_ft_quant = args.quantization.name in ["q4f16_ft", "q8f16_ft"]
     mod = mlc_llm.transform.FuseDecodeTranspose(skip_gemm=not use_ft_quant)(mod)
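Since both steps are now ordinary passes, they can also be grouped into a single named pipeline. A minimal sketch, assuming a TVM build that provides `relax.transform.BundleModelParams` and an existing `param_manager` and `mod` (the pipeline name below is invented, not from the commit):

    import tvm
    from tvm import relax

    # Hypothetical grouping of the two core.py steps into one Sequential, so
    # pass printing/instrumentation sees them as a unit.
    dequantize_and_bundle = tvm.transform.Sequential(
        [
            param_manager.transform_dequantize(),  # returns a Pass
            relax.transform.BundleModelParams(),
        ],
        name="DequantizeAndBundle",
    )
    mod = dequantize_and_bundle(mod)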

mlc_llm/relax_model/param_manager.py

Lines changed: 54 additions & 62 deletions

@@ -369,7 +369,7 @@ def set_param_loading_func(
         else:
             self.pidx2pname = dict()
 
-    def transform_dequantize(self, mod: tvm.IRModule) -> tvm.IRModule:
+    def transform_dequantize(self) -> tvm.ir.transform.Pass:
         """Apply dequantization to the input IRModule.
 
         Parameters
@@ -386,38 +386,48 @@ def transform_dequantize(self, mod: tvm.IRModule) -> tvm.IRModule:
         The IRModule updated with the dequantization computation.
         """
 
-        # For each Relax function in the input IRModule (e.g., "prefill"),
-        # we create its input relax.Var of all the quantized data, and
-        # store the mapping from function name to the var.
-        func2param_var: Dict[str, relax.Var] = {}
-        for gv, func in mod.functions.items():
-            if not isinstance(func, relax.Function):
-                continue
-            if func.attrs is None or not "num_input" in func.attrs:
-                continue
-            func2param_var[gv.name_hint] = relax.Var(
-                "params", self.get_quantized_param_info(gv.name_hint)
-            )
+        @tvm.ir.transform.module_pass(opt_level=0, name="ParamManager.transform_dequantize")
+        def transform_func(mod: tvm.IRModule, _context) -> tvm.IRModule:
+            # For each Relax function in the input IRModule (e.g., "prefill"),
+            # we create its input relax.Var of all the quantized data, and
+            # store the mapping from function name to the var.
+            func_name_to_quantized_params: Dict[str, List[relax.Var]] = {}
 
-        # Cache mapping to avoid duplicate dequantization.
-        dequantized_cache: Dict[relax.Var, relax.Var] = {}
+            for gv, func in mod.functions.items():
+                if isinstance(func, relax.Function) and func.attrs and "num_input" in func.attrs:
+                    quantized_param_info = self.get_quantized_param_info(gv.name_hint)
+                    param_vars = [
+                        relax.Var(f"param_{i}", info)
+                        for i, info in enumerate(quantized_param_info.fields)
+                    ]
+                    func_name_to_quantized_params[gv.name_hint] = param_vars
 
-        # Define a var replacement function for applying dequantization.
-        def f_replace(var: relax.Var, bb: relax.BlockBuilder, func_name: str) -> relax.Var:
-            if var in dequantized_cache:
-                return dequantized_cache[var]
-            assert var in self.func_raw_param_map
-            func_name, param = self.func_raw_param_map[var]
-            dequantized = self._dequantize(param, func2param_var[func_name], bb, func_name)
-            dequantized_cache[var] = dequantized
-            return dequantized
+            # Cache mapping to avoid duplicate dequantization.
+            dequantized_cache: Dict[relax.Var, relax.Var] = {}
 
-        # Create the function mutator for applying dequantization.
-        replacer = ParamReplacer(mod, func2param_var, f_replace)
-        # Update the input IRModule with dequantization.
-        mod = replacer.transform()
+            # Define a var replacement function for applying dequantization.
+            def f_replace(var: relax.Var, bb: relax.BlockBuilder) -> relax.Var:
+                if var in dequantized_cache:
+                    return dequantized_cache[var]
+                assert var in self.func_raw_param_map
 
-        return mod
+                func_name, param = self.func_raw_param_map[var]
+                quantized_params = func_name_to_quantized_params[func_name]
+                relevant_quantized_params = [quantized_params[i] for i in self.param2qrange[param]]
+
+                dequantized = self._dequantize(param, relevant_quantized_params, bb, func_name)
+
+                dequantized_cache[var] = dequantized
+                return dequantized
+
+            # Create the function mutator for applying dequantization.
+            replacer = ParamReplacer(mod, func_name_to_quantized_params, f_replace)
+            # Update the input IRModule with dequantization.
+            mod = replacer.transform()
+
+            return mod
+
+        return transform_func
 
     def get_quantized_param_info(self, func_name: str) -> List[relax.TensorStructInfo]:
         bb = relax.BlockBuilder()
@@ -697,10 +707,9 @@ def _register_param(
     def _dequantize(
         self,
         param: Parameter,
-        quantized_tuple: relax.Var,
+        qparams: List[relax.Var],
         bb: relax.BlockBuilder,
         func_name: str,
-        qparams: List[relax.Var] = None,
     ) -> relax.Var:
         """Applying dequantization to the input parameter.
         This method is called by `transform_module` below, and is not
@@ -711,30 +720,13 @@
         param : Parameter
             The parameter whose quantized tensors are to be dequantized.
 
-        quantized_tuple : relax.Var
-            The relax.Var of the quantized tensors of all parameters in the model.
-
-        bb : relax.BlockBuilder
-            The Relax BlockBuilder used for inserting the dequantization computations.
-
-        func_name : str
-            The name of the function which dequantization is applied to.
-
         qparams : List[relax.Var]
-            The quantized parts of the parameter.
-            By default it is `None`, in which case we will get the quantized parts
-            from `quantized_tuple`.
+            The relax.Var of the quantized tensors of all parameters in the model.
 
         Returns
         -------
         The dequantized parameter, in the form of a relax.Var.
         """
-        if not qparams:
-            # Get the corresponding Relax vars of the quantized tensors of this parameter.
-            qparams: List[relax.Var] = []
-            for qparam_idx in self.param2qrange[param]:
-                qparams.append(bb.emit(relax.TupleGetItem(quantized_tuple, qparam_idx)))
-
         # Get the dequantization function of this parameter.
         f_dequantize = param.quant_spec.get_dequantize_func(
             param_info=param.param_info_dict[func_name],
@@ -789,7 +781,7 @@ class ParamReplacer(PyExprMutator):
     mod : tvm.IRModule
         The IRModule of the model to be updated.
 
-    func2param_var : Dict[str, relax.Var]
+    func_name_to_quantized_params : Dict[str, List[relax.Var]]
        The mapping from each function name to its input var of quantized data tuple.
 
    f_replace : Callable[[relax.Var, relax.BlockBuilder], relax.Var]
@@ -801,7 +793,7 @@ class ParamReplacer(PyExprMutator):
     """
 
     mod: tvm.IRModule
-    func2param_var: Dict[str, relax.Var]
+    func_name_to_quantized_params: Dict[str, List[relax.Var]]
    f_replace: Callable[[relax.Var, relax.BlockBuilder], relax.Var]
    param_set: Set[relax.Var]
 
@@ -810,12 +802,12 @@ class ParamReplacer(PyExprMutator):
     def __init__(
         self,
         mod: tvm.IRModule,
-        func2param_var: Dict[str, relax.Var],
+        func_name_to_quantized_params: Dict[str, relax.Var],
         f_replace: Callable[[relax.Var, relax.BlockBuilder], relax.Var],
     ):
         super().__init__(mod)
         self.mod = mod
-        self.func2param_var = func2param_var
+        self.func_name_to_quantized_params = func_name_to_quantized_params
         self.f_replace = f_replace
         self.cur_func_name = ""
 
@@ -827,31 +819,31 @@ def transform(self) -> tvm.IRModule:
                 continue
 
             assert (
-                gv.name_hint in self.func2param_var
-            ), f"{gv.name_hint} not in {self.func2param_var}"
-            self.cur_func_name = gv.name_hint
-            updated_func = self.rewrite_func(func, self.func2param_var[gv.name_hint])
+                gv.name_hint in self.func_name_to_quantized_params
+            ), f"{gv.name_hint} not in {self.func_name_to_quantized_params}"
+            updated_func = self.rewrite_func(func, self.func_name_to_quantized_params[gv.name_hint])
             updated_func = remove_all_unused(updated_func)
             self.builder_.update_func(gv, updated_func)
         return self.builder_.get()
 
-    def rewrite_func(self, func: Function, param_var: relax.Var) -> relax.Function:
+    def rewrite_func(self, func: Function, quantized_params: List[relax.Var]) -> relax.Function:
         num_input = int(func.attrs["num_input"])
         self.param_set = set(func.params[num_input:])
 
         body = self.visit_expr(func.body)
         return relax.Function(
-            params=func.params[:num_input] + [param_var],
+            params=func.params[:num_input] + quantized_params,
             body=body,
             ret_struct_info=func.ret_struct_info,
             is_pure=func.is_pure,
             attrs=func.attrs,
-        ).without_attr("num_input")
+        )
 
     def visit_var_(self, var: Var) -> Expr:
-        if var not in self.param_set:
+        if var in self.param_set:
+            return self.f_replace(var, self.builder_)
+        else:
             return super().visit_var_(var)
-        return self.f_replace(var, self.builder_, self.cur_func_name)
 
 
 ##################################################################
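The `transform_dequantize` hunk above wraps the rewrite in `@tvm.ir.transform.module_pass`, and the `rewrite_func` hunk keeps the `"num_input"` attribute so that the later `BundleModelParams` pass can tell runtime inputs apart from weight parameters. The following self-contained sketch illustrates the decorator pattern only; the `ExampleManager` class and its identity pass are invented for this example, not taken from the commit.

    import tvm
    from tvm import IRModule


    class ExampleManager:
        def make_identity_pass(self) -> tvm.ir.transform.Pass:
            # Same shape as ParamManager.transform_dequantize: a bound method
            # builds a module-level pass and returns it; `self` is captured by
            # the closure.
            @tvm.ir.transform.module_pass(opt_level=0, name="ExampleManager.identity")
            def transform_func(mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule:
                # A real pass would rewrite `mod` here; this sketch returns it as-is.
                return mod

            return transform_func


    # The returned object is an ordinary Pass, so it composes with other passes:
    #     mod = ExampleManager().make_identity_pass()(mod)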
