Introduce new W8A8-FP-CSR quantization API #3258
Changes from all commits: 3b28e18, 9a48316, ea18e89, 2d5f86d, 4e7f482, 96f8374, f5f7a17
New file (@@ -0,0 +1,151 @@):

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from torch.testing._internal import common_utils

from torchao.prototype.quantization.quantize_.workflows.float8.float8_semisparse_tensor import (
    Float8SemiSparseTensor,
)
from torchao.testing.utils import TorchAOIntegrationTestCase


@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@common_utils.instantiate_parametrized_tests
class TestFloat8SemiSparseTensor(TorchAOIntegrationTestCase):
    def setUp(self):
        super().setUp()
        torch.manual_seed(42)

        # Use 512x512 for 2:4 compatibility (multiples of 32)
        self.shape = (512, 512)
        self.dtype = torch.bfloat16
        self.block_size = [1, 512]
        self.weight_fp = torch.randn(*self.shape, dtype=self.dtype, device="cuda")

    def test_creation_and_shape(self):
        """Test tensor creation and shape preservation"""
        tensor = Float8SemiSparseTensor.from_hp(self.weight_fp, self.block_size)

        self.assertEqual(tensor.shape, self.shape)
        self.assertEqual(tensor.original_shape, self.shape)
        self.assertEqual(tensor.scale.shape[0], self.shape[0])

    def test_sparsity_pattern(self):
        """Test that the 2:4 sparsity pattern is maintained"""
        tensor = Float8SemiSparseTensor.from_hp(self.weight_fp, self.block_size)
        dequantized = tensor.dequantize()

        # Check 2:4 pattern: two zeros in every group of four elements
        reshaped = dequantized.reshape(-1, 4)
        zeros_per_group = (reshaped == 0).sum(dim=1)
        valid_groups = (zeros_per_group == 2).sum().item()
        total_groups = zeros_per_group.numel()

        self.assertGreaterEqual(valid_groups / total_groups, 0.99)

    def test_dequantization_accuracy(self):
        """Test that dequantization error is reasonable"""
        tensor = Float8SemiSparseTensor.from_hp(self.weight_fp, self.block_size)
        dequantized = tensor.dequantize()

        # Apply the same pruning to the original for a fair comparison
        w_sparse = self.weight_fp.detach().clone()
        pruning_inds = w_sparse.abs().view(-1, 4).argsort(dim=1)[:, :2]
        w_sparse.view(-1, 4).scatter_(1, pruning_inds, value=0)

        # Norm-based metrics for numerical stability
        error = (dequantized - w_sparse).abs()
        max_error = error.max()
        mean_error = error.mean()

        # Relative error over the non-zero elements
        non_zero_mask = w_sparse != 0
        if non_zero_mask.any():
            rel_error = (error[non_zero_mask] / w_sparse[non_zero_mask].abs()).mean()
            self.assertLess(rel_error.item(), 0.1)

        self.assertLess(max_error.item(), 1.0)
        self.assertLess(mean_error.item(), 0.1)

    def test_invalid_dimensions(self):
        """Test dimension validation"""
        # Not a multiple of 32
        invalid_weight = torch.randn(100, 100, dtype=self.dtype, device="cuda")

        with self.assertRaises(ValueError):
            Float8SemiSparseTensor.from_hp(invalid_weight, [1, 100])

    def test_device(self):
        """Test device handling"""
        # CPU tensor should be rejected
        cpu_weight = torch.randn(*self.shape, dtype=self.dtype, device="cpu")
        with self.assertRaises(ValueError):
            Float8SemiSparseTensor.from_hp(cpu_weight, self.block_size)

        # CUDA tensor components should all be on CUDA
        tensor = Float8SemiSparseTensor.from_hp(self.weight_fp, self.block_size)
        self.assertEqual(tensor.qdata.device, tensor.qdata_compressed.device)
        self.assertEqual(tensor.qdata.device, tensor.scale.device)
        self.assertTrue(tensor.qdata.is_cuda)

    def test_w8a8_dynamic_activation(self):
        """Test W8A8-FP-CSR with dynamic activation quantization"""
        weight_tensor = Float8SemiSparseTensor.from_hp(self.weight_fp, self.block_size)

        batch_size = 32
        in_features = self.shape[1]
        activation = torch.randn(
            batch_size, in_features, dtype=self.dtype, device="cuda"
        )
        output = torch.nn.functional.linear(activation, weight_tensor)

        expected_shape = (batch_size, self.shape[0])
        self.assertEqual(output.shape, expected_shape)
        self.assertEqual(output.dtype, self.dtype)
        self.assertFalse(output.isnan().any())
        self.assertFalse(output.isinf().any())

    def test_linear_with_bias(self):
        """Test linear operation with bias"""
        weight_tensor = Float8SemiSparseTensor.from_hp(self.weight_fp, self.block_size)
        activation = torch.randn(32, self.shape[1], dtype=self.dtype, device="cuda")
        bias = torch.randn(self.shape[0], dtype=self.dtype, device="cuda")

        output = torch.nn.functional.linear(activation, weight_tensor, bias)

        self.assertEqual(output.shape, (32, self.shape[0]))
        self.assertEqual(output.dtype, self.dtype)
        self.assertFalse(output.isnan().any())
        self.assertFalse(output.isinf().any())

    def test_batched_input(self):
        """Test 3D batched input"""
        weight_tensor = Float8SemiSparseTensor.from_hp(self.weight_fp, self.block_size)
        batch_dims = (4, 8)
        activation = torch.randn(
            *batch_dims, self.shape[1], dtype=self.dtype, device="cuda"
        )

        output = torch.nn.functional.linear(activation, weight_tensor)

        expected_shape = (*batch_dims, self.shape[0])
        self.assertEqual(output.shape, expected_shape)
        self.assertEqual(output.dtype, self.dtype)
        self.assertFalse(output.isnan().any())

    def test_zero_weight_validation(self):
        """Test scale validation with zero weights"""
        zero_weight = torch.zeros(*self.shape, dtype=self.dtype, device="cuda")

        with self.assertRaises(ValueError):
            Float8SemiSparseTensor.from_hp(zero_weight, self.block_size)


if __name__ == "__main__":
    common_utils.run_tests()
```
Review comment: delete? we don't want this to be in prototype I think

Review comment: this should be added to the init file without the prototype in path
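The suggested move would presumably change the test import by dropping `prototype` from the path, along these lines (assumed path, not confirmed by this PR):

```python
# Hypothetical non-prototype import location, per the review comment above
from torchao.quantization.quantize_.workflows.float8.float8_semisparse_tensor import (
    Float8SemiSparseTensor,
)
```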
New file (@@ -0,0 +1,7 @@):

```python
from .float8.float8_semisparse_tensor import (
    Float8SemiSparseTensor,
)

__all__ = [
    "Float8SemiSparseTensor",
]
```
New file (@@ -0,0 +1,232 @@):

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional

import torch
from torch.sparse import SparseSemiStructuredTensorCUSPARSELT, to_sparse_semi_structured

from torchao.quantization.quant_primitives import (
    _choose_qparams_affine_floatx,
)
from torchao.utils import TorchAOBaseTensor

__all__ = ["Float8SemiSparseTensor"]

aten = torch.ops.aten


class Float8SemiSparseTensor(TorchAOBaseTensor):
    """
    W8A8-FP-CSR: float8 quantized tensor with 2:4 semi-structured sparsity layout
```
Review comment: nit: comment looks wrong, CSR is compressed sparse row and it's not the sparse format used here (2:4 sparsity)
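To illustrate the distinction the reviewer is drawing, a minimal comparison (illustrative only, not part of the PR):

```python
import torch

w = torch.randn(4, 8)

# CSR = compressed sparse row: arbitrary sparsity pattern, row-pointer format
w_csr = w.to_sparse_csr()

# 2:4 semi-structured: exactly two of every four contiguous elements are zero,
# a fixed hardware-friendly pattern that is distinct from CSR
idx = w.abs().view(-1, 4).argsort(dim=1)[:, :2]
w_24 = w.clone()
w_24.view(-1, 4).scatter_(1, idx, value=0)
```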
```python
    Tensor Attributes:
        qdata: float8 data in dense format
        qdata_compressed: SparseSemiStructuredTensor (compressed format for matmul)
        scale: scale factors for dequantization

    Non-Tensor Attributes:
        block_size: block size for quantization granularity
        original_shape: original uncompressed shape
        float8_dtype: float8 dtype variant
    """

    tensor_data_names = ["qdata", "qdata_compressed", "scale"]
    tensor_attribute_names = ["block_size", "original_shape"]
    optional_tensor_attribute_names = ["dtype"]

    def __new__(
        cls: type,
        qdata: torch.Tensor,
        qdata_compressed: torch.Tensor,
        scale: torch.Tensor,
        block_size: list[int],
        original_shape: tuple[int, ...],
        dtype=None,
        float8_dtype: torch.dtype = torch.float8_e4m3fn,
    ):
        kwargs = {
            "device": qdata.device,
            "dtype": dtype or scale.dtype,
            "requires_grad": False,
        }
        return torch.Tensor._make_wrapper_subclass(cls, original_shape, **kwargs)

    def __init__(
        self,
        qdata: torch.Tensor,
        qdata_compressed: torch.Tensor,
        scale: torch.Tensor,
        block_size: list[int],
        original_shape: tuple[int, ...],
        dtype=None,
        float8_dtype: torch.dtype = torch.float8_e4m3fn,
    ):
        super().__init__()
        self.qdata = qdata  # dense fp8 for dequantization
        self.qdata_compressed = qdata_compressed  # compressed for matmul
        self.scale = scale
        self.block_size = block_size
        self.original_shape = original_shape
        self.float8_dtype = float8_dtype

    def __repr__(self):
        return (
            f"{self.__class__.__name__}("
            f"qdata_shape={self.qdata.shape}, {self.scale=}, "
            f"{self.block_size=}, {self.shape=}, {self.device=}, {self.dtype=})"
        )

    @property
    def qdata_fp8(self):
```
Review comment: why do we need this?
| """For test compatibility""" | ||||
| return self.qdata | ||||
|
|
||||
| @classmethod | ||||
| def from_hp( | ||||
| cls, | ||||
| w: torch.Tensor, | ||||
| block_size: list[int], | ||||
| float8_dtype: torch.dtype = torch.float8_e4m3fn, | ||||
| ): | ||||
| if w.dim() != 2 or len(block_size) != 2: | ||||
| raise ValueError("Expected 2D tensor and block_size length 2") | ||||
|
|
||||
| if not w.is_cuda: | ||||
| raise ValueError("Semi-sparse layout requires CUDA tensors") | ||||
|
|
||||
| # Verify dimensions are compatible with 2:4 compression | ||||
| rows, cols = w.shape | ||||
| if rows % 32 != 0 or cols % 32 != 0: | ||||
| raise ValueError( | ||||
| "Tensor dimensions must be multiples of 32 for CUDA sparse compression" | ||||
| ) | ||||
|
|
||||
| # Validate block_size | ||||
| if not all(bs > 0 for bs in block_size): | ||||
| raise ValueError(f"block_size must be positive, got {block_size}") | ||||
|
|
||||
| if rows % block_size[0] != 0 or cols % block_size[1] != 0: | ||||
| raise ValueError( | ||||
| f"Dimensions {w.shape} must be divisible by block_size {block_size}" | ||||
| ) | ||||
|
|
||||
| # Apply 2:4 sparsity pruning | ||||
| with torch.no_grad(): | ||||
| w_sparse = w.clone() | ||||
|
|
||||
| pruning_inds = w_sparse.abs().view(-1, 4).argsort(dim=1)[:, :2] | ||||
|
Review comment: you can use this util: Line 101 in 315e9b4
```python
            w_sparse.view(-1, 4).scatter_(1, pruning_inds, value=0)

        # Check for all-zero (sparsity=1) tensor
        if w_sparse.abs().max() == 0:
```
Review comment: I think this should be supported actually? I don't see why we should error here.
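Supporting an all-zero input inside `from_hp` would presumably mean clamping the scale instead of raising; a minimal sketch of that alternative, reusing the PR's own helpers (not the PR's current behavior):

```python
# Sketch: quantize all-zero rows to zero instead of erroring
scale = _choose_qparams_affine_floatx(w_sparse, ebits=4, mbits=3)
scale = scale.clamp(min=torch.finfo(torch.float32).tiny)  # avoid divide-by-zero
fp8_data = (w_sparse / scale.unsqueeze(1)).to(float8_dtype)  # zeros stay zero
```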
```python
            raise ValueError("Input tensor is all zeros after pruning")

        scale = _choose_qparams_affine_floatx(w_sparse, ebits=4, mbits=3)

        if not torch.isfinite(scale).all():
            raise ValueError("Scale contains NaN or inf values")
        if not (scale > 0).all():
            raise ValueError(f"Scale contains non-positive values: min={scale.min()}")

        scale_expanded = scale.unsqueeze(1)
```
Review comment: Is this different from Float8Tensor? Can we use the same scale calculation logic as we use there?
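For comparison, a Float8Tensor-style per-row scale is typically amax-based; a minimal sketch of that convention (the exact torchao helper and its signature may differ):

```python
import torch

def rowwise_fp8_scale(w: torch.Tensor, float8_dtype=torch.float8_e4m3fn) -> torch.Tensor:
    # Per-row absolute maximum, clamped so all-zero rows don't yield scale 0
    amax = w.abs().amax(dim=1).clamp(min=1e-12)
    # Map each row's amax onto the fp8 representable maximum (448 for e4m3fn)
    return amax / torch.finfo(float8_dtype).max
```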
```python
        scaled_data = w_sparse / scale_expanded
        scaled_data = scaled_data.clamp(-448.0, 448.0)
        fp8_data = scaled_data.to(float8_dtype).contiguous()

        if not torch.isfinite(fp8_data.to(torch.float32)).all():
            raise ValueError("fp8_data contains NaN/Inf after quantization")

        # Store fp8 data in both dense and compressed formats
        fp8_data_fp16 = fp8_data.to(torch.float16)

        fp8_compressed = to_sparse_semi_structured(fp8_data_fp16)
```
Review comment: We should use the torchao cutlass packing kernels here, not the default torch ones:
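The kernels the reviewer links are truncated from this view. As a rough illustration of moving off the default cuSPARSELt packing, core PyTorch exposes a CUTLASS backend toggle (the dedicated torchao packing ops would presumably replace this, and `_FORCE_CUTLASS` is a semi-private flag):

```python
import torch
from torch.sparse import SparseSemiStructuredTensor, to_sparse_semi_structured

# Select the CUTLASS packing/matmul backend instead of the cuSPARSELt default
SparseSemiStructuredTensor._FORCE_CUTLASS = True

w = torch.randn(512, 512, dtype=torch.float16, device="cuda")
idx = w.abs().view(-1, 4).argsort(dim=1)[:, :2]
w.view(-1, 4).scatter_(1, idx, value=0)  # prune to 2:4 in place
w_24 = to_sparse_semi_structured(w)      # packed with the CUTLASS layout
```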
```python
        return cls(
            fp8_data,  # dense for dequantization
```
Review comment: we shouldn't be storing both the dense data and the compressed data; we should be storing the sparse specified values and the sparse metadata.
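What "specified values plus metadata" storage could look like against the public torch.sparse API, as a sketch (attribute availability varies by backend and PyTorch version):

```python
sparse = to_sparse_semi_structured(fp8_data_fp16)
specified = sparse.values()   # the kept values, shape (N, K // 2)
meta = sparse.indices()       # backend-specific reordering metadata
# The subclass would persist only (specified, meta, scale) and reconstruct
# dense data on demand, rather than carrying a redundant dense copy.
```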
```python
            fp8_compressed,  # compressed for matmul
            scale,
            block_size,
            original_shape=w.shape,
            dtype=w.dtype,
            float8_dtype=float8_dtype,
        )

    def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor:
```
Review comment: we should multiply by an identity matrix to dequantize, like we do here:
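The link after "here:" is truncated; the identity-multiply trick the reviewer describes is roughly the following (a sketch, not the referenced code):

```python
def decompress_24(w_24: torch.Tensor) -> torch.Tensor:
    # Multiplying the 2:4 compressed matrix by an identity runs the sparse
    # matmul kernel and materializes the dense, pruned matrix
    cols = w_24.shape[-1]
    eye = torch.eye(cols, dtype=w_24.dtype, device=w_24.device)
    return torch.mm(w_24, eye)
```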
```python
        if output_dtype is None:
            output_dtype = self.dtype

        # Use dense fp8 data
        qdata_fp = self.qdata.to(output_dtype)
        scale_expanded = self.scale.view(-1, 1).to(output_dtype)
        return qdata_fp * scale_expanded


implements = Float8SemiSparseTensor.implements
implements_torch_function = Float8SemiSparseTensor.implements_torch_function


@implements(aten.linear.default)
@implements_torch_function(torch.nn.functional.linear)
def _(func, types, args, kwargs):
    activation_tensor, weight_tensor, bias = (
        args[0],
        args[1],
        args[2] if len(args) > 2 else None,
    )

    assert isinstance(weight_tensor, Float8SemiSparseTensor)
    assert activation_tensor.dim() == 2, "Only 2D input supported"
    assert activation_tensor.shape[-1] == weight_tensor.original_shape[1]

    x_scales = _choose_qparams_affine_floatx(activation_tensor, ebits=4, mbits=3)
    w_scales = weight_tensor.scale

    # Global normalizer to prevent overflow
    global_scale = (x_scales.max() * w_scales.max()).sqrt().clamp(min=0.01)
    x_scales_adj = (x_scales.unsqueeze(1) / global_scale).to(torch.float32)
    scaled_x = (activation_tensor.to(torch.float32) / x_scales_adj).clamp(-448.0, 448.0)
    x_vals_fp8 = scaled_x.to(torch.float8_e4m3fn)

    # MatMul
    x_padded = SparseSemiStructuredTensorCUSPARSELT._pad_dense_input(
```
Review comment: We should use the torchao cutlass fp8 kernels, which fuse in scale multiplication here. See
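The reference after "See" is truncated. For flavor, core PyTorch's dense fp8 matmul already fuses the scale multiplication; the sparse torchao kernel the reviewer means would play the analogous role without the fp16 padding round-trip below (per-tensor scales shown; the signature varies across PyTorch versions):

```python
import torch

x_fp8 = torch.randn(32, 512, device="cuda").to(torch.float8_e4m3fn)
w_fp8 = torch.randn(512, 512, device="cuda").to(torch.float8_e4m3fn)
scale_x = torch.tensor(1.0, device="cuda")
scale_w = torch.tensor(1.0, device="cuda")

# Scales are applied inside the kernel; no separate rescale pass needed
y = torch._scaled_mm(
    x_fp8, w_fp8.t(), scale_a=scale_x, scale_b=scale_w, out_dtype=torch.bfloat16
)
```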
```python
        x_vals_fp8.to(torch.float16)
    )
    y_fp16 = torch.matmul(x_padded, weight_tensor.qdata_compressed.t())
    y = y_fp16[: activation_tensor.shape[0], :].to(torch.float32)

    # Restore scale
    w_scales_fp32 = w_scales.to(torch.float32)
    y = y * (x_scales_adj * w_scales_fp32.unsqueeze(0) * global_scale)
    y = y.to(activation_tensor.dtype).contiguous()

    if bias is not None:
        y += bias
    return y


@implements(aten.slice.Tensor)
def _(func, types, args, kwargs):
    """Slice operation - not supported for compressed sparse format"""
    # TODO: Build this tensor utility operation
    raise NotImplementedError(
        "Slicing not supported for Float8SemiSparseTensor. "
        "Decompress first using dequantize() if needed."
    )


@implements(aten.select.int)
def _(func, types, args, kwargs):
    """Select operation - not supported for compressed sparse format"""
    # TODO: Build this tensor utility operation
    raise NotImplementedError(
        "Select not supported for Float8SemiSparseTensor. "
        "Decompress first using dequantize() if needed."
    )


Float8SemiSparseTensor.__module__ = "torchao.quantization"
torch.serialization.add_safe_globals([Float8SemiSparseTensor])
```
Review comment: also need to add to Float8DynamicActivationFloat8WeightConfig?
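For context, that config is applied through torchao's `quantize_` entry point; wiring the new tensor in might eventually look like this (the sparsity argument below is invented for illustration, not an existing parameter):

```python
import torch
from torchao.quantization import Float8DynamicActivationFloat8WeightConfig, quantize_

model = torch.nn.Sequential(torch.nn.Linear(512, 512)).cuda().to(torch.bfloat16)

# Existing dense fp8 dynamic-activation flow
quantize_(model, Float8DynamicActivationFloat8WeightConfig())

# Hypothetical extension routing to Float8SemiSparseTensor (argument invented):
# quantize_(model, Float8DynamicActivationFloat8WeightConfig(sparsity="semi_structured"))
```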