Commit e38a551
1 parent: b7bd107

updated docs; code clean up

Signed-off-by: realAsma <[email protected]>

clean ups

Signed-off-by: realAsma <[email protected]>

2 files changed, +46 -33 lines changed

modelopt/torch/opt/hparam.py

Lines changed: 1 addition & 1 deletion
@@ -48,7 +48,7 @@ def __eq__(self, other) -> bool:
 class Hparam:
     """A base hyperparameter of a DynamicModule.
 
-    An example of such a Hparam could be an hparam with identity dependencies.
+    Keeps track of hyperparameter values and their importance, which can be used for search algorithms.
     """
 
     Importance = Union[torch.Tensor, None]  # noqa: UP007
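For orientation, a minimal sketch of the behavior the new docstring describes: a hyperparameter object that tracks candidate values together with a per-choice importance that search algorithms can consume. `ToyHparam` is a hypothetical stand-in, not the real `Hparam` API:

import torch

class ToyHparam:
    """Illustrative stand-in for Hparam: tracks values and their importance."""

    def __init__(self, choices):
        self.choices = list(choices)   # candidate hyperparameter values
        self.active = self.choices[0]  # currently selected value
        self.importance = None         # per-choice scores a searcher can fill in

hp = ToyHparam([16, 32, 64])
hp.active = 32
hp.importance = torch.tensor([0.1, 0.7, 0.2])  # used to rank choices during search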

modelopt/torch/quantization/algorithms.py

Lines changed: 45 additions & 32 deletions
@@ -170,12 +170,12 @@ def fold_pqs_to_weights(model):
 class QuantRecipeHparam(Hparam):
     """An Hparam for quantization recipes.
 
-    In addition, this Hparam also:
+    See :class:`Hparam <modelopt.torch.opt.hparam.Hparam>` for more details. In addition, this Hparam also:
 
-    * Keeps a link to its quant_modules and score_modules and sets the quantizers for the
-      quant_modules based on the active recipe.
+    * Keeps a link to its ``quant_modules`` and ``score_modules`` and sets the quantizers for the
+      ``quant_modules`` based on the active recipe.
     * Keeps track of the importance of each recipe in a dict instead of a tensor.
-    * Registers itself with each score_module via the _hparams_for_scoring attribute.
+    * Registers itself with each ``score_module`` via the ``_hparams_for_scoring`` attribute.
     """
 
     def __init__(
@@ -271,8 +271,14 @@ def attrs(self) -> list[str]:
         """Return the attributes of the hparam for repr."""
         return ["name", *super().attrs]
 
+
 class _AutoQuantizeBaseSearcher(BaseSearcher, ABC):
-    """A base searcher for AutoQuantize algorithm."""
+    """Base searcher for AutoQuantize algorithm."""
+
+    # This searcher finds optimal per-layer quantization by searching across quantization formats
+    # for each quantizable module (quant module). Optionally, quant grouping rules can restrict
+    # certain modules to share the same format. Sensitivity scores are computed from perturbations
+    # at score modules. See AutoQuantizeGradientSearcher for detailed documentation.
 
     candidate_stats: dict[str, dict[str, list[float]]]
     best: dict[str, Any]
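The quant grouping mentioned in the new comment can be pictured as a mapping from module names to a shared group key. A minimal sketch, assuming a purely hypothetical rule format (the searcher's actual rules and API differ):

import re

def group_key(module_name: str) -> str:
    """Map Q/K/V projections of one attention layer to a shared group key."""
    m = re.match(r"(.*\.attention)\.(q_proj|k_proj|v_proj)$", module_name)
    return m.group(1) + ".qkv" if m else module_name

# All three projections land in one group, so the search assigns them one format:
assert group_key("layers.0.attention.q_proj") == "layers.0.attention.qkv"
assert group_key("layers.0.attention.k_proj") == "layers.0.attention.qkv"
assert group_key("layers.0.mlp.up_proj") == "layers.0.mlp.up_proj"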
@@ -383,25 +389,26 @@ def _apply_score_group_rule(self, name: str, rule) -> str | None:
         return None
 
     def _get_score_module_from_name(
-        self, model: nn.Module, score_module_name: str, fallback_module: nn.Module
+        self, model: nn.Module, score_module_name: str, quant_module: nn.Module
     ) -> nn.Module:
         """Get the actual score module object from its name.
 
         Args:
             model: The model containing all modules
             score_module_name: The name of the score module to retrieve
-            fallback_module: The fallback module to use if score_module_name doesn't exist (typically the quant module)
+            quant_module: The quantized module for which the score is estimated
 
         Returns:
-            The score module object, or fallback_module if not found
+            The score module object, or the quantized module itself if the score module is not found
         """
         try:
             score_module = model.get_submodule(score_module_name)
             return score_module
         except AttributeError:
-            # If score module doesn't exist, fall back to the provided fallback module
-            # This shouldn't happen with valid rules, but provide a safe fallback
-            return fallback_module
+            warnings.warn(
+                f"Score module '{score_module_name}' not found. Score will be estimated from the quantized module itself."
+            )
+            return quant_module
 
     def insert_hparams_after_merge_rules(self, model, quant_recipes, disabled_layers=None):
         """Restrict the search space using the merge rules and insert the hparams for the model."""
@@ -459,20 +466,12 @@ def insert_hparams_after_merge_rules(self, model, quant_recipes, disabled_layers
         disabled = any(disabled for _, _, disabled, _ in module_info_list)
         score_modules = [score_module for _, _, _, score_module in module_info_list]
 
-        hparam = (
-            QuantRecipeHparam(
-                None,
-                quant_modules=quant_modules,
-                score_modules=score_modules,
-                name=str(group_key),
-            )
-            if disabled
-            else QuantRecipeHparam(
-                quant_recipes,
-                quant_modules=quant_modules,
-                score_modules=score_modules,
-                name=str(group_key),
-            )
+        quant_recipes = None if disabled else quant_recipes
+        hparam = QuantRecipeHparam(
+            quant_recipes,
+            quant_modules=quant_modules,
+            score_modules=score_modules,
+            name=str(group_key),
         )
 
         for module in quant_modules:
@@ -495,8 +494,8 @@ def _verify_constraint(self, search_recipes):
         )
 
     @abstractmethod
-    def estimate_sensitivity_scores(self):
-        """Estimate the sensitivity scores for the model."""
+    def estimate_sensitivity_scores(self) -> None:
+        """Estimate sensitivity scores and track them with Hparam."""
 
     def _run_func(self, func, num_iters=1, desc=""):
         for i, data in tqdm(
@@ -656,8 +655,6 @@ def run_search(self):
         QuantRecipe.fold_pqs_to_weights(self.model)
 
 
-
-
 @torch.compile
 def _get_auto_quantize_score(grad_output, output_diff):
     return ((grad_output.float() ** 2) * (output_diff.float() ** 2)).sum()
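This one-liner implements the score that the AutoQuantizeGradientSearcher docstring below derives from a Taylor expansion with a Fisher-information substitute for the Hessian. As a hedged sketch of that math, with \Delta y the quantization-induced change in a layer's output and g = \partial \mathcal{L} / \partial y the gradient at that point:

\Delta \mathcal{L} \;\approx\; \tfrac{1}{2}\, \Delta y^{\top} H \, \Delta y,
\qquad
H \;\approx\; \mathbb{E}\big[ g\, g^{\top} \big] \quad \text{(Fisher information, for log-likelihood losses)}

Keeping only the diagonal of the Fisher matrix and dropping the constant 1/2 (which does not change the ranking of recipes) gives \text{score} = \sum_i g_i^{2}\,(\Delta y_i)^{2}, which is exactly the sum computed by `_get_auto_quantize_score`.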
@@ -675,13 +672,29 @@ class AutoQuantizeGradientSearcher(_AutoQuantizeBaseSearcher):
     scores while meeting the specified constraint. AutoQuantize uses Linear Programming Solver to find the
     optimal quantization configuration.
 
-    The auto_quantize score for a layer quantization configuration is an approximation of model loss change change due
+    The auto_quantize score for a layer quantization configuration is an approximation of model loss change due
     to quantizing the particular layer with the particular configuration.
     The approximation is based on taylor expansion of the loss function wrt to the quantized output of the layer and
     substitution of Fisher information for Hessian.
     This approximation is mathematically correct for models where the loss
     is a log likelihood loss such as BERT, GPT, etc. However, the auto_quantize score can still be used as a proxy
     for other models such as ResNet.
+
+    **Quant Modules:**
+
+    This searcher operates on quantizable modules (quant modules), which are typically Linear or Conv layers
+    that support quantization. Optionally, grouping rules can be applied to ensure certain layers share the same
+    quantization format (e.g., Q, K, V projections in the same attention layer). For details on quant_grouping_rules
+    and customization, see the :meth:`auto_quantize <modelopt.torch.quantization.model_quant.auto_quantize>`
+    API documentation.
+
+    **Score Modules:**
+
+    By default, for each quant module, its sensitivity score is estimated using that module's output perturbation.
+    However, the sensitivity can also be estimated by looking at perturbation at a separate point in the neural
+    network (score module). This is helpful in some cases such as MoEs for speed and lower memory consumption.
+    Since all experts are already restricted to the same quant format by quant grouping rules, their sensitivity
+    can be estimated together at a single point (e.g., the MLP output level).
     """
 
     score_module_rules = [
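To make the score-module idea in the docstring concrete, a minimal hypothetical sketch (not modelopt's implementation): a toy uniform-mixing MoE whose experts share one perturbation, standing in for a shared quant format, so a single score at the MLP output covers the whole group. None of the names below come from modelopt:

import torch
import torch.nn as nn

class TinyMoE(nn.Module):
    """Toy MoE; uniform mixing keeps the sketch simple (real MoEs route per token)."""

    def __init__(self, dim=16, n_experts=4):
        super().__init__()
        self.experts = nn.ModuleList(nn.Linear(dim, dim) for _ in range(n_experts))

    def forward(self, x):
        return torch.stack([e(x) for e in self.experts]).mean(dim=0)

moe = TinyMoE()
x = torch.randn(8, 16)
y_ref = moe(x).detach()  # MLP output before the (stand-in) quantization

with torch.no_grad():  # crude stand-in for switching all experts to one quant format
    for expert in moe.experts:
        expert.weight.add_(1e-3 * torch.randn_like(expert.weight))

y_quant = moe(x)
y_quant.retain_grad()               # keep the gradient at the score point (MLP output)
y_quant.square().mean().backward()  # any scalar loss works for the sketch
output_diff = y_quant.detach() - y_ref
score = ((y_quant.grad ** 2) * (output_diff ** 2)).sum()  # one score for all experts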
@@ -872,8 +885,8 @@ def cleanup_params_after_score_estimation(name, param, params_metadata):
         del params_metadata
         gc.collect()
 
-    def estimate_sensitivity_scores(self):
-        """Estimate the sensitivity scores for the model."""
+    def estimate_sensitivity_scores(self) -> None:
+        """Estimate sensitivity scores using hessian approximation."""
         self.model.eval()
 
         def _default_is_param_grad_enabled(pname, model):
