57 changes: 55 additions & 2 deletions test/quantization/test_quant_api.py
@@ -10,6 +10,7 @@
import gc
import tempfile
import unittest
import warnings
from pathlib import Path

import torch
@@ -37,6 +38,8 @@
PerGroup,
)
from torchao.quantization.quant_api import (
Float8DynamicActivationFloat8WeightConfig,
Float8StaticActivationFloat8WeightConfig,
Int4WeightOnlyConfig,
Int8DynamicActivationIntxWeightConfig,
Int8WeightOnlyConfig,
@@ -623,8 +626,8 @@ def test_workflow_e2e_numerics(self, config):
isinstance(
config,
(
float8_dynamic_activation_float8_weight,
float8_static_activation_float8_weight,
Float8DynamicActivationFloat8WeightConfig,
Float8StaticActivationFloat8WeightConfig,
),
)
and not is_sm_at_least_89()
@@ -755,6 +758,56 @@ def test_int4wo_cuda_serialization(self):
# load state_dict in cuda
model.load_state_dict(sd, assign=True)

def test_config_deprecation(self):
"""
Test that old config functions like `int4_weight_only` trigger deprecation warnings.
"""
from torchao.quantization import (
float8_dynamic_activation_float8_weight,
float8_static_activation_float8_weight,
float8_weight_only,
fpx_weight_only,
gemlite_uintx_weight_only,
int4_dynamic_activation_int4_weight,
int4_weight_only,
int8_dynamic_activation_int4_weight,
int8_dynamic_activation_int8_weight,
int8_weight_only,
uintx_weight_only,
)

        # Reset deprecation warning state; otherwise warnings already triggered
        # earlier in the process won't be logged here
warnings.resetwarnings()

# Map from deprecated API to the args needed to instantiate it
deprecated_apis_to_args = {
float8_dynamic_activation_float8_weight: (),
            float8_static_activation_float8_weight: (torch.randn(3),),
float8_weight_only: (),
fpx_weight_only: (3, 2),
gemlite_uintx_weight_only: (),
int4_dynamic_activation_int4_weight: (),
int4_weight_only: (),
int8_dynamic_activation_int4_weight: (),
int8_dynamic_activation_int8_weight: (),
int8_weight_only: (),
uintx_weight_only: (torch.uint4,),
}

with warnings.catch_warnings(record=True) as _warnings:
# Call each deprecated API twice
for cls, args in deprecated_apis_to_args.items():
cls(*args)
cls(*args)

            # Each deprecated API should warn only once, even though it was called twice
self.assertEqual(len(_warnings), len(deprecated_apis_to_args))
for w in _warnings:
self.assertIn(
"is deprecated and will be removed in a future release",
str(w.message),
)


common_utils.instantiate_parametrized_tests(TestQuantFlow)

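For reference, here is a minimal sketch of the user-facing behavior the new test pins down. The snippet is not part of the diff and assumes a build that includes this change:

import warnings

from torchao.quantization import int4_weight_only
from torchao.quantization.quant_api import Int4WeightOnlyConfig

warnings.resetwarnings()
with warnings.catch_warnings(record=True) as w:
    config = int4_weight_only()  # deprecated spelling: emits the warning
    config = int4_weight_only()  # repeat call is deduplicated by the warnings registry

assert isinstance(config, Int4WeightOnlyConfig)  # still constructs the new config class
assert len(w) == 1
assert "is deprecated and will be removed in a future release" in str(w[0].message)
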
39 changes: 28 additions & 11 deletions torchao/quantization/quant_api.py
@@ -92,6 +92,7 @@
to_weight_tensor_with_linear_activation_quantization_metadata,
)
from torchao.utils import (
_ConfigDeprecationWrapper,
_is_fbgemm_genai_gpu_available,
is_MI300,
is_sm_at_least_89,
@@ -639,7 +640,9 @@ def __post_init__(self):


# for BC
int8_dynamic_activation_int4_weight = Int8DynamicActivationInt4WeightConfig
int8_dynamic_activation_int4_weight = _ConfigDeprecationWrapper(
"int8_dynamic_activation_int4_weight", Int8DynamicActivationInt4WeightConfig
)


@register_quantize_module_handler(Int8DynamicActivationInt4WeightConfig)
@@ -972,7 +975,9 @@ def __post_init__(self):


# for bc
int4_dynamic_activation_int4_weight = Int4DynamicActivationInt4WeightConfig
int4_dynamic_activation_int4_weight = _ConfigDeprecationWrapper(
"int4_dynamic_activation_int4_weight", Int4DynamicActivationInt4WeightConfig
)


@register_quantize_module_handler(Int4DynamicActivationInt4WeightConfig)
@@ -1033,7 +1038,9 @@ def __post_init__(self):


# for BC
gemlite_uintx_weight_only = GemliteUIntXWeightOnlyConfig
gemlite_uintx_weight_only = _ConfigDeprecationWrapper(
"gemlite_uintx_weight_only", GemliteUIntXWeightOnlyConfig
)


@register_quantize_module_handler(GemliteUIntXWeightOnlyConfig)
@@ -1115,7 +1122,7 @@ def __post_init__(self):

# for BC
# TODO maybe change other callsites
int4_weight_only = Int4WeightOnlyConfig
int4_weight_only = _ConfigDeprecationWrapper("int4_weight_only", Int4WeightOnlyConfig)


def _int4_weight_only_quantize_tensor(weight, config):
@@ -1325,7 +1332,7 @@ def __post_init__(self):


# for BC
int8_weight_only = Int8WeightOnlyConfig
int8_weight_only = _ConfigDeprecationWrapper("int8_weight_only", Int8WeightOnlyConfig)


def _int8_weight_only_quantize_tensor(weight, config):
@@ -1486,7 +1493,9 @@ def __post_init__(self):


# for BC
int8_dynamic_activation_int8_weight = Int8DynamicActivationInt8WeightConfig
int8_dynamic_activation_int8_weight = _ConfigDeprecationWrapper(
"int8_dynamic_activation_int8_weight", Int8DynamicActivationInt8WeightConfig
)


def _int8_dynamic_activation_int8_weight_quantize_tensor(weight, config):
@@ -1595,7 +1604,9 @@ def __post_init__(self):


# for BC
float8_weight_only = Float8WeightOnlyConfig
float8_weight_only = _ConfigDeprecationWrapper(
"float8_weight_only", Float8WeightOnlyConfig
)


def _float8_weight_only_quant_tensor(weight, config):
@@ -1753,7 +1764,9 @@ def __post_init__(self):


# for bc
float8_dynamic_activation_float8_weight = Float8DynamicActivationFloat8WeightConfig
float8_dynamic_activation_float8_weight = _ConfigDeprecationWrapper(
"float8_dynamic_activation_float8_weight", Float8DynamicActivationFloat8WeightConfig
)


def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config):
@@ -1926,7 +1939,9 @@ def __post_init__(self):


# for bc
float8_static_activation_float8_weight = Float8StaticActivationFloat8WeightConfig
float8_static_activation_float8_weight = _ConfigDeprecationWrapper(
"float8_static_activation_float8_weight", Float8StaticActivationFloat8WeightConfig
)


@register_quantize_module_handler(Float8StaticActivationFloat8WeightConfig)
@@ -2009,7 +2024,9 @@ def __post_init__(self):


# for BC
uintx_weight_only = UIntXWeightOnlyConfig
uintx_weight_only = _ConfigDeprecationWrapper(
"uintx_weight_only", UIntXWeightOnlyConfig
)


@register_quantize_module_handler(UIntXWeightOnlyConfig)
@@ -2262,7 +2279,7 @@ def __post_init__(self):


# for BC
fpx_weight_only = FPXWeightOnlyConfig
fpx_weight_only = _ConfigDeprecationWrapper("fpx_weight_only", FPXWeightOnlyConfig)


@register_quantize_module_handler(FPXWeightOnlyConfig)
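Every hunk in this file applies the same mechanical substitution, so one schematic example covers them all. `MyQuantConfig` and `my_quant` below are illustrative names, not torchao APIs:

from dataclasses import dataclass

from torchao.utils import _ConfigDeprecationWrapper

@dataclass
class MyQuantConfig:  # stand-in for one of the *Config classes
    group_size: int = 128

# Before this PR: the old function-style name was a bare class alias. Old call
# sites kept working, but silently.
my_quant = MyQuantConfig

# After this PR: the old name is a callable wrapper that still constructs the
# new config class but emits a deprecation warning when invoked.
my_quant = _ConfigDeprecationWrapper("my_quant", MyQuantConfig)

cfg = my_quant(group_size=64)  # warns, then returns MyQuantConfig(group_size=64)
assert isinstance(cfg, MyQuantConfig)

One side effect of replacing a class alias with a wrapper instance is that the old names are no longer types, which is what forces the isinstance() updates in test_workflow_e2e_numerics above.
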
21 changes: 20 additions & 1 deletion torchao/utils.py
@@ -12,7 +12,7 @@
from functools import reduce
from importlib.metadata import version
from math import gcd
from typing import Any, Callable, Optional
from typing import Any, Callable, Optional, Type

import torch
import torch.nn.utils.parametrize as parametrize
@@ -433,6 +433,25 @@ def __eq__(self, other):
TORCH_VERSION_AFTER_2_2 = _deprecated_torch_version_after("2.2.0.dev")


class _ConfigDeprecationWrapper:
"""
A deprecation wrapper that directs users from a deprecated "config function"
(e.g. `int4_weight_only`) to the replacement config class.
"""

def __init__(self, deprecated_name: str, config_cls: Type):
self.deprecated_name = deprecated_name
self.config_cls = config_cls

def __call__(self, *args, **kwargs):
warnings.warn(
f"`{self.deprecated_name}` is deprecated and will be removed in a future release. "
f"Please use `{self.config_cls.__name__}` instead. Example usage:\n"
f" quantize_(model, {self.config_cls.__name__}(...))"
)
return self.config_cls(*args, **kwargs)


"""
Helper function for implementing aten op or torch function dispatch
and dispatching to these implementations.
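The wrapper and the test both lean on CPython's default warning deduplication: under the default filter action, a given warnings.warn() call site reports a message only once, and warnings.resetwarnings() invalidates that cached state by mutating the filter list. A small standard-library-only illustration of the mechanism, independent of torchao:

import warnings

def noisy():
    warnings.warn("x is deprecated and will be removed in a future release")

warnings.resetwarnings()
with warnings.catch_warnings(record=True) as w:
    noisy()
    noisy()  # suppressed: same message from the same warn() call site

print(len(w))  # 1 under the default filter action

This is also why test_config_deprecation calls warnings.resetwarnings() up front: a deprecation warning already triggered earlier in the process would otherwise stay deduplicated, and the recorded list would come up short.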