|
13 | 13 |
|
14 | 14 | from .. import quantization
|
15 | 15 | from .modules import named_parameters
|
| 16 | +from ..transform import ReorderTransformFunc |
16 | 17 |
|
17 | 18 |
|
18 | 19 | def f_default_compute_relax_param(relax_pname: str, torch_params: List[Any]) -> Any:
|
@@ -274,6 +275,31 @@ def register_params(
|
274 | 275 |
|
275 | 276 | self.params_in_func[func_name].append(param)
|
276 | 277 |
|
def run_pre_quantize(self, model_path: str):
    """Run the optional pre-quantization hook and remember the model path.

    If a ``f_run_prequantize`` callback was registered, it is given the
    chance to rewrite ``model_path`` (e.g. to point at a preprocessed
    checkpoint). The resulting path is stored on ``self.model_path`` for
    later lookups and also returned to the caller.

    Parameters
    ----------
    model_path : str
        Path to the model checkpoint directory.

    Returns
    -------
    str
        The (possibly rewritten) model path.
    """
    hook = self.f_run_prequantize
    if hook is not None:
        model_path = hook(model_path)

    self.model_path = model_path
    return model_path
| 284 | + |
def init_torch_pname_to_bin_name(self, use_safetensors: bool):
    """Build the map from torch parameter names to checkpoint bin file names.

    Populates ``self.torch_pname2binname`` by reading the checkpoint index
    under ``self.model_path``. Requires that ``self.model_path`` was set
    beforehand, via either ``set_param_loading_func`` or
    ``run_pre_quantize``.

    Parameters
    ----------
    use_safetensors : bool
        If True, consult the safetensors index instead of the PyTorch
        ``.bin`` index.

    Raises
    ------
    RuntimeError
        If ``self.model_path`` has not been set yet.
    """
    # A plain `assert` would be stripped under `python -O`, silently
    # disabling this call-order check — raise explicitly instead.
    if not hasattr(self, "model_path"):
        raise RuntimeError(
            "Must call either set_param_loading_func or run_pre_quantize "
            "before init_torch_pname_to_bin_name"
        )

    if self.pidx2pname:
        mapping = load_torch_pname2binname_map(
            self.model_path,
            use_safetensors,
            set(self.pidx2pname.values()),
            self.f_convert_pname_fwd,
        )
    else:
        # No parameters registered yet: nothing to look up.
        mapping = {}

    self.torch_pname2binname = mapping
| 302 | + |
277 | 303 | def set_param_loading_func(
|
278 | 304 | self,
|
279 | 305 | model_path: str,
|
@@ -726,6 +752,33 @@ def _dequantize(
|
726 | 752 | # Apply the dequantization function.
|
727 | 753 | return bb.emit(f_dequantize(bb, qparams))
|
728 | 754 |
|
def create_parameter_transformation(self, optimize_parameter_order: bool = True):
    """Produce an IRModule that transforms (quantizes) the parameters.

    Parameters
    ----------
    optimize_parameter_order : bool
        If True, reorder the parameter transformations so that operations
        reading from a currently-open checkpoint file are prioritized.
        If False, keep the parameters in their default order.

    Returns
    -------
    tvm.IRModule
        The parameter-transformation module.
    """
    transform_mod = _create_quantize_func(self)
    if not optimize_parameter_order:
        return transform_mod

    # Group transformations by source bin file to minimize file reopening.
    reorder = ReorderTransformFunc(
        self.pidx2pname,
        self.torch_pname2binname,
        self.f_convert_pname_fwd,
    )
    return reorder(transform_mod)
| 781 | + |
729 | 782 |
|
730 | 783 | @mutator
|
731 | 784 | class ParamReplacer(PyExprMutator):
|
@@ -868,7 +921,7 @@ def load_torch_pname2binname_map(
|
868 | 921 | return torch_pname2binname
|
869 | 922 |
|
870 | 923 |
|
871 |
| -def create_quantize_func(param_manager: ParamManager) -> tvm.IRModule: |
| 924 | +def _create_quantize_func(param_manager: ParamManager) -> tvm.IRModule: |
872 | 925 | """Construct the Relax function which computes quantization.
|
873 | 926 | This method is called by `transform_module` below, and is not
|
874 | 927 | directly invoked outside the class.
|
|
0 commit comments