Skip to content

Commit 06b211c

Browse files
committed
Assign module to device after quantization
Summary: Previously, we moved the module to the device and then quantized it; now we quantize first and move the module to the device afterwards. The old ordering (moving before quantizing) was causing some of the Hugging Face integration tests to fail. Test Plan: Reviewers: Subscribers: Tasks: Tags:
1 parent 1e473ed commit 06b211c

File tree

1 file changed

+3
-6
lines changed

1 file changed

+3
-6
lines changed

torchao/quantization/quant_api.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -487,7 +487,9 @@ def quantize_(
487487
module_fqn.rsplit(".", 1) if "." in module_fqn else module_fqn
488488
)
489489
# this replaces inplace, so no need to reassign
490-
_fqn_to_config_handler(module, module_name, config, device)
490+
_fqn_to_config_handler(module, module_name, config)
491+
if device is not None:
492+
module.to(device=device)
491493
return
492494
if isinstance(config, AOBaseConfig):
493495
filter_fn = _is_linear if filter_fn is None else filter_fn
@@ -2451,7 +2453,6 @@ def _fqn_to_config_handler(
24512453
module: torch.nn.Module,
24522454
fqn: str,
24532455
config: FqnToConfig,
2454-
device: Optional[torch.device] = None,
24552456
):
24562457
"""This function expects a module that either is specified in FqnToConfig or has a parameter that is specified in FqnToConfig.
24572458
@@ -2460,17 +2461,13 @@ def _fqn_to_config_handler(
24602461
fqn (str): The fully qualified name of the module containing the parameters.
24612462
config (FqnToConfig): Configuration object containing regex patterns / fqn mapped
24622463
to quantization configurations.
2463-
device (Optional[torch.device]): The device to move the module to as part of quantization
24642464
24652465
Returns:
24662466
torch.nn.Module: The modified module with quantized parameters.
24672467
24682468
Raises:
24692469
NotImplementedError: If the quantization configuration is not yet supported for parameter quantization.
24702470
"""
2471-
if device is not None:
2472-
module = module.to(device)
2473-
24742471
parameter_config_found = False
24752472
top_level_params = []
24762473
for i, (parameter_name, param) in enumerate(list(module.named_parameters())):

0 commit comments

Comments
 (0)