@@ -369,7 +369,7 @@ def set_param_loading_func(
369
369
else :
370
370
self .pidx2pname = dict ()
371
371
372
- def transform_dequantize (self , mod : tvm . IRModule ) -> tvm .IRModule :
372
+ def transform_dequantize (self ) -> tvm .ir . transform . Pass :
373
373
"""Return a module pass that applies dequantization to an IRModule.
374
374
375
375
Parameters
@@ -386,38 +386,48 @@ def transform_dequantize(self, mod: tvm.IRModule) -> tvm.IRModule:
386
386
The module pass that updates an IRModule with the dequantization computation.
387
387
"""
388
388
389
- # For each Relax function in the input IRModule (e.g., "prefill"),
390
- # we create its input relax.Var of all the quantized data, and
391
- # store the mapping from function name to the var.
392
- func2param_var : Dict [str , relax .Var ] = {}
393
- for gv , func in mod .functions .items ():
394
- if not isinstance (func , relax .Function ):
395
- continue
396
- if func .attrs is None or not "num_input" in func .attrs :
397
- continue
398
- func2param_var [gv .name_hint ] = relax .Var (
399
- "params" , self .get_quantized_param_info (gv .name_hint )
400
- )
389
@tvm.ir.transform.module_pass(opt_level=0, name="ParamManager.transform_dequantize")
def transform_func(mod: tvm.IRModule, _context) -> tvm.IRModule:
    """Rewrite every eligible Relax function in `mod` to dequantize its weights.

    A function is eligible when it is a `relax.Function` whose attrs
    contain "num_input". The rewritten module is returned.
    """
    # Function name -> the fresh quantized-parameter vars created for it,
    # one var per field of that function's quantized param info.
    func_name_to_quantized_params: Dict[str, List[relax.Var]] = {}

    for global_var, relax_func in mod.functions.items():
        if not isinstance(relax_func, relax.Function):
            continue
        if not (relax_func.attrs and "num_input" in relax_func.attrs):
            continue
        name = global_var.name_hint
        fields = self.get_quantized_param_info(name).fields
        func_name_to_quantized_params[name] = [
            relax.Var(f"param_{i}", info) for i, info in enumerate(fields)
        ]

    # Raw param var -> its dequantized var, so each param is dequantized once.
    dequantized_cache: Dict[relax.Var, relax.Var] = {}

    def f_replace(var: relax.Var, bb: relax.BlockBuilder) -> relax.Var:
        """Return the dequantized replacement for the raw param `var`."""
        if var in dequantized_cache:
            return dequantized_cache[var]
        assert var in self.func_raw_param_map

        func_name, param = self.func_raw_param_map[var]
        all_qparams = func_name_to_quantized_params[func_name]
        # Keep only the quantized vars that `param2qrange` assigns to this param.
        own_qparams = [all_qparams[i] for i in self.param2qrange[param]]

        result = self._dequantize(param, own_qparams, bb, func_name)
        dequantized_cache[var] = result
        return result

    # The mutator swaps the raw params for the quantized ones and routes
    # every use of a raw param through `f_replace`.
    replacer = ParamReplacer(mod, func_name_to_quantized_params, f_replace)
    return replacer.transform()
429
+
430
+ return transform_func
421
431
422
432
def get_quantized_param_info (self , func_name : str ) -> List [relax .TensorStructInfo ]:
423
433
bb = relax .BlockBuilder ()
@@ -697,10 +707,9 @@ def _register_param(
697
707
def _dequantize (
698
708
self ,
699
709
param : Parameter ,
700
- quantized_tuple : relax .Var ,
710
+ qparams : List [ relax .Var ] ,
701
711
bb : relax .BlockBuilder ,
702
712
func_name : str ,
703
- qparams : List [relax .Var ] = None ,
704
713
) -> relax .Var :
705
714
"""Applying dequantization to the input parameter.
706
715
This method is called by `transform_module` below, and is not
@@ -711,30 +720,13 @@ def _dequantize(
711
720
param : Parameter
712
721
The parameter whose quantized tensors are to be dequantized.
713
722
714
- quantized_tuple : relax.Var
715
- The relax.Var of the quantized tensors of all parameters in the model.
716
-
717
- bb : relax.BlockBuilder
718
- The Relax BlockBuilder used for inserting the dequantization computations.
719
-
720
- func_name : str
721
- The name of the function which dequantization is applied to.
722
-
723
723
qparams : List[relax.Var]
724
- The quantized parts of the parameter.
725
- By default it is `None`, in which case we will get the quantized parts
726
- from `quantized_tuple`.
724
+ The relax.Vars of the quantized tensors that belong to this parameter.
727
725
728
726
Returns
729
727
-------
730
728
The dequantized parameter, in the form of a relax.Var.
731
729
"""
732
- if not qparams :
733
- # Get the corresponding Relax vars of the quantized tensors of this parameter.
734
- qparams : List [relax .Var ] = []
735
- for qparam_idx in self .param2qrange [param ]:
736
- qparams .append (bb .emit (relax .TupleGetItem (quantized_tuple , qparam_idx )))
737
-
738
730
# Get the dequantization function of this parameter.
739
731
f_dequantize = param .quant_spec .get_dequantize_func (
740
732
param_info = param .param_info_dict [func_name ],
@@ -789,7 +781,7 @@ class ParamReplacer(PyExprMutator):
789
781
mod : tvm.IRModule
790
782
The IRModule of the model to be updated.
791
783
792
- func2param_var : Dict[str, relax.Var]
784
+ func_name_to_quantized_params : Dict[str, List[ relax.Var] ]
793
785
The mapping from each function name to the list of its input vars of quantized data.
794
786
795
787
f_replace : Callable[[relax.Var, relax.BlockBuilder], relax.Var]
@@ -801,7 +793,7 @@ class ParamReplacer(PyExprMutator):
801
793
"""
802
794
803
795
mod : tvm .IRModule
804
- func2param_var : Dict [str , relax .Var ]
796
+ func_name_to_quantized_params : Dict [str , List [ relax .Var ] ]
805
797
f_replace : Callable [[relax .Var , relax .BlockBuilder ], relax .Var ]
806
798
param_set : Set [relax .Var ]
807
799
@@ -810,12 +802,12 @@ class ParamReplacer(PyExprMutator):
810
802
def __init__(
    self,
    mod: tvm.IRModule,
    func_name_to_quantized_params: Dict[str, List[relax.Var]],
    f_replace: Callable[[relax.Var, relax.BlockBuilder], relax.Var],
):
    """Store the module, the per-function quantized params and the replacer.

    Parameters
    ----------
    mod : tvm.IRModule
        The IRModule of the model to be updated.

    func_name_to_quantized_params : Dict[str, List[relax.Var]]
        The mapping from each function name to the list of vars that
        `rewrite_func` appends to that function's kept params.

    f_replace : Callable[[relax.Var, relax.BlockBuilder], relax.Var]
        The callback invoked by `visit_var_` for vars in `param_set`.
    """
    super().__init__(mod)
    self.mod = mod
    self.func_name_to_quantized_params = func_name_to_quantized_params
    self.f_replace = f_replace
    # NOTE(review): no method visible in this view reads or updates
    # `cur_func_name` any more -- confirm before removing it.
    self.cur_func_name = ""
821
813
@@ -827,31 +819,31 @@ def transform(self) -> tvm.IRModule:
827
819
continue
828
820
829
821
assert (
830
- gv .name_hint in self .func2param_var
831
- ), f"{ gv .name_hint } not in { self .func2param_var } "
832
- self .cur_func_name = gv .name_hint
833
- updated_func = self .rewrite_func (func , self .func2param_var [gv .name_hint ])
822
+ gv .name_hint in self .func_name_to_quantized_params
823
+ ), f"{ gv .name_hint } not in { self .func_name_to_quantized_params } "
824
+ updated_func = self .rewrite_func (func , self .func_name_to_quantized_params [gv .name_hint ])
834
825
updated_func = remove_all_unused (updated_func )
835
826
self .builder_ .update_func (gv , updated_func )
836
827
return self .builder_ .get ()
837
828
838
def rewrite_func(self, func: Function, quantized_params: List[relax.Var]) -> relax.Function:
    """Rebuild `func` so that `quantized_params` replace its trailing params.

    Parameters
    ----------
    func : Function
        The function to rewrite. Its "num_input" attribute gives the
        number of leading params that are kept unchanged.

    quantized_params : List[relax.Var]
        The vars appended after the kept params in the new signature.

    Returns
    -------
    The rewritten function, carrying over the original return struct
    info, purity flag and attrs.
    """
    input_count = int(func.attrs["num_input"])
    # Every param past the first `num_input` is recorded so that
    # `visit_var_` sends its uses through `f_replace`.
    self.param_set = set(func.params[input_count:])

    new_body = self.visit_expr(func.body)
    new_params = func.params[:input_count] + quantized_params
    return relax.Function(
        params=new_params,
        body=new_body,
        ret_struct_info=func.ret_struct_info,
        is_pure=func.is_pure,
        attrs=func.attrs,
    )
850
841
851
842
def visit_var_(self, var: Var) -> Expr:
    """Replace vars recorded in `param_set`; defer every other var to the base class."""
    if var not in self.param_set:
        return super().visit_var_(var)
    # `var` is one of the params being replaced: hand it to the callback
    # together with this mutator's block builder.
    return self.f_replace(var, self.builder_)
855
847
856
848
857
849
##################################################################
0 commit comments