
Commit 8184431

[ParamManager] Cleanup creation of quantization IRModule (mlc-ai#1053)
This commit replaces the single-parameter `relax_model.param_manager.create_quantize_func` function with a method on `ParamManager`, `create_parameter_transformation`. This avoids potential mix-ups between `param_manager`, the imported Python module `mlc_llm.relax_model.param_manager`, and `param_manager`, an instance of the `ParamManager` class, and makes the functionality easier to find.

The new method also takes an optional `optimize_parameter_order` flag, defaulting to `True`, which applies the `ReorderTransformFunc` pass. Since `ReorderTransformFunc` is intended to be used with several configuration objects owned by `ParamManager`, this simplifies the common path of producing an optimally-ordered parameter transformation module.
1 parent 481cd92 commit 8184431
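
For illustration, a minimal sketch of how a caller might use the new method. This is not taken from the diff: it assumes `param_manager` is an existing, fully registered `ParamManager` instance (for example, the one built in `mlc_llm/core.py`); only `create_parameter_transformation` and its `optimize_parameter_order` flag come from this commit.

# Hedged sketch: `param_manager` is assumed to already exist and have its
# parameters registered; the surrounding setup is not part of this commit.

# Before this commit, callers used the module-level function and applied the
# reordering pass themselves:
#     mod_transform = param_manager_module.create_quantize_func(param_manager)
#     mod_transform = ReorderTransformFunc(...)(mod_transform)

# After this commit, the instance method applies ReorderTransformFunc by default:
mod_transform = param_manager.create_parameter_transformation()

# Pass optimize_parameter_order=False to keep the default transformation order:
mod_transform_unordered = param_manager.create_parameter_transformation(
    optimize_parameter_order=False
)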

3 files changed (+61 −23 lines)

mlc_llm/core.py

Lines changed: 4 additions & 0 deletions
@@ -638,6 +638,10 @@ def build_model_from_args(args: argparse.Namespace):
         qspec_updater.visit_module(mod)
 
     if not args.build_model_only:
+        # Run pre-quantization if provided.
+        args.model_path = param_manager.run_pre_quantize(args.model_path)
+        param_manager.init_torch_pname_to_bin_name(args.use_safetensors)
+
         new_params = utils.convert_weights(param_manager, params, args)
         utils.save_params(new_params, args.artifact_path)
         if args.model_category != "minigpt":

mlc_llm/relax_model/param_manager.py

Lines changed: 54 additions & 1 deletion
@@ -13,6 +13,7 @@
 
 from .. import quantization
 from .modules import named_parameters
+from ..transform import ReorderTransformFunc
 
 
 def f_default_compute_relax_param(relax_pname: str, torch_params: List[Any]) -> Any:
@@ -274,6 +275,31 @@ def register_params(
 
         self.params_in_func[func_name].append(param)
 
+    def run_pre_quantize(self, model_path: str):
+        if self.f_run_prequantize is not None:
+            model_path = self.f_run_prequantize(model_path)
+
+        self.model_path = model_path
+        return model_path
+
+    def init_torch_pname_to_bin_name(self, use_safetensors: bool):
+        assert hasattr(self, "model_path"), (
+            "Must call either set_param_loading_func or run_pre_quantize "
+            "before init_torch_pname_to_bin_name"
+        )
+
+        if self.pidx2pname:
+            mapping = load_torch_pname2binname_map(
+                self.model_path,
+                use_safetensors,
+                set(self.pidx2pname.values()),
+                self.f_convert_pname_fwd,
+            )
+        else:
+            mapping = {}
+
+        self.torch_pname2binname = mapping
+
     def set_param_loading_func(
         self,
         model_path: str,
@@ -726,6 +752,33 @@ def _dequantize(
         # Apply the dequantization function.
         return bb.emit(f_dequantize(bb, qparams))
 
+    def create_parameter_transformation(self, optimize_parameter_order: bool = True):
+        """Produce an IRModule that can transform the parameters
+
+        Parameters
+        ----------
+        optimize_parameter_order: bool
+
+            If true, reorder the parameter transformations to
+            prioritize operations that use a currently-open file. If
+            false, transform the parameters in their default order.
+
+        Returns
+        -------
+        tvm.IRModule
+            The transformation module
+
+        """
+        mod = _create_quantize_func(self)
+        if optimize_parameter_order:
+            reorder_pass = ReorderTransformFunc(
+                self.pidx2pname,
+                self.torch_pname2binname,
+                self.f_convert_pname_fwd,
+            )
+            mod = reorder_pass(mod)
+        return mod
+
 
 @mutator
 class ParamReplacer(PyExprMutator):
@@ -868,7 +921,7 @@ def load_torch_pname2binname_map(
     return torch_pname2binname
 
 
-def create_quantize_func(param_manager: ParamManager) -> tvm.IRModule:
+def _create_quantize_func(param_manager: ParamManager) -> tvm.IRModule:
     """Construct the Relax function which computes quantization.
     This method is called by `transform_module` below, and is not
     directly invoked outside the class.

mlc_llm/utils.py

Lines changed: 3 additions & 22 deletions
@@ -10,7 +10,7 @@
 
 from .quantization import quantization_schemes
 from .relax_model import param_manager
-from .transform import ReorderTransformFunc
+
 
 supported_model_types = set(
     ["llama", "gpt_neox", "gpt_bigcode", "minigpt", "moss", "rwkv", "gptj", "chatglm", "mistral", "stablelm_epoch"]
@@ -192,31 +192,12 @@ def convert_weights(
     model_params: List[Optional[tvm.nd.NDArray]],
     args: argparse.Namespace,
 ):
-    # Run pre-quantization if provided.
-    if param_mgr.f_run_prequantize is not None:
-        args.model_path = param_mgr.f_run_prequantize(args.model_path)
-        param_mgr.model_path = args.model_path
-    param_mgr.torch_pname2binname = (
-        param_manager.load_torch_pname2binname_map(
-            args.model_path,
-            args.use_safetensors,
-            set(param_mgr.pidx2pname.values()),
-            param_mgr.f_convert_pname_fwd,
-        )
-        if len(param_mgr.pidx2pname) != 0
-        else dict()
-    )
-
     # Create the quantization function.
     # We first create an initial one, then reorder it according to each
     # weight's location in the binary files, in the purpose of reducing
    # memory usage when loading torch weights as well as acceleration.
-    mod_transform = param_manager.create_quantize_func(param_mgr)
-    mod_transform = ReorderTransformFunc(
-        param_mgr.pidx2pname,
-        param_mgr.torch_pname2binname,
-        param_mgr.f_convert_pname_fwd,
-    )(mod_transform)
+    mod_transform = param_mgr.create_parameter_transformation()
+
     # Remove the dataflow block inside the param transform function,
     # so that the LazyTransformParams pass can be applied.
     mod_transform = relax.transform.ToNonDataflow()(mod_transform)
