Skip to content

Commit d202077

Browse files
authored
[ParamManager] Added progress bar for get_item/set_item (mlc-ai#1063)
1 parent 204860b commit d202077

File tree

1 file changed

+33
-0
lines changed

1 file changed

+33
-0
lines changed

mlc_llm/utils.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# pylint: disable=missing-docstring,invalid-name
22
import argparse
3+
import functools
34
import json
45
import os
56
import shutil
@@ -17,6 +18,24 @@
1718
)
1819

1920

21+
def wrap_tqdm_counter(func, **tqdm_kwargs):
22+
# tqdm isn't a hard requirement, so return the original function
23+
# if it isn't available.
24+
try:
25+
from tqdm import tqdm
26+
except ImportError:
27+
return func
28+
29+
pbar = tqdm(**tqdm_kwargs)
30+
31+
@functools.wraps(func)
32+
def inner(*args, **kwargs):
33+
pbar.update(1)
34+
return func(*args, **kwargs)
35+
36+
return inner
37+
38+
2039
def argparse_postproc_common(args: argparse.Namespace) -> None:
2140
if hasattr(args, "device_name"):
2241
if args.device_name == "auto":
@@ -198,6 +217,12 @@ def convert_weights(
198217
# memory usage when loading torch weights as well as acceleration.
199218
mod_transform = param_mgr.create_parameter_transformation()
200219

220+
# Save the number of parameters before we lower mod_transform, so
221+
# we can use them in the progress bar.
222+
transform_func = mod_transform["transform_params"]
223+
num_original_params = len(transform_func.params[0].struct_info.fields)
224+
num_transformed_params = len(transform_func.struct_info.ret.fields)
225+
201226
# Remove the dataflow block inside the param transform function,
202227
# so that the LazyTransformParams pass can be applied.
203228
mod_transform = relax.transform.ToNonDataflow()(mod_transform)
@@ -227,6 +252,14 @@ def convert_weights(
227252
device,
228253
device_cpu,
229254
)
255+
256+
get_item = wrap_tqdm_counter(
257+
get_item, desc="Get old param", position=0, unit="tensors", total=num_original_params
258+
)
259+
set_item = wrap_tqdm_counter(
260+
set_item, desc="Set new param", position=1, unit="tensors", total=num_transformed_params
261+
)
262+
230263
tvm.register_func(func_name="get_item", f=get_item, override=True)
231264
tvm.register_func(func_name="set_item", f=set_item, override=True)
232265

0 commit comments

Comments
 (0)