@@ -682,7 +682,7 @@ def autoround_quantize(
682
682
enable_full_range: bool = False,  ##for symmetric, TODO support later
683
683
bs: int = 8,
684
684
amp: bool = True,
685
- device="cuda:0",
685
+ device=None,
686
686
lr_scheduler=None,
687
687
dataloader=None,  ## to support later
688
688
dataset_name: str = "NeelNanda/pile-10k",
@@ -703,7 +703,6 @@ def autoround_quantize(
703
703
dynamic_max_gap: int = -1,
704
704
data_type: str = "int",  ##only support data_type
705
705
scale_dtype="fp16",
706
- export_args: dict = {"format": None, "inplace": True},
707
706
**kwargs,
708
707
):
709
708
"""Run autoround weight-only quantization.
@@ -726,8 +725,8 @@ def autoround_quantize(
726
725
}
727
726
enable_full_range (bool): Whether to enable full range quantization (default is False).
728
727
bs (int): Batch size for training (default is 8).
729
- amp (bool): Whether to use automatic mixed precision (default is True).
730
- device: The device to be used for tuning (default is "cuda:0").
728
+ amp (bool): Whether to use automatic mixed precision (default is True). Automatically detect and set.
729
+ device: The device to be used for tuning (default is None). Automatically detect and set.
731
730
lr_scheduler: The learning rate scheduler to be used.
732
731
dataloader: The dataloader for input data (to be supported in future).
733
732
dataset_name (str): The default dataset name (default is "NeelNanda/pile-10k").
@@ -747,8 +746,6 @@ def autoround_quantize(
747
746
not_use_best_mse (bool): Whether to use mean squared error (default is False).
748
747
dynamic_max_gap (int): The dynamic maximum gap (default is -1).
749
748
data_type (str): The data type to be used (default is "int").
750
- export_args (dict): The arguments for exporting compressed model, default is {"format": None, "inplace": True}.
751
- Supported format: "itrex", "auto_gptq".
752
749
**kwargs: Additional keyword arguments.
753
750
754
751
Returns:
@@ -790,11 +787,4 @@ def autoround_quantize(
790
787
**kwargs,
791
788
)
792
789
qdq_model, weight_config = rounder.quantize()
793
- if export_args["format"] is not None:
794
- output_dir = export_args.get("output_dir", None)
795
- format = export_args["format"]
796
- inplace = export_args.get("inplace", True)
797
- use_triton = export_args.get("use_triton", False)
798
- model = rounder.save_quantized(output_dir=output_dir, format=format, inplace=inplace, use_triton=use_triton)
799
- return model, weight_config
800
790
return qdq_model, weight_config
0 commit comments