From 7f0830c2342440a33160459c5d8b006f67702e2c Mon Sep 17 00:00:00 2001 From: andrewor14 Date: Wed, 3 Sep 2025 17:49:44 -0700 Subject: [PATCH] Improve QAT fp8-int4 numerics **Summary:** This commit improved the prepare vs convert SQNR of fp8-int4 QAT from 12 to 22. This is achieved by mimicking the numerics of the target FBGEMM fp8-int4 kernel more closely. In particular, FBGEMM first quantizes the weights to fp8, and then uses max abs values to compute the scale, which is significantly different from what torchao's quant primitives do. **Test Plan:** ``` python test/quantization/test_qat.py -k test_fbgemm_fp8_primitives python test/quantization/test_qat.py -k test_fbgemm_int4_primitives python test/quantization/test_qat.py -k test_quantize_api_fp8_int4 ``` --- test/quantization/test_qat.py | 141 +++++++++++++++++- .../quantization/qat/fake_quantize_config.py | 24 ++- torchao/quantization/qat/fake_quantizer.py | 66 +++++++- torchao/quantization/quant_primitives.py | 19 ++- 4 files changed, 236 insertions(+), 14 deletions(-) diff --git a/test/quantization/test_qat.py b/test/quantization/test_qat.py index d67b922f41..f8e07c8954 100644 --- a/test/quantization/test_qat.py +++ b/test/quantization/test_qat.py @@ -49,6 +49,7 @@ ) from torchao.quantization.qat.fake_quantize_config import ( Float8FakeQuantizeConfig, + Int4WeightPreshuffledFakeQuantizeConfig, IntxFakeQuantizeConfig, ) from torchao.quantization.qat.fake_quantizer import ( @@ -1929,7 +1930,7 @@ def test_quantize_api_fp8_int4(self): """ self._test_quantize_api_against_ptq( Float8DynamicActivationInt4WeightConfig(), - target_prepare_sqnr=12, + target_prepare_sqnr=22, target_convert_sqnr=float("inf"), ) @@ -1950,6 +1951,19 @@ def test_quantize_api_int4(self, version: int): target_convert_sqnr=float("inf"), ) + @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available") + def test_quantize_api_int8_int4(self): + """ + Test the following: + quantize_(model, QATConfig(Int8DynamicActivationInt4WeightConfig(), step="prepare")) + quantize_(model, QATConfig(Int8DynamicActivationInt4WeightConfig(), step="convert")) + """ + self._test_quantize_api_against_ptq( + Int8DynamicActivationInt4WeightConfig(group_size=32), + target_prepare_sqnr=30, + target_convert_sqnr=float("inf"), + ) + def test_infer_fp8_int4_config(self): """ Test that fake quantize configs are correctly inferred from @@ -1964,10 +1978,9 @@ def test_infer_fp8_int4_config(self): self.assertIsInstance(act_config, Float8FakeQuantizeConfig) self.assertEqual(act_config.dtype, torch.float8_e4m3fn) self.assertIsInstance(act_config.granularity, PerRow) - self.assertIsInstance(weight_config, IntxFakeQuantizeConfig) - self.assertEqual(weight_config.dtype, torch.int4) + self.assertIsInstance(weight_config, Int4WeightPreshuffledFakeQuantizeConfig) self.assertEqual(weight_config.group_size, 128) - self.assertTrue(weight_config.is_symmetric) + self.assertEqual(weight_config.activation_dtype, torch.float8_e4m3fn) def test_infer_int4_weight_only_config(self): """ @@ -2033,6 +2046,126 @@ def test_qat_nvfp4(self, use_per_tensor_scale: bool): sqnr = compute_error(out, baseline_out).item() self.assertGreater(sqnr, 24) + @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available") + @unittest.skipIf( + not _is_fbgemm_genai_gpu_available(), "Requires fbgemm-gpu-genai >= 1.2.0" + ) + def test_fbgemm_fp8_primitives(self): + """ + Compare numerics between: + (1) fbgemm_gpu.experimental.gen_ai.quantize.quantize_fp8_row + (2) Our reference QAT version in `Float8FakeQuantizer` + """ + from fbgemm_gpu.experimental.gen_ai.quantize import quantize_fp8_row + + from torchao.quantization.quant_primitives import ( + _choose_scale_float8, + _quantize_affine_float8, + ) + + x1 = torch.randn([128, 256], dtype=torch.bfloat16).cuda() + x2 = copy.deepcopy(x1) + + # (1) Just call `quantize_fp8_row` + (q1, scale1) = quantize_fp8_row(x1) + + # (2) Our reference implementation for QAT without the dequantize + scale2 = _choose_scale_float8( + x2, + (1, x2.shape[-1]), + torch.float8_e4m3fn, + hp_value_lb=1e-12, + ) + q2 = _quantize_affine_float8(x2, scale2, torch.float8_e4m3fn) + sqnr = compute_error(q1.to(torch.float32), q2.to(torch.float32)) + scale_sqnr = compute_error( + scale1.to(torch.float32).flatten(), + scale2.to(torch.float32).flatten(), + ) + self.assertGreater(sqnr, 40) + self.assertGreater(scale_sqnr, 50) + + @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available") + @unittest.skipIf( + not _is_fbgemm_genai_gpu_available(), "Requires fbgemm-gpu-genai >= 1.2.0" + ) + def test_fbgemm_int4_preshuffled_primitives(self): + """ + Compare numerics between: + (1) fbgemm_gpu.experimental.gen_ai.quantize.quantize_int4_preshuffle + (2) Our reference QAT version in `Int4WeightPreshuffledFakeQuantizer` + """ + from fbgemm_gpu.experimental.gen_ai.quantize import ( + int4_row_quantize, + pack_int4, + quantize_fp8_row, + quantize_int4_preshuffle, + ) + + from torchao.quantization.quant_primitives import ( + _choose_scale_float8, + _quantize_affine_float8, + _quantize_affine_no_dtype_cast, + ) + + group_size = 128 + x1 = torch.randn([128, 256], dtype=torch.bfloat16).cuda() + x2 = copy.deepcopy(x1) + x3 = copy.deepcopy(x1) + + # (1) Just call `quantize_int4_preshuffle` + (q1, (scale1, _)) = quantize_int4_preshuffle(x1, group_size, dtype="fp8") + + # (2) Call `quantize_int4_preshuffle` but skip packing and shuffling + (q2, _) = quantize_fp8_row(x2) + (q2, scale2) = int4_row_quantize(q2, group_size) + + # (3) Reference implementation for QAT without the dequantize + fp8_scale = _choose_scale_float8( + x3, + (1, x3.shape[-1]), + torch.float8_e4m3fn, + hp_value_lb=1e-12, + ) + x3_fp8 = _quantize_affine_float8(x3, fp8_scale, torch.float8_e4m3fn) + x3_fp8 = x3_fp8.to(torch.float32) + x3_fp8_grouped = x3_fp8.view(x3_fp8.shape[0], -1, group_size) + max_abs = torch.amax(torch.abs(x3_fp8_grouped), dim=-1, keepdim=False) + scale = torch.clamp(max_abs / 8, min=1e-6) + zero_point = torch.zeros_like(scale) + q3 = _quantize_affine_no_dtype_cast( + x3_fp8, + (1, group_size), + scale, + zero_point, + quant_min=-8, + quant_max=7, + ) + scale3 = scale + + def shuffle_and_pack(t: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: + t = pack_int4(t.to(torch.int8)) + return torch.ops.fbgemm.preshuffle_i4(t, scale.to(torch.float8_e4m3fn))[0] + + # First, sanity check that shuffle_and_pack(q2) == q1 + torch.testing.assert_close(q1, shuffle_and_pack(q2, scale2), atol=0, rtol=0) + + # Now check q2 vs q3 with and without shuffle + sqnr_q2_q3 = compute_error(q2.to(torch.float32), q3.to(torch.float32)) + sqnr_q2_q3_preshuffle = compute_error( + shuffle_and_pack(q2, scale2).to(torch.float32), + shuffle_and_pack(q3, scale3).to(torch.float32), + ) + self.assertGreater(sqnr_q2_q3, 32) + self.assertGreater(sqnr_q2_q3_preshuffle, 32) + + # Now check shuffle_and_pack(q3) vs q1 + sqnr_q1_q3_preshuffle = compute_error( + q1.to(torch.float32), + shuffle_and_pack(q3, scale3).to(torch.float32), + ) + self.assertGreater(sqnr_q1_q3_preshuffle, 32) + instantiate_parametrized_tests(TestQAT) diff --git a/torchao/quantization/qat/fake_quantize_config.py b/torchao/quantization/qat/fake_quantize_config.py index 2999af5264..7bc1e69c85 100644 --- a/torchao/quantization/qat/fake_quantize_config.py +++ b/torchao/quantization/qat/fake_quantize_config.py @@ -77,6 +77,25 @@ def __post_init__(self): ) +@dataclass +class Int4WeightPreshuffledFakeQuantizeConfig(FakeQuantizeConfigBase): + """ + Config for pint4 weight fake quantization that targets the numerics in the following preshuffled kernel: + torch.ops.fbgemm.f8i4bf16_shuffled + + Currently this only supports float8 input activations. It is expected to be used in conjunction with + :class:`~torchao.quantization.Float8DynamicActivationInt4WeightConfig`. In the future, we may extend + this to support bfloat16 as well. + """ + + group_size: int = 128 + activation_dtype: torch.dtype = e4m3_dtype + + def __post_init__(self): + if self.activation_dtype != e4m3_dtype: + raise ValueError(f"Only {e4m3_dtype} activation is supported currently") + + @dataclass class IntxFakeQuantizeConfig(FakeQuantizeConfigBase): """ @@ -404,10 +423,9 @@ def _infer_fake_quantize_configs( dtype=torch.float8_e4m3fn, granularity=PerRow(), ) - weight_config = IntxFakeQuantizeConfig( - dtype=torch.int4, + weight_config = Int4WeightPreshuffledFakeQuantizeConfig( group_size=128, - is_symmetric=True, + activation_dtype=e4m3_dtype, ) elif isinstance(base_config, NVFP4InferenceConfig): # Note: today the PTQ config does not allow the user to specify diff --git a/torchao/quantization/qat/fake_quantizer.py b/torchao/quantization/qat/fake_quantizer.py index 7bf27f4719..8a63a0d0ad 100644 --- a/torchao/quantization/qat/fake_quantizer.py +++ b/torchao/quantization/qat/fake_quantizer.py @@ -11,6 +11,7 @@ from torchao.quantization.granularity import ( PerAxis, PerGroup, + PerRow, PerToken, ) from torchao.quantization.observer import get_block_size @@ -20,6 +21,7 @@ MappingType, _choose_scale_float8, _dequantize_affine_float8, + _fake_quantize_affine, _quantize_affine_float8, _Round, choose_qparams_affine, @@ -33,6 +35,7 @@ from .fake_quantize_config import ( FakeQuantizeConfigBase, Float8FakeQuantizeConfig, + Int4WeightPreshuffledFakeQuantizeConfig, IntxFakeQuantizeConfig, ) from .utils import ( @@ -65,6 +68,8 @@ def from_config(config: FakeQuantizeConfigBase) -> "FakeQuantizerBase": if isinstance(config, IntxFakeQuantizeConfig): return IntxFakeQuantizer(config) + elif isinstance(config, Int4WeightPreshuffledFakeQuantizeConfig): + return Int4WeightPreshuffledFakeQuantizer(config) elif isinstance(config, Float8FakeQuantizeConfig): return Float8FakeQuantizer(config) elif isinstance(config, NVFP4FakeQuantizeConfig): @@ -93,13 +98,68 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: hp_value_lb=self.config.hp_value_lb, hp_value_ub=self.config.hp_value_ub, ) - q = _quantize_affine_float8( - x, scale, self.config.dtype, cast_to_float8_dtype=False - ) + q = _quantize_affine_float8(x, scale, self.config.dtype) dq = _dequantize_affine_float8(q, scale, original_dtype) return dq +class Int4WeightPreshuffledFakeQuantizer(FakeQuantizerBase): + """ + Generic module for applying int4 fake quantization to a weight tensor, + targeting the following FBGEMM kernel: + torch.ops.fbgemm.f8i4bf16_shuffled + """ + + def __init__(self, config: Int4WeightPreshuffledFakeQuantizeConfig): + super().__init__() + self.config = config + torch._C._log_api_usage_once( + "torchao.quantization.qat.Int4WeightPreshuffledFakeQuantizer" + ) + + def forward(self, w: torch.Tensor) -> torch.Tensor: + """ + Apply int4 fake quantization to the weight tensor, using the following as a reference: + https://github.com/pytorch/FBGEMM/blob/80cc48c4b2b7fcc579e53211fc8715a8592cbd2c/fbgemm_gpu/experimental/gen_ai/gen_ai/quantize.py#L112 + + Currently, we expect the activations to always be rowwise float8. + """ + assert w.dim() == 2 + assert self.config.activation_dtype == torch.float8_e4m3fn + + # First quantize weights to fp8 per row + # This simulates the numerics of fbgemm_gpu.experimental.gen_ai.quantize.quantize_fp8_row + per_row_block_size = get_block_size(w.shape, PerRow()) + fp8_scale = _choose_scale_float8( + w, + per_row_block_size, + torch.float8_e4m3fn, + hp_value_lb=1e-12, + ) + w_fp8 = _quantize_affine_float8(w, fp8_scale, torch.float8_e4m3fn) + w_fp8 = _dequantize_affine_float8(w_fp8, fp8_scale, w.dtype) + + # Now quantize to int4 per group + # This simulates the numerics of fbgemm_gpu.experimental.gen_ai.quantize.int4_row_quantize + eps = 1e-6 + fbgemm_scale_quant_max = 8 + w_fp8_grouped = w_fp8.view(w_fp8.shape[0], -1, self.config.group_size) + max_abs = torch.amax(torch.abs(w_fp8_grouped), dim=-1, keepdim=False) + scale = torch.clamp(max_abs / fbgemm_scale_quant_max, min=eps) + zero_point = torch.zeros_like(scale) + per_group_block_size = (1, self.config.group_size) + fq = _fake_quantize_affine( + w_fp8, + per_group_block_size, + scale, + zero_point, + quant_dtype=torch.int8, + quant_min=-8, + quant_max=7, + ) + return fq.to(w.dtype) + + class IntxFakeQuantizer(FakeQuantizerBase): """ Generic module for applying integer fake quantization to a tensor, as specified in the config. diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py index c118e0b4ce..6298344745 100644 --- a/torchao/quantization/quant_primitives.py +++ b/torchao/quantization/quant_primitives.py @@ -219,6 +219,20 @@ def backward(ctx, gy: torch.Tensor) -> torch.Tensor: return gy +class _RoundToFloat8(torch.autograd.Function): + """ + Implementation of `tensor.to(float8_dtype)` with backward STE. + """ + + @staticmethod + def forward(ctx, x: torch.Tensor, float8_dtype: torch.dtype) -> torch.Tensor: + return x.to(float8_dtype) + + @staticmethod + def backward(ctx, gy: torch.Tensor) -> torch.Tensor: + return gy, None + + # TODO: decide on if we want to allow custom quant_min/quant_max here def _get_and_check_qmin_qmax(dtype, quant_min, quant_max): """Get quant_min and quant_max args based on dtype and also verify bounds. @@ -2275,7 +2289,6 @@ def _quantize_affine_float8( tensor: torch.Tensor, scale: torch.Tensor, float8_dtype: torch.dtype = torch.float8_e4m3fn, - cast_to_float8_dtype: bool = True, ) -> torch.Tensor: """ Quantizes the high precision floating point tensor to a float8 tensor, using the given scaling factor. @@ -2288,9 +2301,7 @@ def _quantize_affine_float8( tensor_scaled = tensor_fp32 / scale_expanded max_value = torch.finfo(float8_dtype).max tensor_clamped = tensor_scaled.clamp(min=-max_value, max=max_value) - if cast_to_float8_dtype: - tensor_clamped = tensor_clamped.to(float8_dtype) - return tensor_clamped + return _RoundToFloat8.apply(tensor_clamped, float8_dtype) # TODO: don't register as custom op?