|
1 | 1 | # pylint: disable=missing-docstring,invalid-name
|
2 | 2 | import argparse
|
| 3 | +import functools |
3 | 4 | import json
|
4 | 5 | import os
|
5 | 6 | import shutil
|
|
17 | 18 | )
|
18 | 19 |
|
19 | 20 |
|
| 21 | +def wrap_tqdm_counter(func, **tqdm_kwargs): |
| 22 | + # tqdm isn't a hard requirement, so return the original function |
| 23 | + # if it isn't available. |
| 24 | + try: |
| 25 | + from tqdm import tqdm |
| 26 | + except ImportError: |
| 27 | + return func |
| 28 | + |
| 29 | + pbar = tqdm(**tqdm_kwargs) |
| 30 | + |
| 31 | + @functools.wraps(func) |
| 32 | + def inner(*args, **kwargs): |
| 33 | + pbar.update(1) |
| 34 | + return func(*args, **kwargs) |
| 35 | + |
| 36 | + return inner |
| 37 | + |
| 38 | + |
20 | 39 | def argparse_postproc_common(args: argparse.Namespace) -> None:
|
21 | 40 | if hasattr(args, "device_name"):
|
22 | 41 | if args.device_name == "auto":
|
@@ -198,6 +217,12 @@ def convert_weights(
|
198 | 217 | # memory usage when loading torch weights as well as acceleration.
|
199 | 218 | mod_transform = param_mgr.create_parameter_transformation()
|
200 | 219 |
|
| 220 | + # Save the number of parameters before we lower mod_transform, so |
| 221 | + # we can use them in the progress bar. |
| 222 | + transform_func = mod_transform["transform_params"] |
| 223 | + num_original_params = len(transform_func.params[0].struct_info.fields) |
| 224 | + num_transformed_params = len(transform_func.struct_info.ret.fields) |
| 225 | + |
201 | 226 | # Remove the dataflow block inside the param transform function,
|
202 | 227 | # so that the LazyTransformParams pass can be applied.
|
203 | 228 | mod_transform = relax.transform.ToNonDataflow()(mod_transform)
|
@@ -227,6 +252,14 @@ def convert_weights(
|
227 | 252 | device,
|
228 | 253 | device_cpu,
|
229 | 254 | )
|
| 255 | + |
| 256 | + get_item = wrap_tqdm_counter( |
| 257 | + get_item, desc="Get old param", position=0, unit="tensors", total=num_original_params |
| 258 | + ) |
| 259 | + set_item = wrap_tqdm_counter( |
| 260 | + set_item, desc="Set new param", position=1, unit="tensors", total=num_transformed_params |
| 261 | + ) |
| 262 | + |
230 | 263 | tvm.register_func(func_name="get_item", f=get_item, override=True)
|
231 | 264 | tvm.register_func(func_name="set_item", f=set_item, override=True)
|
232 | 265 |
|
|
0 commit comments