@@ -1122,6 +1122,41 @@ def reset_memory():
11221122 assert param .is_cuda
11231123 self .assertLess (memory_streaming , memory_baseline )
11241124
1125+ from torchao .quantization .quant_api import (
1126+ CUSTOM_PARAM_QUANTIZATION_SUPPORTED_CONFIGS ,
1127+ )
1128+
1129+ @common_utils .parametrize ("config" , CUSTOM_PARAM_QUANTIZATION_SUPPORTED_CONFIGS )
1130+ def test_fqn_to_config_supported_param_configs (self , config ):
1131+ """Test that all supported parameter configs are in FqnToConfig."""
1132+
1133+ from torchao .utils import (
1134+ TorchAOBaseTensor ,
1135+ )
1136+
1137+ torchao_tensor_types = (TorchAOBaseTensor , AffineQuantizedTensor )
1138+ m = ToyLinearModel (m = 128 , k = 128 , n = 128 )
1139+ m .linear1 .register_parameter (
1140+ "custom_param_name" , torch .nn .Parameter (torch .randn (m .linear1 .weight .shape ))
1141+ )
1142+ m = m .cuda ().bfloat16 ()
1143+
1144+ fqn_config = FqnToConfig (
1145+ {
1146+ "linear1.custom_param_name" : config (),
1147+ "linear1.weight" : config (),
1148+ "linear2.weight" : config (),
1149+ }
1150+ )
1151+
1152+ quantize_ (m , fqn_config , filter_fn = None )
1153+
1154+ assert isinstance (m .linear1 .custom_param_name .data , torchao_tensor_types )
1155+ assert isinstance (m .linear1 .weight .data , torchao_tensor_types )
1156+ assert isinstance (m .linear2 .weight .data , torchao_tensor_types )
1157+
1158+
1159+ common_utils .instantiate_parametrized_tests (TestFqnToConfig )
11251160
11261161if __name__ == "__main__" :
11271162 unittest .main ()
0 commit comments