35 changes: 35 additions & 0 deletions test/quantization/test_quant_api.py
@@ -1122,6 +1122,41 @@ def reset_memory():
assert param.is_cuda
self.assertLess(memory_streaming, memory_baseline)

from torchao.quantization.quant_api import (
CUSTOM_PARAM_QUANTIZATION_SUPPORTED_CONFIGS,
)

@common_utils.parametrize("config", CUSTOM_PARAM_QUANTIZATION_SUPPORTED_CONFIGS)
def test_fqn_to_config_supported_param_configs(self, config):
"""Test that all supported parameter configs are in FqnToConfig."""

from torchao.utils import (
TorchAOBaseTensor,
)

torchao_tensor_types = (TorchAOBaseTensor, AffineQuantizedTensor)
m = ToyLinearModel(m=128, k=128, n=128)
m.linear1.register_parameter(
"custom_param_name", torch.nn.Parameter(torch.randn(m.linear1.weight.shape))
)
m = m.cuda().bfloat16()

fqn_config = FqnToConfig(
{
"linear1.custom_param_name": config(),
"linear1.weight": config(),
"linear2.weight": config(),
}
)

quantize_(m, fqn_config, filter_fn=None)

assert isinstance(m.linear1.custom_param_name.data, torchao_tensor_types)
assert isinstance(m.linear1.weight.data, torchao_tensor_types)
assert isinstance(m.linear2.weight.data, torchao_tensor_types)


common_utils.instantiate_parametrized_tests(TestFqnToConfig)

if __name__ == "__main__":
unittest.main()
9 changes: 5 additions & 4 deletions torchao/core/config.py
@@ -144,10 +144,11 @@ def default(self, o):
return [self.encode_value(item) for item in o]

elif isinstance(o, tuple):
raise NotImplementedError(
"Tuples will be serialized as List in JSON, so we recommend to use "
f"Lists instead to avoid surprises. got: {o}"
)
return [self.encode_value(item) for item in o]
# raise NotImplementedError(
# "Tuples will be serialized as List in JSON, so we recommend to use "
# f"Lists instead to avoid surprises. got: {o}"
# )

if isinstance(o, dict):
return {k: self.encode_value(v) for k, v in o.items()}
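Note: this switches tuple handling from a hard error to best-effort encoding. A minimal sketch of the resulting behavior using plain `json` (no torchao APIs involved): JSON has no tuple type, so a tuple-valued field round-trips as a list, which is exactly the surprise the removed error used to warn about.

```python
import json

# Tuples are encoded as JSON arrays, so they come back as lists.
payload = {"block_size": (1, 128)}          # hypothetical config field
restored = json.loads(json.dumps(payload))
assert restored["block_size"] == [1, 128]   # tuple in, list out
```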
3 changes: 3 additions & 0 deletions torchao/dtypes/affine_quantized_tensor_ops.py
@@ -456,6 +456,9 @@ def _(func, types, args, kwargs):
def _(func, types, args, kwargs):
self = args[0]
src = args[1]
if type(self) is torch.Tensor and isinstance(src, AffineQuantizedTensor):
func(self, src.dequantize())
return
if _same_metadata(self, src):
self_tensors = self.__tensor_flatten__()[0]
for tensor_name in self_tensors:
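A hedged sketch of what this new `copy_` branch enables: copying a quantized weight into a plain floating-point tensor now routes through `dequantize()` instead of falling into the metadata-mismatch path. The config choice, shapes, and CUDA device below are illustrative assumptions, not taken from the PR.

```python
import torch
import torch.nn as nn
from torchao.quantization import Int8WeightOnlyConfig, quantize_

# Illustrative setup: quantize a linear weight into a torchao tensor subclass.
linear = nn.Linear(128, 128, dtype=torch.bfloat16, device="cuda")
quantize_(linear, Int8WeightOnlyConfig())

# With the new branch, copying into a regular tensor dequantizes the source.
plain = torch.empty(128, 128, dtype=torch.bfloat16, device="cuda")
plain.copy_(linear.weight.data)
```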
3 changes: 3 additions & 0 deletions torchao/quantization/__init__.py
@@ -69,6 +69,7 @@
float8_static_activation_float8_weight,
float8_weight_only,
fpx_weight_only,
fqn_matches_fqn_config,
gemlite_uintx_weight_only,
int4_dynamic_activation_int4_weight,
int4_weight_only,
@@ -221,4 +222,6 @@
"Int4WeightOnlyGPTQQuantizer",
"MultiTensor",
"MultiTensorInputRecorder",
# helper functions
"fqn_matches_fqn_config",
]
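A small usage sketch of the newly exported helper (the config value is illustrative, and importing `FqnToConfig` from `quant_api` is an assumption about where it lives):

```python
from torchao.quantization import Int8WeightOnlyConfig, fqn_matches_fqn_config
from torchao.quantization.quant_api import FqnToConfig

cfg = FqnToConfig({"linear1.weight": Int8WeightOnlyConfig()})
fqn_matches_fqn_config("linear1.weight", cfg)   # expected True: exact fqn listed
fqn_matches_fqn_config("linear2.weight", cfg)   # expected False: not covered
```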
44 changes: 25 additions & 19 deletions torchao/quantization/quant_api.py
@@ -161,6 +161,7 @@
"Int8DynActInt4WeightQuantizer",
"Float8DynamicActivationFloat8SemiSparseWeightConfig",
"ModuleFqnToConfig",
"FqnToConfig",
]

LAYOUT_TO_ZERO_POINT_DOMAIN = {
@@ -480,15 +481,17 @@ def quantize_(

for module_fqn, module in model.named_modules():
if (
_fqn_matches_fqn_config(module_fqn, config)
fqn_matches_fqn_config(module_fqn, config)
or _module_param_matches_fqn_config(module, module_fqn, config)
or ("_default" in config.fqn_to_config and _is_linear(module))
):
module_name = (
module_fqn.rsplit(".", 1) if "." in module_fqn else module_fqn
module_fqn.rsplit(".", 1)[0] if "." in module_fqn else module_fqn
)
# this replaces inplace, so no need to reassign
_fqn_to_config_handler(module, module_name, config, device)
_fqn_to_config_handler(module, module_name, config)
if device is not None:
module.to(device=device)
return
if isinstance(config, AOBaseConfig):
filter_fn = _is_linear if filter_fn is None else filter_fn
@@ -1253,17 +1256,22 @@ def _int4_weight_only_quantize_tensor(weight, config):

@register_quantize_module_handler(Int4WeightOnlyConfig)
def _int4_weight_only_transform(
module: torch.nn.Module, config: Int4WeightOnlyConfig
module: torch.nn.Module,
config: Int4WeightOnlyConfig,
*,
parameter_name: str = "weight",
) -> torch.nn.Module:
if config.set_inductor_config:
torchao.quantization.utils.recommended_inductor_config_setter()

assert hasattr(module, "weight"), (
"applying int8 weight only quant requires module to have weight attribute"
assert hasattr(module, parameter_name), (
"applying int8 weight only quant requires module to have {parameter_name} attribute"
+ " but {module} does not have one"
)
new_weight = _int4_weight_only_quantize_tensor(module.weight, config)
module.weight = torch.nn.Parameter(new_weight, requires_grad=False)
new_weight = _int4_weight_only_quantize_tensor(
getattr(module, parameter_name), config
)
setattr(module, parameter_name, torch.nn.Parameter(new_weight, requires_grad=False))
module.extra_repr = types.MethodType(_linear_extra_repr, module)
return module

@@ -2315,18 +2323,19 @@ def _intx_weight_only_transform(
*,
custom_scale: Optional[torch.Tensor] = None,
custom_zero_point: Optional[torch.Tensor] = None,
parameter_name="weight",
) -> torch.nn.Module:
assert hasattr(module, "weight"), (
"applying intx weight only quant requires module to have weight attribute"
assert hasattr(module, parameter_name), (
"applying intx weight only quant requires module to have {parameter_name} attribute"
+ " but {module} does not have one"
)
new_weight = _intx_weight_only_quantize_tensor(
module.weight,
getattr(module, parameter_name),
config,
custom_scale=custom_scale,
custom_zero_point=custom_zero_point,
)
module.weight = torch.nn.Parameter(new_weight, requires_grad=False)
setattr(module, parameter_name, torch.nn.Parameter(new_weight, requires_grad=False))

if isinstance(module, nn.Linear):
module.extra_repr = types.MethodType(_linear_extra_repr, module)
@@ -2463,14 +2472,15 @@ def __post_init__(self):
Float8DynamicActivationFloat8WeightConfig,
Float8WeightOnlyConfig,
Int8WeightOnlyConfig,
Int4WeightOnlyConfig,
IntxWeightOnlyConfig,
}


def _fqn_to_config_handler(
module: torch.nn.Module,
fqn: str,
config: FqnToConfig,
device: Optional[torch.device] = None,
):
"""This function expects a module that either is specified in FqnToConfig or has a parameter that is specified in FqnToConfig.

@@ -2479,17 +2489,13 @@ def _fqn_to_config_handler(
fqn (str): The fully qualified name of the module containing the parameters.
config (FqnToConfig): Configuration object containing regex patterns / fqn mapped
to quantization configurations.
device (Optional[torch.device]): The device to move the module to as part of quantization

Returns:
torch.nn.Module: The modified module with quantized parameters.

Raises:
NotImplementedError: If the quantization configuration is not yet supported for parameter quantization.
"""
if device is not None:
module = module.to(device)

parameter_config_found = False
top_level_params = []
for i, (parameter_name, param) in enumerate(list(module.named_parameters())):
@@ -2563,7 +2569,7 @@ def _fqn_to_config_handler(
return module


def _fqn_matches_fqn_config(
def fqn_matches_fqn_config(
fqn: str,
config: FqnToConfig,
):
@@ -2608,7 +2614,7 @@ def _module_param_matches_fqn_config(
for name, param in module.named_parameters():
if name in dir(module):
parameter_fqn = f"{fqn}.{name}" if len(fqn) > 0 else name
if _fqn_matches_fqn_config(parameter_fqn, config):
if fqn_matches_fqn_config(parameter_fqn, config):
return True

return False
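Taken together, the quant_api.py changes let a config keyed by a parameter's fully qualified name reach parameters that are not named `weight`, routing them through the per-config transforms via the new `parameter_name` keyword. A standalone sketch of that flow, mirroring the new test (the toy module, shapes, and CUDA device are illustrative assumptions):

```python
import torch
import torch.nn as nn
from torchao.quantization import quantize_
from torchao.quantization.quant_api import FqnToConfig, Int4WeightOnlyConfig

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(128, 128)
        # A parameter that is not named "weight".
        self.linear1.register_parameter(
            "custom_param_name", nn.Parameter(torch.randn(128, 128))
        )

m = Toy().cuda().bfloat16()
quantize_(
    m,
    FqnToConfig(
        {
            "linear1.weight": Int4WeightOnlyConfig(),
            "linear1.custom_param_name": Int4WeightOnlyConfig(),
        }
    ),
    filter_fn=None,
)
```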
15 changes: 8 additions & 7 deletions torchao/quantization/quantize_/workflows/float8/float8_tensor.py
@@ -208,13 +208,14 @@ def from_hp(
else:
maybe_hp_value_ub_tensor = None
if isinstance(granularity, PerRow):
data, scale = torch.ops.triton.quantize_fp8_row(
hp_tensor, scale_ub=maybe_hp_value_ub_tensor
)
scale_shape = []
for i in range(hp_tensor.ndim):
scale_shape.append(hp_tensor.shape[i] // block_size[i])
scale = scale.reshape(*scale_shape)
with torch.cuda.device(hp_tensor.device):
data, scale = torch.ops.triton.quantize_fp8_row(
hp_tensor, scale_ub=maybe_hp_value_ub_tensor
)
scale_shape = []
for i in range(hp_tensor.ndim):
scale_shape.append(hp_tensor.shape[i] // block_size[i])
scale = scale.reshape(*scale_shape)
else:
assert isinstance(granularity, PerTensor), (
f"Expected per tensor, got {granularity}"
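The new `with torch.cuda.device(...)` block pins the current CUDA device before launching the Triton row-wise quantization kernel, so per-row float8 quantization behaves correctly when `hp_tensor` lives on a GPU other than the default one. A generic sketch of the pattern (the helper name is illustrative):

```python
import torch

def launch_on_tensor_device(t: torch.Tensor, kernel, *args, **kwargs):
    # Triton/CUDA kernels launch on the *current* device; pin it to the
    # input tensor's device so data on cuda:1 is not processed under cuda:0.
    with torch.cuda.device(t.device):
        return kernel(t, *args, **kwargs)
```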
3 changes: 3 additions & 0 deletions torchao/utils.py
@@ -572,6 +572,9 @@ def _same_metadata(self: TorchAOBaseTensor, src: TorchAOBaseTensor) -> bool:
def _(func, types, args, kwargs):
self = args[0]
src = args[1]
if type(self) is torch.Tensor and isinstance(src, TorchAOBaseTensor):
func(self, src.dequantize())
return
if _same_metadata(self, src):
self_tensors = self.__tensor_flatten__()[0]
for tensor_name in self_tensors: