Integrate AutoRound v0.3 #1925

Merged · 15 commits · Jul 17, 2024

134 changes: 86 additions & 48 deletions neural_compressor/torch/algorithms/weight_only/autoround.py
@@ -31,69 +31,95 @@ class AutoRoundQuantizer(Quantizer):
def __init__(
self,
quant_config: dict = {},
enable_full_range: bool = False,
enable_full_range: bool = False, ##for symmetric, TODO support later
batch_size: int = 8,
amp: bool = True,
device=None,
device: str = None,
lr_scheduler=None,
dataset: Union[str, list, tuple, torch.utils.data.DataLoader] = "NeelNanda/pile-10k",
enable_quanted_input: bool = True,
enable_minmax_tuning: bool = True,
lr: float = None,
minmax_lr: float = None,
low_gpu_mem_usage: bool = True,
low_gpu_mem_usage: bool = False,
iters: int = 200,
seqlen: int = 2048,
n_samples: int = 512,
nsamples: int = 128,
sampler: str = "rand",
seed: int = 42,
n_blocks: int = 1,
nblocks: int = 1,
gradient_accumulate_steps: int = 1,
not_use_best_mse: bool = False,
dynamic_max_gap: int = -1,
data_type: str = "int",
scale_dtype: str = "fp16",
multimodal: bool = False,
act_bits: int = 32,
act_group_size: int = None,
act_sym: bool = None,
act_dynamic: bool = True,
low_cpu_mem_usage: bool = False,
**kwargs,
):
"""Init a AutQRoundQuantizer object.

Args:
quant_config (dict): Configuration for weight quantization (default is None).
quant_config={
'layer1':##layer_name
{
'data_type': 'int',
'bits': 4,
'group_size': 32,
'sym': False,
quant_config (dict): Configuration for weight quantization (default is None).
quant_config={
'layer1':##layer_name
{
'data_type': 'int',
'bits': 4,
'group_size': 32,
'sym': False,
'act_data_type': None,
'act_bits': 32,
'act_sym': None,
'act_dynamic': True,
}
...,
}
...
}
keys:
data_type (str): The data type to be used (default is "int").
bits (int): Number of bits for quantization (default is 4).
group_size (int): Size of the quantization group (default is 128).
sym (bool): Whether to use symmetric quantization. (default is None).
enable_full_range (bool): Whether to enable full range quantization (default is False).
batch_size (int): Batch size for training (default is 8).
amp (bool): Whether to use automatic mixed precision (default is True). Automatically detect and set.
device: The device to be used for tuning (default is None). Automatically detect and set.
lr_scheduler: The learning rate scheduler to be used.
use_quant_input (bool): Whether to use quantized input data (default is True).
enable_minmax_tuning (bool): Whether to enable min-max tuning (default is True).
lr (float): The learning rate (default is 0.005).
minmax_lr (float): The learning rate for min-max tuning (default is None).
low_gpu_mem_usage (bool): Whether to use low GPU memory (default is True).
iters (int): Number of iterations (default is 200).
seqlen (int): Length of the sequence.
n_samples (int): Number of samples (default is 512).
sampler (str): The sampling method (default is "rand").
seed (int): The random seed (default is 42).
n_blocks (int): Number of blocks (default is 1).
gradient_accumulate_steps (int): Number of gradient accumulation steps (default is 1).
not_use_best_mse (bool): Whether to use mean squared error (default is False).
dynamic_max_gap (int): The dynamic maximum gap (default is -1).
scale_dtype (str): The data type of quantization scale to be used (default is "float16"), different kernels
have different choices.
keys:
data_type (str): The data type to be used (default is "int").
bits (int): Number of bits for quantization (default is 4).
group_size (int): Size of the quantization group (default is 128).
sym (bool): Whether to use symmetric quantization. (default is None).
bits (int): Number of bits for quantization (default is 4).
group_size (int): Size of the quantization group (default is 128).
sym (bool): Whether symmetric quantization is to be used (default is False).
enable_full_range (bool): Whether to enable full range quantization (default is False).
batch_size (int): Batch size for training (default is 8).
amp (bool): Whether to use automatic mixed precision (default is True).
device: The device to be used for tuning (default is "auto").
lr_scheduler: The learning rate scheduler to be used.
dataset (str): The default dataset name (default is "NeelNanda/pile-10k").
enable_quanted_input (bool): Whether to use the output of the previous quantized block as
the input for the current block (default is True).
enable_minmax_tuning (bool): Whether to enable weight min-max tuning (default is True).
lr (float): The learning rate (default is None, will be set to 1.0/iters).
minmax_lr (float): The learning rate for min-max tuning
(default is None, it will be set to lr automatically).
low_gpu_mem_usage (bool): Whether to use low GPU memory (default is False).
iters (int): Number of iterations (default is 200).
seqlen (int): Data length of the sequence for tuning (default is 2048).
nsamples (int): Number of samples (default is 128).
sampler (str): The sampling method (default is "rand").
seed (int): The random seed (default is 42).
nblocks (int): Number of blocks (default is 1).
gradient_accumulate_steps (int): Number of gradient accumulation steps (default is 1).
not_use_best_mse (bool): Whether to skip tracking the parameters with the lowest mean squared error (default is False).
dynamic_max_gap (int): The dynamic maximum gap (default is -1).
data_type (str): The data type to be used (default is "int").
scale_dtype (str): The data type of quantization scale to be used (default is "float16"), different kernels
have different choices.
multimodal (bool): Whether to enable multimodal model quantization (default is False).
act_bits (int): Number of bits for activation quantization. Default is 32.
act_group_size (int): Group size for activation quantization. Default is None.
act_sym (bool): Whether to use symmetric activation quantization. Default is None.
act_dynamic (bool): Whether to use dynamic activation quantization. Default is True.

Returns:
The quantized model.
"""
super().__init__(quant_config)
self.tokenizer = None
@@ -109,15 +135,21 @@ def __init__(
self.low_gpu_mem_usage = low_gpu_mem_usage
self.iters = iters
self.seqlen = seqlen
self.n_samples = n_samples
self.nsamples = nsamples
self.sampler = sampler
self.seed = seed
self.n_blocks = n_blocks
self.nblocks = nblocks
self.gradient_accumulate_steps = gradient_accumulate_steps
self.not_use_best_mse = not_use_best_mse
self.dynamic_max_gap = dynamic_max_gap
self.data_type = data_type
self.scale_dtype = scale_dtype
self.multimodal = multimodal
self.act_bits = act_bits
self.act_group_size = act_group_size
self.act_sym = act_sym
self.act_dynamic = act_dynamic
self.low_cpu_mem_usage = low_cpu_mem_usage

def prepare(self, model: torch.nn.Module, *args, **kwargs):
"""Prepares a given model for quantization.
Expand All @@ -137,7 +169,7 @@ def convert(self, model: torch.nn.Module, *args, **kwargs):
model=model,
tokenizer=None,
dataset=dataloader,
weight_config=self.quant_config or {},
layer_config=self.quant_config or {},
enable_full_range=self.enable_full_range,
batch_size=self.batch_size,
amp=self.amp,
@@ -150,23 +182,29 @@ def convert(self, model: torch.nn.Module, *args, **kwargs):
low_gpu_mem_usage=self.low_gpu_mem_usage,
iters=self.iters,
seqlen=self.seqlen,
n_samples=self.n_samples,
nsamples=self.nsamples,
sampler=self.sampler,
seed=self.seed,
n_blocks=self.n_blocks,
nblocks=self.nblocks,
gradient_accumulate_steps=self.gradient_accumulate_steps,
not_use_best_mse=self.not_use_best_mse,
dynamic_max_gap=self.dynamic_max_gap,
data_type=self.data_type,
scale_dtype=self.scale_dtype,
multimodal=self.multimodal,
act_bits=self.act_bits,
act_group_size=self.act_group_size,
act_sym=self.act_sym,
act_dynamic=self.act_dynamic,
low_cpu_mem_usage=self.low_cpu_mem_usage,
)
model, weight_config = rounder.quantize()
model.autoround_config = weight_config
model = pack_model(model, weight_config, device=self.device, inplace=True)
return model


def get_dataloader(tokenizer, seqlen, dataset_name="NeelNanda/pile-10k", seed=42, bs=8, n_samples=512):
def get_dataloader(tokenizer, seqlen, dataset_name="NeelNanda/pile-10k", seed=42, bs=8, nsamples=128):
"""Generate a DataLoader for calibration using specified parameters.

Args:
@@ -186,6 +224,6 @@ def get_dataloader(tokenizer, seqlen, dataset_name="NeelNanda/pile-10k", seed=42
from auto_round.calib_dataset import get_dataloader # pylint: disable=E0401

dataloader = get_dataloader(
tokenizer, seqlen, dataset_name="NeelNanda/pile-10k", seed=seed, bs=bs, n_samples=n_samples
tokenizer, seqlen, dataset_name="NeelNanda/pile-10k", seed=seed, bs=bs, nsamples=nsamples
)
return dataloader
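
For reference, a minimal sketch of the renamed calibration helper as it now stands: nsamples replaces n_samples and defaults to 128. The tokenizer checkpoint below is only an illustrative assumption, not part of this PR.

    from transformers import AutoTokenizer
    from neural_compressor.torch.algorithms.weight_only.autoround import get_dataloader

    # Illustrative tokenizer; any causal-LM tokenizer should work here.
    tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")

    # `nsamples` (renamed from `n_samples`) now defaults to 128.
    dataloader = get_dataloader(
        tokenizer, seqlen=2048, dataset_name="NeelNanda/pile-10k", seed=42, bs=8, nsamples=128
    )

    for batch in dataloader:
        # Each batch is expected to hold tokenized pile-10k text of length `seqlen`.
        print(type(batch))
        break
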
16 changes: 12 additions & 4 deletions neural_compressor/torch/quantization/algorithm_entry.py
@@ -467,6 +467,10 @@ def autoround_quantize_entry(
"bits": quant_config.bits,
"sym": quant_config.use_sym,
"group_size": quant_config.group_size,
"act_bits": quant_config.act_bits,
"act_group_size": quant_config.act_group_size,
"act_sym": quant_config.act_sym,
"act_dynamic": quant_config.act_dynamic,
}
enable_full_range = quant_config.enable_full_range
batch_size = quant_config.batch_size
@@ -478,14 +482,16 @@
low_gpu_mem_usage = quant_config.low_gpu_mem_usage
iters = quant_config.iters
seqlen = quant_config.seqlen
n_samples = quant_config.n_samples
nsamples = quant_config.nsamples
sampler = quant_config.sampler
seed = quant_config.seed
n_blocks = quant_config.n_blocks
nblocks = quant_config.nblocks
gradient_accumulate_steps = quant_config.gradient_accumulate_steps
not_use_best_mse = quant_config.not_use_best_mse
dynamic_max_gap = quant_config.dynamic_max_gap
scale_dtype = quant_config.scale_dtype
multimodal = quant_config.multimodal
low_cpu_mem_usage = quant_config.use_layer_wise

kwargs.pop("example_inputs")

@@ -503,14 +509,16 @@
low_gpu_mem_usage=low_gpu_mem_usage,
iters=iters,
seqlen=seqlen,
n_samples=n_samples,
nsamples=nsamples,
sampler=sampler,
seed=seed,
n_blocks=n_blocks,
nblocks=nblocks,
gradient_accumulate_steps=gradient_accumulate_steps,
not_use_best_mse=not_use_best_mse,
dynamic_max_gap=dynamic_max_gap,
scale_dtype=scale_dtype,
multimodal=multimodal,
low_cpu_mem_usage=low_cpu_mem_usage,
)
model = quantizer.execute(model=model, mode=mode, *args, **kwargs)
model.qconfig = configs_mapping
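
As a small hedged illustration of the new wiring in this entry point: the activation options on AutoRoundConfig are passed straight through, and use_layer_wise is forwarded to the quantizer as low_cpu_mem_usage (names taken from the diff above; the config values are illustrative).

    from neural_compressor.torch.quantization import AutoRoundConfig

    cfg = AutoRoundConfig(act_bits=32, act_dynamic=True, use_layer_wise=True)

    # autoround_quantize_entry() then builds the quantizer roughly as:
    #   AutoRoundQuantizer(..., act_bits=cfg.act_bits, act_group_size=cfg.act_group_size,
    #                      act_sym=cfg.act_sym, act_dynamic=cfg.act_dynamic,
    #                      low_cpu_mem_usage=cfg.use_layer_wise)
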
36 changes: 26 additions & 10 deletions neural_compressor/torch/quantization/config.py
@@ -731,8 +731,8 @@ class AutoRoundConfig(TorchBaseConfig):
"minmax_lr",
"iters",
"seqlen",
"n_samples",
"n_blocks",
"nsamples",
"nblocks",
"gradient_accumulate_steps",
"not_use_best_mse",
"dynamic_max_gap",
@@ -746,6 +746,10 @@ def __init__(
use_sym: bool = False,
group_size: int = 128,
# AUTOROUND
act_bits: int = 32,
act_group_size: int = None,
act_sym: bool = None,
act_dynamic: bool = True,
enable_full_range: bool = False,
batch_size: int = 8,
lr_scheduler=None,
@@ -755,16 +759,17 @@
minmax_lr: float = None,
low_gpu_mem_usage: bool = True,
iters: int = 200,
seqlen: int = 512,
n_samples: int = 512,
seqlen: int = 2048,
nsamples: int = 128,
sampler: str = "rand",
seed: int = 42,
n_blocks: int = 1,
nblocks: int = 1,
gradient_accumulate_steps: int = 1,
not_use_best_mse: bool = False,
dynamic_max_gap: int = -1,
scale_dtype: str = "fp16",
use_layer_wise: bool = False,
multimodal: bool = False,
white_list: Optional[List[OP_NAME_OR_MODULE_TYPE]] = DEFAULT_WHITE_LIST,
):
"""Init AUTOROUND weight-only quantization config.
@@ -774,6 +779,10 @@
bits (int): Number of bits used to represent weights, default is 4.
use_sym (bool): Indicates whether weights are symmetric, default is False.
group_size (int): Size of weight groups, default is 128.
act_bits (int): Number of bits for activation quantization. Default is 32.
act_group_size (int): Group size for activation quantization. Default is None.
act_sym (bool): Whether to use symmetric activation quantization. Default is None.
act_dynamic (bool): Whether to use dynamic activation quantization. Default is True.
enable_full_range (bool): Whether to enable full range quantization (default is False).
batch_size (int): Batch size for training (default is 8).
lr_scheduler: The learning rate scheduler to be used.
@@ -784,21 +793,27 @@
low_gpu_mem_usage (bool): Whether to use low GPU memory (default is True).
iters (int): Number of iterations (default is 200).
seqlen (int): Length of the sequence.
n_samples (int): Number of samples (default is 512).
nsamples (int): Number of samples (default is 128).
sampler (str): The sampling method (default is "rand").
seed (int): The random seed (default is 42).
n_blocks (int): Number of blocks (default is 1).
nblocks (int): Number of blocks (default is 1).
gradient_accumulate_steps (int): Number of gradient accumulation steps (default is 1).
not_use_best_mse (bool): Whether to use mean squared error (default is False).
dynamic_max_gap (int): The dynamic maximum gap (default is -1).
scale_dtype (str): The data type of quantization scale to be used (default is "float16"), different kernels
have different choices.
use_layer_wise (bool): Whether to quantize the model layer by layer (default is False).
multimodal (bool): Whether to enable multimodal model quantization (default is False).
"""
super().__init__(white_list=white_list)
self.dtype = dtype
self.bits = bits
self.use_sym = use_sym
self.group_size = group_size
self.act_bits = act_bits
self.act_group_size = act_group_size
self.act_sym = act_sym
self.act_dynamic = act_dynamic
self.enable_full_range = enable_full_range
self.batch_size = batch_size
self.lr_scheduler = lr_scheduler
@@ -809,15 +824,16 @@ def __init__(
self.low_gpu_mem_usage = low_gpu_mem_usage
self.iters = iters
self.seqlen = seqlen
self.n_samples = n_samples
self.nsamples = nsamples
self.sampler = sampler
self.seed = seed
self.n_blocks = n_blocks
self.nblocks = nblocks
self.gradient_accumulate_steps = gradient_accumulate_steps
self.not_use_best_mse = not_use_best_mse
self.dynamic_max_gap = dynamic_max_gap
self.scale_dtype = scale_dtype
self.use_layer_wise = use_layer_wise
self.multimodal = multimodal
self._post_init()

@classmethod
@@ -1522,7 +1538,7 @@ def get_woq_tuning_config() -> list:
the list of WOQ quant config.
"""
RTN_G32ASYM = RTNConfig(use_sym=False, group_size=32)
AUTO_ROUND_CONFIG = AutoRoundConfig(use_sym=False, group_size=32)
AUTO_ROUND_CONFIG = AutoRoundConfig(use_sym=False, group_size=32, seqlen=512)
GPTQ_G32ASYM = GPTQConfig(use_sym=False, group_size=32)
AWQ_G32ASYM = AWQConfig(use_sym=False, group_size=32)
return [RTN_G32ASYM, AUTO_ROUND_CONFIG, GPTQ_G32ASYM, AWQ_G32ASYM]
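
Finally, a hedged end-to-end sketch of the updated AutoRoundConfig surface: the renamed nsamples/nblocks fields, the seqlen default raised from 512 to 2048, and the new activation-quantization options. It assumes the prepare/convert flow of neural_compressor.torch.quantization; the checkpoint name and calibration loop are illustrative only.

    from transformers import AutoModelForCausalLM, AutoTokenizer
    from neural_compressor.torch.quantization import AutoRoundConfig, prepare, convert
    from neural_compressor.torch.algorithms.weight_only.autoround import get_dataloader

    model_name = "facebook/opt-125m"  # illustrative checkpoint
    model = AutoModelForCausalLM.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    quant_config = AutoRoundConfig(
        bits=4,
        use_sym=False,
        group_size=128,
        iters=200,
        seqlen=2048,    # default raised from 512 to 2048 in this PR
        nsamples=128,   # renamed from n_samples
        nblocks=1,      # renamed from n_blocks
        act_bits=32,    # new activation options; 32 bits leaves activations unquantized
        act_dynamic=True,
        scale_dtype="fp16",
    )

    dataloader = get_dataloader(tokenizer, seqlen=2048, nsamples=128)

    def run_fn(model):
        # Feed calibration batches so AutoRound can capture block inputs.
        for batch in dataloader:
            if isinstance(batch, dict):
                model(**batch)
            elif isinstance(batch, (list, tuple)):
                model(*batch)
            else:
                model(batch)

    model = prepare(model, quant_config)
    run_fn(model)
    q_model = convert(model)
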