
Commit 5c7f336

Integrate AutoRound (#1619)
Signed-off-by: Kaihui-intel <[email protected]>
Signed-off-by: chensuyue <[email protected]>
1 parent 354791d commit 5c7f336

File tree: 9 files changed, +384 −3 lines changed

.azure-pipelines/scripts/ut/env_setup.sh (+5)

@@ -97,6 +97,11 @@ elif [[ $(echo "${test_case}" | grep -c "tf pruning") != 0 ]]; then
     # so test distribute cases in the env with single fw installed
     pip install horovod
 fi
+
+if [[ $(echo "${test_case}" | grep -c "api") != 0 ]] || [[ $(echo "${test_case}" | grep -c "adaptor") != 0 ]]; then
+    pip install git+https://github.com/intel/auto-round.git@a868c805de4be271cfe7403309a64d9bf03a0ecf
+fi
+
 # test deps
 pip install coverage
 pip install pytest

neural_compressor/adaptor/pytorch.py (+90)

@@ -4615,6 +4615,9 @@ def quantize(self, tune_cfg, model, dataloader, calib_func=None):
             q_model._model = self.awq_quantize(q_model._model, tune_cfg, dataloader, calib_func)
         if "RTN" in all_algo:
             q_model._model = self.rtn_quantize(q_model._model, tune_cfg)
+        if "AUTOROUND" in all_algo:
+            q_model._model, autoround_config = self.autoround_quantize(q_model._model, tune_cfg, dataloader)
+            q_model.autoround_config = autoround_config

         q_model.q_config = copy.deepcopy(self.tune_cfg)
         q_model.is_quantized = True

@@ -4911,6 +4914,93 @@ def awq_quantize(self, model, tune_cfg, dataloader, calib_func):
         )
         return model

+    def autoround_quantize(self, model, tune_cfg, dataloader):
+        logger.info("quantizing with the AutoRound algorithm")
+        from .torch_utils.weight_only import autoround_quantize
+
+        # build weight_config, e.g.:
+        """
+        weight_config = {
+            'layer1': {  ## layer_name
+                'data_type': 'int',
+                'bits': 4,
+                'group_size': 32,
+                'scheme': "asym",  ## or sym
+            },
+            ...
+        }
+        """
+        weight_config = {}
+        for key, config in tune_cfg["op"].items():
+            if config["weight"]["dtype"] == "fp32":
+                continue
+            op_name, op_type = key
+            weight_config[op_name] = {}
+            weight_config[op_name]["data_type"] = config["weight"]["dtype"]
+            weight_config[op_name]["bits"] = config["weight"]["bits"]
+            weight_config[op_name]["group_size"] = config["weight"]["group_size"]
+            weight_config[op_name]["scheme"] = config["weight"]["scheme"]
+
+        # AutoRound recipes
+        enable_full_range = self.recipes["autoround_args"].get("enable_full_range", False)
+        bs = self.recipes["autoround_args"].get("bs", 8)
+        amp = self.recipes["autoround_args"].get("amp", True)
+        device = self.recipes["autoround_args"].get("device", "cpu")
+        lr_scheduler = self.recipes["autoround_args"].get("lr_scheduler", None)
+        dataset_name = self.recipes["autoround_args"].get("dataset_name", "NeelNanda/pile-10k")
+        dataset_split = self.recipes["autoround_args"].get("dataset_split", "train")
+        use_quant_input = self.recipes["autoround_args"].get("use_quant_input", True)
+        enable_minmax_tuning = self.recipes["autoround_args"].get("enable_minmax_tuning", True)
+        lr = self.recipes["autoround_args"].get("lr", None)
+        minmax_lr = self.recipes["autoround_args"].get("minmax_lr", None)
+        low_gpu_mem_usage = self.recipes["autoround_args"].get("low_gpu_mem_usage", True)
+        iters = self.recipes["autoround_args"].get("iters", 200)
+        seqlen = self.recipes["autoround_args"].get("seqlen", 2048)
+        n_samples = self.recipes["autoround_args"].get("n_samples", 512)
+        sampler = self.recipes["autoround_args"].get("sampler", "rand")
+        seed = self.recipes["autoround_args"].get("seed", 42)
+        n_blocks = self.recipes["autoround_args"].get("n_blocks", 1)
+        gradient_accumulate_steps = self.recipes["autoround_args"].get("gradient_accumulate_steps", 1)
+        not_use_best_mse = self.recipes["autoround_args"].get("not_use_best_mse", False)
+        dynamic_max_gap = self.recipes["autoround_args"].get("dynamic_max_gap", -1)
+        data_type = self.recipes["autoround_args"].get("data_type", "int")  ## only 'int' data_type is supported for now
+        scale_dtype = self.recipes["autoround_args"].get("scale_dtype", "fp16")
+
+        model, autoround_config = autoround_quantize(
+            model=model,
+            tokenizer=None,
+            bits=4,
+            group_size=128,
+            scheme="asym",
+            weight_config=weight_config,
+            enable_full_range=enable_full_range,
+            bs=bs,
+            amp=amp,
+            device=device,
+            lr_scheduler=lr_scheduler,
+            dataloader=dataloader,
+            dataset_name=dataset_name,
+            dataset_split=dataset_split,
+            use_quant_input=use_quant_input,
+            enable_minmax_tuning=enable_minmax_tuning,
+            lr=lr,
+            minmax_lr=minmax_lr,
+            low_gpu_mem_usage=low_gpu_mem_usage,
+            iters=iters,
+            seqlen=seqlen,
+            n_samples=n_samples,
+            sampler=sampler,
+            seed=seed,
+            n_blocks=n_blocks,
+            gradient_accumulate_steps=gradient_accumulate_steps,
+            not_use_best_mse=not_use_best_mse,
+            dynamic_max_gap=dynamic_max_gap,
+            data_type=data_type,
+            scale_dtype=scale_dtype,
+        )
+        return model, autoround_config
+
     def _dump_model_op_stats(self, model, tune_cfg):
         """This is a function to dump quantizable ops of model to user.
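For reference, the weight_config loop above converts per-op entries of tune_cfg["op"] into AutoRound's per-layer schema, skipping ops left in fp32. A hedged illustration of that mapping (op names are examples only):

    # Illustrative input: two entries of tune_cfg["op"], keyed by (op_name, op_type)
    tune_cfg_op = {
        ("transformer.h.0.attn.k_proj", "Linear"): {
            "weight": {"dtype": "int", "bits": 4, "group_size": 32, "scheme": "sym"}
        },
        ("lm_head", "Linear"): {"weight": {"dtype": "fp32"}},  # skipped: stays fp32
    }
    # Resulting weight_config after the loop:
    # {"transformer.h.0.attn.k_proj": {"data_type": "int", "bits": 4, "group_size": 32, "scheme": "sym"}}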

neural_compressor/adaptor/pytorch_cpu.yaml (+1 −1)

@@ -267,7 +267,7 @@
             # group_size=-1 means per-channel, others means per-group
             'group_size': [32, -1, 1, 4, 8, 16, 64, 128, 256, 512, 1024], # [1-inf], # 32
             'scheme': ['sym', 'asym'], # sym, no ZP
-            'algorithm': ['RTN', 'AWQ', 'GPTQ', 'TEQ'], # RTN, [RTN, GPTQ, AWQ,] RTN+AWQ+TEQ order
+            'algorithm': ['RTN', 'AWQ', 'GPTQ', 'TEQ', 'AUTOROUND'], # RTN, [RTN, GPTQ, AWQ,] RTN+AWQ+TEQ order
         },
         'activation': {
             'dtype': ['fp32'],
neural_compressor/adaptor/torch_utils/auto_round.py (+25, new file)

@@ -0,0 +1,25 @@
+# Copyright (c) 2024 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from auto_round.calib_dataset import CALIB_DATASETS  # pylint: disable=E0401
+
+
+def get_dataloader(
+    tokenizer, seqlen=2048, seed=42, train_bs=8, dataset_split="train", dataset_name="NeelNanda/pile-10k"
+):
+    get_dataloader = CALIB_DATASETS.get(dataset_name, CALIB_DATASETS["NeelNanda/pile-10k"])
+    dataloader = get_dataloader(
+        tokenizer, seqlen=seqlen, seed=seed, bs=train_bs, split=dataset_split, dataset_name=dataset_name
+    )
+    return dataloader
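
A minimal usage sketch of this helper, assuming transformers is installed and the default calibration dataset is reachable (this mirrors the unit test further below):

    from transformers import AutoTokenizer
    from neural_compressor.adaptor.torch_utils.auto_round import get_dataloader

    tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-GPTJForCausalLM")
    # Unrecognized dataset names fall back to the NeelNanda/pile-10k loader.
    dataloader = get_dataloader(tokenizer, seqlen=32, train_bs=4)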

neural_compressor/adaptor/torch_utils/weight_only.py (+118)

@@ -670,3 +670,121 @@ def quant_weight_w_scale(weight, scale, zp, group_size=-1):
         int_weight_tmp.add_(zp[:, -1].unsqueeze(1))
         int_weight[:, leng * group_size :].copy_(int_weight_tmp.round_())
     return int_weight
+
+
+def autoround_quantize(
+    model,
+    tokenizer,
+    bits: int = 4,
+    group_size: int = 128,
+    scheme: str = "asym",
+    weight_config: dict = {},
+    enable_full_range: bool = False,  ## for symmetric, TODO: support later
+    bs: int = 8,
+    amp: bool = True,
+    device="cuda:0",
+    lr_scheduler=None,
+    dataloader=None,  ## to support later
+    dataset_name: str = "NeelNanda/pile-10k",
+    dataset_split: str = "train",
+    use_quant_input: bool = True,
+    enable_minmax_tuning: bool = True,
+    lr: float = None,
+    minmax_lr: float = None,
+    low_gpu_mem_usage: bool = True,
+    iters: int = 200,
+    seqlen: int = 2048,
+    n_samples: int = 512,
+    sampler: str = "rand",
+    seed: int = 42,
+    n_blocks: int = 1,
+    gradient_accumulate_steps: int = 1,
+    not_use_best_mse: bool = False,
+    dynamic_max_gap: int = -1,
+    data_type: str = "int",  ## only 'int' data_type is supported for now
+    scale_dtype="fp16",
+    **kwargs,
+):
+    """Run AutoRound weight-only quantization.
+
+    Args:
+        model: The PyTorch model to be quantized.
+        tokenizer: Tokenizer for processing input data. Temporarily set as a mandatory parameter.
+        bits (int): Number of bits for quantization (default is 4).
+        group_size (int): Size of the quantization group (default is 128).
+        scheme (str): The quantization scheme to be used (default is "asym").
+        weight_config (dict): Configuration for weight quantization (default is an empty dictionary), e.g.:
+            weight_config = {
+                'layer1': {  ## layer_name
+                    'data_type': 'int',
+                    'bits': 4,
+                    'group_size': 32,
+                    'scheme': "asym",  ## or sym
+                },
+                ...
+            }
+        enable_full_range (bool): Whether to enable full-range quantization (default is False).
+        bs (int): Batch size for training (default is 8).
+        amp (bool): Whether to use automatic mixed precision (default is True).
+        device: The device to be used for tuning (default is "cuda:0").
+        lr_scheduler: The learning rate scheduler to be used.
+        dataloader: The dataloader for input data (to be supported in the future).
+        dataset_name (str): The default dataset name (default is "NeelNanda/pile-10k").
+        dataset_split (str): The split of the dataset to be used (default is "train").
+        use_quant_input (bool): Whether to use quantized input data (default is True).
+        enable_minmax_tuning (bool): Whether to enable min-max tuning (default is True).
+        lr (float): The learning rate (default is None).
+        minmax_lr (float): The learning rate for min-max tuning (default is None).
+        low_gpu_mem_usage (bool): Whether to use low GPU memory (default is True).
+        iters (int): Number of iterations (default is 200).
+        seqlen (int): Length of the calibration sequence (default is 2048).
+        n_samples (int): Number of samples (default is 512).
+        sampler (str): The sampling method (default is "rand").
+        seed (int): The random seed (default is 42).
+        n_blocks (int): Number of blocks (default is 1).
+        gradient_accumulate_steps (int): Number of gradient accumulation steps (default is 1).
+        not_use_best_mse (bool): Whether to skip using the checkpoint with the best mean squared error (default is False).
+        dynamic_max_gap (int): The dynamic maximum gap (default is -1).
+        data_type (str): The data type to be used (default is "int").
+        scale_dtype: The data type of the quantization scales (default is "fp16").
+        **kwargs: Additional keyword arguments.
+
+    Returns:
+        The quantized model and its per-layer weight configuration.
+    """
+    from auto_round import AutoRound  # pylint: disable=E0401
+
+    rounder = AutoRound(
+        model=model,
+        tokenizer=tokenizer,
+        bits=bits,
+        group_size=group_size,
+        scheme=scheme,
+        weight_config=weight_config,
+        enable_full_range=enable_full_range,  ## for symmetric, TODO: support later
+        bs=bs,
+        amp=amp,
+        device=device,
+        lr_scheduler=lr_scheduler,
+        dataloader=dataloader,  ## to support later
+        dataset_name=dataset_name,
+        dataset_split=dataset_split,
+        use_quant_input=use_quant_input,
+        enable_minmax_tuning=enable_minmax_tuning,
+        lr=lr,
+        minmax_lr=minmax_lr,
+        low_gpu_mem_usage=low_gpu_mem_usage,
+        iters=iters,
+        seqlen=seqlen,
+        n_samples=n_samples,
+        sampler=sampler,
+        seed=seed,
+        n_blocks=n_blocks,
+        gradient_accumulate_steps=gradient_accumulate_steps,
+        not_use_best_mse=not_use_best_mse,
+        dynamic_max_gap=dynamic_max_gap,
+        data_type=data_type,  ## only 'int' data_type is supported for now
+        scale_dtype=scale_dtype,
+        **kwargs,
+    )
+    qdq_model, weight_config = rounder.quantize()
+    return qdq_model, weight_config
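
The adaptor method in pytorch.py is the intended entry point, but this function can also be driven directly; a minimal sketch, assuming fp32_model is a causal LM and tokenizer is its tokenizer (values are illustrative):

    from neural_compressor.adaptor.torch_utils.weight_only import autoround_quantize

    q_model, autoround_config = autoround_quantize(
        model=fp32_model,      # assumption: a causal LM, e.g. a small GPT-J
        tokenizer=tokenizer,   # used to build the default calibration dataloader
        bits=4,
        group_size=128,
        scheme="asym",
        device="cpu",
        iters=10,              # small values keep a smoke test fast
        n_samples=20,
    )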

neural_compressor/config.py (+8 −1)

@@ -60,7 +60,7 @@
         ),
         Optional("algorithm"): And(
             list,  # TODO: allow AWQ+GPTQ algo
-            lambda s: all(i in ["minmax", "RTN", "AWQ", "GPTQ", "TEQ"] for i in s),
+            lambda s: all(i in ["minmax", "RTN", "AWQ", "GPTQ", "TEQ", "AUTOROUND"] for i in s),
         ),
         Optional("bits"): And(list, lambda s: all(0 < i <= 8 and type(i) == int for i in s)),
         Optional("group_size"): And(list, lambda s: all(i >= -1 and i != 0 and type(i) == int for i in s)),

@@ -941,6 +941,12 @@ def teq_args(val=None):
             else:
                 return {}

+        def autoround_args(val=None):
+            if val is not None:
+                return _check_value("autoround_args", val, dict)
+            else:
+                return {}
+
         def fast_bias_correction(val=None):
             if val is not None:
                 return _check_value("fast_bias_correction", val, bool)

@@ -1025,6 +1031,7 @@ def dedicated_qdq_pair(val=None):
             "awq_args": awq_args,
             "gptq_args": gptq_args,
             "teq_args": teq_args,
+            "autoround_args": autoround_args,
         }
         self._recipes = {}
         for k in RECIPES.keys():
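
With this schema change, "AUTOROUND" validates as a weight-only algorithm and "autoround_args" as a recipe key. A minimal configuration sketch (illustrative values; the test below exercises the full path):

    from neural_compressor import PostTrainingQuantConfig

    conf = PostTrainingQuantConfig(
        approach="weight_only",
        op_type_dict={".*": {"weight": {"dtype": "int", "bits": 4, "group_size": 32, "algorithm": "AUTOROUND"}}},
        recipes={"autoround_args": {"iters": 10, "device": "cpu"}},
    )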

test/adaptor/pytorch_adaptor/test_weight_only_adaptor.py (+63)

@@ -12,6 +12,13 @@
 from neural_compressor.utils.load_huggingface import export_compressed_model
 from neural_compressor.utils.pytorch import load

+try:
+    import auto_round
+
+    auto_round_installed = True
+except ImportError:
+    auto_round_installed = False
+

 class Model(torch.nn.Module):
     def __init__(self):

@@ -738,6 +745,62 @@ def __iter__(self):
         out2 = q_model.model(input)
         self.assertTrue(torch.allclose(out1[0], out2[0], atol=1e-01))

+    @unittest.skipIf(not auto_round_installed, "auto_round module is not installed")
+    def test_AutoRound_quant(self):
+        from neural_compressor.adaptor.torch_utils.auto_round import get_dataloader
+
+        tokenizer = transformers.AutoTokenizer.from_pretrained(
+            "hf-internal-testing/tiny-random-GPTJForCausalLM", trust_remote_code=True
+        )
+        dataloader = get_dataloader(
+            tokenizer, seqlen=10, seed=42, train_bs=8, dataset_split="train", dataset_name="NeelNanda/pile-10k"
+        )
+        fp32_model = copy.deepcopy(self.gptj)
+
+        conf = PostTrainingQuantConfig(
+            approach="weight_only",
+            op_type_dict={
+                ".*": {  # re.match
+                    "weight": {
+                        "dtype": "int",
+                        "bits": 4,
+                        "group_size": 32,  # -1 (per-channel)
+                        "scheme": "sym",
+                        "algorithm": "AUTOROUND",
+                    },
+                },
+            },
+            op_name_dict={
+                ".*lm_head": {  # re.match
+                    "weight": {"dtype": "fp32"},
+                },
+            },
+            recipes={
+                "autoround_args": {
+                    "n_samples": 20,
+                    "amp": False,
+                    "seq_len": 10,
+                    "iters": 10,
+                    "scale_dtype": "fp32",
+                    "device": "cpu",
+                },
+            },
+        )
+
+        input = torch.ones([1, 512], dtype=torch.long)
+        fp32_model = copy.deepcopy(self.gptj)
+        out1 = fp32_model(input)
+        q_model = quantization.fit(
+            fp32_model,
+            conf,
+            calib_dataloader=dataloader,
+        )
+        out2 = q_model.model(input)
+        self.assertTrue(torch.allclose(out1[0], out2[0], atol=1e-01))
+        self.assertTrue("transformer.h.0.attn.k_proj" in q_model.autoround_config.keys())
+        self.assertTrue("scale" in q_model.autoround_config["transformer.h.0.attn.k_proj"].keys())
+        self.assertTrue(torch.float32 == q_model.autoround_config["transformer.h.0.attn.k_proj"]["scale_dtype"])


 if __name__ == "__main__":
     unittest.main()
