Commit dbc89d3

Add generic TorchAOTensor extra_repr for nn.Modules (#3328)
* Fix nn.Linear module repr for param quantization
* quantized_tensor instead of new_weight
* update
* update
* update
* update
* update test
1 parent 017326a commit dbc89d3

File tree (2 files changed: +94, −7)

test/quantization/test_quant_api.py
torchao/quantization/quant_api.py

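In effect, any nn.Module whose parameter is replaced by a TorchAO quantized tensor now reports that tensor in its repr, not just nn.Linear. A minimal usage sketch based on the tests below, assuming FqnToConfig, Float8DynamicActivationFloat8WeightConfig, PerTensor, and quantize_ are exported from torchao.quantization as the test file uses them (requires CUDA and SM90+, per the test gating):

import torch
from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    FqnToConfig,
    PerTensor,
    quantize_,
)

# Quantize the top-level "weight" parameter of a bare Linear, mirroring
# the top-level-parameter case in test_fqn_to_config_repr_custom below.
model = torch.nn.Linear(64, 32, bias=False).to(torch.bfloat16).cuda()
quantize_(
    model,
    FqnToConfig(
        {"weight": Float8DynamicActivationFloat8WeightConfig(granularity=PerTensor())}
    ),
    filter_fn=None,
)
# The patched extra_repr appends the quantized tensor, e.g.:
# Linear(in_features=64, out_features=32, bias=False, weight=Float8Tensor(...))
print(model)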

test/quantization/test_quant_api.py

Lines changed: 48 additions & 0 deletions
@@ -853,6 +853,54 @@ def test_config_deprecation(self):
 @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
 @unittest.skipIf(not is_sm_at_least_90(), "Checkpoints are produced in SM90+")
 class TestFqnToConfig(TestCase):
+    def test_fqn_to_config_repr_custom(self):
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.register_parameter(
+                    "x", torch.nn.Parameter(torch.randn(128, 128, dtype=torch.bfloat16))
+                )
+                self.register_parameter(
+                    "y", torch.nn.Parameter(torch.randn(128, 128, dtype=torch.bfloat16))
+                )
+
+        custom_module = TestModule().cuda().eval()
+        custom_module_config = FqnToConfig(
+            {
+                "x": Float8DynamicActivationFloat8WeightConfig(
+                    granularity=PerTensor(),
+                ),
+            }
+        )
+        quantize_(
+            custom_module,
+            custom_module_config,
+            filter_fn=None,
+        )
+        assert str(custom_module).startswith("TestModule(x=Float8Tensor(")
+        assert str(custom_module.x) in str(custom_module)
+
+    def test_fqn_to_config_repr_linear(self):
+        linear_model = ToyLinearModel().to(torch.bfloat16).cuda().eval()
+        linear_quant_config = FqnToConfig(
+            {
+                "linear1.weight": Float8DynamicActivationFloat8WeightConfig(
+                    granularity=PerTensor(),
+                ),
+            }
+        )
+        quantize_(
+            linear_model,
+            linear_quant_config,
+            filter_fn=None,
+        )
+        expected_starting_str = (
+            "Linear(in_features=64, out_features=32, bias=False, weight=Float8Tensor("
+        )
+
+        assert str(linear_model).startswith(expected_starting_str)
+        assert str(linear_model.linear1.weight) in str(linear_model)
+
     def test_quantize_param_fqn_exact(self):
         from transformers import AutoConfig
         from transformers.models.llama4.modeling_llama4 import Llama4TextMoe
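Both repr tests are gated on CUDA and SM90+ by the class-level skipIf decorators above; on suitable hardware they can be selected by name with pytest's substring filter:

pytest test/quantization/test_quant_api.py -k test_fqn_to_config_repr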

torchao/quantization/quant_api.py

Lines changed: 46 additions & 7 deletions
@@ -21,6 +21,7 @@
 import warnings
 from collections import OrderedDict
 from dataclasses import dataclass, field
+from functools import partial
 from typing import Any, Callable, List, Optional, Tuple, Union
 from typing import OrderedDict as OrderedDictType

@@ -416,6 +417,19 @@ def _embedding_extra_repr(self):
     return f"num_embeddings={self.weight.shape[0]}, embedding_dim={self.weight.shape[1]}, weight={_quantization_type(self.weight)}"


+def _module_extra_repr(self, original_extra_repr, parameter_name):
+    module_torchao_extra_repr = []
+
+    original_extra_repr_str = original_extra_repr()
+    if len(original_extra_repr_str) > 0:
+        module_torchao_extra_repr.append(original_extra_repr_str)
+
+    module_torchao_extra_repr.append(
+        f"{parameter_name}={_quantization_type(getattr(self, parameter_name))}"
+    )
+    return ", ".join(module_torchao_extra_repr)
+
+
 def _get_linear_subclass_inserter(
     constructor, *, allow_requires_grad=False, propagate_bias=False, **kwargs
 ):
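The helper is generic because it is bound per-module at quantize time: a partial pins the original extra_repr and the parameter name, and types.MethodType attaches the result as the instance's extra_repr. A minimal, CPU-only sketch of the same binding pattern, with a hypothetical _tensor_summary standing in for torchao's _quantization_type:

import types
from functools import partial

import torch

def _tensor_summary(t):
    # Hypothetical stand-in for torchao's _quantization_type().
    return f"{type(t).__name__}(shape={tuple(t.shape)}, dtype={t.dtype})"

def _module_extra_repr(self, original_extra_repr, parameter_name):
    # Same shape as the helper above: keep the module's own extra_repr,
    # then append "<param>=<summary of the (quantized) tensor>".
    parts = []
    original = original_extra_repr()
    if original:
        parts.append(original)
    parts.append(f"{parameter_name}={_tensor_summary(getattr(self, parameter_name))}")
    return ", ".join(parts)

linear = torch.nn.Linear(64, 32, bias=False)
linear.extra_repr = types.MethodType(
    partial(
        _module_extra_repr,
        original_extra_repr=linear.extra_repr,  # captured before it is shadowed
        parameter_name="weight",
    ),
    linear,
)
print(linear)
# Linear(in_features=64, out_features=32, bias=False, weight=Parameter(shape=(32, 64), dtype=torch.float32))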
@@ -1375,11 +1389,22 @@ def _int8_weight_only_transform(
         "applying int8 weight only quant requires module to have {parameter_name} attribute"
         + " but {module} does not have one"
     )
-    new_weight = _int8_weight_only_quantize_tensor(
+    quantized_tensor = _int8_weight_only_quantize_tensor(
         getattr(module, parameter_name), config
     )
-    setattr(module, parameter_name, torch.nn.Parameter(new_weight, requires_grad=False))
-    module.extra_repr = types.MethodType(_linear_extra_repr, module)
+    setattr(
+        module,
+        parameter_name,
+        torch.nn.Parameter(quantized_tensor, requires_grad=False),
+    )
+    module.extra_repr = types.MethodType(
+        partial(
+            _module_extra_repr,
+            original_extra_repr=module.extra_repr,
+            parameter_name=parameter_name,
+        ),
+        module,
+    )
     return module

@@ -1664,16 +1689,23 @@ def _float8_weight_only_transform(
     if isinstance(module, Float8Linear):
         module = _unwrap_float8_linear(module)

-    new_weight = _float8_weight_only_quant_tensor(
+    quantized_tensor = _float8_weight_only_quant_tensor(
         getattr(module, parameter_name), config
     )

     setattr(
         module,
         parameter_name,
-        torch.nn.Parameter(new_weight, requires_grad=False),
+        torch.nn.Parameter(quantized_tensor, requires_grad=False),
+    )
+    module.extra_repr = types.MethodType(
+        partial(
+            _module_extra_repr,
+            original_extra_repr=module.extra_repr,
+            parameter_name=parameter_name,
+        ),
+        module,
     )
-    module.extra_repr = types.MethodType(_linear_extra_repr, module)
     return module

@@ -1946,7 +1978,14 @@ def _float8_dynamic_activation_float8_weight_transform(
         parameter_name,
         torch.nn.Parameter(quantized_tensor, requires_grad=False),
     )
-    module.extra_repr = types.MethodType(_linear_extra_repr, module)
+    module.extra_repr = types.MethodType(
+        partial(
+            _module_extra_repr,
+            original_extra_repr=module.extra_repr,
+            parameter_name=parameter_name,
+        ),
+        module,
+    )
     return module

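Note the capture order shared by all three transforms: the right-hand side reads module.extra_repr (the module's original bound method) before the assignment shadows it on the instance, so _module_extra_repr can prepend the stock extra_repr output. Because each patch wraps whatever extra_repr is currently in place, quantizing a second parameter on the same module would presumably chain another "name=..." entry onto the repr.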