diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py
index e1c6471b17..909ddd1842 100644
--- a/test/quantization/test_quant_api.py
+++ b/test/quantization/test_quant_api.py
@@ -1122,6 +1122,41 @@ def reset_memory():
             assert param.is_cuda
         self.assertLess(memory_streaming, memory_baseline)
 
+    from torchao.quantization.quant_api import (
+        CUSTOM_PARAM_QUANTIZATION_SUPPORTED_CONFIGS,
+    )
+
+    @common_utils.parametrize("config", CUSTOM_PARAM_QUANTIZATION_SUPPORTED_CONFIGS)
+    def test_fqn_to_config_supported_param_configs(self, config):
+        """Test that all supported parameter configs can be applied through FqnToConfig."""
+
+        from torchao.utils import (
+            TorchAOBaseTensor,
+        )
+
+        torchao_tensor_types = (TorchAOBaseTensor, AffineQuantizedTensor)
+        m = ToyLinearModel(m=128, k=128, n=128)
+        m.linear1.register_parameter(
+            "custom_param_name", torch.nn.Parameter(torch.randn(m.linear1.weight.shape))
+        )
+        m = m.cuda().bfloat16()
+
+        fqn_config = FqnToConfig(
+            {
+                "linear1.custom_param_name": config(),
+                "linear1.weight": config(),
+                "linear2.weight": config(),
+            }
+        )
+
+        quantize_(m, fqn_config, filter_fn=None)
+
+        assert isinstance(m.linear1.custom_param_name.data, torchao_tensor_types)
+        assert isinstance(m.linear1.weight.data, torchao_tensor_types)
+        assert isinstance(m.linear2.weight.data, torchao_tensor_types)
+
+
+common_utils.instantiate_parametrized_tests(TestFqnToConfig)
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/torchao/core/config.py b/torchao/core/config.py
index 330e6a42af..421dee52b8 100644
--- a/torchao/core/config.py
+++ b/torchao/core/config.py
@@ -144,10 +144,7 @@ def default(self, o):
             return [self.encode_value(item) for item in o]
 
         elif isinstance(o, tuple):
-            raise NotImplementedError(
-                "Tuples will be serialized as List in JSON, so we recommend to use "
-                f"Lists instead to avoid surprises. got: {o}"
-            )
+            return [self.encode_value(item) for item in o]
 
         if isinstance(o, dict):
             return {k: self.encode_value(v) for k, v in o.items()}
diff --git a/torchao/dtypes/affine_quantized_tensor_ops.py b/torchao/dtypes/affine_quantized_tensor_ops.py
index ffadece729..e033e9d8b3 100644
--- a/torchao/dtypes/affine_quantized_tensor_ops.py
+++ b/torchao/dtypes/affine_quantized_tensor_ops.py
@@ -456,6 +456,9 @@ def _(func, types, args, kwargs):
 def _(func, types, args, kwargs):
     self = args[0]
     src = args[1]
+    if type(self) is torch.Tensor and isinstance(src, AffineQuantizedTensor):
+        func(self, src.dequantize())
+        return
     if _same_metadata(self, src):
         self_tensors = self.__tensor_flatten__()[0]
         for tensor_name in self_tensors:
diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py
index 7459b2504c..ba7f38facd 100644
--- a/torchao/quantization/__init__.py
+++ b/torchao/quantization/__init__.py
@@ -69,6 +69,7 @@
     float8_static_activation_float8_weight,
     float8_weight_only,
     fpx_weight_only,
+    fqn_matches_fqn_config,
     gemlite_uintx_weight_only,
     int4_dynamic_activation_int4_weight,
     int4_weight_only,
@@ -221,4 +222,6 @@
     "Int4WeightOnlyGPTQQuantizer",
     "MultiTensor",
     "MultiTensorInputRecorder",
+    # helper functions
+    "fqn_matches_fqn_config",
 ]
diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index ddeb8c7ca6..ce9303a1f1 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -161,6 +161,7 @@
     "Int8DynActInt4WeightQuantizer",
     "Float8DynamicActivationFloat8SemiSparseWeightConfig",
     "ModuleFqnToConfig",
+    "FqnToConfig",
 ]
 
 LAYOUT_TO_ZERO_POINT_DOMAIN = {
@@ -480,15 +481,17 @@ def quantize_(
 
         for module_fqn, module in model.named_modules():
             if (
-                _fqn_matches_fqn_config(module_fqn, config)
+                fqn_matches_fqn_config(module_fqn, config)
                 or _module_param_matches_fqn_config(module, module_fqn, config)
                 or ("_default" in config.fqn_to_config and _is_linear(module))
             ):
                 module_name = (
-                    module_fqn.rsplit(".", 1) if "." in module_fqn else module_fqn
+                    module_fqn.rsplit(".", 1)[0] if "." in module_fqn else module_fqn
                 )
                 # this replaces inplace, so no need to reassign
-                _fqn_to_config_handler(module, module_name, config, device)
+                _fqn_to_config_handler(module, module_name, config)
+                if device is not None:
+                    module.to(device=device)
         return
     if isinstance(config, AOBaseConfig):
         filter_fn = _is_linear if filter_fn is None else filter_fn
@@ -1253,17 +1256,22 @@ def _int4_weight_only_quantize_tensor(weight, config):
 
 @register_quantize_module_handler(Int4WeightOnlyConfig)
 def _int4_weight_only_transform(
-    module: torch.nn.Module, config: Int4WeightOnlyConfig
+    module: torch.nn.Module,
+    config: Int4WeightOnlyConfig,
+    *,
+    parameter_name: str = "weight",
 ) -> torch.nn.Module:
     if config.set_inductor_config:
         torchao.quantization.utils.recommended_inductor_config_setter()
 
-    assert hasattr(module, "weight"), (
-        "applying int8 weight only quant requires module to have weight attribute"
+    assert hasattr(module, parameter_name), (
+        f"applying int4 weight only quant requires module to have {parameter_name} attribute"
+        f" but {module} does not have one"
     )
-    new_weight = _int4_weight_only_quantize_tensor(module.weight, config)
-    module.weight = torch.nn.Parameter(new_weight, requires_grad=False)
+    new_weight = _int4_weight_only_quantize_tensor(
+        getattr(module, parameter_name), config
+    )
+    setattr(module, parameter_name, torch.nn.Parameter(new_weight, requires_grad=False))
     module.extra_repr = types.MethodType(_linear_extra_repr, module)
     return module
@@ -2315,18 +2323,19 @@ def _intx_weight_only_transform(
     *,
     custom_scale: Optional[torch.Tensor] = None,
     custom_zero_point: Optional[torch.Tensor] = None,
+    parameter_name: str = "weight",
 ) -> torch.nn.Module:
-    assert hasattr(module, "weight"), (
-        "applying intx weight only quant requires module to have weight attribute"
+    assert hasattr(module, parameter_name), (
+        f"applying intx weight only quant requires module to have {parameter_name} attribute"
+        f" but {module} does not have one"
     )
     new_weight = _intx_weight_only_quantize_tensor(
-        module.weight,
+        getattr(module, parameter_name),
         config,
         custom_scale=custom_scale,
         custom_zero_point=custom_zero_point,
     )
-    module.weight = torch.nn.Parameter(new_weight, requires_grad=False)
+    setattr(module, parameter_name, torch.nn.Parameter(new_weight, requires_grad=False))
 
     if isinstance(module, nn.Linear):
         module.extra_repr = types.MethodType(_linear_extra_repr, module)
@@ -2463,6 +2472,8 @@ def __post_init__(self):
     Float8DynamicActivationFloat8WeightConfig,
     Float8WeightOnlyConfig,
     Int8WeightOnlyConfig,
+    Int4WeightOnlyConfig,
+    IntxWeightOnlyConfig,
 }
 
 
@@ -2470,7 +2481,6 @@ def _fqn_to_config_handler(
     module: torch.nn.Module,
     fqn: str,
     config: FqnToConfig,
-    device: Optional[torch.device] = None,
 ):
     """This function expects a module that either is specified in FqnToConfig or has a parameter
     that is specified in FqnToConfig.
@@ -2479,7 +2489,6 @@
         fqn (str): The fully qualified name of the module containing the parameters.
         config (FqnToConfig): Configuration object containing regex patterns / fqn mapped to
             quantization configurations.
-        device (Optional[torch.device]): The device to move the module to as part of quantization
 
     Returns:
         torch.nn.Module: The modified module with quantized parameters.
 
     Raises:
         NotImplementedError: If the quantization configuration is not yet supported for parameter
             quantization.
""" - if device is not None: - module = module.to(device) - parameter_config_found = False top_level_params = [] for i, (parameter_name, param) in enumerate(list(module.named_parameters())): @@ -2563,7 +2569,7 @@ def _fqn_to_config_handler( return module -def _fqn_matches_fqn_config( +def fqn_matches_fqn_config( fqn: str, config: FqnToConfig, ): @@ -2608,7 +2614,7 @@ def _module_param_matches_fqn_config( for name, param in module.named_parameters(): if name in dir(module): parameter_fqn = f"{fqn}.{name}" if len(fqn) > 0 else name - if _fqn_matches_fqn_config(parameter_fqn, config): + if fqn_matches_fqn_config(parameter_fqn, config): return True return False diff --git a/torchao/quantization/quantize_/workflows/float8/float8_tensor.py b/torchao/quantization/quantize_/workflows/float8/float8_tensor.py index 3581cb619c..984eba59ca 100644 --- a/torchao/quantization/quantize_/workflows/float8/float8_tensor.py +++ b/torchao/quantization/quantize_/workflows/float8/float8_tensor.py @@ -208,13 +208,14 @@ def from_hp( else: maybe_hp_value_ub_tensor = None if isinstance(granularity, PerRow): - data, scale = torch.ops.triton.quantize_fp8_row( - hp_tensor, scale_ub=maybe_hp_value_ub_tensor - ) - scale_shape = [] - for i in range(hp_tensor.ndim): - scale_shape.append(hp_tensor.shape[i] // block_size[i]) - scale = scale.reshape(*scale_shape) + with torch.cuda.device(hp_tensor.device): + data, scale = torch.ops.triton.quantize_fp8_row( + hp_tensor, scale_ub=maybe_hp_value_ub_tensor + ) + scale_shape = [] + for i in range(hp_tensor.ndim): + scale_shape.append(hp_tensor.shape[i] // block_size[i]) + scale = scale.reshape(*scale_shape) else: assert isinstance(granularity, PerTensor), ( f"Expected per tensor, got {granularity}" diff --git a/torchao/utils.py b/torchao/utils.py index 02013c5197..95c59aa8de 100644 --- a/torchao/utils.py +++ b/torchao/utils.py @@ -572,6 +572,9 @@ def _same_metadata(self: TorchAOBaseTensor, src: TorchAOBaseTensor) -> bool: def _(func, types, args, kwargs): self = args[0] src = args[1] + if type(self) is torch.Tensor and isinstance(src, TorchAOBaseTensor): + func(self, src.dequantize()) + return if _same_metadata(self, src): self_tensors = self.__tensor_flatten__()[0] for tensor_name in self_tensors: