
Commit 0ec5e20

realAsma (Asma Kuriparambil Thekkumpate) authored and committed
[1/N] Refactored AutoQuantizeSearcher to _AutoQuantizeBaseSearcher & AutoQuantizeGradientSearcher; separated quant modules and score modules (NVIDIA#586)
## What does this PR do?

**Type of change:** Refactor; minor new feature

**Overview:**

1. Refactored AutoQuantizeSearcher into _AutoQuantizeBaseSearcher & AutoQuantizeGradientSearcher - prepares the architecture for additional search methods.
2. Separated quant modules and score modules - the quantization modules are now distinct from the scoring modules, enabling auto-quantization to measure sensitivity at parent layers (e.g., the MLP output for MoE experts) rather than at individual ops.
3. Also see NVIDIA#592 and NVIDIA#588.

## Testing

See the unit tests: `tests/unit/torch/quantization/test_autoquant.py` and `tests/unit/torch/quantization/plugins/test_huggingface.py`.

## Before your PR is "*Ready for review*"

- **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed.
- **Is this change backward compatible?**: Yes
- **Did you write any new necessary tests?**: Yes
- **Did you add or update any necessary documentation?**: Yes
- **Did you update [Changelog](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CHANGELOG.rst)?**: Not required

## Summary by CodeRabbit

* **New Features**
  * Added support for score modules in quantization workflows.
  * Added optional naming for quantization recipes.
* **Bug Fixes**
  * Improved quantization grouping rules documentation with clearer configuration examples.
* **Refactor**
  * Renamed quantization module parameters for improved clarity.
  * Enhanced quantization search architecture for better scalability.

---------

Signed-off-by: realAsma <[email protected]>
Co-authored-by: Asma Kuriparambil Thekkumpate <[email protected]>
1 parent 3e725c3 commit 0ec5e20
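To make the refactoring described above concrete, the split implied by the PR looks roughly like the sketch below. Only the two class names (_AutoQuantizeBaseSearcher, AutoQuantizeGradientSearcher) come from this commit; the method names and bodies are illustrative assumptions, not ModelOpt's actual implementation.

class _AutoQuantizeBaseSearcher:
    """Method-agnostic search loop shared by all auto_quantize searchers.

    Quant modules (the ops that actually get quantized) are tracked separately
    from score modules (the layers where sensitivity is measured, e.g. the MLP
    output for MoE experts), so scoring can happen at a parent layer.
    """

    def search(self, model, constraints):
        scores = self.estimate_sensitivity_scores(model)  # hypothetical hook, method-specific
        return self.solve_format_selection(scores, constraints)  # pick per-module formats under the effective-bits budget

    def estimate_sensitivity_scores(self, model):
        raise NotImplementedError


class AutoQuantizeGradientSearcher(_AutoQuantizeBaseSearcher):
    """Gradient-based scoring, i.e. the behavior of the former AutoQuantizeSearcher.

    A KL-divergence searcher (see the changelog entry below) can plug in as
    another subclass that overrides estimate_sensitivity_scores.
    """

    def estimate_sensitivity_scores(self, model):
        ...  # backpropagate a task loss and accumulate per-score-module sensitivities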

File tree

15 files changed (+1170, -341 lines)


CHANGELOG.rst

Lines changed: 2 additions & 0 deletions
@@ -14,6 +14,8 @@ Model Optimizer Changelog (Linux)
 - Add MoE (e.g. Qwen3-30B-A3B, gpt-oss-20b) pruning support for ``num_moe_experts``, ``moe_ffn_hidden_size`` and ``moe_shared_expert_intermediate_size`` parameters in Minitron pruning (``mcore_minitron``).
 - Add ``specdec_bench`` example to benchmark speculative decoding performance. See `examples/specdec_bench/README.md <https://github.com/NVIDIA/TensorRT-Model-Optimizer/tree/main/examples/specdec_bench#speculative-decoding-benchmark>`_ for more details.
 - Add FP8/NVFP4 KV cache quantization support for Megatron Core models.
+- Add KL Divergence loss based auto_quantize method. See `auto_quantize API docs <https://nvidia.github.io/TensorRT-Model-Optimizer/reference/generated/modelopt.torch.quantization.model_quant.html#modelopt.torch.quantization.model_quant.auto_quantize>`_ for more details.
+- Add support for saving and resuming auto_quantize search state. This speeds up the auto_quantize process by skipping the score estimation step if the search state is provided.
 - Add flag ``trt_plugins_precision`` in ONNX autocast to indicate custom ops precision. This is similar to the flag already existing in the quantization workflow.
 - Add support for PyTorch Geometric quantization.
 - Add per tensor and per channel MSE calibrator support.
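The two new entries above correspond to the `method` and `checkpoint` arguments that the updated examples pass to `mtq.auto_quantize` (see the `examples/llm_eval/quantization_utils.py` diff below). A minimal sketch of using them together; the model, calibration loader, bit budget, and quantization format list here are placeholders, not part of this commit:

import modelopt.torch.quantization as mtq

# Assumed inputs: `model` is a transformers AutoModelForCausalLM-style module and
# `calib_loader` yields dicts of tokenized batches ({"input_ids": ..., "attention_mask": ...}).
model, search_state = mtq.auto_quantize(
    model,
    constraints={"effective_bits": 4.5},              # target budget, as with --auto_quantize_bits
    quantization_formats=[mtq.NVFP4_DEFAULT_CFG, mtq.FP8_DEFAULT_CFG],  # assumed example formats
    data_loader=calib_loader,
    forward_step=lambda m, batch: m(**batch).logits,  # kl_div compares logits, so no labels are needed
    loss_func=None,                                   # the KL-divergence method supplies its own score
    num_calib_steps=len(calib_loader),
    num_score_steps=min(len(calib_loader), 32),       # scoring dominates runtime; fewer steps run faster
    method="kl_div",                                  # new: KL-divergence sensitivity scoring
    checkpoint="autoquant_state.pth",                 # new: save/resume the search state to skip re-scoring
    verbose=True,
)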

examples/llm_eval/gen_model_answer.py

Lines changed: 34 additions & 1 deletion
@@ -201,8 +201,11 @@ def get_model_answers(
            tokenizer,
            args.calib_batch_size,
            args.calib_size,
-            args.auto_quantize_bits,
            test_generated=False,
+            auto_quantize_bits=args.auto_quantize_bits,
+            auto_quantize_method=args.auto_quantize_method,
+            auto_quantize_score_size=args.auto_quantize_score_size,
+            auto_quantize_checkpoint=args.auto_quantize_checkpoint,
        )

        for question in tqdm(questions):
@@ -450,6 +453,36 @@ def reorg_answer_file(answer_file):
            "regular quantization without auto_quantize search will be applied."
        ),
    )
+    parser.add_argument(
+        "--auto_quantize_method",
+        type=str,
+        default="gradient",
+        choices=["gradient", "kl_div"],
+        help=(
+            "Method for auto_quantize sensitivity analysis. 'gradient' uses gradient-based method "
+            "(requires labels in dataset). 'kl_div' uses KL divergence between original and "
+            "quantized model outputs (no labels required). Default: 'gradient'"
+        ),
+    )
+    parser.add_argument(
+        "--auto_quantize_score_size",
+        type=int,
+        default=128,
+        help=(
+            "Number of samples to use for auto_quantize scoring. Most of auto_quantize time is spent on "
+            "sensitivity score estimation, so reducing this speeds it up while only minimally affecting "
+            "final model accuracy compared to lowering --calib_size (the number of samples used for calibration)."
+        ),
+    )
+    parser.add_argument(
+        "--auto_quantize_checkpoint",
+        type=str,
+        default=None,
+        help=(
+            "Path to checkpoint file for saving/restoring auto_quantize search state "
+            "(sensitivity scores, costs, etc.). Only used when auto_quantize_bits is specified."
+        ),
+    )
    parser.add_argument(
        "--trust_remote_code",
        help="Set trust_remote_code for Huggingface models and tokenizers",

examples/llm_eval/lm_eval_hf.py

Lines changed: 37 additions & 2 deletions
@@ -53,6 +53,9 @@ def create_from_arg_obj(cls: type[T], arg_dict: dict, additional_config: dict |

        quant_cfg = arg_dict.pop("quant_cfg", None)
        auto_quantize_bits = arg_dict.pop("auto_quantize_bits", None)
+        auto_quantize_method = arg_dict.pop("auto_quantize_method", "gradient")
+        auto_quantize_score_size = arg_dict.pop("auto_quantize_score_size", 128)
+        auto_quantize_checkpoint = arg_dict.pop("auto_quantize_checkpoint", None)
        calib_batch_size = arg_dict.pop("calib_batch_size", None)
        calib_size = arg_dict.pop("calib_size", 512)
        compress = arg_dict.pop("compress", False)
@@ -81,8 +84,11 @@ def create_from_arg_obj(cls: type[T], arg_dict: dict, additional_config: dict |
            batch_size=calib_batch_size,
            calib_size=calib_size,
            auto_quantize_bits=auto_quantize_bits,
+            auto_quantize_method=auto_quantize_method,
+            auto_quantize_score_size=auto_quantize_score_size,
            test_generated=False,
            compress=compress,
+            auto_quantize_checkpoint=auto_quantize_checkpoint,
        )

        return model_obj
@@ -101,6 +107,12 @@ def setup_parser_with_modelopt_args():
            "comma-separated list of quantization quantization formats that will be searched by `auto_quantize`"
        ),
    )
+    parser.add_argument(
+        "--calib_batch_size", type=int, help="Batch size for quantization calibration"
+    )
+    parser.add_argument(
+        "--calib_size", type=int, help="Calibration size for quantization", default=512
+    )
    parser.add_argument(
        "--auto_quantize_bits",
        type=float,
@@ -110,10 +122,30 @@ def setup_parser_with_modelopt_args():
        ),
    )
    parser.add_argument(
-        "--calib_batch_size", type=int, help="Batch size for quantization calibration"
+        "--auto_quantize_method",
+        type=str,
+        default="gradient",
+        choices=["gradient", "kl_div"],
+        help=(
+            "Method for auto_quantize sensitivity analysis. 'gradient' uses gradient-based method "
+            "(requires labels in dataset). 'kl_div' uses KL divergence between original and "
+            "quantized model outputs (no labels required). Default: 'gradient'"
+        ),
    )
    parser.add_argument(
-        "--calib_size", type=int, help="Calibration size for quantization", default=512
+        "--auto_quantize_score_size",
+        type=int,
+        default=128,
+        help=(
+            "Number of samples to use for auto_quantize scoring. Most of auto_quantize time is spent on "
+            "sensitivity score estimation, so reducing this speeds it up while only minimally affecting "
+            "final model accuracy compared to lowering --calib_size (the number of samples used for calibration)."
+        ),
+    )
+    parser.add_argument(
+        "--auto_quantize_checkpoint",
+        type=str,
+        help=("Path to checkpoint file for saving/restoring auto_quantize search state. "),
    )
    parser.add_argument(
        "--compress",
@@ -139,6 +171,9 @@ def setup_parser_with_modelopt_args():
        {
            "quant_cfg": args.quant_cfg,
            "auto_quantize_bits": args.auto_quantize_bits,
+            "auto_quantize_method": args.auto_quantize_method,
+            "auto_quantize_score_size": args.auto_quantize_score_size,
+            "auto_quantize_checkpoint": args.auto_quantize_checkpoint,
            "calib_batch_size": args.calib_batch_size,
            "calib_size": args.calib_size,
            "compress": args.compress,

examples/llm_eval/mmlu.py

Lines changed: 6 additions & 0 deletions
@@ -227,6 +227,9 @@ def main(
    batch_size: int = 0,
    calib_size: int = 512,
    dtype: str = "bfloat16",
+    auto_quantize_method: str = "gradient",
+    auto_quantize_score_size: int = 128,
+    auto_quantize_checkpoint: str | None = None,
    **kwargs,
):
    random.seed(RAND_SEED)
@@ -281,6 +284,9 @@ def main(
        batch_size=batch_size,
        calib_size=calib_size,
        auto_quantize_bits=auto_quantize_bits,
+        auto_quantize_method=auto_quantize_method,
+        auto_quantize_score_size=auto_quantize_score_size,
+        auto_quantize_checkpoint=auto_quantize_checkpoint,
    )

    for subject in tqdm(subjects):

examples/llm_eval/quantization_utils.py

Lines changed: 52 additions & 13 deletions
@@ -66,8 +66,11 @@ def _quantize_model_with_dataset(
    quant_cfg: str | list[str],
    calib_dataset,
    auto_quantize_bits=None,
+    auto_quantize_method="gradient",
+    auto_quantize_score_size=128,
    batch_size=1,
    compress=False,
+    auto_quantize_checkpoint=None,
):
    if hasattr(lm, "gpt2"):
        net = lm.gpt2
@@ -81,23 +84,42 @@ def _quantize_model_with_dataset(
            getattr(mtq, quant_fmt) for quant_fmt in quant_cfg if quant_fmt != "NONE"
        ]

-        def loss_func(output, data):
-            # For transformers AutoModelForCausalLM models, the outputs are wrapped in `CausalLMOutputWithPast`
-            # which contains the loss attribute.
-            return output.loss
+        # Configure forward_step and loss_func based on method
+        if auto_quantize_method == "gradient":
+            # For gradient-based method, return full output with loss
+            def forward_step(model, batch):
+                return model(**batch)
+
+            def loss_func(output, data):
+                # For transformers AutoModelForCausalLM models, the outputs are wrapped in `CausalLMOutputWithPast`
+                # which contains the loss attribute.
+                return output.loss
+        elif auto_quantize_method == "kl_div":
+            # For KL divergence method, return only logits
+            def forward_step(model, batch):
+                return model(**batch).logits
+
+            loss_func = None  # KL divergence doesn't need a custom loss function
+        else:
+            raise ValueError(
+                f"Invalid auto_quantize_method: {auto_quantize_method}. "
+                "Must be 'gradient' or 'kl_div'"
+            )

        net, _ = mtq.auto_quantize(
            net,
            constraints={"effective_bits": auto_quantize_bits},
            quantization_formats=quant_cfg_for_search,
            data_loader=calib_dataset,
-            forward_step=lambda model, batch: model(**batch),
+            forward_step=forward_step,
            loss_func=loss_func,
            num_calib_steps=len(calib_dataset),
-            num_score_steps=min(
-                len(calib_dataset), 128 // batch_size
-            ),  # Limit the number of score steps to avoid long calibration time
+            # Most time is spent on score estimation; fewer samples speed it up with little accuracy impact.
+            num_score_steps=min(len(calib_dataset), max(auto_quantize_score_size // batch_size, 1)),
            verbose=True,
+            method=auto_quantize_method,
+            # disabled_layers=["*lm_head*", "*mlp.gate.*"],
+            checkpoint=auto_quantize_checkpoint,
        )
    else:
        mtq_cfg = CUSTOM_CONFIG.get(quant_cfg)  # type: ignore [arg-type]
@@ -141,10 +163,13 @@ def quantize_model(
    tokenizer,
    batch_size,
    calib_size,
-    auto_quantize_bits=None,
    data="cnn_dailymail",
    test_generated=True,
    compress=False,
+    auto_quantize_bits=None,
+    auto_quantize_method="gradient",
+    auto_quantize_score_size=128,
+    auto_quantize_checkpoint=None,
):
    """Quantizes the model with the provided calibration dataset.

@@ -155,10 +180,14 @@ def quantize_model(
        tokenizer: the tokenizer.
        batch_size: the calibration batch size for each calibration inference run.
        calib_size: the total calibration dataset size.
-        auto_quantize_bits: The effective bits constraint for auto_quantize.
        data: the name of the calibration dataset.
        test_generated: If ``True``, test the generated text before and after quantization.
        compress: If ``True``, compress the model after quantization.
+        auto_quantize_bits: The effective bits constraint for auto_quantize.
+        auto_quantize_method: The method for auto_quantize ('gradient' or 'kl_div').
+        auto_quantize_score_size: Number of samples used for auto_quantize scoring.
+        auto_quantize_checkpoint: Path to checkpoint file for saving/restoring auto_quantize search state
+            (sensitivity scores, costs, etc.). Only used when auto_quantize_bits is specified.
    """
    if "AWQ" in quant_cfg:
        print(
@@ -170,8 +199,10 @@ def quantize_model(
    if hasattr(model, "model"):
        device = model.model.device

+    is_gradient_based = auto_quantize_bits is not None and auto_quantize_method == "gradient"
+
    if batch_size == 0:
-        if auto_quantize_bits is not None or torch.distributed.is_initialized():
+        if is_gradient_based or torch.distributed.is_initialized():
            raise ValueError("We dont support automatic batch size inference for this case.")

    net = model.gpt2 if hasattr(model, "gpt2") else model.model
@@ -186,15 +217,23 @@ def quantize_model(
        batch_size=batch_size,
        num_samples=calib_size,
        device=device,
-        include_labels=auto_quantize_bits is not None,
+        include_labels=is_gradient_based,
    )

    if test_generated:
        input_str = tokenizer.decode(next(iter(calib_dataloader))["input_ids"][0])
        generated_str_before_ptq = model.run(input_str)

    _quantize_model_with_dataset(
-        model, quant_cfg, calib_dataloader, auto_quantize_bits, batch_size, compress
+        model,
+        quant_cfg,
+        calib_dataloader,
+        auto_quantize_bits,
+        auto_quantize_method,
+        auto_quantize_score_size,
+        batch_size,
+        compress,
+        auto_quantize_checkpoint,
    )

    if test_generated:
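For context on the 'kl_div' branch above: the KL-divergence method scores a candidate quantization by how much it perturbs the model's output distribution, which is why forward_step only returns logits and loss_func is None. The sketch below illustrates that comparison conceptually; it is not ModelOpt's internal implementation.

import torch
import torch.nn.functional as F

def kl_div_score(original_logits: torch.Tensor, quantized_logits: torch.Tensor) -> torch.Tensor:
    """KL(P_original || P_quantized): larger values mean the candidate quantization
    distorts the output distribution more, i.e. the quantized modules are more sensitive."""
    p_log = F.log_softmax(original_logits, dim=-1)   # reference (unquantized) distribution
    q_log = F.log_softmax(quantized_logits, dim=-1)  # distribution with the candidate format enabled
    # F.kl_div(input, target) computes KL(target || input); log_target=True lets both be log-probs.
    return F.kl_div(q_log, p_log, log_target=True, reduction="batchmean")

# Usage sketch (shapes are illustrative): compare logits of shape (batch, seq_len, vocab)
# captured from the same batch with quantizers disabled vs. enabled.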