Commit 3537a22

add test
1 parent 90502df commit 3537a22

File tree

1 file changed: +35 -0 lines changed

test/quantization/test_quant_api.py

Lines changed: 35 additions & 0 deletions
@@ -1122,6 +1122,41 @@ def reset_memory():
             assert param.is_cuda
         self.assertLess(memory_streaming, memory_baseline)
 
+    from torchao.quantization.quant_api import (
+        CUSTOM_PARAM_QUANTIZATION_SUPPORTED_CONFIGS,
+    )
+
+    @common_utils.parametrize("config", CUSTOM_PARAM_QUANTIZATION_SUPPORTED_CONFIGS)
+    def test_fqn_to_config_supported_param_configs(self, config):
+        """Test that all supported parameter configs are in FqnToConfig."""
+
+        from torchao.utils import (
+            TorchAOBaseTensor,
+        )
+
+        torchao_tensor_types = (TorchAOBaseTensor, AffineQuantizedTensor)
+        m = ToyLinearModel(m=128, k=128, n=128)
+        m.linear1.register_parameter(
+            "custom_param_name", torch.nn.Parameter(torch.randn(m.linear1.weight.shape))
+        )
+        m = m.cuda().bfloat16()
+
+        fqn_config = FqnToConfig(
+            {
+                "linear1.custom_param_name": config(),
+                "linear1.weight": config(),
+                "linear2.weight": config(),
+            }
+        )
+
+        quantize_(m, fqn_config, filter_fn=None)
+
+        assert isinstance(m.linear1.custom_param_name.data, torchao_tensor_types)
+        assert isinstance(m.linear1.weight.data, torchao_tensor_types)
+        assert isinstance(m.linear2.weight.data, torchao_tensor_types)
+
+
+common_utils.instantiate_parametrized_tests(TestFqnToConfig)
 
 if __name__ == "__main__":
     unittest.main()
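
For readers skimming the diff, the sketch below restates the pattern the new test exercises, outside the test harness: FqnToConfig maps fully-qualified parameter names (including custom, non-weight parameters) to config instances, and quantize_ with filter_fn=None applies each config to the matching parameter. This is a minimal sketch and not part of the commit: the TwoLinear module is a stand-in for the test's ToyLinearModel, a CUDA device with bfloat16 support is assumed, and FqnToConfig plus CUSTOM_PARAM_QUANTIZATION_SUPPORTED_CONFIGS are assumed importable from torchao.quantization.quant_api, as the test's own imports suggest.

# Minimal sketch of the pattern the new test covers (assumptions noted above).
import torch

from torchao.quantization import quantize_
from torchao.quantization.quant_api import (  # assumed import location
    CUSTOM_PARAM_QUANTIZATION_SUPPORTED_CONFIGS,
    FqnToConfig,
)


class TwoLinear(torch.nn.Module):
    """Stand-in for the test's ToyLinearModel(m=128, k=128, n=128)."""

    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(128, 128, bias=False)
        self.linear2 = torch.nn.Linear(128, 128, bias=False)
        # Extra tensor registered under a custom parameter name, mirroring the test.
        self.linear1.register_parameter(
            "custom_param_name",
            torch.nn.Parameter(torch.randn(self.linear1.weight.shape)),
        )

    def forward(self, x):
        return self.linear2(self.linear1(x))


m = TwoLinear().cuda().bfloat16()  # assumes a CUDA device is available

# Pick any one of the config classes the test parametrizes over.
config_cls = next(iter(CUSTOM_PARAM_QUANTIZATION_SUPPORTED_CONFIGS))

# Keys are fully-qualified parameter names; values are config instances.
fqn_config = FqnToConfig(
    {
        "linear1.custom_param_name": config_cls(),
        "linear1.weight": config_cls(),
        "linear2.weight": config_cls(),
    }
)

# filter_fn=None defers parameter selection entirely to the FqnToConfig mapping.
quantize_(m, fqn_config, filter_fn=None)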
