Commit 85a6113

Enhance Auto-Round (#870)

* Bring `torch.compile` to `quant_block_v2_` (#18)
* Add `AO_USE_DETERMINISTIC_ALGORITHMS` for reproducing results (#19)
* Add `gradient_accumulate_steps` and update results (#20)
* Update the README
* Update the description
* Rename `train_bs` to `batch_size`
* Update the eval

Signed-off-by: yiliu30 <[email protected]>

1 parent df358ce · commit 85a6113

File tree: 6 files changed, +201 −61 lines

torchao/_models/llama/eval.py

Lines changed: 22 additions & 10 deletions
```diff
@@ -128,19 +128,28 @@ def run_evaluation(

         _tokenizer = AutoTokenizer.from_pretrained(checkpoint_path.parent)
         # parse args from quantization string:
-        # autoround-<model_device>-<quant_lm_head>-<iters>-<groupsize>-<batch_size>-<seqlen>-<nsamples>
+        # autoround-<model_device>-<quant_lm_head>-<iters>-<groupsize>-<batch_size>-<seqlen>-<nsamples>-<grad_acc_steps>-<c>
        _quant_args = quantization.split("-")
-        _default_quant_args = [False, 200, 128, 8, 2048, 128]
+        _default_quant_args = [False, 200, 128, 8, 2048, 128, 1, 0]
        _model_devie = _quant_args[1] if len(_quant_args) > 1 else device
        _quant_args = _quant_args[2:]
-        quant_lm_head, iters, groupsize, batch_size, seqlen, nsamples = [
-            int(x) for x in _quant_args
-        ] + _default_quant_args[len(_quant_args) :]
+        (
+            quant_lm_head,
+            iters,
+            groupsize,
+            batch_size,
+            seqlen,
+            nsamples,
+            grad_acc_steps,
+            compile_optimization_process,
+        ) = [int(x) for x in _quant_args] + _default_quant_args[len(_quant_args) :]
        model = model.to(_model_devie)
        print(
            (
                f"Quantizing model with autoround(iters={iters}, groupsize={groupsize}, "
-                f"quant_lm_head={quant_lm_head}, batch_size={batch_size}, seqlen={seqlen}, nsamples={nsamples})"
+                f"quant_lm_head={quant_lm_head}, batch_size={batch_size}, seqlen={seqlen}, nsamples={nsamples}, "
+                f"gradient_accumulate_steps={grad_acc_steps}, "
+                f"compile_optimization_process={compile_optimization_process})"
            )
        )
        with torch.device(_model_devie):
```
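To make the flag-string scheme concrete, here is a minimal standalone sketch of the same parse-then-pad-with-defaults pattern used above (the `quantization` value is a hypothetical CLI input, not from the commit):

```python
# Standalone sketch of the flag-string parsing above (illustrative values only).
# Trailing fields that are omitted fall back to the defaults, so short strings stay valid.
_default_quant_args = [False, 200, 128, 8, 2048, 128, 1, 0]

quantization = "autoround-cuda-1-200-128"  # hypothetical CLI value
_quant_args = quantization.split("-")[2:]  # drop "autoround" and the device field
(
    quant_lm_head,
    iters,
    groupsize,
    batch_size,
    seqlen,
    nsamples,
    grad_acc_steps,
    compile_optimization_process,
) = [int(x) for x in _quant_args] + _default_quant_args[len(_quant_args):]

print(quant_lm_head, iters, groupsize, grad_acc_steps)  # -> 1 200 128 1
```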
```diff
@@ -161,9 +170,11 @@ def run_evaluation(
            is_target_module=is_target_module,
            bits=4,
            seqlen=seqlen,
-            bs=batch_size,
+            batch_size=batch_size,
            iters=iters,
            nsamples=nsamples,
+            gradient_accumulate_steps=grad_acc_steps,
+            compile_optimization_process=compile_optimization_process == 1,
        )
        model.to(device)
        model.reset_caches()
```
```diff
@@ -195,9 +206,10 @@ def run_evaluation(
        "--quantization",
        type=str,
        help=(
-            "Which quantization techniques to apply: int8dq, int8wo, fp6, int4wo-<groupsize>, int4wo-<groupsize>-gptq, "
-            "autoquant, autoquant-int4, int4wo-<groupsize>-hqq, uintx-<nbits>-<groupsize>, uintx-<nbits>-<groupsize>-hqq, "
-            "sparse-marlin, autoround-<model_device>-<quant_lm_head>-<iters>-<groupsize>-<batch_size>-<seqlen>-<nsamples>"
+            "Which quantization techniques to apply: int8dq, int8wo, fp6, int4wo-<groupsize>, "
+            "int4wo-<groupsize>-gptq, autoquant, autoquant-int4, int4wo-<groupsize>-hqq, "
+            "uintx-<nbits>-<groupsize>, uintx-<nbits>-<groupsize>-hqq, sparse-marlin, "
+            "autoround-<model_device>-<quant_lm_head>-<iters>-<groupsize>-<batch_size>-<seqlen>-<nsamples>-<grad_acc_steps>-<c>"
        ),
    )
    parser.add_argument('--compile', action='store_true', help='Whether to compile the model.')
```

torchao/prototype/autoround/README.md

Lines changed: 40 additions & 18 deletions
````diff
@@ -10,6 +10,24 @@ Auto-Round is an advanced quantization algorithm designed for low-bit LLM infere
 python autoround_llm.py -m /model/name/or/path
 ```

+This script applies `Auto-Round` to a given model directly; the available configuration options are listed below:
+
+| Argument                       | Default                | Description                                                        |
+| ------------------------------ | ---------------------- | ------------------------------------------------------------------ |
+| `model_name_or_path`           | `"facebook/opt-125m"`  | Pretrained model name or path                                      |
+| `dataset_name`                 | `"NeelNanda/pile-10k"` | Dataset name for calibration                                       |
+| `iters`                        | 200                    | Number of steps for optimizing each block                          |
+| `bits`                         | 4                      | Number of bits for quantization                                    |
+| `batch_size`                   | 8                      | Batch size for calibration                                         |
+| `nsamples`                     | 128                    | Number of samples for the calibration process                      |
+| `seqlen`                       | 2048                   | Sequence length for each sample                                    |
+| `group_size`                   | 128                    | Group size for quantization                                        |
+| `gradient_accumulate_steps`    | 1                      | Number of steps for accumulating gradients <br> before performing the backward pass |
+| `quant_lm_head`                | `False`                | Whether to quantize the `lm_head`                                  |
+| `use_optimized_layer_output`   | `False`                | Whether to use optimized layer output as input for the next layer  |
+| `compile_optimization_process` | `False`                | Whether to compile the optimization process                        |
+| `model_device`                 | `"cuda"`               | Device for loading the float model (choices: `cpu`, `cuda`)        |
+

 > [!NOTE]
 > Before running, ensure you have installed the `auto-round` with `pip install -r requirements.txt`.
````
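For reference, the option table corresponds to an argparse surface along the lines of the condensed sketch below. This is illustrative only: the real parser lives in `autoround_llm.py`, and the `-d` short flag for `model_device` is an assumption inferred from the diff.

```python
# Condensed, illustrative sketch of the CLI surface described in the table above.
import argparse

parser = argparse.ArgumentParser(description="Apply Auto-Round to a model")
parser.add_argument("-m", "--model_name_or_path", default="facebook/opt-125m")
parser.add_argument("--dataset_name", default="NeelNanda/pile-10k")
parser.add_argument("--iters", type=int, default=200)
parser.add_argument("--bits", type=int, default=4)
parser.add_argument("--batch_size", type=int, default=8)
parser.add_argument("--nsamples", type=int, default=128)
parser.add_argument("--seqlen", type=int, default=2048)
parser.add_argument("--group_size", type=int, default=128)
parser.add_argument("--gradient_accumulate_steps", type=int, default=1)
parser.add_argument("--quant_lm_head", action="store_true")
parser.add_argument("--use_optimized_layer_output", action="store_true")
parser.add_argument("-c", "--compile_optimization_process", action="store_true")
parser.add_argument("-d", "--model_device", default="cuda", choices=["cpu", "cuda"])  # "-d" is assumed

# Simulate a run that trades batch size for gradient accumulation.
args = parser.parse_args(["--batch_size", "4", "--gradient_accumulate_steps", "2"])
print(args.batch_size, args.gradient_accumulate_steps)  # -> 4 2
```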
```diff
@@ -71,31 +89,35 @@ quantize_(model, apply_auto_round(), is_target_module)

 ## End-to-End Results
 ### [meta-llama/Meta-Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct)
-|                 | Avg.    | Mmlu   | Piqa   | Winogrande | Hellaswag | Lambada_openai |
-| --------------- | ------- | ------ | ------ | ---------- | --------- | -------------- |
-| bf16            | 0.7080  | 0.6783 | 0.8003 | 0.7403     | 0.5910    | 0.7303         |
-| auto-round-4bit | 0.6988  | 0.6533 | 0.7949 | 0.7372     | 0.5837    | 0.7250         |
-| torchao-int4wo  | 0.6883  | 0.6363 | 0.7938 | 0.7348     | 0.5784    | 0.6980         |
+|                  | Avg.   | Mmlu   | Piqa   | Winogrande | Hellaswag | Lambada_openai |
+| ---------------- | ------ | ------ | ------ | ---------- | --------- | -------------- |
+| bf16             | 0.7080 | 0.6783 | 0.8003 | 0.7403     | 0.5910    | 0.7303         |
+| torchao-int4wo   | 0.6883 | 0.6363 | 0.7938 | 0.7348     | 0.5784    | 0.6980         |
+| autoround-4bit   | 0.6996 | 0.6669 | 0.7916 | 0.7285     | 0.5846    | 0.7262         |
+| autoround-4bit*  | 0.7010 | 0.6621 | 0.7976 | 0.7316     | 0.5847    | 0.7291         |

 ### [meta-llama/Meta-Llama-3-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct)
-|                 | Avg.    | Mmlu   | Piqa   | Winogrande | Hellaswag | Lambada_openai |
-| --------------- | ------- | ------ | ------ | ---------- | --------- | -------------- |
-| bf16            | 0.6881  | 0.6389 | 0.7840 | 0.7222     | 0.5772    | 0.7184         |
-| auto-round-4bit | 0.6818  | 0.6232 | 0.7862 | 0.7230     | 0.5661    | 0.7105         |
-| torchao-int4wo  | 0.6728  | 0.5939 | 0.7737 | 0.7222     | 0.5612    | 0.7132         |
+|                  | Avg.   | Mmlu   | Piqa   | Winogrande | Hellaswag | Lambada_openai |
+| ---------------- | ------ | ------ | ------ | ---------- | --------- | -------------- |
+| bf16             | 0.6881 | 0.6389 | 0.7840 | 0.7222     | 0.5772    | 0.7184         |
+| torchao-int4wo   | 0.6728 | 0.5939 | 0.7737 | 0.7222     | 0.5612    | 0.7132         |
+| autoround-4bit   | 0.6796 | 0.6237 | 0.7758 | 0.7198     | 0.5664    | 0.7122         |
+| autoround-4bit*  | 0.6827 | 0.6273 | 0.7737 | 0.7348     | 0.5657    | 0.7120         |


 ### [meta-llama/Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf)
-|                 | Avg.    | Mmlu   | Piqa   | Winogrande | Hellaswag | Lambada_openai |
-| --------------- | ------- | ------ | ------ | ---------- | --------- | -------------- |
-| bf16            | 0.6347  | 0.4647 | 0.7644 | 0.6606     | 0.577     | 0.7070         |
-| auto-round-4bit | 0.6327  | 0.4534 | 0.7590 | 0.6661     | 0.5706    | 0.7143         |
-| torchao-int4wo  | 0.6252  | 0.4427 | 0.7617 | 0.6654     | 0.5674    | 0.6889         |
+|                  | Avg.   | Mmlu   | Piqa   | Winogrande | Hellaswag | Lambada_openai |
+| ---------------- | ------ | ------ | ------ | ---------- | --------- | -------------- |
+| bf16             | 0.6347 | 0.4647 | 0.7644 | 0.6606     | 0.5770    | 0.7070         |
+| torchao-int4wo   | 0.6252 | 0.4427 | 0.7617 | 0.6654     | 0.5674    | 0.6889         |
+| autoround-4bit   | 0.6311 | 0.4548 | 0.7606 | 0.6614     | 0.5717    | 0.7072         |
+| autoround-4bit*  | 0.6338 | 0.4566 | 0.7661 | 0.6646     | 0.5688    | 0.7130         |

 > [!NOTE]
-> - `auto-round-4bit` represents the following configuration: `bits=4`, `iters=200`, `seqlen=2048`, `train_bs=8`, `group_size=128`, and `quant_lm_head=False`. <br>
-> - `torchao-int4wo` represents `int4_weight_only(group_size=128)` and `quant_lm_head=False`.
-> - If the model includes operations without a deterministic implementation (such as Flash Attention), the results may differ slightly.
+> - `torchao-int4wo` quantizes the model to 4 bits with a group size of 128 (`int4_weight_only(group_size=128)`) while leaving the `lm_head` unquantized. <br>
+> - `auto-round-4bit` uses the default configuration from [quick start](#quick-start). <br>
+> - `auto-round-4bit*` follows the same settings as `auto-round-4bit`, but with `gradient_accumulate_steps=2` and `batch_size=4`, accumulating two batches (4 samples per batch) before performing the backward pass. <br>
+> - To reproduce the results, run `eval_autoround.py` with `AO_USE_DETERMINISTIC_ALGORITHMS=1`.


 ## Credits
```
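The `auto-round-4bit*` row trades per-step batch size for accumulation steps. A minimal, self-contained sketch of the general pattern (a toy module, not the Auto-Round internals) shows why `batch_size=4` with `gradient_accumulate_steps=2` matches an effective batch of 8:

```python
# Minimal sketch of gradient accumulation: accumulate gradients over
# `gradient_accumulate_steps` micro-batches before each optimizer step,
# so batch_size=4 with 2 accumulation steps sees 8 samples per update.
import torch

model = torch.nn.Linear(16, 16)  # stand-in for a transformer block
opt = torch.optim.SGD(model.parameters(), lr=0.1)
gradient_accumulate_steps, batch_size = 2, 4

data = [torch.randn(batch_size, 16) for _ in range(4)]  # toy calibration batches
for step, x in enumerate(data, start=1):
    loss = (model(x) - x).pow(2).mean()
    (loss / gradient_accumulate_steps).backward()  # scale so gradients average
    if step % gradient_accumulate_steps == 0:
        opt.step()
        opt.zero_grad()
```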

torchao/prototype/autoround/autoround_llm.py

Lines changed: 38 additions & 10 deletions
```diff
@@ -1,11 +1,11 @@
 import argparse
 import logging
+from typing import Optional

 import torch

 import torchao
 import torchao.prototype.autoround.utils as ar_utils
-
 from torchao.prototype.autoround.core import (
     apply_auto_round,
     prepare_model_for_applying_auto_round_,
@@ -26,9 +26,11 @@ def quantize_model_with_autoround_(
     iters: int = 200,
     seqlen: int = 2048,
     dataset_name: str = "NeelNanda/pile-10k",
-    bs: int = 8,
+    batch_size: int = 8,
     nsamples: int = 128,
     use_optimized_layer_output: bool = False,
+    gradient_accumulate_steps: Optional[int] = 1,
+    compile_optimization_process: Optional[bool] = False,
 ):
     # Step 1. Prepare the model for applying auto-round

@@ -42,6 +44,8 @@ def quantize_model_with_autoround_(
         group_size,
         iters,
         use_optimized_layer_output,
+        gradient_accumulate_steps,
+        compile_optimization_process,
         device=device,
     )

@@ -50,7 +54,7 @@ def quantize_model_with_autoround_(
         tokenizer,
         seqlen=seqlen,
         dataset_name=dataset_name,
-        bs=bs,
+        bs=batch_size,
         nsamples=nsamples,
     )
     input_ids_lst = []
@@ -104,9 +108,11 @@ def main(args):
         iters=args.iters,
         seqlen=args.seqlen,
         dataset_name=args.dataset_name,
-        bs=args.train_bs,
+        batch_size=args.batch_size,
         nsamples=args.nsamples,
         use_optimized_layer_output=args.use_optimized_layer_output,
+        gradient_accumulate_steps=args.gradient_accumulate_steps,
+        compile_optimization_process=args.compile_optimization_process,
     )
     # Revert the `use_cache` for generation stage.
     model.config.use_cache = True
@@ -124,7 +130,7 @@ def main(args):
         "--model_name_or_path",
         type=str,
         default="facebook/opt-125m",
-        help="Model name or path",
+        help="Pretrained model name or path",
     )
     parser.add_argument(
         "--dataset_name",
@@ -136,37 +142,59 @@ def main(args):
         "--iters",
         default=200,
         type=int,
-        help="Number of iterations for auto-round optimization",
+        help="Number of steps for optimizing each block",
     )
     parser.add_argument(
         "--bits", default=4, type=int, help="Number of bits for quantization"
     )
     parser.add_argument(
-        "--train_bs", default=8, type=int, help="Batch size for auto-round optimization"
+        "--batch_size", default=8, type=int, help="Batch size for calibration"
     )
     parser.add_argument(
         "--nsamples",
         default=128,
         type=int,
         help="Number of samples for calibration process",
     )
+    parser.add_argument(
+        "--group_size",
+        default=128,
+        type=int,
+        help="Group size for quantization",
+    )
     parser.add_argument(
         "--seqlen",
         default=2048,
         type=int,
-        help="Sequence length for calibration process",
+        help="Sequence length for each sample",
+    )
+    parser.add_argument(
+        "--gradient_accumulate_steps",
+        default=1,
+        type=int,
+        help=(
+            "Number of steps for accumulating gradients before performing "
+            "the backward pass when optimizing each target module"
+        ),
     )
     parser.add_argument(
         "--quant_lm_head",
         default=False,
         action="store_true",
-        help="Quantize the `lm_head` or not",
+        help="Whether to quantize the `lm_head`",
     )
     parser.add_argument(
         "--use_optimized_layer_output",
         default=False,
         action="store_true",
-        help="Use the optimized layer output for next layer or not",
+        help="Whether to use optimized layer output as input for the next layer",
+    )
+    parser.add_argument(
+        "-c",
+        "--compile_optimization_process",
+        default=False,
+        action="store_true",
+        help="Whether to compile the optimization process",
     )
     parser.add_argument(
         "-d",
```

torchao/prototype/autoround/core.py

Lines changed: 15 additions & 3 deletions
```diff
@@ -20,6 +20,8 @@ class _AutoRoundConfig:
     group_size: int = 128
     iters: int = 200
     use_optimized_layer_output: bool = False
+    gradient_accumulate_steps: int = 1
+    compile_optimization_process: bool = False


 _auto_round_config = _AutoRoundConfig()
@@ -82,6 +84,8 @@ def prepare_model_for_applying_auto_round_(
     group_size: int = 128,
     iters: int = 200,
     use_optimized_layer_output: bool = False,
+    gradient_accumulate_steps: Optional[int] = 1,
+    compile_optimization_process: Optional[bool] = False,
     device: Optional[torch.types.Device] = None,
 ):
     """Prepares the model for applying auto round optimization.
@@ -94,7 +98,10 @@ def prepare_model_for_applying_auto_round_(
         group_size (int, optional): The group size for quantization. Defaults to 128.
         iters (int, optional): The number of iterations for optimization. Defaults to 200.
         use_optimized_layer_output (bool, optional): Whether to use optimized layer output. Defaults to False.
-        device (Optional[torch.types.Device], optional): The device to use for accelrating optimization and calibration.
+        gradient_accumulate_steps (Optional[int]): Number of steps for accumulating gradients before
+            performing the backward pass when optimizing each target module. Defaults to 1.
+        compile_optimization_process (Optional[bool]): Whether to compile the optimization process. Defaults to False.
+        device (Optional[torch.types.Device]): The device to use for accelerating optimization and calibration.
             Defaults to None.
     """
     _multi_tensor_config.device = device
@@ -105,6 +112,8 @@ def prepare_model_for_applying_auto_round_(
     _auto_round_config.group_size = group_size
     _auto_round_config.iters = iters
     _auto_round_config.use_optimized_layer_output = use_optimized_layer_output
+    _auto_round_config.gradient_accumulate_steps = gradient_accumulate_steps
+    _auto_round_config.compile_optimization_process = compile_optimization_process

     logging.warning(f"config {_auto_round_config}")

@@ -172,7 +181,7 @@ def to_uintx_weight(input_float):
     quant_min = 0
     quant_max = _auto_round_config.bits**2 - 1
     block_size = (1, observed_linear.group_size)
-    from torchao.dtypes.uintx.Uintx import (
+    from torchao.dtypes.uintx.uintx import (
         _BIT_WIDTH_TO_DTYPE,
         UintxLayoutType,
     )
@@ -312,9 +321,12 @@ def _apply_auto_round_optimization(
         bits=config.bits,
         iters=config.iters,
         group_size=config.group_size,
+        gradient_accumulate_steps=config.gradient_accumulate_steps,
         amp=True,
         model_dtype=next(block.parameters()).dtype,
     )
+    if config.compile_optimization_process:
+        rounder.quant_block_v2_ = torch.compile(rounder.quant_block_v2_)

     with torch.enable_grad():
         rounder.quant_block_v2_(
@@ -326,7 +338,7 @@ def _apply_auto_round_optimization(
     block.to(orig_device)


-@ar_utils.dump_elapsed_time()
+@ar_utils.dump_elapsed_time(record=True)
 @torch.no_grad()
 def apply_auto_round_optimization(
     module: torch.nn.Module,
```
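The `compile_optimization_process` branch above simply rebinds the per-block optimization entry point to a `torch.compile`d version. A toy sketch of the same conditional-wrap pattern (a stand-in function, not `quant_block_v2_` itself):

```python
# Sketch: conditionally wrapping a hot function with torch.compile,
# mirroring the `compile_optimization_process` branch above (toy example).
import torch

def quant_block_step(x: torch.Tensor) -> torch.Tensor:
    # stand-in for the per-block optimization body
    return torch.nn.functional.relu(x) * 2.0

compile_optimization_process = True  # illustrative flag
if compile_optimization_process:
    quant_block_step = torch.compile(quant_block_step)

out = quant_block_step(torch.randn(4, 8))  # first call triggers compilation
```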
