
Commit bfa27e4

Integrate AutoRound v0.3 (#1925)
Signed-off-by: Kaihui-intel <[email protected]>
1 parent 5767aed commit bfa27e4

File tree

5 files changed: +130 −68 lines changed


neural_compressor/torch/algorithms/weight_only/autoround.py

Lines changed: 86 additions & 48 deletions

@@ -31,69 +31,95 @@ class AutoRoundQuantizer(Quantizer):
     def __init__(
         self,
         quant_config: dict = {},
-        enable_full_range: bool = False,
+        enable_full_range: bool = False,  ##for symmetric, TODO support later
         batch_size: int = 8,
         amp: bool = True,
-        device=None,
+        device: str = None,
         lr_scheduler=None,
+        dataset: Union[str, list, tuple, torch.utils.data.DataLoader] = "NeelNanda/pile-10k",
         enable_quanted_input: bool = True,
         enable_minmax_tuning: bool = True,
         lr: float = None,
         minmax_lr: float = None,
-        low_gpu_mem_usage: bool = True,
+        low_gpu_mem_usage: bool = False,
         iters: int = 200,
         seqlen: int = 2048,
-        n_samples: int = 512,
+        nsamples: int = 128,
         sampler: str = "rand",
         seed: int = 42,
-        n_blocks: int = 1,
+        nblocks: int = 1,
         gradient_accumulate_steps: int = 1,
         not_use_best_mse: bool = False,
         dynamic_max_gap: int = -1,
         data_type: str = "int",
         scale_dtype: str = "fp16",
+        multimodal: bool = False,
+        act_bits: int = 32,
+        act_group_size: int = None,
+        act_sym: bool = None,
+        act_dynamic: bool = True,
+        low_cpu_mem_usage: bool = False,
         **kwargs,
     ):
         """Init a AutQRoundQuantizer object.

         Args:
-            quant_config (dict): Configuration for weight quantization (default is None).
-                quant_config={
-                    'layer1':##layer_name
-                    {
-                        'data_type': 'int',
-                        'bits': 4,
-                        'group_size': 32,
-                        'sym': False,
+            quant_config (dict): Configuration for weight quantization (default is None).
+                quant_config={
+                    'layer1':##layer_name
+                    {
+                        'data_type': 'int',
+                        'bits': 4,
+                        'group_size': 32,
+                        'sym': False,
+                        'act_data_type': None,
+                        'act_bits': 32,
+                        'act_sym': None,
+                        'act_dynamic': True,
+                    }
+                    ...,
                 }
-                    ...
-                }
-            keys:
-                data_type (str): The data type to be used (default is "int").
-                bits (int): Number of bits for quantization (default is 4).
-                group_size (int): Size of the quantization group (default is 128).
-                sym (bool): Whether to use symmetric quantization. (default is None).
-            enable_full_range (bool): Whether to enable full range quantization (default is False).
-            batch_size (int): Batch size for training (default is 8).
-            amp (bool): Whether to use automatic mixed precision (default is True). Automatically detect and set.
-            device: The device to be used for tuning (default is None). Automatically detect and set.
-            lr_scheduler: The learning rate scheduler to be used.
-            use_quant_input (bool): Whether to use quantized input data (default is True).
-            enable_minmax_tuning (bool): Whether to enable min-max tuning (default is True).
-            lr (float): The learning rate (default is 0.005).
-            minmax_lr (float): The learning rate for min-max tuning (default is None).
-            low_gpu_mem_usage (bool): Whether to use low GPU memory (default is True).
-            iters (int): Number of iterations (default is 200).
-            seqlen (int): Length of the sequence.
-            n_samples (int): Number of samples (default is 512).
-            sampler (str): The sampling method (default is "rand").
-            seed (int): The random seed (default is 42).
-            n_blocks (int): Number of blocks (default is 1).
-            gradient_accumulate_steps (int): Number of gradient accumulation steps (default is 1).
-            not_use_best_mse (bool): Whether to use mean squared error (default is False).
-            dynamic_max_gap (int): The dynamic maximum gap (default is -1).
-            scale_dtype (str): The data type of quantization scale to be used (default is "float16"), different kernels
-                have different choices.
+            keys:
+                data_type (str): The data type to be used (default is "int").
+                bits (int): Number of bits for quantization (default is 4).
+                group_size (int): Size of the quantization group (default is 128).
+                sym (bool): Whether to use symmetric quantization. (default is None).
+            bits (int): Number of bits for quantization (default is 4).
+            group_size (int): Size of the quantization group (default is 128).
+            sym (bool): Whether symmetric quantization is to be used (default is False).
+            enable_full_range (bool): Whether to enable full range quantization (default is False).
+            batch_size (int): Batch size for training (default is 8).
+            amp (bool): Whether to use automatic mixed precision (default is True).
+            device: The device to be used for tuning (default is "auto").
+            lr_scheduler: The learning rate scheduler to be used.
+            dataset (str): The default dataset name (default is "NeelNanda/pile-10k").
+            enable_quanted_input (bool): Whether to use the output of the previous quantized block as
+                the input for the current block (default is True).
+            enable_minmax_tuning (bool): Whether to enable weight min-max tuning (default is True).
+            lr (float): The learning rate (default is None, will be set to 1.0/iters).
+            minmax_lr (float): The learning rate for min-max tuning
+                (default is None, it will be set to lr automatically).
+            low_gpu_mem_usage (bool): Whether to use low GPU memory (default is True).
+            iters (int): Number of iterations (default is 200).
+            seqlen (int): Data length of the sequence for tuning (default is 2048).
+            nsamples (int): Number of samples (default is 128).
+            sampler (str): The sampling method (default is "rand").
+            seed (int): The random seed (default is 42).
+            nblocks (int): Number of blocks (default is 1).
+            gradient_accumulate_steps (int): Number of gradient accumulation steps (default is 1).
+            not_use_best_mse (bool): Whether to use mean squared error (default is False).
+            dynamic_max_gap (int): The dynamic maximum gap (default is -1).
+            data_type (str): The data type to be used (default is "int").
+            scale_dtype (str): The data type of quantization scale to be used (default is "float16"), different kernels
+                have different choices.
+            multimodal(bool): Enable multimodal model quantization, (default is "False").
+            act_bits (int): Number of bits for activation quantization. Default is 32.
+            act_group_size (int): Group size for activation quantization. Default is None.
+            act_sym (bool): Whether to use symmetric activation quantization. Default is None.
+            act_dynamic (bool): Whether to use dynamic activation quantization. Default is True.
+
+        Returns:
+            The quantized model.
         """
         super().__init__(quant_config)
         self.tokenizer = None
@@ -109,15 +135,21 @@ def __init__(
         self.low_gpu_mem_usage = low_gpu_mem_usage
         self.iters = iters
         self.seqlen = seqlen
-        self.n_samples = n_samples
+        self.nsamples = nsamples
         self.sampler = sampler
         self.seed = seed
-        self.n_blocks = n_blocks
+        self.nblocks = nblocks
         self.gradient_accumulate_steps = gradient_accumulate_steps
         self.not_use_best_mse = not_use_best_mse
         self.dynamic_max_gap = dynamic_max_gap
         self.data_type = data_type
         self.scale_dtype = scale_dtype
+        self.multimodal = multimodal
+        self.act_bits = act_bits
+        self.act_group_size = act_group_size
+        self.act_sym = act_sym
+        self.act_dynamic = act_dynamic
+        self.low_cpu_mem_usage = low_cpu_mem_usage

     def prepare(self, model: torch.nn.Module, *args, **kwargs):
         """Prepares a given model for quantization.
@@ -137,7 +169,7 @@ def convert(self, model: torch.nn.Module, *args, **kwargs):
             model=model,
             tokenizer=None,
             dataset=dataloader,
-            weight_config=self.quant_config or {},
+            layer_config=self.quant_config or {},
             enable_full_range=self.enable_full_range,
             batch_size=self.batch_size,
             amp=self.amp,
@@ -150,23 +182,29 @@ def convert(self, model: torch.nn.Module, *args, **kwargs):
             low_gpu_mem_usage=self.low_gpu_mem_usage,
             iters=self.iters,
             seqlen=self.seqlen,
-            n_samples=self.n_samples,
+            nsamples=self.nsamples,
             sampler=self.sampler,
             seed=self.seed,
-            n_blocks=self.n_blocks,
+            nblocks=self.nblocks,
             gradient_accumulate_steps=self.gradient_accumulate_steps,
             not_use_best_mse=self.not_use_best_mse,
             dynamic_max_gap=self.dynamic_max_gap,
             data_type=self.data_type,
             scale_dtype=self.scale_dtype,
+            multimodal=self.multimodal,
+            act_bits=self.act_bits,
+            act_group_size=self.act_group_size,
+            act_sym=self.act_sym,
+            act_dynamic=self.act_dynamic,
+            low_cpu_mem_usage=self.low_cpu_mem_usage,
         )
         model, weight_config = rounder.quantize()
         model.autoround_config = weight_config
         model = pack_model(model, weight_config, device=self.device, inplace=True)
         return model


-def get_dataloader(tokenizer, seqlen, dataset_name="NeelNanda/pile-10k", seed=42, bs=8, n_samples=512):
+def get_dataloader(tokenizer, seqlen, dataset_name="NeelNanda/pile-10k", seed=42, bs=8, nsamples=128):
     """Generate a DataLoader for calibration using specified parameters.

     Args:
@@ -186,6 +224,6 @@ def get_dataloader(tokenizer, seqlen, dataset_name="NeelNanda/pile-10k", seed=42
     from auto_round.calib_dataset import get_dataloader  # pylint: disable=E0401

     dataloader = get_dataloader(
-        tokenizer, seqlen, dataset_name="NeelNanda/pile-10k", seed=seed, bs=bs, n_samples=n_samples
+        tokenizer, seqlen, dataset_name="NeelNanda/pile-10k", seed=seed, bs=bs, nsamples=nsamples
     )
     return dataloader
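
The net effect of this file is an API rename (n_samples → nsamples, n_blocks → nblocks, weight_config → layer_config) plus new activation-quantization and memory knobs. Below is a minimal usage sketch against the updated signatures; it is not part of the commit, the model name and layer name are placeholders, and the values simply echo the new defaults.

    from transformers import AutoModelForCausalLM, AutoTokenizer

    from neural_compressor.torch.algorithms.weight_only.autoround import (
        AutoRoundQuantizer,
        get_dataloader,
    )

    model_name = "facebook/opt-125m"  # placeholder model chosen for illustration
    model = AutoModelForCausalLM.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # Calibration dataloader: `nsamples` replaces the old `n_samples` keyword.
    dataloader = get_dataloader(tokenizer, seqlen=2048, seed=42, bs=8, nsamples=128)

    # Per-layer weight config; the act_* keys are the ones added in this commit.
    quant_config = {
        "layer1": {  # placeholder layer name, as in the docstring example above
            "data_type": "int",
            "bits": 4,
            "group_size": 32,
            "sym": False,
            "act_data_type": None,
            "act_bits": 32,
            "act_sym": None,
            "act_dynamic": True,
        },
    }

    quantizer = AutoRoundQuantizer(
        quant_config=quant_config,
        nsamples=128,  # was n_samples
        nblocks=1,     # was n_blocks
        seqlen=2048,
        iters=200,
        act_bits=32,
        act_dynamic=True,
        low_cpu_mem_usage=False,
    )

The quantizer is then driven through the prepare()/convert() flow shown in the diff above, with convert() handing the per-layer dictionary to AutoRound as layer_config.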

neural_compressor/torch/quantization/algorithm_entry.py

Lines changed: 12 additions & 4 deletions

@@ -572,6 +572,10 @@ def autoround_quantize_entry(
                 "bits": quant_config.bits,
                 "sym": quant_config.use_sym,
                 "group_size": quant_config.group_size,
+                "act_bits": quant_config.act_bits,
+                "act_group_size": quant_config.act_group_size,
+                "act_sym": quant_config.act_sym,
+                "act_dynamic": quant_config.act_dynamic,
             }
     enable_full_range = quant_config.enable_full_range
     batch_size = quant_config.batch_size
@@ -583,14 +587,16 @@ def autoround_quantize_entry(
     low_gpu_mem_usage = quant_config.low_gpu_mem_usage
     iters = quant_config.iters
     seqlen = quant_config.seqlen
-    n_samples = quant_config.n_samples
+    nsamples = quant_config.nsamples
     sampler = quant_config.sampler
     seed = quant_config.seed
-    n_blocks = quant_config.n_blocks
+    nblocks = quant_config.nblocks
     gradient_accumulate_steps = quant_config.gradient_accumulate_steps
     not_use_best_mse = quant_config.not_use_best_mse
     dynamic_max_gap = quant_config.dynamic_max_gap
     scale_dtype = quant_config.scale_dtype
+    multimodal = quant_config.multimodal
+    low_cpu_mem_usage = quant_config.use_layer_wise

     kwargs.pop("example_inputs")

@@ -608,14 +614,16 @@ def autoround_quantize_entry(
         low_gpu_mem_usage=low_gpu_mem_usage,
         iters=iters,
         seqlen=seqlen,
-        n_samples=n_samples,
+        nsamples=nsamples,
         sampler=sampler,
         seed=seed,
-        n_blocks=n_blocks,
+        nblocks=nblocks,
         gradient_accumulate_steps=gradient_accumulate_steps,
         not_use_best_mse=not_use_best_mse,
         dynamic_max_gap=dynamic_max_gap,
         scale_dtype=scale_dtype,
+        multimodal=multimodal,
+        low_cpu_mem_usage=low_cpu_mem_usage,
     )
     model = quantizer.execute(model=model, mode=mode, *args, **kwargs)
     model.qconfig = configs_mapping

neural_compressor/torch/quantization/config.py

Lines changed: 26 additions & 10 deletions

@@ -735,8 +735,8 @@ class AutoRoundConfig(TorchBaseConfig):
         "minmax_lr",
         "iters",
         "seqlen",
-        "n_samples",
-        "n_blocks",
+        "nsamples",
+        "nblocks",
         "gradient_accumulate_steps",
         "not_use_best_mse",
         "dynamic_max_gap",
@@ -750,6 +750,10 @@ def __init__(
         use_sym: bool = False,
         group_size: int = 128,
         # AUTOROUND
+        act_bits: int = 32,
+        act_group_size: int = None,
+        act_sym: bool = None,
+        act_dynamic: bool = True,
         enable_full_range: bool = False,
         batch_size: int = 8,
         lr_scheduler=None,
@@ -759,16 +763,17 @@ def __init__(
         minmax_lr: float = None,
         low_gpu_mem_usage: bool = True,
         iters: int = 200,
-        seqlen: int = 512,
-        n_samples: int = 512,
+        seqlen: int = 2048,
+        nsamples: int = 128,
         sampler: str = "rand",
         seed: int = 42,
-        n_blocks: int = 1,
+        nblocks: int = 1,
         gradient_accumulate_steps: int = 1,
         not_use_best_mse: bool = False,
         dynamic_max_gap: int = -1,
         scale_dtype: str = "fp16",
         use_layer_wise: bool = False,
+        multimodal: bool = False,
         white_list: Optional[List[OP_NAME_OR_MODULE_TYPE]] = DEFAULT_WHITE_LIST,
     ):
         """Init AUTOROUND weight-only quantization config.
@@ -778,6 +783,10 @@ def __init__(
             bits (int): Number of bits used to represent weights, default is 4.
             use_sym (bool): Indicates whether weights are symmetric, default is False.
             group_size (int): Size of weight groups, default is 128.
+            act_bits (int): Number of bits for activation quantization. Default is 32.
+            act_group_size (int): Group size for activation quantization. Default is None.
+            act_sym (bool): Whether to use symmetric activation quantization. Default is None.
+            act_dynamic (bool): Whether to use dynamic activation quantization. Default is True.
             enable_full_range (bool): Whether to enable full range quantization (default is False).
             batch_size (int): Batch size for training (default is 8).
             lr_scheduler: The learning rate scheduler to be used.
@@ -788,21 +797,27 @@ def __init__(
             low_gpu_mem_usage (bool): Whether to use low GPU memory (default is True).
             iters (int): Number of iterations (default is 200).
             seqlen (int): Length of the sequence.
-            n_samples (int): Number of samples (default is 512).
+            nsamples (int): Number of samples (default is 512).
             sampler (str): The sampling method (default is "rand").
             seed (int): The random seed (default is 42).
-            n_blocks (int): Number of blocks (default is 1).
+            nblocks (int): Number of blocks (default is 1).
             gradient_accumulate_steps (int): Number of gradient accumulation steps (default is 1).
             not_use_best_mse (bool): Whether to use mean squared error (default is False).
             dynamic_max_gap (int): The dynamic maximum gap (default is -1).
             scale_dtype (str): The data type of quantization scale to be used (default is "float16"), different kernels
                 have different choices.
+            use_layer_wise (bool): Enables quantize model per layer. Defaults to False.
+            multimodal(bool): Enable multimodal model quantization, (default is "False").
         """
         super().__init__(white_list=white_list)
         self.dtype = dtype
         self.bits = bits
         self.use_sym = use_sym
         self.group_size = group_size
+        self.act_bits = act_bits
+        self.act_group_size = act_group_size
+        self.act_sym = act_sym
+        self.act_dynamic = act_dynamic
         self.enable_full_range = enable_full_range
         self.batch_size = batch_size
         self.lr_scheduler = lr_scheduler
@@ -813,15 +828,16 @@ def __init__(
         self.low_gpu_mem_usage = low_gpu_mem_usage
         self.iters = iters
         self.seqlen = seqlen
-        self.n_samples = n_samples
+        self.nsamples = nsamples
         self.sampler = sampler
         self.seed = seed
-        self.n_blocks = n_blocks
+        self.nblocks = nblocks
         self.gradient_accumulate_steps = gradient_accumulate_steps
         self.not_use_best_mse = not_use_best_mse
         self.dynamic_max_gap = dynamic_max_gap
         self.scale_dtype = scale_dtype
         self.use_layer_wise = use_layer_wise
+        self.multimodal = multimodal
         self._post_init()

     @classmethod
@@ -1526,7 +1542,7 @@ def get_woq_tuning_config() -> list:
         the list of WOQ quant config.
     """
     RTN_G32ASYM = RTNConfig(use_sym=False, group_size=32)
-    AUTO_ROUND_CONFIG = AutoRoundConfig(use_sym=False, group_size=32)
+    AUTO_ROUND_CONFIG = AutoRoundConfig(use_sym=False, group_size=32, seqlen=512)
     GPTQ_G32ASYM = GPTQConfig(use_sym=False, group_size=32)
     AWQ_G32ASYM = AWQConfig(use_sym=False, group_size=32)
     return [RTN_G32ASYM, AUTO_ROUND_CONFIG, GPTQ_G32ASYM, AWQ_G32ASYM]
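
For reference, a minimal sketch (not part of the commit) of how the updated AutoRoundConfig surface might be exercised, assuming the config is imported from neural_compressor.torch.quantization as elsewhere in INC 3.x; the argument values simply restate the new defaults.

    from neural_compressor.torch.quantization import AutoRoundConfig

    config = AutoRoundConfig(
        bits=4,
        use_sym=False,
        group_size=128,
        seqlen=2048,           # default raised from 512
        nsamples=128,          # renamed from n_samples, default lowered from 512
        nblocks=1,             # renamed from n_blocks
        act_bits=32,           # new: activation quantization bit-width
        act_group_size=None,   # new: activation quantization group size
        act_sym=None,          # new: symmetric activation quantization
        act_dynamic=True,      # new: dynamic activation quantization
        multimodal=False,      # new: multimodal model quantization switch
        use_layer_wise=False,  # forwarded to the quantizer as low_cpu_mem_usage by the entry function
    )

Within get_woq_tuning_config(), the AutoRound entry now also passes seqlen=512 explicitly.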
