@@ -170,12 +170,12 @@ def fold_pqs_to_weights(model):
170170class QuantRecipeHparam (Hparam ):
171171 """An Hparam for quantization recipes.
172172
173- In addition, this Hparam also:
173+ See :class:`Hparam <modelopt.torch.opt.hparam.Hparam>` for more details. In addition, this Hparam also:
174174
175- * Keeps a link to its quant_modules and score_modules and sets the quantizers for the
176- quant_modules based on the active recipe.
175+ * Keeps a link to its ``quant_modules`` and ``score_modules`` and sets the quantizers for the
176+ ``quant_modules`` based on the active recipe.
177177 * Keeps track of the importance of each recipe in a dict instead of a tensor.
178- * Registers itself with each score_module via the _hparams_for_scoring attribute.
178+ * Registers itself with each ``score_module`` via the ``_hparams_for_scoring`` attribute.
179179 """
180180
181181 def __init__ (
@@ -271,8 +271,14 @@ def attrs(self) -> list[str]:
271271 """Return the attributes of the hparam for repr."""
272272 return ["name" , * super ().attrs ]
273273
274+
274275class _AutoQuantizeBaseSearcher (BaseSearcher , ABC ):
275- """A base searcher for AutoQuantize algorithm."""
276+ """Base searcher for AutoQuantize algorithm."""
277+
278+ # This searcher finds optimal per-layer quantization by searching across quantization formats
279+ # for each quantizable module (quant module). Optionally, quant grouping rules can restrict
280+ # certain modules to share the same format. Sensitivity scores are computed from perturbations
281+ # at score modules. See AutoQuantizeGradientSearcher for detailed documentation.
276282
277283 candidate_stats : dict [str , dict [str , list [float ]]]
278284 best : dict [str , Any ]
@@ -383,25 +389,26 @@ def _apply_score_group_rule(self, name: str, rule) -> str | None:
383389 return None
384390
385391 def _get_score_module_from_name (
386- self , model : nn .Module , score_module_name : str , fallback_module : nn .Module
392+ self , model : nn .Module , score_module_name : str , quant_module : nn .Module
387393 ) -> nn .Module :
388394 """Get the actual score module object from its name.
389395
390396 Args:
391397 model: The model containing all modules
392398 score_module_name: The name of the score module to retrieve
393- fallback_module : The fallback module to use if score_module_name doesn't exist (typically the quant module)
399+ quant_module : The quantized module for which the score is estimated
394400
395401 Returns:
396- The score module object, or fallback_module if not found
402+ The score module object, or the quantized module itself if the score module is not found
397403 """
398404 try :
399405 score_module = model .get_submodule (score_module_name )
400406 return score_module
401407 except AttributeError :
402- # If score module doesn't exist, fall back to the provided fallback module
403- # This shouldn't happen with valid rules, but provide a safe fallback
404- return fallback_module
408+ warnings .warn (
409+ f"Score module '{ score_module_name } ' not found. Score will be estimated from the quantized module itself."
410+ )
411+ return quant_module
405412
406413 def insert_hparams_after_merge_rules (self , model , quant_recipes , disabled_layers = None ):
407414 """Restrict the search space using the merge rules and insert the hparams for the model."""
@@ -459,20 +466,12 @@ def insert_hparams_after_merge_rules(self, model, quant_recipes, disabled_layers
459466 disabled = any (disabled for _ , _ , disabled , _ in module_info_list )
460467 score_modules = [score_module for _ , _ , _ , score_module in module_info_list ]
461468
462- hparam = (
463- QuantRecipeHparam (
464- None ,
465- quant_modules = quant_modules ,
466- score_modules = score_modules ,
467- name = str (group_key ),
468- )
469- if disabled
470- else QuantRecipeHparam (
471- quant_recipes ,
472- quant_modules = quant_modules ,
473- score_modules = score_modules ,
474- name = str (group_key ),
475- )
469+ quant_recipes = None if disabled else quant_recipes
470+ hparam = QuantRecipeHparam (
471+ quant_recipes ,
472+ quant_modules = quant_modules ,
473+ score_modules = score_modules ,
474+ name = str (group_key ),
476475 )
477476
478477 for module in quant_modules :
@@ -495,8 +494,8 @@ def _verify_constraint(self, search_recipes):
495494 )
496495
497496 @abstractmethod
498- def estimate_sensitivity_scores (self ):
499- """Estimate the sensitivity scores for the model ."""
497+ def estimate_sensitivity_scores (self ) -> None :
498+ """Estimate sensitivity scores and track them with Hparam ."""
500499
501500 def _run_func (self , func , num_iters = 1 , desc = "" ):
502501 for i , data in tqdm (
@@ -656,8 +655,6 @@ def run_search(self):
656655 QuantRecipe .fold_pqs_to_weights (self .model )
657656
658657
659-
660-
661658@torch .compile
662659def _get_auto_quantize_score (grad_output , output_diff ):
663660 return ((grad_output .float () ** 2 ) * (output_diff .float () ** 2 )).sum ()
@@ -675,13 +672,29 @@ class AutoQuantizeGradientSearcher(_AutoQuantizeBaseSearcher):
675672 scores while meeting the specified constraint. AutoQuantize uses Linear Programming Solver to find the
676673 optimal quantization configuration.
677674
678- The auto_quantize score for a layer quantization configuration is an approximation of model loss change change due
675+ The auto_quantize score for a layer quantization configuration is an approximation of model loss change due
679676 to quantizing the particular layer with the particular configuration.
680677 The approximation is based on a Taylor expansion of the loss function with respect to the quantized output of the layer and
681678 substitution of Fisher information for the Hessian.
682679 This approximation is mathematically correct for models where the loss
683680 is a log likelihood loss such as BERT, GPT, etc. However, the auto_quantize score can still be used as a proxy
684681 for other models such as ResNet.
682+
683+ **Quant Modules:**
684+
685+ This searcher operates on quantizable modules (quant modules), which are typically Linear or Conv layers
686+ that support quantization. Optionally, grouping rules can be applied to ensure certain layers share the same
687+ quantization format (e.g., Q, K, V projections in the same attention layer). For details on quant_grouping_rules
688+ and customization, see the :meth:`auto_quantize <modelopt.torch.quantization.model_quant.auto_quantize>`
689+ API documentation.
690+
691+ **Score Modules:**
692+
693+ By default, for each quant module, its sensitivity score is estimated using that module's output perturbation.
694+ However, the sensitivity can also be estimated by looking at perturbation at a separate point in the neural
695+ network (score module). This is helpful in some cases such as MoEs for speed and lower memory consumption.
696+ Since all experts are already restricted to the same quant format by quant grouping rules, their sensitivity
697+ can be estimated together at a single point (e.g., the MLP output level).
685698 """
686699
687700 score_module_rules = [
@@ -872,8 +885,8 @@ def cleanup_params_after_score_estimation(name, param, params_metadata):
872885 del params_metadata
873886 gc .collect ()
874887
875- def estimate_sensitivity_scores (self ):
876- """Estimate the sensitivity scores for the model ."""
888+ def estimate_sensitivity_scores (self ) -> None :
889+ """Estimate sensitivity scores using a Hessian approximation."""
877890 self .model .eval ()
878891
879892 def _default_is_param_grad_enabled (pname , model ):
0 commit comments