Contributor:
also need to add to Float8DynamicActivationFloat8WeightConfig?
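A minimal sketch of what that could look like from the user side, assuming the config (or a semi-sparse variant of it) learns to build Float8SemiSparseTensor for weights; the 2:4 selection mechanism is hypothetical, not an existing torchao option:

import torch
from torchao.quantization import Float8DynamicActivationFloat8WeightConfig, quantize_

model = torch.nn.Sequential(torch.nn.Linear(512, 512)).cuda().to(torch.bfloat16)
# Hypothetical wiring: the config's transform would call
# Float8SemiSparseTensor.from_hp(module.weight, block_size=[1, 512])
# for each Linear weight instead of constructing a dense Float8Tensor.
quantize_(model, Float8DynamicActivationFloat8WeightConfig())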

@@ -0,0 +1,151 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from torch.testing._internal import common_utils

from torchao.prototype.quantization.quantize_.workflows.float8.float8_semisparse_tensor import (
Float8SemiSparseTensor,
)
from torchao.testing.utils import TorchAOIntegrationTestCase


@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@common_utils.instantiate_parametrized_tests
class TestFloat8SemiSparseTensor(TorchAOIntegrationTestCase):
def setUp(self):
super().setUp()
torch.manual_seed(42)

# Use 512x512 for 2:4 compatibility (multiples of 32)
self.shape = (512, 512)
self.dtype = torch.bfloat16
self.block_size = [1, 512]
self.weight_fp = torch.randn(*self.shape, dtype=self.dtype, device="cuda")

def test_creation_and_shape(self):
"""Test tensor creation and shape preservation"""
tensor = Float8SemiSparseTensor.from_hp(self.weight_fp, self.block_size)

self.assertEqual(tensor.shape, self.shape)
self.assertEqual(tensor.original_shape, self.shape)
self.assertEqual(tensor.scale.shape[0], self.shape[0])

def test_sparsity_pattern(self):
"""Test 2:4 sparsity pattern is maintained"""
tensor = Float8SemiSparseTensor.from_hp(self.weight_fp, self.block_size)
dequantized = tensor.dequantize()

# Check 2:4 pattern
reshaped = dequantized.reshape(-1, 4)
zeros_per_group = (reshaped == 0).sum(dim=1)
valid_groups = (zeros_per_group == 2).sum().item()
total_groups = zeros_per_group.numel()

self.assertGreaterEqual(valid_groups / total_groups, 0.99)

def test_dequantization_accuracy(self):
"""Test dequantization error is reasonable"""
tensor = Float8SemiSparseTensor.from_hp(self.weight_fp, self.block_size)
dequantized = tensor.dequantize()

# Apply same pruning to original for fair comparison
w_sparse = self.weight_fp.detach().clone()
pruning_inds = w_sparse.abs().view(-1, 4).argsort(dim=1)[:, :2]
w_sparse.view(-1, 4).scatter_(1, pruning_inds, value=0)

# Norm-based metrics for numerical stability
error = (dequantized - w_sparse).abs()
max_error = error.max()
mean_error = error.mean()

# Relative error using non-zero elements
non_zero_mask = w_sparse != 0
if non_zero_mask.any():
rel_error = (error[non_zero_mask] / w_sparse[non_zero_mask].abs()).mean()
self.assertLess(rel_error.item(), 0.1)

self.assertLess(max_error.item(), 1.0)
self.assertLess(mean_error.item(), 0.1)

def test_invalid_dimensions(self):
"""Test dimension validation"""
# Not multiple of 32
invalid_weight = torch.randn(100, 100, dtype=self.dtype, device="cuda")

with self.assertRaises(ValueError):
Float8SemiSparseTensor.from_hp(invalid_weight, [1, 100])

def test_device(self):
"""Test if device handler work"""
# CPU tensor should be rejected
cpu_weight = torch.randn(*self.shape, dtype=self.dtype, device="cpu")
with self.assertRaises(ValueError):
Float8SemiSparseTensor.from_hp(cpu_weight, self.block_size)

# CUDA tensor components should all be on CUDA
tensor = Float8SemiSparseTensor.from_hp(self.weight_fp, self.block_size)
self.assertEqual(tensor.qdata.device, tensor.qdata_compressed.device)
self.assertEqual(tensor.qdata.device, tensor.scale.device)
self.assertTrue(tensor.qdata.is_cuda)

def test_w8a8_dynamic_activation(self):
"""Test W8A8-FP-CSR with dynamic activation quantization"""
weight_tensor = Float8SemiSparseTensor.from_hp(self.weight_fp, self.block_size)

batch_size = 32
in_features = self.shape[1]
activation = torch.randn(
batch_size, in_features, dtype=self.dtype, device="cuda"
)
output = torch.nn.functional.linear(activation, weight_tensor)

expected_shape = (batch_size, self.shape[0])
self.assertEqual(output.shape, expected_shape)
self.assertEqual(output.dtype, self.dtype)
self.assertFalse(output.isnan().any())
self.assertFalse(output.isinf().any())

def test_linear_with_bias(self):
"""Test linear operation with bias"""
weight_tensor = Float8SemiSparseTensor.from_hp(self.weight_fp, self.block_size)
activation = torch.randn(32, self.shape[1], dtype=self.dtype, device="cuda")
bias = torch.randn(self.shape[0], dtype=self.dtype, device="cuda")

output = torch.nn.functional.linear(activation, weight_tensor, bias)

self.assertEqual(output.shape, (32, self.shape[0]))
self.assertEqual(output.dtype, self.dtype)
self.assertFalse(output.isnan().any())
self.assertFalse(output.isinf().any())

def test_batched_input(self):
"""Test 3D batched input"""
weight_tensor = Float8SemiSparseTensor.from_hp(self.weight_fp, self.block_size)
batch_dims = (4, 8)
activation = torch.randn(
*batch_dims, self.shape[1], dtype=self.dtype, device="cuda"
)

output = torch.nn.functional.linear(activation, weight_tensor)

expected_shape = (*batch_dims, self.shape[0])
self.assertEqual(output.shape, expected_shape)
self.assertEqual(output.dtype, self.dtype)
self.assertFalse(output.isnan().any())

def test_zero_weight_validation(self):
"""Test scale validation with zero weights"""
zero_weight = torch.zeros(*self.shape, dtype=self.dtype, device="cuda")

with self.assertRaises(ValueError):
Float8SemiSparseTensor.from_hp(zero_weight, self.block_size)


if __name__ == "__main__":
common_utils.run_tests()
Contributor:
delete? we don't want this to be in prototype I think

Contributor:
this should be added to the init file without the prototype in the path

@@ -0,0 +1,7 @@
from .float8.float8_semisparse_tensor import (
Float8SemiSparseTensor,
)

__all__ = [
"Float8SemiSparseTensor",
]
@@ -0,0 +1,232 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional

import torch
from torch.sparse import SparseSemiStructuredTensorCUSPARSELT, to_sparse_semi_structured

from torchao.quantization.quant_primitives import (
_choose_qparams_affine_floatx,
)
from torchao.utils import TorchAOBaseTensor

__all__ = ["Float8SemiSparseTensor"]

aten = torch.ops.aten


class Float8SemiSparseTensor(TorchAOBaseTensor):
"""
W8A8-FP-CSR: float8 quantized tensor with 2:4 semi-structured sparsity layout
Contributor:
nit: comment looks wrong, CSR is compressed sparse row and it's not the sparse format used here (2:4 sparsity)


Tensor Attributes:
qdata: float8 data in dense format
qdata_compressed: SparseSemiStructuredTensor (compressed format for matmul)
scale: scale factors for dequantization

Non-Tensor Attributes:
block_size: block size for quantization granularity
original_shape: original uncompressed shape
float8_dtype: float8 dtype variant
"""

tensor_data_names = ["qdata", "qdata_compressed", "scale"]
Contributor:
I think quantized_sparse_data and quantized_sparse_metadata would be better here for variable names.

quantized_sparse_data holds the specified values and quantized_sparse_metadata holds the sparsity metadata.
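# Sketch of the rename suggested above (names only; the TorchAOBaseTensor layout
# and the scale tensor stay exactly as in this file):
#
#     tensor_data_names = ["quantized_sparse_data", "quantized_sparse_metadata", "scale"]
#
# where quantized_sparse_data would hold the 2:4 specified fp8 values and
# quantized_sparse_metadata the packed sparsity metadata.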

tensor_attribute_names = ["block_size", "original_shape"]
optional_tensor_attribute_names = ["dtype"]

def __new__(
cls: type,
qdata: torch.Tensor,
qdata_compressed: torch.Tensor,
scale: torch.Tensor,
block_size: list[int],
original_shape: tuple[int, ...],
dtype=None,
float8_dtype: torch.dtype = torch.float8_e4m3fn,
):
kwargs = {
"device": qdata.device,
"dtype": dtype or scale.dtype,
"requires_grad": False,
}
return torch.Tensor._make_wrapper_subclass(cls, original_shape, **kwargs)

def __init__(
self,
qdata: torch.Tensor,
qdata_compressed: torch.Tensor,
scale: torch.Tensor,
block_size: list[int],
original_shape: tuple[int, ...],
dtype=None,
float8_dtype: torch.dtype = torch.float8_e4m3fn,
):
super().__init__()
self.qdata = qdata # dense fp8 for dequantization
self.qdata_compressed = qdata_compressed # compressed for matmul
self.scale = scale
self.block_size = block_size
self.original_shape = original_shape
self.float8_dtype = float8_dtype

def __repr__(self):
return (
f"{self.__class__.__name__}("
f"qdata_shape={self.qdata.shape}, {self.scale=}, "
f"{self.block_size=}, {self.shape=}, {self.device=}, {self.dtype=})"
)

@property
def qdata_fp8(self):
Contributor:
why do we need this?

"""For test compatibility"""
return self.qdata

@classmethod
def from_hp(
cls,
w: torch.Tensor,
block_size: list[int],
float8_dtype: torch.dtype = torch.float8_e4m3fn,
):
if w.dim() != 2 or len(block_size) != 2:
raise ValueError("Expected 2D tensor and block_size length 2")

if not w.is_cuda:
raise ValueError("Semi-sparse layout requires CUDA tensors")

# Verify dimensions are compatible with 2:4 compression
rows, cols = w.shape
if rows % 32 != 0 or cols % 32 != 0:
raise ValueError(
"Tensor dimensions must be multiples of 32 for CUDA sparse compression"
)

# Validate block_size
if not all(bs > 0 for bs in block_size):
raise ValueError(f"block_size must be positive, got {block_size}")

if rows % block_size[0] != 0 or cols % block_size[1] != 0:
raise ValueError(
f"Dimensions {w.shape} must be divisible by block_size {block_size}"
)

# Apply 2:4 sparsity pruning
with torch.no_grad():
w_sparse = w.clone()

pruning_inds = w_sparse.abs().view(-1, 4).argsort(dim=1)[:, :2]
Contributor:
you can use this util here: def mask_creator(...)

w_sparse.view(-1, 4).scatter_(1, pruning_inds, value=0)
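# --- Sketch of the reviewer's suggestion above (assumption, not part of this diff) ---
# With a 2:4 mask utility like the linked mask_creator (import path assumed), the
# argsort/scatter_ pruning above could be replaced by a mask multiply:
#
#     from torchao.sparsity.utils import mask_creator  # assumed location of the util
#     w_sparse = w * mask_creator(w)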

# Check for all-zero (sparsity=1) tensor
if w_sparse.abs().max() == 0:
Contributor:
I think this should be supported actually? I don't see why we should error here.

raise ValueError("Input tensor is all zeros after pruning")

scale = _choose_qparams_affine_floatx(w_sparse, ebits=4, mbits=3)

if not torch.isfinite(scale).all():
raise ValueError("Scale contains NaN or inf values")
if not (scale > 0).all():
raise ValueError(f"Scale contains non-positive values: min={scale.min()}")

scale_expanded = scale.unsqueeze(1)
Contributor:
Is this different from Float8Tensor, can we use the same scale calculation logic as we use there?

scaled_data = w_sparse / scale_expanded
scaled_data = scaled_data.clamp(-448.0, 448.0)
fp8_data = scaled_data.to(float8_dtype).contiguous()

if not torch.isfinite(fp8_data.to(torch.float32)).all():
raise ValueError("fp8_data contains NaN/Inf after quantization")

# Store fp8 data in both dense and compressed formats
fp8_data_fp16 = fp8_data.to(torch.float16)

fp8_compressed = to_sparse_semi_structured(fp8_data_fp16)
Contributor:
We should use the torchao cutlass packing kernels here, not the default torch ones:

sparse, meta = to_sparse_semi_structured_cutlass_sm9x_f8(dense)


return cls(
fp8_data, # dense for dequantization
Contributor:
we shouldn't be storing both the dense data and the compressed data, we should be storing the sparse specified values and the sparse metadata.

fp8_compressed, # compressed for matmul
scale,
block_size,
original_shape=w.shape,
dtype=w.dtype,
float8_dtype=float8_dtype,
)
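# --- Sketch combining the two review comments above (assumption, not this diff) ---
# Pack with torchao's CUTLASS sm9x fp8 kernel and store only the specified values
# plus the sparsity metadata; the import path is assumed and the call shape is
# taken from the reviewer's snippet:
#
#     from torchao.ops import to_sparse_semi_structured_cutlass_sm9x_f8  # assumed path
#     sparse, meta = to_sparse_semi_structured_cutlass_sm9x_f8(fp8_data)
#     return cls(sparse, meta, scale, block_size, original_shape=w.shape, dtype=w.dtype)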

def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor:
Contributor:
we should multiply by identity matrix to dequantize, like we do here:

if output_dtype is None:
output_dtype = self.dtype

# Use dense fp8 data
qdata_fp = self.qdata.to(output_dtype)
scale_expanded = self.scale.view(-1, 1).to(output_dtype)
return qdata_fp * scale_expanded
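# --- Sketch of the identity-matrix dequantize suggested above (assumption) ---
# Without a dense fp8 copy, the compressed weight can be densified by multiplying
# with an identity matrix and then scaled; dtype/op support for the packed format
# is an assumption here:
#
#     eye = torch.eye(self.shape[1], device=self.scale.device, dtype=torch.float16)
#     dense = torch.mm(self.qdata_compressed, eye)  # W_sparse @ I densifies the weight
#     return dense.to(output_dtype) * self.scale.view(-1, 1).to(output_dtype)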


implements = Float8SemiSparseTensor.implements
implements_torch_function = Float8SemiSparseTensor.implements_torch_function


@implements(aten.linear.default)
@implements_torch_function(torch.nn.functional.linear)
def _(func, types, args, kwargs):
activation_tensor, weight_tensor, bias = (
args[0],
args[1],
args[2] if len(args) > 2 else None,
)

assert isinstance(weight_tensor, Float8SemiSparseTensor)
assert activation_tensor.dim() == 2, "Only 2D input supported"
assert activation_tensor.shape[-1] == weight_tensor.original_shape[1]

x_scales = _choose_qparams_affine_floatx(activation_tensor, ebits=4, mbits=3)
w_scales = weight_tensor.scale

# Global normalizer to prevent overflow
global_scale = (x_scales.max() * w_scales.max()).sqrt().clamp(min=0.01)
x_scales_adj = (x_scales.unsqueeze(1) / global_scale).to(torch.float32)
scaled_x = (activation_tensor.to(torch.float32) / x_scales_adj).clamp(-448.0, 448.0)
x_vals_fp8 = scaled_x.to(torch.float8_e4m3fn)

# MatMul
x_padded = SparseSemiStructuredTensorCUSPARSELT._pad_dense_input(
Contributor:
We should use the torchao cutlass fp8 kernels, which fuse in scale multiplication here.

See def _linear_fp8_act_fp8_weight_sparse_cutlass_impl(input_tensor, weight_tensor, bias):

x_vals_fp8.to(torch.float16)
)
y_fp16 = torch.matmul(x_padded, weight_tensor.qdata_compressed.t())
y = y_fp16[: activation_tensor.shape[0], :].to(torch.float32)

# Restore scale
w_scales_fp32 = w_scales.to(torch.float32)
y = y * (x_scales_adj * w_scales_fp32.unsqueeze(0) * global_scale)
y = y.to(activation_tensor.dtype).contiguous()

if bias is not None:
y += bias
return y
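# --- Sketch of the fused CUTLASS path suggested above (assumption, not this diff) ---
# Replace the pad + fp16 matmul + manual rescale with torchao's rowwise-scaled fp8
# sparse kernel, which fuses the scale multiplications; the op name and argument
# order follow my reading of _linear_fp8_act_fp8_weight_sparse_cutlass_impl and
# should be treated as assumptions:
#
#     from torchao.ops import rowwise_scaled_linear_sparse_cutlass_f8f8  # assumed path
#     y = rowwise_scaled_linear_sparse_cutlass_f8f8(
#         x_vals_fp8, x_scales, weight_sparse, weight_meta, weight_tensor.scale,
#         bias=bias, out_dtype=activation_tensor.dtype,
#     )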


@implements(aten.slice.Tensor)
def _(func, types, args, kwargs):
"""Slice operation - not supported for compressed sparse format"""
# TODO: Build this tensor utility operation
raise NotImplementedError(
"Slicing not supported for Float8SemiSparseTensor. "
"Decompress first using dequantize() if needed."
)


@implements(aten.select.int)
def _(func, types, args, kwargs):
"""Select operation - not supported for compressed sparse format"""
# TODO: Build this tensor utility operation
raise NotImplementedError(
"Select not supported for Float8SemiSparseTensor. "
"Decompress first using dequantize() if needed."
)


Float8SemiSparseTensor.__module__ = "torchao.quantization"
torch.serialization.add_safe_globals([Float8SemiSparseTensor])