diff --git a/test/prototype/quantization/quantize_/float8/test_float8_semisparse_tensor.py b/test/prototype/quantization/quantize_/float8/test_float8_semisparse_tensor.py
new file mode 100644
index 0000000000..b71006f380
--- /dev/null
+++ b/test/prototype/quantization/quantize_/float8/test_float8_semisparse_tensor.py
@@ -0,0 +1,151 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+
+import unittest
+
+import torch
+from torch.testing._internal import common_utils
+
+from torchao.prototype.quantization.quantize_.workflows.float8.float8_semisparse_tensor import (
+    Float8SemiSparseTensor,
+)
+from torchao.testing.utils import TorchAOIntegrationTestCase
+
+
+@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+@common_utils.instantiate_parametrized_tests
+class TestFloat8SemiSparseTensor(TorchAOIntegrationTestCase):
+    def setUp(self):
+        super().setUp()
+        torch.manual_seed(42)
+
+        # Use 512x512 for 2:4 compatibility (multiples of 32)
+        self.shape = (512, 512)
+        self.dtype = torch.bfloat16
+        self.block_size = [1, 512]
+        self.weight_fp = torch.randn(*self.shape, dtype=self.dtype, device="cuda")
+
+    def test_creation_and_shape(self):
+        """Test tensor creation and shape preservation"""
+        tensor = Float8SemiSparseTensor.from_hp(self.weight_fp, self.block_size)
+
+        self.assertEqual(tensor.shape, self.shape)
+        self.assertEqual(tensor.original_shape, self.shape)
+        self.assertEqual(tensor.scale.shape[0], self.shape[0])
+
+    def test_sparsity_pattern(self):
+        """Test 2:4 sparsity pattern is maintained"""
+        tensor = Float8SemiSparseTensor.from_hp(self.weight_fp, self.block_size)
+        dequantized = tensor.dequantize()
+
+        # Check 2:4 pattern: exactly two zeros in each group of four
+        reshaped = dequantized.reshape(-1, 4)
+        zeros_per_group = (reshaped == 0).sum(dim=1)
+        valid_groups = (zeros_per_group == 2).sum().item()
+        total_groups = zeros_per_group.numel()
+
+        self.assertGreaterEqual(valid_groups / total_groups, 0.99)
+
+    def test_dequantization_accuracy(self):
+        """Test dequantization error is reasonable"""
+        tensor = Float8SemiSparseTensor.from_hp(self.weight_fp, self.block_size)
+        dequantized = tensor.dequantize()
+
+        # Apply the same pruning to the original weight for a fair comparison
+        w_sparse = self.weight_fp.detach().clone()
+        pruning_inds = w_sparse.abs().view(-1, 4).argsort(dim=1)[:, :2]
+        w_sparse.view(-1, 4).scatter_(1, pruning_inds, value=0)
+
+        # Norm-based metrics for numerical stability
+        error = (dequantized - w_sparse).abs()
+        max_error = error.max()
+        mean_error = error.mean()
+
+        # Relative error over non-zero elements only
+        non_zero_mask = w_sparse != 0
+        if non_zero_mask.any():
+            rel_error = (error[non_zero_mask] / w_sparse[non_zero_mask].abs()).mean()
+            self.assertLess(rel_error.item(), 0.1)
+
+        self.assertLess(max_error.item(), 1.0)
+        self.assertLess(mean_error.item(), 0.1)
+
+    def test_invalid_dimensions(self):
+        """Test dimension validation"""
+        # Not a multiple of 32
+        invalid_weight = torch.randn(100, 100, dtype=self.dtype, device="cuda")
+
+        with self.assertRaises(ValueError):
+            Float8SemiSparseTensor.from_hp(invalid_weight, [1, 100])
+
+    def test_device(self):
+        """Test device handling"""
+        # CPU tensor should be rejected
+        cpu_weight = torch.randn(*self.shape, dtype=self.dtype, device="cpu")
+        with self.assertRaises(ValueError):
+            Float8SemiSparseTensor.from_hp(cpu_weight, self.block_size)
+
+        # CUDA tensor components should all be on CUDA
+        tensor = Float8SemiSparseTensor.from_hp(self.weight_fp, self.block_size)
+        self.assertEqual(tensor.qdata.device, tensor.qdata_compressed.device)
+        self.assertEqual(tensor.qdata.device, tensor.scale.device)
+        self.assertTrue(tensor.qdata.is_cuda)
+
+    def test_w8a8_dynamic_activation(self):
+        """Test W8A8-FP-CSR with dynamic activation quantization"""
+        weight_tensor = Float8SemiSparseTensor.from_hp(self.weight_fp, self.block_size)
+
+        batch_size = 32
+        in_features = self.shape[1]
+        activation = torch.randn(
+            batch_size, in_features, dtype=self.dtype, device="cuda"
+        )
+        output = torch.nn.functional.linear(activation, weight_tensor)
+
+        expected_shape = (batch_size, self.shape[0])
+        self.assertEqual(output.shape, expected_shape)
+        self.assertEqual(output.dtype, self.dtype)
+        self.assertFalse(output.isnan().any())
+        self.assertFalse(output.isinf().any())
+
+    def test_linear_with_bias(self):
+        """Test linear operation with bias"""
+        weight_tensor = Float8SemiSparseTensor.from_hp(self.weight_fp, self.block_size)
+        activation = torch.randn(32, self.shape[1], dtype=self.dtype, device="cuda")
+        bias = torch.randn(self.shape[0], dtype=self.dtype, device="cuda")
+
+        output = torch.nn.functional.linear(activation, weight_tensor, bias)
+
+        self.assertEqual(output.shape, (32, self.shape[0]))
+        self.assertEqual(output.dtype, self.dtype)
+        self.assertFalse(output.isnan().any())
+        self.assertFalse(output.isinf().any())
+
+    def test_batched_input(self):
+        """Test 3D batched input"""
+        weight_tensor = Float8SemiSparseTensor.from_hp(self.weight_fp, self.block_size)
+        batch_dims = (4, 8)
+        activation = torch.randn(
+            *batch_dims, self.shape[1], dtype=self.dtype, device="cuda"
+        )
+
+        output = torch.nn.functional.linear(activation, weight_tensor)
+
+        expected_shape = (*batch_dims, self.shape[0])
+        self.assertEqual(output.shape, expected_shape)
+        self.assertEqual(output.dtype, self.dtype)
+        self.assertFalse(output.isnan().any())
+
+    def test_zero_weight_validation(self):
+        """Test scale validation with zero weights"""
+        zero_weight = torch.zeros(*self.shape, dtype=self.dtype, device="cuda")
+
+        with self.assertRaises(ValueError):
+            Float8SemiSparseTensor.from_hp(zero_weight, self.block_size)
+
+
+if __name__ == "__main__":
+    common_utils.run_tests()
diff --git a/torchao/prototype/quantization/quantize_/workflows/__init__.py b/torchao/prototype/quantization/quantize_/workflows/__init__.py
new file mode 100644
index 0000000000..e0b4edf34a
--- /dev/null
+++ b/torchao/prototype/quantization/quantize_/workflows/__init__.py
@@ -0,0 +1,7 @@
+from .float8.float8_semisparse_tensor import (
+    Float8SemiSparseTensor,
+)
+
+__all__ = [
+    "Float8SemiSparseTensor",
+]
diff --git a/torchao/prototype/quantization/quantize_/workflows/float8/float8_semisparse_tensor.py b/torchao/prototype/quantization/quantize_/workflows/float8/float8_semisparse_tensor.py
new file mode 100644
index 0000000000..1d5d988bca
--- /dev/null
+++ b/torchao/prototype/quantization/quantize_/workflows/float8/float8_semisparse_tensor.py
@@ -0,0 +1,232 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Optional
+
+import torch
+from torch.sparse import SparseSemiStructuredTensorCUSPARSELT, to_sparse_semi_structured
+
+from torchao.quantization.quant_primitives import (
+    _choose_qparams_affine_floatx,
+)
+from torchao.utils import TorchAOBaseTensor
+
+__all__ = ["Float8SemiSparseTensor"]
+
+aten = torch.ops.aten
+
+
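+# Intended usage, mirroring the tests in this PR: quantize a 2D weight with
+# from_hp() and pass the result as the weight of torch.nn.functional.linear,
+# e.g. (the block_size below is illustrative, taken from the tests):
+#
+#     w_q = Float8SemiSparseTensor.from_hp(weight, block_size=[1, weight.shape[1]])
+#     out = torch.nn.functional.linear(x, w_q, bias)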
+class Float8SemiSparseTensor(TorchAOBaseTensor):
+    """
+    W8A8-FP-CSR: float8 quantized tensor with 2:4 semi-structured sparsity layout
+
+    Tensor Attributes:
+        qdata: float8 data in dense format
+        qdata_compressed: SparseSemiStructuredTensor (compressed format for matmul)
+        scale: scale factors for dequantization
+
+    Non-Tensor Attributes:
+        block_size: block size for quantization granularity
+        original_shape: original uncompressed shape
+        float8_dtype: float8 dtype variant
+    """
+
+    tensor_data_names = ["qdata", "qdata_compressed", "scale"]
+    tensor_attribute_names = ["block_size", "original_shape"]
+    optional_tensor_attribute_names = ["dtype"]
+
+    def __new__(
+        cls: type,
+        qdata: torch.Tensor,
+        qdata_compressed: torch.Tensor,
+        scale: torch.Tensor,
+        block_size: list[int],
+        original_shape: tuple[int, ...],
+        dtype=None,
+        float8_dtype: torch.dtype = torch.float8_e4m3fn,
+    ):
+        kwargs = {
+            "device": qdata.device,
+            "dtype": dtype or scale.dtype,
+            "requires_grad": False,
+        }
+        return torch.Tensor._make_wrapper_subclass(cls, original_shape, **kwargs)
+
+    def __init__(
+        self,
+        qdata: torch.Tensor,
+        qdata_compressed: torch.Tensor,
+        scale: torch.Tensor,
+        block_size: list[int],
+        original_shape: tuple[int, ...],
+        dtype=None,
+        float8_dtype: torch.dtype = torch.float8_e4m3fn,
+    ):
+        super().__init__()
+        self.qdata = qdata  # dense fp8 for dequantization
+        self.qdata_compressed = qdata_compressed  # compressed for matmul
+        self.scale = scale
+        self.block_size = block_size
+        self.original_shape = original_shape
+        self.float8_dtype = float8_dtype
+
+    def __repr__(self):
+        return (
+            f"{self.__class__.__name__}("
+            f"qdata_shape={self.qdata.shape}, {self.scale=}, "
+            f"{self.block_size=}, {self.shape=}, {self.device=}, {self.dtype=})"
+        )
+
+    @property
+    def qdata_fp8(self):
+        """For test compatibility"""
+        return self.qdata
+
+    @classmethod
+    def from_hp(
+        cls,
+        w: torch.Tensor,
+        block_size: list[int],
+        float8_dtype: torch.dtype = torch.float8_e4m3fn,
+    ):
+        if w.dim() != 2 or len(block_size) != 2:
+            raise ValueError("Expected 2D tensor and block_size of length 2")
+
+        if not w.is_cuda:
+            raise ValueError("Semi-sparse layout requires CUDA tensors")
+
+        # Verify dimensions are compatible with 2:4 compression
+        rows, cols = w.shape
+        if rows % 32 != 0 or cols % 32 != 0:
+            raise ValueError(
+                "Tensor dimensions must be multiples of 32 for CUDA sparse compression"
+            )
+
+        # Validate block_size
+        if not all(bs > 0 for bs in block_size):
+            raise ValueError(f"block_size must be positive, got {block_size}")
+
+        if rows % block_size[0] != 0 or cols % block_size[1] != 0:
+            raise ValueError(
+                f"Dimensions {w.shape} must be divisible by block_size {block_size}"
+            )
+
+        # Apply 2:4 sparsity pruning: zero the two smallest-magnitude entries
+        # in every contiguous group of four
+        with torch.no_grad():
+            w_sparse = w.clone()
+
+            pruning_inds = w_sparse.abs().view(-1, 4).argsort(dim=1)[:, :2]
+            w_sparse.view(-1, 4).scatter_(1, pruning_inds, value=0)
+
+        # Check for all-zero (sparsity=1) tensor
+        if w_sparse.abs().max() == 0:
+            raise ValueError("Input tensor is all zeros after pruning")
+
+        scale = _choose_qparams_affine_floatx(w_sparse, ebits=4, mbits=3)
+
+        if not torch.isfinite(scale).all():
+            raise ValueError("Scale contains NaN or inf values")
+        if not (scale > 0).all():
+            raise ValueError(f"Scale contains non-positive values: min={scale.min()}")
+
+        scale_expanded = scale.unsqueeze(1)
+        scaled_data = w_sparse / scale_expanded
+        scaled_data = scaled_data.clamp(-448.0, 448.0)
+        fp8_data = scaled_data.to(float8_dtype).contiguous()
+
+        if not torch.isfinite(fp8_data.to(torch.float32)).all():
+            raise ValueError("fp8_data contains NaN/Inf after quantization")
+
+        # Keep dense fp8 for dequantization; compress an fp16 copy for the
+        # cuSPARSELt matmul path
+        fp8_data_fp16 = fp8_data.to(torch.float16)
+
+        fp8_compressed = to_sparse_semi_structured(fp8_data_fp16)
+
+        return cls(
+            fp8_data,  # dense for dequantization
+            fp8_compressed,  # compressed for matmul
+            scale,
+            block_size,
+            original_shape=w.shape,
+            dtype=w.dtype,
+            float8_dtype=float8_dtype,
+        )
+
+    def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor:
+        if output_dtype is None:
+            output_dtype = self.dtype
+
+        # Use dense fp8 data
+        qdata_fp = self.qdata.to(output_dtype)
+        scale_expanded = self.scale.view(-1, 1).to(output_dtype)
+        return qdata_fp * scale_expanded
+
+
+implements = Float8SemiSparseTensor.implements
+implements_torch_function = Float8SemiSparseTensor.implements_torch_function
+
+
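+# The linear override below implements W8A8 with dynamic activation quantization:
+# per-row fp8 (e4m3) scales are computed for the activation on the fly, both
+# operands share a global normalizer so their product stays in range, the
+# 2:4-compressed weight is multiplied through the cuSPARSELt fp16 path, and the
+# per-row activation and weight scales are re-applied to the fp32 accumulator
+# before casting back to the activation dtype.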
values") + if not (scale > 0).all(): + raise ValueError(f"Scale contains non-positive values: min={scale.min()}") + + scale_expanded = scale.unsqueeze(1) + scaled_data = w_sparse / scale_expanded + scaled_data = scaled_data.clamp(-448.0, 448.0) + fp8_data = scaled_data.to(float8_dtype).contiguous() + + if not torch.isfinite(fp8_data.to(torch.float32)).all(): + raise ValueError("fp8_data contains NaN/Inf after quantization") + + # Store fp8 data in both dense and compressed formats + fp8_data_fp16 = fp8_data.to(torch.float16) + + fp8_compressed = to_sparse_semi_structured(fp8_data_fp16) + + return cls( + fp8_data, # dense for dequantization + fp8_compressed, # compressed for matmul + scale, + block_size, + original_shape=w.shape, + dtype=w.dtype, + float8_dtype=float8_dtype, + ) + + def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor: + if output_dtype is None: + output_dtype = self.dtype + + # Use dense fp8 data + qdata_fp = self.qdata.to(output_dtype) + scale_expanded = self.scale.view(-1, 1).to(output_dtype) + return qdata_fp * scale_expanded + + +implements = Float8SemiSparseTensor.implements +implements_torch_function = Float8SemiSparseTensor.implements_torch_function + + +@implements(aten.linear.default) +@implements_torch_function(torch.nn.functional.linear) +def _(func, types, args, kwargs): + activation_tensor, weight_tensor, bias = ( + args[0], + args[1], + args[2] if len(args) > 2 else None, + ) + + assert isinstance(weight_tensor, Float8SemiSparseTensor) + assert activation_tensor.dim() == 2, "Only 2D input supported" + assert activation_tensor.shape[-1] == weight_tensor.original_shape[1] + + x_scales = _choose_qparams_affine_floatx(activation_tensor, ebits=4, mbits=3) + w_scales = weight_tensor.scale + + # Global normalizer to prevent overflow + global_scale = (x_scales.max() * w_scales.max()).sqrt().clamp(min=0.01) + x_scales_adj = (x_scales.unsqueeze(1) / global_scale).to(torch.float32) + scaled_x = (activation_tensor.to(torch.float32) / x_scales_adj).clamp(-448.0, 448.0) + x_vals_fp8 = scaled_x.to(torch.float8_e4m3fn) + + # MatMul + x_padded = SparseSemiStructuredTensorCUSPARSELT._pad_dense_input( + x_vals_fp8.to(torch.float16) + ) + y_fp16 = torch.matmul(x_padded, weight_tensor.qdata_compressed.t()) + y = y_fp16[: activation_tensor.shape[0], :].to(torch.float32) + + # Restore scale + w_scales_fp32 = w_scales.to(torch.float32) + y = y * (x_scales_adj * w_scales_fp32.unsqueeze(0) * global_scale) + y = y.to(activation_tensor.dtype).contiguous() + + if bias is not None: + y += bias + return y + + +@implements(aten.slice.Tensor) +def _(func, types, args, kwargs): + """Slice operation - not supported for compressed sparse format""" + # TODO: Build this tensor utility operation + raise NotImplementedError( + "Slicing not supported for Float8SemiSparseTensor. " + "Decompress first using dequantize() if needed." + ) + + +@implements(aten.select.int) +def _(func, types, args, kwargs): + """Select operation - not supported for compressed sparse format""" + # TODO: Build this tensor utility operation + raise NotImplementedError( + "Select not supported for Float8SemiSparseTensor. " + "Decompress first using dequantize() if needed." + ) + + +Float8SemiSparseTensor.__module__ = "torchao.quantization" +torch.serialization.add_safe_globals([Float8SemiSparseTensor])