From 3b28e18b63c8cee8211a44ba81007a9f4963816c Mon Sep 17 00:00:00 2001 From: youn17 Date: Thu, 30 Oct 2025 03:07:55 +0900 Subject: [PATCH 1/7] feat: semi-sparse quantization APIs --- .../workflows/test_semisparse_tensor.py | 123 ++++++++ .../quantize_/workflows/__init__.py | 11 + .../float8/float8_semisparse_tensor.py | 262 ++++++++++++++++++ .../workflows/int8/int8_semisparse_tensor.py | 238 ++++++++++++++++ 4 files changed, 634 insertions(+) create mode 100644 test/quantization/quantize_/workflows/test_semisparse_tensor.py create mode 100644 torchao/prototype/quantization/quantize_/workflows/__init__.py create mode 100644 torchao/prototype/quantization/quantize_/workflows/float8/float8_semisparse_tensor.py create mode 100644 torchao/prototype/quantization/quantize_/workflows/int8/int8_semisparse_tensor.py diff --git a/test/quantization/quantize_/workflows/test_semisparse_tensor.py b/test/quantization/quantize_/workflows/test_semisparse_tensor.py new file mode 100644 index 0000000000..7c4e459391 --- /dev/null +++ b/test/quantization/quantize_/workflows/test_semisparse_tensor.py @@ -0,0 +1,123 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +import unittest + +import torch +from torch.testing._internal import common_utils + +from torchao.prototype.quantization.quantize_.workflows.float8.float8_semisparse_tensor import ( + Float8SemiSparseTensor, +) +from torchao.prototype.quantization.quantize_.workflows.int8.int8_semisparse_tensor import ( + Int8SemiSparseTensor, +) +from torchao.testing.utils import TorchAOIntegrationTestCase + + +@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") +@common_utils.instantiate_parametrized_tests +class TestSemiSparseTensor(TorchAOIntegrationTestCase): + def setUp(self): + super().setUp() + torch.manual_seed(42) + + # Use 512x512 for 2:4 compatibility (multiples of 32) + self.shape = (512, 512) + self.dtype = torch.bfloat16 + self.block_size = [1, 512] + self.weight_fp = torch.randn(*self.shape, dtype=self.dtype, device="cuda") + + @common_utils.parametrize("config", [Int8SemiSparseTensor, Float8SemiSparseTensor]) + def test_creation_and_shape(self, config): + """Test tensor creation and shape preservation""" + tensor = config.from_hp(self.weight_fp, self.block_size) + + self.assertEqual(tensor.shape, self.shape) + self.assertEqual(tensor.original_shape, self.shape) + self.assertEqual(tensor.scale.shape[0], self.shape[0]) + + @common_utils.parametrize("config", [Int8SemiSparseTensor, Float8SemiSparseTensor]) + def test_sparsity_pattern(self, config): + """Test 2:4 sparsity pattern is maintained""" + tensor = config.from_hp(self.weight_fp, self.block_size) + dequantized = tensor.dequantize() + + # Check 2:4 pattern (skip overall sparsity check for compressed format) + reshaped = dequantized.reshape(-1, 4) + zeros_per_group = (reshaped == 0).sum(dim=1) + valid_groups = (zeros_per_group == 2).sum().item() + total_groups = zeros_per_group.numel() + + self.assertGreaterEqual(valid_groups / total_groups, 0.99) + + def test_int8_quantization_range(self): + """Test Int8 quantization stays in valid range""" + tensor = Int8SemiSparseTensor.from_hp(self.weight_fp, self.block_size) + + self.assertEqual(tensor.qdata_int8.dtype, torch.int8) + self.assertTrue(torch.all(tensor.qdata_int8 >= -128)) + self.assertTrue(torch.all(tensor.qdata_int8 <= 127)) + + def 
test_float8_quantization_no_nan(self): + """Test Float8 quantization produces no NaN""" + tensor = Float8SemiSparseTensor.from_hp(self.weight_fp, self.block_size) + + self.assertEqual(tensor.qdata_fp8.dtype, torch.float8_e4m3fn) + self.assertFalse(tensor.qdata_fp8.isnan().any()) + self.assertFalse(tensor.scale.isnan().any()) + + @common_utils.parametrize("config", [Int8SemiSparseTensor, Float8SemiSparseTensor]) + def test_dequantization_accuracy(self, config): + """Test dequantization error is reasonable""" + tensor = config.from_hp(self.weight_fp, self.block_size) + dequantized = tensor.dequantize() + + # Apply same pruning to original for fair comparison + w_sparse = self.weight_fp.detach().clone() + pruning_inds = w_sparse.abs().view(-1, 4).argsort(dim=1)[:, :2] + w_sparse.view(-1, 4).scatter_(1, pruning_inds, value=0) + + error = (dequantized - w_sparse).abs().max() + rel_error = error / w_sparse.abs().max() + + # Int8: ~2.0, Float8: ~0.3 + max_error = 2.5 if config == Int8SemiSparseTensor else 0.5 + self.assertLess(error.item(), max_error) + self.assertLess(rel_error.item(), 0.5) + + @common_utils.parametrize("config", [Int8SemiSparseTensor, Float8SemiSparseTensor]) + def test_invalid_dimensions(self, config): + """Test dimension validation""" + # Not multiple of 32 + invalid_weight = torch.randn(100, 100, dtype=self.dtype, device="cuda") + + with self.assertRaises(ValueError): + config.from_hp(invalid_weight, [1, 100]) + + @common_utils.parametrize("config", [Int8SemiSparseTensor, Float8SemiSparseTensor]) + def test_cpu_tensor_rejection(self, config): + """Test CPU tensor is rejected""" + cpu_weight = torch.randn(*self.shape, dtype=self.dtype) + + with self.assertRaises(ValueError): + config.from_hp(cpu_weight, self.block_size) + + def test_float8_dtype_selection(self): + """Test Float8 dtype variants""" + tensor_e4m3 = Float8SemiSparseTensor.from_hp( + self.weight_fp, self.block_size, float8_dtype=torch.float8_e4m3fn + ) + self.assertEqual(tensor_e4m3.qdata_fp8.dtype, torch.float8_e4m3fn) + + tensor_e5m2 = Float8SemiSparseTensor.from_hp( + self.weight_fp, self.block_size, float8_dtype=torch.float8_e5m2 + ) + self.assertEqual(tensor_e5m2.qdata_fp8.dtype, torch.float8_e5m2) + + +if __name__ == "__main__": + common_utils.run_tests() diff --git a/torchao/prototype/quantization/quantize_/workflows/__init__.py b/torchao/prototype/quantization/quantize_/workflows/__init__.py new file mode 100644 index 0000000000..97cf27ae9b --- /dev/null +++ b/torchao/prototype/quantization/quantize_/workflows/__init__.py @@ -0,0 +1,11 @@ +from .float8.float8_semisparse_tensor import ( + Float8SemiSparseTensor, +) +from .int8.int8_semisparse_tensor import ( + Int8SemiSparseTensor, +) + +__all__ = [ + "Float8SemiSparseTensor", + "Int8SemiSparseTensor", +] diff --git a/torchao/prototype/quantization/quantize_/workflows/float8/float8_semisparse_tensor.py b/torchao/prototype/quantization/quantize_/workflows/float8/float8_semisparse_tensor.py new file mode 100644 index 0000000000..64acd10f4e --- /dev/null +++ b/torchao/prototype/quantization/quantize_/workflows/float8/float8_semisparse_tensor.py @@ -0,0 +1,262 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. 
+ +from typing import Optional + +import torch + +from torchao.quantization.quant_primitives import ( + _choose_qparams_affine_floatx, +) +from torchao.utils import TorchAOBaseTensor + +__all__ = ["Float8SemiSparseTensor"] + +aten = torch.ops.aten + + +class Float8SemiSparseTensor(TorchAOBaseTensor): + """ + float8 quantized tensor with 2:4 semi-structured sparsity layout + + Tensor Attributes: + qdata: float8 data in dense format + qdata_compressed: SparseSemiStructuredTensor (compressed format for matmul) + scale: scale factors for dequantization + + Non-Tensor Attributes: + block_size: block size for quantization granularity + original_shape: original uncompressed shape + float8_dtype: float8 dtype variant + """ + + tensor_data_names = ["qdata", "qdata_compressed", "scale"] + tensor_attribute_names = ["block_size", "original_shape"] + optional_tensor_attribute_names = ["dtype"] + + def __new__( + cls: type, + qdata: torch.Tensor, + qdata_compressed: torch.Tensor, + scale: torch.Tensor, + block_size: list[int], + original_shape: tuple[int, ...], + dtype=None, + float8_dtype: torch.dtype = torch.float8_e4m3fn, + ): + kwargs = { + "device": qdata.device, + "dtype": dtype or scale.dtype, + "requires_grad": False, + } + return torch.Tensor._make_wrapper_subclass(cls, original_shape, **kwargs) + + def __init__( + self, + qdata: torch.Tensor, + qdata_compressed: torch.Tensor, + scale: torch.Tensor, + block_size: list[int], + original_shape: tuple[int, ...], + dtype=None, + float8_dtype: torch.dtype = torch.float8_e4m3fn, + ): + super().__init__() + self.qdata = qdata # dense fp8 for dequantization + self.qdata_compressed = qdata_compressed # compressed for matmul + self.scale = scale + self.block_size = block_size + self.original_shape = original_shape + self.float8_dtype = float8_dtype + + def __repr__(self): + return ( + f"{self.__class__.__name__}(" + f"qdata_shape={self.qdata.shape}, {self.scale=}, " + f"{self.block_size=}, {self.shape=}, {self.device=}, {self.dtype=})" + ) + + @property + def qdata_fp8(self): + """For test compatibility""" + return self.qdata + + @classmethod + def from_hp( + cls, + w: torch.Tensor, + block_size: list[int], + float8_dtype: torch.dtype = torch.float8_e4m3fn, + ): + if w.dim() != 2 or len(block_size) != 2: + raise ValueError("Expected 2D tensor and block_size length 2") + + if not w.is_cuda: + raise ValueError("Semi-sparse layout requires CUDA tensors") + + # Verify dimensions are compatible with 2:4 compression + rows, cols = w.shape + if rows % 32 != 0 or cols % 32 != 0: + raise ValueError( + "Tensor dimensions must be multiples of 32 for CUDA sparse compression" + ) + + # Validate block_size + if not all(bs > 0 for bs in block_size): + raise ValueError(f"block_size must be positive, got {block_size}") + + if rows % block_size[0] != 0 or cols % block_size[1] != 0: + raise ValueError( + f"Dimensions {w.shape} must be divisible by block_size {block_size}" + ) + + # Apply 2:4 sparsity pruning + with torch.no_grad(): + w_sparse = w.clone() + + pruning_inds = w_sparse.abs().view(-1, 4).argsort(dim=1)[:, :2] + w_sparse.view(-1, 4).scatter_(1, pruning_inds, value=0) + + # Quantize to float8 + if float8_dtype == torch.float8_e4m3fn: + ebits, mbits = 4, 3 + max_val = 448.0 + elif float8_dtype == torch.float8_e5m2: + ebits, mbits = 5, 2 + max_val = 57344.0 + else: + raise ValueError(f"Unsupported float8 dtype: {float8_dtype}") + + scale = _choose_qparams_affine_floatx(w_sparse, ebits=ebits, mbits=mbits) + + if scale.isnan().any(): + raise ValueError("Scale contains 
NaN") + if not (scale > 0).all(): + raise ValueError(f"Scale contains non-positive values: min={scale.min()}") + + scale_expanded = scale.unsqueeze(1) + scaled_data = w_sparse / scale_expanded + scaled_data = scaled_data.clamp(-max_val, max_val) + fp8_data = scaled_data.to(float8_dtype).contiguous() + + if fp8_data.isnan().any(): + raise ValueError("fp8_data contains NaN after quantization") + + # Store fp8 data in both dense and compressed formats + fp8_data_fp16 = fp8_data.to(torch.float16) + from torch.sparse import to_sparse_semi_structured + + fp8_compressed = to_sparse_semi_structured(fp8_data_fp16) + + return cls( + fp8_data, # dense for dequantization + fp8_compressed, # compressed for matmul + scale, + block_size, + original_shape=w.shape, + dtype=w.dtype, + float8_dtype=float8_dtype, + ) + + def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor: + if output_dtype is None: + output_dtype = self.dtype + + # Use dense fp8 data + qdata_fp = self.qdata.to(output_dtype) + scale_expanded = self.scale.view(-1, 1).to(output_dtype) + return qdata_fp * scale_expanded + + +implements = Float8SemiSparseTensor.implements +implements_torch_function = Float8SemiSparseTensor.implements_torch_function + + +@implements(aten.linear.default) +@implements_torch_function(torch.nn.functional.linear) +def _(func, types, args, kwargs): + activation_tensor, weight_tensor, bias = ( + args[0], + args[1], + args[2] if len(args) > 2 else None, + ) + + assert isinstance(weight_tensor, Float8SemiSparseTensor) + if not isinstance(activation_tensor, Float8SemiSparseTensor): + raise TypeError( + "Float8SemiSparseTensor requires pre-quantized activations (static quantization only). " + "Activation must be Float8SemiSparseTensor." + ) + assert activation_tensor.shape[-1] == weight_tensor.original_shape[1], ( + f"Shape mismatch: {activation_tensor.shape} @ {weight_tensor.original_shape}" + ) + + # Use compressed data for matmul + x_vals_dense = activation_tensor.qdata_compressed.to_dense() + x_vals_fp8 = x_vals_dense.view(torch.float8_e4m3fn) + x_scales = activation_tensor.scale + + w_vals_dense = weight_tensor.qdata_compressed.to_dense() + w_vals_fp8 = w_vals_dense.view(torch.float8_e4m3fn) + w_scales = weight_tensor.scale + + # Prepare activation for sparse matmul + tmp = x_vals_fp8 + if tmp.dim() > 2: + tmp = tmp.view(-1, tmp.shape[-1]) + row = tmp.shape[0] + + from torch.sparse import SparseSemiStructuredTensorCUSPARSELT + + tmp_padded = SparseSemiStructuredTensorCUSPARSELT._pad_dense_input(tmp) + + # Convert weight fp8 to fp16 with scale for matmul + w_scaled = w_vals_fp8.to(torch.float16) * w_scales.unsqueeze(1) + w_sparse_scaled = torch.sparse.to_sparse_semi_structured(w_scaled) + + # Matmul with sparse weight + y = torch.matmul( + tmp_padded.to(torch.bfloat16), w_sparse_scaled.t().to(torch.bfloat16) + ) + y = y[:row, :] + + # Apply activation scale + y = y * x_scales.unsqueeze(1) + + # Reshape to original activation shape + if x_vals_fp8.dim() > 2: + y = y.view(*x_vals_fp8.shape[:-1], y.shape[-1]) + + output_dtype = activation_tensor.dtype + y = y.to(output_dtype).contiguous() + + if bias is not None: + y += bias + return y + + +@implements(aten.slice.Tensor) +def _(func, types, args, kwargs): + """Slice operation - not supported for compressed sparse format""" + # TODO: Build this tensor utility operation + raise NotImplementedError( + "Slicing not supported for Float8SemiSparseTensor. " + "Decompress first using dequantize() if needed." 
+ ) + + +@implements(aten.select.int) +def _(func, types, args, kwargs): + """Select operation - not supported for compressed sparse format""" + # TODO: Build this tensor utility operation + raise NotImplementedError( + "Select not supported for Float8SemiSparseTensor. " + "Decompress first using dequantize() if needed." + ) + + +Float8SemiSparseTensor.__module__ = "torchao.quantization" +torch.serialization.add_safe_globals([Float8SemiSparseTensor]) diff --git a/torchao/prototype/quantization/quantize_/workflows/int8/int8_semisparse_tensor.py b/torchao/prototype/quantization/quantize_/workflows/int8/int8_semisparse_tensor.py new file mode 100644 index 0000000000..4bdebe88ab --- /dev/null +++ b/torchao/prototype/quantization/quantize_/workflows/int8/int8_semisparse_tensor.py @@ -0,0 +1,238 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Optional + +import torch + +from torchao.quantization.quant_primitives import ( + MappingType, + _maybe_expand_scale_to_tensor_shape, + choose_qparams_affine, + quantize_affine, +) +from torchao.utils import TorchAOBaseTensor + +__all__ = ["Int8SemiSparseTensor"] + +aten = torch.ops.aten + + +class Int8SemiSparseTensor(TorchAOBaseTensor): + """ + int8 quantized tensor with 2:4 semi-structured sparsity layout + + Tensor Attributes: + qdata: SparseSemiStructuredTensor (compressed format, fp16) + scale: scale factors for dequantization + + Non-Tensor Attributes: + block_size: block size for quantization granularity + original_shape: original uncompressed shape + """ + + tensor_data_names = ["qdata", "scale"] + tensor_attribute_names = ["block_size", "original_shape"] + optional_tensor_attribute_names = ["dtype"] + + def __new__( + cls: type, + qdata: torch.Tensor, + scale: torch.Tensor, + block_size: list[int], + original_shape: tuple[int, ...], + dtype=None, + ): + kwargs = { + "device": qdata.device, + "dtype": dtype or scale.dtype, + "requires_grad": False, + } + return torch.Tensor._make_wrapper_subclass(cls, original_shape, **kwargs) + + def __init__( + self, + qdata: torch.Tensor, + scale: torch.Tensor, + block_size: list[int], + original_shape: tuple[int, ...], + dtype=None, + ): + super().__init__() + self.qdata = qdata + self.scale = scale + self.block_size = block_size + self.original_shape = original_shape + + def __repr__(self): + return ( + f"{self.__class__.__name__}(" + f"qdata_shape={self.qdata.shape}, {self.scale=}, " + f"{self.block_size=}, {self.shape=}, {self.device=}, {self.dtype=})" + ) + + @property + def qdata_int8(self): + return self.qdata.to_dense().to(torch.int8) + + @classmethod + def from_hp( + cls, + w: torch.Tensor, + block_size: list[int], + ): + if w.dim() != 2 or len(block_size) != 2: + raise ValueError("Expected 2D tensor and block_size length 2") + + if not w.is_cuda: + raise ValueError("Semi-sparse layout requires CUDA tensors") + + # Verify dimensions are compatible with 2:4 compression + rows, cols = w.shape + if rows % 32 != 0 or cols % 32 != 0: + raise ValueError( + "Tensor dimensions must be multiples of 32 for CUDA sparse compression" + ) + + # Validate block_size + if not all(bs > 0 for bs in block_size): + raise ValueError(f"block_size must be positive, got {block_size}") + + if rows % block_size[0] != 0 or cols % block_size[1] != 0: + raise ValueError( + f"Dimensions {w.shape} must be divisible by block_size {block_size}" + 
) + + # Apply 2:4 sparsity pruning (row-wise for weight matrix) + with torch.no_grad(): + w_sparse = w.clone() + + pruning_inds = w_sparse.abs().view(-1, 4).argsort(dim=1)[:, :2] + w_sparse.view(-1, 4).scatter_(1, pruning_inds, value=0) + + # Quantize the sparse weight + scale, zero_point = choose_qparams_affine( + input=w_sparse, + mapping_type=MappingType.SYMMETRIC, + block_size=block_size, + target_dtype=torch.int8, + quant_min=-128, + quant_max=127, + scale_dtype=w.dtype, + zero_point_dtype=torch.int8, + ) + + int_data = quantize_affine( + w_sparse, + block_size=block_size, + scale=scale, + zero_point=zero_point, + output_dtype=torch.int8, + ).contiguous() + + int_data_fp16 = int_data.to(torch.float16) + from torch.sparse import to_sparse_semi_structured + + int_data_compressed = to_sparse_semi_structured(int_data_fp16) + + return cls( + int_data_compressed, + scale, + block_size, + original_shape=w.shape, + dtype=w.dtype, + ) + + def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor: + if output_dtype is None: + output_dtype = self.dtype + + # Decompress and convert to int8 + qdata_dense = self.qdata.to_dense().to(torch.int8) + qdata_fp = qdata_dense.to(output_dtype) + + scale_expanded = _maybe_expand_scale_to_tensor_shape( + self.scale, self.original_shape + ) + return qdata_fp * scale_expanded.to(output_dtype) + + +implements = Int8SemiSparseTensor.implements +implements_torch_function = Int8SemiSparseTensor.implements_torch_function + + +@implements(aten.linear.default) +@implements_torch_function(torch.nn.functional.linear) +def _(func, types, args, kwargs): + activation_tensor, weight_tensor, bias = ( + args[0], + args[1], + args[2] if len(args) > 2 else None, + ) + + assert isinstance(weight_tensor, Int8SemiSparseTensor) + assert isinstance(activation_tensor, Int8SemiSparseTensor), ( + "Int8SemiSparseTensor requires pre-quantized activations (static quantization only)" + ) + assert activation_tensor.shape[-1] == weight_tensor.original_shape[1], ( + f"Shape mismatch: {activation_tensor.shape} @ {weight_tensor.original_shape}" + ) + + # Extract quantized data + x_vals_dense = activation_tensor.qdata.to_dense().to(torch.int8) + x_scales = activation_tensor.scale + w_vals_dense = weight_tensor.qdata.to_dense().to(torch.int8) + w_scales = weight_tensor.scale + + # Prepare activation for sparse matmul + tmp = x_vals_dense.view(-1, x_vals_dense.shape[-1]) + row, col = tmp.shape + + from torch.sparse import SparseSemiStructuredTensorCUSPARSELT + + tmp_padded = SparseSemiStructuredTensorCUSPARSELT._pad_dense_input(tmp) + + # Perform sparse matmul with int8, output as bfloat16 with weight scale + w_scaled = w_vals_dense.to(torch.float16) * w_scales.view(-1, 1) + w_sparse_scaled = torch.sparse.to_sparse_semi_structured(w_scaled) + y_bf16 = torch.matmul( + tmp_padded.to(torch.bfloat16), w_sparse_scaled.t().to(torch.bfloat16) + ) + y_bf16 = y_bf16[:row, :] + + # Apply activation scale + y = (y_bf16 * x_scales.view(-1, 1)).view(*x_vals_dense.shape[:-1], y_bf16.shape[-1]) + + output_dtype = activation_tensor.dtype + y = y.to(output_dtype).contiguous() + + if bias is not None: + y += bias + return y + + +@implements(aten.slice.Tensor) +def _(func, types, args, kwargs): + """Slice operation - not supported for compressed sparse format""" + # TODO: Build this tensor utility operation + raise NotImplementedError( + "Slicing not supported for Int8SemiSparseTensor. " + "Decompress first using dequantize() if needed." 
+ ) + + +@implements(aten.select.int) +def _(func, types, args, kwargs): + """Select operation - not supported for compressed sparse format""" + # TODO: Build this tensor utility operation + raise NotImplementedError( + "Select not supported for Int8SemiSparseTensor. " + "Decompress first using dequantize() if needed." + ) + + +Int8SemiSparseTensor.__module__ = "torchao.quantization" +torch.serialization.add_safe_globals([Int8SemiSparseTensor]) From 9a48316f818e218fcfafe613b25baef43cea01d8 Mon Sep 17 00:00:00 2001 From: youn17 Date: Sat, 1 Nov 2025 22:57:57 +0900 Subject: [PATCH 2/7] drop W4A16-INT-CSR, W8A8-FP-CSR APIs --- .../workflows/test_semisparse_tensor.py | 52 +--- .../float8/float8_semisparse_tensor.py | 262 ------------------ .../workflows/int8/int8_semisparse_tensor.py | 21 +- 3 files changed, 22 insertions(+), 313 deletions(-) delete mode 100644 torchao/prototype/quantization/quantize_/workflows/float8/float8_semisparse_tensor.py diff --git a/test/quantization/quantize_/workflows/test_semisparse_tensor.py b/test/quantization/quantize_/workflows/test_semisparse_tensor.py index 7c4e459391..dc07e7b218 100644 --- a/test/quantization/quantize_/workflows/test_semisparse_tensor.py +++ b/test/quantization/quantize_/workflows/test_semisparse_tensor.py @@ -9,9 +9,6 @@ import torch from torch.testing._internal import common_utils -from torchao.prototype.quantization.quantize_.workflows.float8.float8_semisparse_tensor import ( - Float8SemiSparseTensor, -) from torchao.prototype.quantization.quantize_.workflows.int8.int8_semisparse_tensor import ( Int8SemiSparseTensor, ) @@ -31,19 +28,17 @@ def setUp(self): self.block_size = [1, 512] self.weight_fp = torch.randn(*self.shape, dtype=self.dtype, device="cuda") - @common_utils.parametrize("config", [Int8SemiSparseTensor, Float8SemiSparseTensor]) - def test_creation_and_shape(self, config): + def test_creation_and_shape(self): """Test tensor creation and shape preservation""" - tensor = config.from_hp(self.weight_fp, self.block_size) + tensor = Int8SemiSparseTensor.from_hp(self.weight_fp, self.block_size) self.assertEqual(tensor.shape, self.shape) self.assertEqual(tensor.original_shape, self.shape) self.assertEqual(tensor.scale.shape[0], self.shape[0]) - @common_utils.parametrize("config", [Int8SemiSparseTensor, Float8SemiSparseTensor]) - def test_sparsity_pattern(self, config): + def test_sparsity_pattern(self): """Test 2:4 sparsity pattern is maintained""" - tensor = config.from_hp(self.weight_fp, self.block_size) + tensor = Int8SemiSparseTensor.from_hp(self.weight_fp, self.block_size) dequantized = tensor.dequantize() # Check 2:4 pattern (skip overall sparsity check for compressed format) @@ -62,18 +57,9 @@ def test_int8_quantization_range(self): self.assertTrue(torch.all(tensor.qdata_int8 >= -128)) self.assertTrue(torch.all(tensor.qdata_int8 <= 127)) - def test_float8_quantization_no_nan(self): - """Test Float8 quantization produces no NaN""" - tensor = Float8SemiSparseTensor.from_hp(self.weight_fp, self.block_size) - - self.assertEqual(tensor.qdata_fp8.dtype, torch.float8_e4m3fn) - self.assertFalse(tensor.qdata_fp8.isnan().any()) - self.assertFalse(tensor.scale.isnan().any()) - - @common_utils.parametrize("config", [Int8SemiSparseTensor, Float8SemiSparseTensor]) - def test_dequantization_accuracy(self, config): + def test_dequantization_accuracy(self): """Test dequantization error is reasonable""" - tensor = config.from_hp(self.weight_fp, self.block_size) + tensor = Int8SemiSparseTensor.from_hp(self.weight_fp, self.block_size) dequantized = 
tensor.dequantize() # Apply same pruning to original for fair comparison @@ -84,39 +70,23 @@ def test_dequantization_accuracy(self, config): error = (dequantized - w_sparse).abs().max() rel_error = error / w_sparse.abs().max() - # Int8: ~2.0, Float8: ~0.3 - max_error = 2.5 if config == Int8SemiSparseTensor else 0.5 - self.assertLess(error.item(), max_error) + self.assertLess(error.item(), 2.5) self.assertLess(rel_error.item(), 0.5) - @common_utils.parametrize("config", [Int8SemiSparseTensor, Float8SemiSparseTensor]) - def test_invalid_dimensions(self, config): + def test_invalid_dimensions(self): """Test dimension validation""" # Not multiple of 32 invalid_weight = torch.randn(100, 100, dtype=self.dtype, device="cuda") with self.assertRaises(ValueError): - config.from_hp(invalid_weight, [1, 100]) + Int8SemiSparseTensor.from_hp(invalid_weight, [1, 100]) - @common_utils.parametrize("config", [Int8SemiSparseTensor, Float8SemiSparseTensor]) - def test_cpu_tensor_rejection(self, config): + def test_cpu_tensor_rejection(self): """Test CPU tensor is rejected""" cpu_weight = torch.randn(*self.shape, dtype=self.dtype) with self.assertRaises(ValueError): - config.from_hp(cpu_weight, self.block_size) - - def test_float8_dtype_selection(self): - """Test Float8 dtype variants""" - tensor_e4m3 = Float8SemiSparseTensor.from_hp( - self.weight_fp, self.block_size, float8_dtype=torch.float8_e4m3fn - ) - self.assertEqual(tensor_e4m3.qdata_fp8.dtype, torch.float8_e4m3fn) - - tensor_e5m2 = Float8SemiSparseTensor.from_hp( - self.weight_fp, self.block_size, float8_dtype=torch.float8_e5m2 - ) - self.assertEqual(tensor_e5m2.qdata_fp8.dtype, torch.float8_e5m2) + Int8SemiSparseTensor.from_hp(cpu_weight, self.block_size) if __name__ == "__main__": diff --git a/torchao/prototype/quantization/quantize_/workflows/float8/float8_semisparse_tensor.py b/torchao/prototype/quantization/quantize_/workflows/float8/float8_semisparse_tensor.py deleted file mode 100644 index 64acd10f4e..0000000000 --- a/torchao/prototype/quantization/quantize_/workflows/float8/float8_semisparse_tensor.py +++ /dev/null @@ -1,262 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD 3-Clause license found in the -# LICENSE file in the root directory of this source tree. 
- -from typing import Optional - -import torch - -from torchao.quantization.quant_primitives import ( - _choose_qparams_affine_floatx, -) -from torchao.utils import TorchAOBaseTensor - -__all__ = ["Float8SemiSparseTensor"] - -aten = torch.ops.aten - - -class Float8SemiSparseTensor(TorchAOBaseTensor): - """ - float8 quantized tensor with 2:4 semi-structured sparsity layout - - Tensor Attributes: - qdata: float8 data in dense format - qdata_compressed: SparseSemiStructuredTensor (compressed format for matmul) - scale: scale factors for dequantization - - Non-Tensor Attributes: - block_size: block size for quantization granularity - original_shape: original uncompressed shape - float8_dtype: float8 dtype variant - """ - - tensor_data_names = ["qdata", "qdata_compressed", "scale"] - tensor_attribute_names = ["block_size", "original_shape"] - optional_tensor_attribute_names = ["dtype"] - - def __new__( - cls: type, - qdata: torch.Tensor, - qdata_compressed: torch.Tensor, - scale: torch.Tensor, - block_size: list[int], - original_shape: tuple[int, ...], - dtype=None, - float8_dtype: torch.dtype = torch.float8_e4m3fn, - ): - kwargs = { - "device": qdata.device, - "dtype": dtype or scale.dtype, - "requires_grad": False, - } - return torch.Tensor._make_wrapper_subclass(cls, original_shape, **kwargs) - - def __init__( - self, - qdata: torch.Tensor, - qdata_compressed: torch.Tensor, - scale: torch.Tensor, - block_size: list[int], - original_shape: tuple[int, ...], - dtype=None, - float8_dtype: torch.dtype = torch.float8_e4m3fn, - ): - super().__init__() - self.qdata = qdata # dense fp8 for dequantization - self.qdata_compressed = qdata_compressed # compressed for matmul - self.scale = scale - self.block_size = block_size - self.original_shape = original_shape - self.float8_dtype = float8_dtype - - def __repr__(self): - return ( - f"{self.__class__.__name__}(" - f"qdata_shape={self.qdata.shape}, {self.scale=}, " - f"{self.block_size=}, {self.shape=}, {self.device=}, {self.dtype=})" - ) - - @property - def qdata_fp8(self): - """For test compatibility""" - return self.qdata - - @classmethod - def from_hp( - cls, - w: torch.Tensor, - block_size: list[int], - float8_dtype: torch.dtype = torch.float8_e4m3fn, - ): - if w.dim() != 2 or len(block_size) != 2: - raise ValueError("Expected 2D tensor and block_size length 2") - - if not w.is_cuda: - raise ValueError("Semi-sparse layout requires CUDA tensors") - - # Verify dimensions are compatible with 2:4 compression - rows, cols = w.shape - if rows % 32 != 0 or cols % 32 != 0: - raise ValueError( - "Tensor dimensions must be multiples of 32 for CUDA sparse compression" - ) - - # Validate block_size - if not all(bs > 0 for bs in block_size): - raise ValueError(f"block_size must be positive, got {block_size}") - - if rows % block_size[0] != 0 or cols % block_size[1] != 0: - raise ValueError( - f"Dimensions {w.shape} must be divisible by block_size {block_size}" - ) - - # Apply 2:4 sparsity pruning - with torch.no_grad(): - w_sparse = w.clone() - - pruning_inds = w_sparse.abs().view(-1, 4).argsort(dim=1)[:, :2] - w_sparse.view(-1, 4).scatter_(1, pruning_inds, value=0) - - # Quantize to float8 - if float8_dtype == torch.float8_e4m3fn: - ebits, mbits = 4, 3 - max_val = 448.0 - elif float8_dtype == torch.float8_e5m2: - ebits, mbits = 5, 2 - max_val = 57344.0 - else: - raise ValueError(f"Unsupported float8 dtype: {float8_dtype}") - - scale = _choose_qparams_affine_floatx(w_sparse, ebits=ebits, mbits=mbits) - - if scale.isnan().any(): - raise ValueError("Scale contains 
NaN") - if not (scale > 0).all(): - raise ValueError(f"Scale contains non-positive values: min={scale.min()}") - - scale_expanded = scale.unsqueeze(1) - scaled_data = w_sparse / scale_expanded - scaled_data = scaled_data.clamp(-max_val, max_val) - fp8_data = scaled_data.to(float8_dtype).contiguous() - - if fp8_data.isnan().any(): - raise ValueError("fp8_data contains NaN after quantization") - - # Store fp8 data in both dense and compressed formats - fp8_data_fp16 = fp8_data.to(torch.float16) - from torch.sparse import to_sparse_semi_structured - - fp8_compressed = to_sparse_semi_structured(fp8_data_fp16) - - return cls( - fp8_data, # dense for dequantization - fp8_compressed, # compressed for matmul - scale, - block_size, - original_shape=w.shape, - dtype=w.dtype, - float8_dtype=float8_dtype, - ) - - def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor: - if output_dtype is None: - output_dtype = self.dtype - - # Use dense fp8 data - qdata_fp = self.qdata.to(output_dtype) - scale_expanded = self.scale.view(-1, 1).to(output_dtype) - return qdata_fp * scale_expanded - - -implements = Float8SemiSparseTensor.implements -implements_torch_function = Float8SemiSparseTensor.implements_torch_function - - -@implements(aten.linear.default) -@implements_torch_function(torch.nn.functional.linear) -def _(func, types, args, kwargs): - activation_tensor, weight_tensor, bias = ( - args[0], - args[1], - args[2] if len(args) > 2 else None, - ) - - assert isinstance(weight_tensor, Float8SemiSparseTensor) - if not isinstance(activation_tensor, Float8SemiSparseTensor): - raise TypeError( - "Float8SemiSparseTensor requires pre-quantized activations (static quantization only). " - "Activation must be Float8SemiSparseTensor." - ) - assert activation_tensor.shape[-1] == weight_tensor.original_shape[1], ( - f"Shape mismatch: {activation_tensor.shape} @ {weight_tensor.original_shape}" - ) - - # Use compressed data for matmul - x_vals_dense = activation_tensor.qdata_compressed.to_dense() - x_vals_fp8 = x_vals_dense.view(torch.float8_e4m3fn) - x_scales = activation_tensor.scale - - w_vals_dense = weight_tensor.qdata_compressed.to_dense() - w_vals_fp8 = w_vals_dense.view(torch.float8_e4m3fn) - w_scales = weight_tensor.scale - - # Prepare activation for sparse matmul - tmp = x_vals_fp8 - if tmp.dim() > 2: - tmp = tmp.view(-1, tmp.shape[-1]) - row = tmp.shape[0] - - from torch.sparse import SparseSemiStructuredTensorCUSPARSELT - - tmp_padded = SparseSemiStructuredTensorCUSPARSELT._pad_dense_input(tmp) - - # Convert weight fp8 to fp16 with scale for matmul - w_scaled = w_vals_fp8.to(torch.float16) * w_scales.unsqueeze(1) - w_sparse_scaled = torch.sparse.to_sparse_semi_structured(w_scaled) - - # Matmul with sparse weight - y = torch.matmul( - tmp_padded.to(torch.bfloat16), w_sparse_scaled.t().to(torch.bfloat16) - ) - y = y[:row, :] - - # Apply activation scale - y = y * x_scales.unsqueeze(1) - - # Reshape to original activation shape - if x_vals_fp8.dim() > 2: - y = y.view(*x_vals_fp8.shape[:-1], y.shape[-1]) - - output_dtype = activation_tensor.dtype - y = y.to(output_dtype).contiguous() - - if bias is not None: - y += bias - return y - - -@implements(aten.slice.Tensor) -def _(func, types, args, kwargs): - """Slice operation - not supported for compressed sparse format""" - # TODO: Build this tensor utility operation - raise NotImplementedError( - "Slicing not supported for Float8SemiSparseTensor. " - "Decompress first using dequantize() if needed." 
- ) - - -@implements(aten.select.int) -def _(func, types, args, kwargs): - """Select operation - not supported for compressed sparse format""" - # TODO: Build this tensor utility operation - raise NotImplementedError( - "Select not supported for Float8SemiSparseTensor. " - "Decompress first using dequantize() if needed." - ) - - -Float8SemiSparseTensor.__module__ = "torchao.quantization" -torch.serialization.add_safe_globals([Float8SemiSparseTensor]) diff --git a/torchao/prototype/quantization/quantize_/workflows/int8/int8_semisparse_tensor.py b/torchao/prototype/quantization/quantize_/workflows/int8/int8_semisparse_tensor.py index 4bdebe88ab..f31713b191 100644 --- a/torchao/prototype/quantization/quantize_/workflows/int8/int8_semisparse_tensor.py +++ b/torchao/prototype/quantization/quantize_/workflows/int8/int8_semisparse_tensor.py @@ -174,26 +174,27 @@ def _(func, types, args, kwargs): ) assert isinstance(weight_tensor, Int8SemiSparseTensor) - assert isinstance(activation_tensor, Int8SemiSparseTensor), ( - "Int8SemiSparseTensor requires pre-quantized activations (static quantization only)" - ) assert activation_tensor.shape[-1] == weight_tensor.original_shape[1], ( f"Shape mismatch: {activation_tensor.shape} @ {weight_tensor.original_shape}" ) - # Extract quantized data - x_vals_dense = activation_tensor.qdata.to_dense().to(torch.int8) - x_scales = activation_tensor.scale + # Dynamic quantization: Activation scale in runtime + original_shape = activation_tensor.shape + x_flat = activation_tensor.view(-1, activation_tensor.shape[-1]) + x_scales = x_flat.abs().max(dim=-1, keepdim=True)[0] / 127.0 + x_scales = x_scales.clamp(min=1e-5) + x_vals = (x_flat / x_scales).round().clamp(-128, 127).to(torch.int8) + + # Weight quantization after activation quantization w_vals_dense = weight_tensor.qdata.to_dense().to(torch.int8) w_scales = weight_tensor.scale # Prepare activation for sparse matmul - tmp = x_vals_dense.view(-1, x_vals_dense.shape[-1]) - row, col = tmp.shape + row, col = x_vals.shape from torch.sparse import SparseSemiStructuredTensorCUSPARSELT - tmp_padded = SparseSemiStructuredTensorCUSPARSELT._pad_dense_input(tmp) + tmp_padded = SparseSemiStructuredTensorCUSPARSELT._pad_dense_input(x_vals) # Perform sparse matmul with int8, output as bfloat16 with weight scale w_scaled = w_vals_dense.to(torch.float16) * w_scales.view(-1, 1) @@ -204,7 +205,7 @@ def _(func, types, args, kwargs): y_bf16 = y_bf16[:row, :] # Apply activation scale - y = (y_bf16 * x_scales.view(-1, 1)).view(*x_vals_dense.shape[:-1], y_bf16.shape[-1]) + y = (y_bf16 * x_scales).view(*original_shape[:-1], y_bf16.shape[-1]) output_dtype = activation_tensor.dtype y = y.to(output_dtype).contiguous() From ea18e893e54dbd6af69f06e3884e7faf44827ea7 Mon Sep 17 00:00:00 2001 From: youn17 Date: Sun, 2 Nov 2025 14:11:39 +0900 Subject: [PATCH 3/7] drop W8A8-INT-CSR, pick W8A8-FP-CSR --- .../float8}/test_semisparse_tensor.py | 24 +- .../quantize_/workflows/__init__.py | 4 - .../float8/float8_semisparse_tensor.py | 262 ++++++++++++++++++ 3 files changed, 270 insertions(+), 20 deletions(-) rename test/{quantization/quantize_/workflows => prototype/quantization/quantize_/float8}/test_semisparse_tensor.py (74%) create mode 100644 torchao/prototype/quantization/quantize_/workflows/float8/float8_semisparse_tensor.py diff --git a/test/quantization/quantize_/workflows/test_semisparse_tensor.py b/test/prototype/quantization/quantize_/float8/test_semisparse_tensor.py similarity index 74% rename from 
test/quantization/quantize_/workflows/test_semisparse_tensor.py rename to test/prototype/quantization/quantize_/float8/test_semisparse_tensor.py index dc07e7b218..f3c2b673b8 100644 --- a/test/quantization/quantize_/workflows/test_semisparse_tensor.py +++ b/test/prototype/quantization/quantize_/float8/test_semisparse_tensor.py @@ -9,15 +9,15 @@ import torch from torch.testing._internal import common_utils -from torchao.prototype.quantization.quantize_.workflows.int8.int8_semisparse_tensor import ( - Int8SemiSparseTensor, +from torchao.prototype.quantization.quantize_.workflows.float8.float8_semisparse_tensor import ( + Float8SemiSparseTensor, ) from torchao.testing.utils import TorchAOIntegrationTestCase @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") @common_utils.instantiate_parametrized_tests -class TestSemiSparseTensor(TorchAOIntegrationTestCase): +class TestFloat8SemiSparseTensor(TorchAOIntegrationTestCase): def setUp(self): super().setUp() torch.manual_seed(42) @@ -30,7 +30,7 @@ def setUp(self): def test_creation_and_shape(self): """Test tensor creation and shape preservation""" - tensor = Int8SemiSparseTensor.from_hp(self.weight_fp, self.block_size) + tensor = Float8SemiSparseTensor.from_hp(self.weight_fp, self.block_size) self.assertEqual(tensor.shape, self.shape) self.assertEqual(tensor.original_shape, self.shape) @@ -38,7 +38,7 @@ def test_creation_and_shape(self): def test_sparsity_pattern(self): """Test 2:4 sparsity pattern is maintained""" - tensor = Int8SemiSparseTensor.from_hp(self.weight_fp, self.block_size) + tensor = Float8SemiSparseTensor.from_hp(self.weight_fp, self.block_size) dequantized = tensor.dequantize() # Check 2:4 pattern (skip overall sparsity check for compressed format) @@ -49,17 +49,9 @@ def test_sparsity_pattern(self): self.assertGreaterEqual(valid_groups / total_groups, 0.99) - def test_int8_quantization_range(self): - """Test Int8 quantization stays in valid range""" - tensor = Int8SemiSparseTensor.from_hp(self.weight_fp, self.block_size) - - self.assertEqual(tensor.qdata_int8.dtype, torch.int8) - self.assertTrue(torch.all(tensor.qdata_int8 >= -128)) - self.assertTrue(torch.all(tensor.qdata_int8 <= 127)) - def test_dequantization_accuracy(self): """Test dequantization error is reasonable""" - tensor = Int8SemiSparseTensor.from_hp(self.weight_fp, self.block_size) + tensor = Float8SemiSparseTensor.from_hp(self.weight_fp, self.block_size) dequantized = tensor.dequantize() # Apply same pruning to original for fair comparison @@ -79,14 +71,14 @@ def test_invalid_dimensions(self): invalid_weight = torch.randn(100, 100, dtype=self.dtype, device="cuda") with self.assertRaises(ValueError): - Int8SemiSparseTensor.from_hp(invalid_weight, [1, 100]) + Float8SemiSparseTensor.from_hp(invalid_weight, [1, 100]) def test_cpu_tensor_rejection(self): """Test CPU tensor is rejected""" cpu_weight = torch.randn(*self.shape, dtype=self.dtype) with self.assertRaises(ValueError): - Int8SemiSparseTensor.from_hp(cpu_weight, self.block_size) + Float8SemiSparseTensor.from_hp(cpu_weight, self.block_size) if __name__ == "__main__": diff --git a/torchao/prototype/quantization/quantize_/workflows/__init__.py b/torchao/prototype/quantization/quantize_/workflows/__init__.py index 97cf27ae9b..e0b4edf34a 100644 --- a/torchao/prototype/quantization/quantize_/workflows/__init__.py +++ b/torchao/prototype/quantization/quantize_/workflows/__init__.py @@ -1,11 +1,7 @@ from .float8.float8_semisparse_tensor import ( Float8SemiSparseTensor, ) -from .int8.int8_semisparse_tensor 
import ( - Int8SemiSparseTensor, -) __all__ = [ "Float8SemiSparseTensor", - "Int8SemiSparseTensor", ] diff --git a/torchao/prototype/quantization/quantize_/workflows/float8/float8_semisparse_tensor.py b/torchao/prototype/quantization/quantize_/workflows/float8/float8_semisparse_tensor.py new file mode 100644 index 0000000000..64acd10f4e --- /dev/null +++ b/torchao/prototype/quantization/quantize_/workflows/float8/float8_semisparse_tensor.py @@ -0,0 +1,262 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Optional + +import torch + +from torchao.quantization.quant_primitives import ( + _choose_qparams_affine_floatx, +) +from torchao.utils import TorchAOBaseTensor + +__all__ = ["Float8SemiSparseTensor"] + +aten = torch.ops.aten + + +class Float8SemiSparseTensor(TorchAOBaseTensor): + """ + float8 quantized tensor with 2:4 semi-structured sparsity layout + + Tensor Attributes: + qdata: float8 data in dense format + qdata_compressed: SparseSemiStructuredTensor (compressed format for matmul) + scale: scale factors for dequantization + + Non-Tensor Attributes: + block_size: block size for quantization granularity + original_shape: original uncompressed shape + float8_dtype: float8 dtype variant + """ + + tensor_data_names = ["qdata", "qdata_compressed", "scale"] + tensor_attribute_names = ["block_size", "original_shape"] + optional_tensor_attribute_names = ["dtype"] + + def __new__( + cls: type, + qdata: torch.Tensor, + qdata_compressed: torch.Tensor, + scale: torch.Tensor, + block_size: list[int], + original_shape: tuple[int, ...], + dtype=None, + float8_dtype: torch.dtype = torch.float8_e4m3fn, + ): + kwargs = { + "device": qdata.device, + "dtype": dtype or scale.dtype, + "requires_grad": False, + } + return torch.Tensor._make_wrapper_subclass(cls, original_shape, **kwargs) + + def __init__( + self, + qdata: torch.Tensor, + qdata_compressed: torch.Tensor, + scale: torch.Tensor, + block_size: list[int], + original_shape: tuple[int, ...], + dtype=None, + float8_dtype: torch.dtype = torch.float8_e4m3fn, + ): + super().__init__() + self.qdata = qdata # dense fp8 for dequantization + self.qdata_compressed = qdata_compressed # compressed for matmul + self.scale = scale + self.block_size = block_size + self.original_shape = original_shape + self.float8_dtype = float8_dtype + + def __repr__(self): + return ( + f"{self.__class__.__name__}(" + f"qdata_shape={self.qdata.shape}, {self.scale=}, " + f"{self.block_size=}, {self.shape=}, {self.device=}, {self.dtype=})" + ) + + @property + def qdata_fp8(self): + """For test compatibility""" + return self.qdata + + @classmethod + def from_hp( + cls, + w: torch.Tensor, + block_size: list[int], + float8_dtype: torch.dtype = torch.float8_e4m3fn, + ): + if w.dim() != 2 or len(block_size) != 2: + raise ValueError("Expected 2D tensor and block_size length 2") + + if not w.is_cuda: + raise ValueError("Semi-sparse layout requires CUDA tensors") + + # Verify dimensions are compatible with 2:4 compression + rows, cols = w.shape + if rows % 32 != 0 or cols % 32 != 0: + raise ValueError( + "Tensor dimensions must be multiples of 32 for CUDA sparse compression" + ) + + # Validate block_size + if not all(bs > 0 for bs in block_size): + raise ValueError(f"block_size must be positive, got {block_size}") + + if rows % block_size[0] != 0 or cols % block_size[1] != 0: + raise ValueError( + 
f"Dimensions {w.shape} must be divisible by block_size {block_size}" + ) + + # Apply 2:4 sparsity pruning + with torch.no_grad(): + w_sparse = w.clone() + + pruning_inds = w_sparse.abs().view(-1, 4).argsort(dim=1)[:, :2] + w_sparse.view(-1, 4).scatter_(1, pruning_inds, value=0) + + # Quantize to float8 + if float8_dtype == torch.float8_e4m3fn: + ebits, mbits = 4, 3 + max_val = 448.0 + elif float8_dtype == torch.float8_e5m2: + ebits, mbits = 5, 2 + max_val = 57344.0 + else: + raise ValueError(f"Unsupported float8 dtype: {float8_dtype}") + + scale = _choose_qparams_affine_floatx(w_sparse, ebits=ebits, mbits=mbits) + + if scale.isnan().any(): + raise ValueError("Scale contains NaN") + if not (scale > 0).all(): + raise ValueError(f"Scale contains non-positive values: min={scale.min()}") + + scale_expanded = scale.unsqueeze(1) + scaled_data = w_sparse / scale_expanded + scaled_data = scaled_data.clamp(-max_val, max_val) + fp8_data = scaled_data.to(float8_dtype).contiguous() + + if fp8_data.isnan().any(): + raise ValueError("fp8_data contains NaN after quantization") + + # Store fp8 data in both dense and compressed formats + fp8_data_fp16 = fp8_data.to(torch.float16) + from torch.sparse import to_sparse_semi_structured + + fp8_compressed = to_sparse_semi_structured(fp8_data_fp16) + + return cls( + fp8_data, # dense for dequantization + fp8_compressed, # compressed for matmul + scale, + block_size, + original_shape=w.shape, + dtype=w.dtype, + float8_dtype=float8_dtype, + ) + + def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor: + if output_dtype is None: + output_dtype = self.dtype + + # Use dense fp8 data + qdata_fp = self.qdata.to(output_dtype) + scale_expanded = self.scale.view(-1, 1).to(output_dtype) + return qdata_fp * scale_expanded + + +implements = Float8SemiSparseTensor.implements +implements_torch_function = Float8SemiSparseTensor.implements_torch_function + + +@implements(aten.linear.default) +@implements_torch_function(torch.nn.functional.linear) +def _(func, types, args, kwargs): + activation_tensor, weight_tensor, bias = ( + args[0], + args[1], + args[2] if len(args) > 2 else None, + ) + + assert isinstance(weight_tensor, Float8SemiSparseTensor) + if not isinstance(activation_tensor, Float8SemiSparseTensor): + raise TypeError( + "Float8SemiSparseTensor requires pre-quantized activations (static quantization only). " + "Activation must be Float8SemiSparseTensor." 
+ ) + assert activation_tensor.shape[-1] == weight_tensor.original_shape[1], ( + f"Shape mismatch: {activation_tensor.shape} @ {weight_tensor.original_shape}" + ) + + # Use compressed data for matmul + x_vals_dense = activation_tensor.qdata_compressed.to_dense() + x_vals_fp8 = x_vals_dense.view(torch.float8_e4m3fn) + x_scales = activation_tensor.scale + + w_vals_dense = weight_tensor.qdata_compressed.to_dense() + w_vals_fp8 = w_vals_dense.view(torch.float8_e4m3fn) + w_scales = weight_tensor.scale + + # Prepare activation for sparse matmul + tmp = x_vals_fp8 + if tmp.dim() > 2: + tmp = tmp.view(-1, tmp.shape[-1]) + row = tmp.shape[0] + + from torch.sparse import SparseSemiStructuredTensorCUSPARSELT + + tmp_padded = SparseSemiStructuredTensorCUSPARSELT._pad_dense_input(tmp) + + # Convert weight fp8 to fp16 with scale for matmul + w_scaled = w_vals_fp8.to(torch.float16) * w_scales.unsqueeze(1) + w_sparse_scaled = torch.sparse.to_sparse_semi_structured(w_scaled) + + # Matmul with sparse weight + y = torch.matmul( + tmp_padded.to(torch.bfloat16), w_sparse_scaled.t().to(torch.bfloat16) + ) + y = y[:row, :] + + # Apply activation scale + y = y * x_scales.unsqueeze(1) + + # Reshape to original activation shape + if x_vals_fp8.dim() > 2: + y = y.view(*x_vals_fp8.shape[:-1], y.shape[-1]) + + output_dtype = activation_tensor.dtype + y = y.to(output_dtype).contiguous() + + if bias is not None: + y += bias + return y + + +@implements(aten.slice.Tensor) +def _(func, types, args, kwargs): + """Slice operation - not supported for compressed sparse format""" + # TODO: Build this tensor utility operation + raise NotImplementedError( + "Slicing not supported for Float8SemiSparseTensor. " + "Decompress first using dequantize() if needed." + ) + + +@implements(aten.select.int) +def _(func, types, args, kwargs): + """Select operation - not supported for compressed sparse format""" + # TODO: Build this tensor utility operation + raise NotImplementedError( + "Select not supported for Float8SemiSparseTensor. " + "Decompress first using dequantize() if needed." 
+ ) + + +Float8SemiSparseTensor.__module__ = "torchao.quantization" +torch.serialization.add_safe_globals([Float8SemiSparseTensor]) From 2d5f86d8eaf4e85615b608912996916d157f408b Mon Sep 17 00:00:00 2001 From: youn17 Date: Sun, 2 Nov 2025 14:14:08 +0900 Subject: [PATCH 4/7] rename test file --- ...test_semisparse_tensor.py => test_float8_semisparse_tensor.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename test/prototype/quantization/quantize_/float8/{test_semisparse_tensor.py => test_float8_semisparse_tensor.py} (100%) diff --git a/test/prototype/quantization/quantize_/float8/test_semisparse_tensor.py b/test/prototype/quantization/quantize_/float8/test_float8_semisparse_tensor.py similarity index 100% rename from test/prototype/quantization/quantize_/float8/test_semisparse_tensor.py rename to test/prototype/quantization/quantize_/float8/test_float8_semisparse_tensor.py From 4e7f4825358d1ea833e12520021d17bac0f2595d Mon Sep 17 00:00:00 2001 From: youn17 Date: Sun, 2 Nov 2025 14:15:28 +0900 Subject: [PATCH 5/7] drop --- .../workflows/int8/int8_semisparse_tensor.py | 239 ------------------ 1 file changed, 239 deletions(-) delete mode 100644 torchao/prototype/quantization/quantize_/workflows/int8/int8_semisparse_tensor.py diff --git a/torchao/prototype/quantization/quantize_/workflows/int8/int8_semisparse_tensor.py b/torchao/prototype/quantization/quantize_/workflows/int8/int8_semisparse_tensor.py deleted file mode 100644 index f31713b191..0000000000 --- a/torchao/prototype/quantization/quantize_/workflows/int8/int8_semisparse_tensor.py +++ /dev/null @@ -1,239 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD 3-Clause license found in the -# LICENSE file in the root directory of this source tree. 
- -from typing import Optional - -import torch - -from torchao.quantization.quant_primitives import ( - MappingType, - _maybe_expand_scale_to_tensor_shape, - choose_qparams_affine, - quantize_affine, -) -from torchao.utils import TorchAOBaseTensor - -__all__ = ["Int8SemiSparseTensor"] - -aten = torch.ops.aten - - -class Int8SemiSparseTensor(TorchAOBaseTensor): - """ - int8 quantized tensor with 2:4 semi-structured sparsity layout - - Tensor Attributes: - qdata: SparseSemiStructuredTensor (compressed format, fp16) - scale: scale factors for dequantization - - Non-Tensor Attributes: - block_size: block size for quantization granularity - original_shape: original uncompressed shape - """ - - tensor_data_names = ["qdata", "scale"] - tensor_attribute_names = ["block_size", "original_shape"] - optional_tensor_attribute_names = ["dtype"] - - def __new__( - cls: type, - qdata: torch.Tensor, - scale: torch.Tensor, - block_size: list[int], - original_shape: tuple[int, ...], - dtype=None, - ): - kwargs = { - "device": qdata.device, - "dtype": dtype or scale.dtype, - "requires_grad": False, - } - return torch.Tensor._make_wrapper_subclass(cls, original_shape, **kwargs) - - def __init__( - self, - qdata: torch.Tensor, - scale: torch.Tensor, - block_size: list[int], - original_shape: tuple[int, ...], - dtype=None, - ): - super().__init__() - self.qdata = qdata - self.scale = scale - self.block_size = block_size - self.original_shape = original_shape - - def __repr__(self): - return ( - f"{self.__class__.__name__}(" - f"qdata_shape={self.qdata.shape}, {self.scale=}, " - f"{self.block_size=}, {self.shape=}, {self.device=}, {self.dtype=})" - ) - - @property - def qdata_int8(self): - return self.qdata.to_dense().to(torch.int8) - - @classmethod - def from_hp( - cls, - w: torch.Tensor, - block_size: list[int], - ): - if w.dim() != 2 or len(block_size) != 2: - raise ValueError("Expected 2D tensor and block_size length 2") - - if not w.is_cuda: - raise ValueError("Semi-sparse layout requires CUDA tensors") - - # Verify dimensions are compatible with 2:4 compression - rows, cols = w.shape - if rows % 32 != 0 or cols % 32 != 0: - raise ValueError( - "Tensor dimensions must be multiples of 32 for CUDA sparse compression" - ) - - # Validate block_size - if not all(bs > 0 for bs in block_size): - raise ValueError(f"block_size must be positive, got {block_size}") - - if rows % block_size[0] != 0 or cols % block_size[1] != 0: - raise ValueError( - f"Dimensions {w.shape} must be divisible by block_size {block_size}" - ) - - # Apply 2:4 sparsity pruning (row-wise for weight matrix) - with torch.no_grad(): - w_sparse = w.clone() - - pruning_inds = w_sparse.abs().view(-1, 4).argsort(dim=1)[:, :2] - w_sparse.view(-1, 4).scatter_(1, pruning_inds, value=0) - - # Quantize the sparse weight - scale, zero_point = choose_qparams_affine( - input=w_sparse, - mapping_type=MappingType.SYMMETRIC, - block_size=block_size, - target_dtype=torch.int8, - quant_min=-128, - quant_max=127, - scale_dtype=w.dtype, - zero_point_dtype=torch.int8, - ) - - int_data = quantize_affine( - w_sparse, - block_size=block_size, - scale=scale, - zero_point=zero_point, - output_dtype=torch.int8, - ).contiguous() - - int_data_fp16 = int_data.to(torch.float16) - from torch.sparse import to_sparse_semi_structured - - int_data_compressed = to_sparse_semi_structured(int_data_fp16) - - return cls( - int_data_compressed, - scale, - block_size, - original_shape=w.shape, - dtype=w.dtype, - ) - - def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> 
torch.Tensor: - if output_dtype is None: - output_dtype = self.dtype - - # Decompress and convert to int8 - qdata_dense = self.qdata.to_dense().to(torch.int8) - qdata_fp = qdata_dense.to(output_dtype) - - scale_expanded = _maybe_expand_scale_to_tensor_shape( - self.scale, self.original_shape - ) - return qdata_fp * scale_expanded.to(output_dtype) - - -implements = Int8SemiSparseTensor.implements -implements_torch_function = Int8SemiSparseTensor.implements_torch_function - - -@implements(aten.linear.default) -@implements_torch_function(torch.nn.functional.linear) -def _(func, types, args, kwargs): - activation_tensor, weight_tensor, bias = ( - args[0], - args[1], - args[2] if len(args) > 2 else None, - ) - - assert isinstance(weight_tensor, Int8SemiSparseTensor) - assert activation_tensor.shape[-1] == weight_tensor.original_shape[1], ( - f"Shape mismatch: {activation_tensor.shape} @ {weight_tensor.original_shape}" - ) - - # Dynamic quantization: Activation scale in runtime - original_shape = activation_tensor.shape - x_flat = activation_tensor.view(-1, activation_tensor.shape[-1]) - x_scales = x_flat.abs().max(dim=-1, keepdim=True)[0] / 127.0 - x_scales = x_scales.clamp(min=1e-5) - x_vals = (x_flat / x_scales).round().clamp(-128, 127).to(torch.int8) - - # Weight quantization after activation quantization - w_vals_dense = weight_tensor.qdata.to_dense().to(torch.int8) - w_scales = weight_tensor.scale - - # Prepare activation for sparse matmul - row, col = x_vals.shape - - from torch.sparse import SparseSemiStructuredTensorCUSPARSELT - - tmp_padded = SparseSemiStructuredTensorCUSPARSELT._pad_dense_input(x_vals) - - # Perform sparse matmul with int8, output as bfloat16 with weight scale - w_scaled = w_vals_dense.to(torch.float16) * w_scales.view(-1, 1) - w_sparse_scaled = torch.sparse.to_sparse_semi_structured(w_scaled) - y_bf16 = torch.matmul( - tmp_padded.to(torch.bfloat16), w_sparse_scaled.t().to(torch.bfloat16) - ) - y_bf16 = y_bf16[:row, :] - - # Apply activation scale - y = (y_bf16 * x_scales).view(*original_shape[:-1], y_bf16.shape[-1]) - - output_dtype = activation_tensor.dtype - y = y.to(output_dtype).contiguous() - - if bias is not None: - y += bias - return y - - -@implements(aten.slice.Tensor) -def _(func, types, args, kwargs): - """Slice operation - not supported for compressed sparse format""" - # TODO: Build this tensor utility operation - raise NotImplementedError( - "Slicing not supported for Int8SemiSparseTensor. " - "Decompress first using dequantize() if needed." - ) - - -@implements(aten.select.int) -def _(func, types, args, kwargs): - """Select operation - not supported for compressed sparse format""" - # TODO: Build this tensor utility operation - raise NotImplementedError( - "Select not supported for Int8SemiSparseTensor. " - "Decompress first using dequantize() if needed." 
- ) - - -Int8SemiSparseTensor.__module__ = "torchao.quantization" -torch.serialization.add_safe_globals([Int8SemiSparseTensor]) From 96f837443da856a98141cc987e85857751610fc0 Mon Sep 17 00:00:00 2001 From: youn17 Date: Sun, 2 Nov 2025 18:14:09 +0900 Subject: [PATCH 6/7] update api and test --- .../float8/test_float8_semisparse_tensor.py | 84 ++++++++++++++-- .../float8/float8_semisparse_tensor.py | 96 +++++++++---------- 2 files changed, 118 insertions(+), 62 deletions(-) diff --git a/test/prototype/quantization/quantize_/float8/test_float8_semisparse_tensor.py b/test/prototype/quantization/quantize_/float8/test_float8_semisparse_tensor.py index f3c2b673b8..b71006f380 100644 --- a/test/prototype/quantization/quantize_/float8/test_float8_semisparse_tensor.py +++ b/test/prototype/quantization/quantize_/float8/test_float8_semisparse_tensor.py @@ -41,7 +41,7 @@ def test_sparsity_pattern(self): tensor = Float8SemiSparseTensor.from_hp(self.weight_fp, self.block_size) dequantized = tensor.dequantize() - # Check 2:4 pattern (skip overall sparsity check for compressed format) + # Check 2:4 pattern reshaped = dequantized.reshape(-1, 4) zeros_per_group = (reshaped == 0).sum(dim=1) valid_groups = (zeros_per_group == 2).sum().item() @@ -59,11 +59,19 @@ def test_dequantization_accuracy(self): pruning_inds = w_sparse.abs().view(-1, 4).argsort(dim=1)[:, :2] w_sparse.view(-1, 4).scatter_(1, pruning_inds, value=0) - error = (dequantized - w_sparse).abs().max() - rel_error = error / w_sparse.abs().max() + # Norm-based metrics for numerical stability + error = (dequantized - w_sparse).abs() + max_error = error.max() + mean_error = error.mean() - self.assertLess(error.item(), 2.5) - self.assertLess(rel_error.item(), 0.5) + # Relative error using non-zero elements + non_zero_mask = w_sparse != 0 + if non_zero_mask.any(): + rel_error = (error[non_zero_mask] / w_sparse[non_zero_mask].abs()).mean() + self.assertLess(rel_error.item(), 0.1) + + self.assertLess(max_error.item(), 1.0) + self.assertLess(mean_error.item(), 0.1) def test_invalid_dimensions(self): """Test dimension validation""" @@ -73,13 +81,71 @@ def test_invalid_dimensions(self): with self.assertRaises(ValueError): Float8SemiSparseTensor.from_hp(invalid_weight, [1, 100]) - def test_cpu_tensor_rejection(self): - """Test CPU tensor is rejected""" - cpu_weight = torch.randn(*self.shape, dtype=self.dtype) - + def test_device(self): + """Test if device handler work""" + # CPU tensor should be rejected + cpu_weight = torch.randn(*self.shape, dtype=self.dtype, device="cpu") with self.assertRaises(ValueError): Float8SemiSparseTensor.from_hp(cpu_weight, self.block_size) + # CUDA tensor components should all be on CUDA + tensor = Float8SemiSparseTensor.from_hp(self.weight_fp, self.block_size) + self.assertEqual(tensor.qdata.device, tensor.qdata_compressed.device) + self.assertEqual(tensor.qdata.device, tensor.scale.device) + self.assertTrue(tensor.qdata.is_cuda) + + def test_w8a8_dynamic_activation(self): + """Test W8A8-FP-CSR with dynamic activation quantization""" + weight_tensor = Float8SemiSparseTensor.from_hp(self.weight_fp, self.block_size) + + batch_size = 32 + in_features = self.shape[1] + activation = torch.randn( + batch_size, in_features, dtype=self.dtype, device="cuda" + ) + output = torch.nn.functional.linear(activation, weight_tensor) + + expected_shape = (batch_size, self.shape[0]) + self.assertEqual(output.shape, expected_shape) + self.assertEqual(output.dtype, self.dtype) + self.assertFalse(output.isnan().any()) + 
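        # Rough sketch of the dynamic-activation step this test exercises (assuming the
        # kernel applies per-row e4m3 scaling; 448.0 is the float8_e4m3fn maximum):
        #   x_scale = activation.abs().amax(dim=-1, keepdim=True) / 448.0
        #   x_fp8 = (activation / x_scale).clamp(-448.0, 448.0).to(torch.float8_e4m3fn)
        # so a NaN or Inf in the linear output usually points at a zero or non-finite
        # activation scale, which is what the NaN/Inf assertions here guard against.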
self.assertFalse(output.isinf().any()) + + def test_linear_with_bias(self): + """Test linear operation with bias""" + weight_tensor = Float8SemiSparseTensor.from_hp(self.weight_fp, self.block_size) + activation = torch.randn(32, self.shape[1], dtype=self.dtype, device="cuda") + bias = torch.randn(self.shape[0], dtype=self.dtype, device="cuda") + + output = torch.nn.functional.linear(activation, weight_tensor, bias) + + self.assertEqual(output.shape, (32, self.shape[0])) + self.assertEqual(output.dtype, self.dtype) + self.assertFalse(output.isnan().any()) + self.assertFalse(output.isinf().any()) + + def test_batched_input(self): + """Test 3D batched input""" + weight_tensor = Float8SemiSparseTensor.from_hp(self.weight_fp, self.block_size) + batch_dims = (4, 8) + activation = torch.randn( + *batch_dims, self.shape[1], dtype=self.dtype, device="cuda" + ) + + output = torch.nn.functional.linear(activation, weight_tensor) + + expected_shape = (*batch_dims, self.shape[0]) + self.assertEqual(output.shape, expected_shape) + self.assertEqual(output.dtype, self.dtype) + self.assertFalse(output.isnan().any()) + + def test_zero_weight_validation(self): + """Test scale validation with zero weights""" + zero_weight = torch.zeros(*self.shape, dtype=self.dtype, device="cuda") + + with self.assertRaises(ValueError): + Float8SemiSparseTensor.from_hp(zero_weight, self.block_size) + if __name__ == "__main__": common_utils.run_tests() diff --git a/torchao/prototype/quantization/quantize_/workflows/float8/float8_semisparse_tensor.py b/torchao/prototype/quantization/quantize_/workflows/float8/float8_semisparse_tensor.py index 64acd10f4e..8a015a8b4a 100644 --- a/torchao/prototype/quantization/quantize_/workflows/float8/float8_semisparse_tensor.py +++ b/torchao/prototype/quantization/quantize_/workflows/float8/float8_semisparse_tensor.py @@ -7,6 +7,7 @@ from typing import Optional import torch +from torch.sparse import SparseSemiStructuredTensorCUSPARSELT, to_sparse_semi_structured from torchao.quantization.quant_primitives import ( _choose_qparams_affine_floatx, @@ -20,7 +21,7 @@ class Float8SemiSparseTensor(TorchAOBaseTensor): """ - float8 quantized tensor with 2:4 semi-structured sparsity layout + W8A8-FP-CSR: float8 quantized tensor with 2:4 semi-structured sparsity layout Tensor Attributes: qdata: float8 data in dense format @@ -120,30 +121,24 @@ def from_hp( pruning_inds = w_sparse.abs().view(-1, 4).argsort(dim=1)[:, :2] w_sparse.view(-1, 4).scatter_(1, pruning_inds, value=0) - # Quantize to float8 - if float8_dtype == torch.float8_e4m3fn: - ebits, mbits = 4, 3 - max_val = 448.0 - elif float8_dtype == torch.float8_e5m2: - ebits, mbits = 5, 2 - max_val = 57344.0 - else: - raise ValueError(f"Unsupported float8 dtype: {float8_dtype}") + # Check for all-zero (sparsity=1) tensor + if w_sparse.abs().max() == 0: + raise ValueError("Input tensor is all zeros after pruning") - scale = _choose_qparams_affine_floatx(w_sparse, ebits=ebits, mbits=mbits) + scale = _choose_qparams_affine_floatx(w_sparse, ebits=4, mbits=3) - if scale.isnan().any(): - raise ValueError("Scale contains NaN") + if not torch.isfinite(scale).all(): + raise ValueError("Scale contains NaN or inf values") if not (scale > 0).all(): raise ValueError(f"Scale contains non-positive values: min={scale.min()}") scale_expanded = scale.unsqueeze(1) scaled_data = w_sparse / scale_expanded - scaled_data = scaled_data.clamp(-max_val, max_val) + scaled_data = scaled_data.clamp(-448.0, 448.0) fp8_data = scaled_data.to(float8_dtype).contiguous() - if 
fp8_data.isnan().any(): - raise ValueError("fp8_data contains NaN after quantization") + if not torch.isfinite(fp8_data.to(torch.float32)).all(): + raise ValueError("fp8_data contains NaN/Inf after quantization") # Store fp8 data in both dense and compressed formats fp8_data_fp16 = fp8_data.to(torch.float16) @@ -185,53 +180,48 @@ def _(func, types, args, kwargs): ) assert isinstance(weight_tensor, Float8SemiSparseTensor) - if not isinstance(activation_tensor, Float8SemiSparseTensor): - raise TypeError( - "Float8SemiSparseTensor requires pre-quantized activations (static quantization only). " - "Activation must be Float8SemiSparseTensor." - ) assert activation_tensor.shape[-1] == weight_tensor.original_shape[1], ( f"Shape mismatch: {activation_tensor.shape} @ {weight_tensor.original_shape}" ) - # Use compressed data for matmul - x_vals_dense = activation_tensor.qdata_compressed.to_dense() - x_vals_fp8 = x_vals_dense.view(torch.float8_e4m3fn) - x_scales = activation_tensor.scale - - w_vals_dense = weight_tensor.qdata_compressed.to_dense() - w_vals_fp8 = w_vals_dense.view(torch.float8_e4m3fn) - w_scales = weight_tensor.scale - - # Prepare activation for sparse matmul - tmp = x_vals_fp8 - if tmp.dim() > 2: - tmp = tmp.view(-1, tmp.shape[-1]) - row = tmp.shape[0] - - from torch.sparse import SparseSemiStructuredTensorCUSPARSELT - - tmp_padded = SparseSemiStructuredTensorCUSPARSELT._pad_dense_input(tmp) + # Flatten batch dimensions for scale computation + orig_shape = activation_tensor.shape + if activation_tensor.dim() > 2: + activation_flat = activation_tensor.view(-1, orig_shape[-1]) + else: + activation_flat = activation_tensor + + # Compute dynamic scale for activation quantization + x_scales = _choose_qparams_affine_floatx(activation_flat, ebits=4, mbits=3) + x_scales = x_scales.unsqueeze(1) # [batch, 1] + + # Quantize activation + scaled_x = activation_flat / x_scales + scaled_x = scaled_x.clamp(-448.0, 448.0) + x_vals_fp8 = scaled_x.to(torch.float8_e4m3fn) + + # Dequantize both activation and weight before MatMul to avoid FP16 overflow + x_dequant = (x_vals_fp8.to(torch.float32) * x_scales.to(torch.float32)).to( + torch.float16 + ) + w_dequant = ( + weight_tensor.qdata.to(torch.float32) + * weight_tensor.scale.unsqueeze(1).to(torch.float32) + ).to(torch.float16) - # Convert weight fp8 to fp16 with scale for matmul - w_scaled = w_vals_fp8.to(torch.float16) * w_scales.unsqueeze(1) - w_sparse_scaled = torch.sparse.to_sparse_semi_structured(w_scaled) + # Sparse MatMul with dequntized tensor + w_sparse = to_sparse_semi_structured(w_dequant) + row = x_dequant.shape[0] + x_padded = SparseSemiStructuredTensorCUSPARSELT._pad_dense_input(x_dequant) - # Matmul with sparse weight - y = torch.matmul( - tmp_padded.to(torch.bfloat16), w_sparse_scaled.t().to(torch.bfloat16) - ) + y = torch.matmul(x_padded, w_sparse.t()) y = y[:row, :] - # Apply activation scale - y = y * x_scales.unsqueeze(1) - # Reshape to original activation shape - if x_vals_fp8.dim() > 2: - y = y.view(*x_vals_fp8.shape[:-1], y.shape[-1]) + if activation_tensor.dim() > 2: + y = y.view(*orig_shape[:-1], -1) - output_dtype = activation_tensor.dtype - y = y.to(output_dtype).contiguous() + y = y.to(activation_tensor.dtype).contiguous() if bias is not None: y += bias From f5f7a1717521b2a711602cab19640fec0dfe7700 Mon Sep 17 00:00:00 2001 From: youn17 Date: Sun, 2 Nov 2025 18:43:41 +0900 Subject: [PATCH 7/7] fix FP16 bit-range overflow --- .../float8/float8_semisparse_tensor.py | 52 ++++++------------- 1 file changed, 16 insertions(+), 36 
deletions(-) diff --git a/torchao/prototype/quantization/quantize_/workflows/float8/float8_semisparse_tensor.py b/torchao/prototype/quantization/quantize_/workflows/float8/float8_semisparse_tensor.py index 8a015a8b4a..1d5d988bca 100644 --- a/torchao/prototype/quantization/quantize_/workflows/float8/float8_semisparse_tensor.py +++ b/torchao/prototype/quantization/quantize_/workflows/float8/float8_semisparse_tensor.py @@ -142,7 +142,6 @@ def from_hp( # Store fp8 data in both dense and compressed formats fp8_data_fp16 = fp8_data.to(torch.float16) - from torch.sparse import to_sparse_semi_structured fp8_compressed = to_sparse_semi_structured(fp8_data_fp16) @@ -180,47 +179,28 @@ def _(func, types, args, kwargs): ) assert isinstance(weight_tensor, Float8SemiSparseTensor) - assert activation_tensor.shape[-1] == weight_tensor.original_shape[1], ( - f"Shape mismatch: {activation_tensor.shape} @ {weight_tensor.original_shape}" - ) - - # Flatten batch dimensions for scale computation - orig_shape = activation_tensor.shape - if activation_tensor.dim() > 2: - activation_flat = activation_tensor.view(-1, orig_shape[-1]) - else: - activation_flat = activation_tensor + assert activation_tensor.dim() == 2, "Only 2D input supported" + assert activation_tensor.shape[-1] == weight_tensor.original_shape[1] - # Compute dynamic scale for activation quantization - x_scales = _choose_qparams_affine_floatx(activation_flat, ebits=4, mbits=3) - x_scales = x_scales.unsqueeze(1) # [batch, 1] + x_scales = _choose_qparams_affine_floatx(activation_tensor, ebits=4, mbits=3) + w_scales = weight_tensor.scale - # Quantize activation - scaled_x = activation_flat / x_scales - scaled_x = scaled_x.clamp(-448.0, 448.0) + # Global normalizer to prevent overflow + global_scale = (x_scales.max() * w_scales.max()).sqrt().clamp(min=0.01) + x_scales_adj = (x_scales.unsqueeze(1) / global_scale).to(torch.float32) + scaled_x = (activation_tensor.to(torch.float32) / x_scales_adj).clamp(-448.0, 448.0) x_vals_fp8 = scaled_x.to(torch.float8_e4m3fn) - # Dequantize both activation and weight before MatMul to avoid FP16 overflow - x_dequant = (x_vals_fp8.to(torch.float32) * x_scales.to(torch.float32)).to( - torch.float16 + # MatMul + x_padded = SparseSemiStructuredTensorCUSPARSELT._pad_dense_input( + x_vals_fp8.to(torch.float16) ) - w_dequant = ( - weight_tensor.qdata.to(torch.float32) - * weight_tensor.scale.unsqueeze(1).to(torch.float32) - ).to(torch.float16) - - # Sparse MatMul with dequantized tensor - w_sparse = to_sparse_semi_structured(w_dequant) - row = x_dequant.shape[0] - x_padded = SparseSemiStructuredTensorCUSPARSELT._pad_dense_input(x_dequant) - - y = torch.matmul(x_padded, w_sparse.t()) - y = y[:row, :] - - # Reshape to original activation shape - if activation_tensor.dim() > 2: - y = y.view(*orig_shape[:-1], -1) + y_fp16 = torch.matmul(x_padded, weight_tensor.qdata_compressed.t()) + y = y_fp16[: activation_tensor.shape[0], :].to(torch.float32) + # Restore scale; x_scales_adj already folds in 1/global_scale, so x_scales_adj * w_scales undoes both the quantization and the normalizer + w_scales_fp32 = w_scales.to(torch.float32) + y = y * (x_scales_adj * w_scales_fp32.unsqueeze(0)) y = y.to(activation_tensor.dtype).contiguous() if bias is not None:
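    # Worked example of the global normalizer above (illustrative numbers, not taken
    # from this patch): with x_scales.max() = 0.04 and w_scales.max() = 0.01,
    #   global_scale = sqrt(0.04 * 0.01) = 0.02
    # so the quantized activations fed to the FP16 matmul top out near 448 * 0.02 ~= 9
    # instead of 448, and each product with a weight value (|w_q| <= 448) stays around
    # 9 * 448 ~= 4e3, comfortably below the FP16 maximum of 65504 (448 * 448 ~= 2e5
    # would already overflow). The rescale then multiplies by x_scales_adj
    # (= x_scales / global_scale) and w_scales, which undoes the quantization and the
    # normalizer in one step.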