
Conversation

@namgyu-youn
Contributor

namgyu-youn commented Oct 29, 2025

Summary:
Introduce a new W8A8-FP-CSR quantization API, Float8SemiSparseTensor, which specializes in the 2:4 semi-sparse pattern using cuSPARSELt acceleration (https://docs.nvidia.com/cuda/cusparselt/)

Related Issue/PR: #2752

Future Plan:
This PR only introduces core operations (quantization/dequantization). For better API support, we have to introduce tensor utility operations like indexing and slicing.

Test Plan:
test/prototype/quantization/quantize_/float8/test_float8_semisparse_tensor.py
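For orientation, a hedged sketch of how the new tensor could be used end to end. The import path and the from_hp constructor name below are assumptions based on the description and similar torchao subclasses, not the final API; only dequantize(output_dtype=...) appears in the diff.

import torch

# Hypothetical usage sketch; constructor name, signature, and import path are assumptions.
from torchao.prototype.quantization.quantize_.float8 import (  # path assumed from the test plan
    Float8SemiSparseTensor,
)

weight = torch.randn(1024, 1024, dtype=torch.bfloat16, device="cuda")

# Prune to a 2:4 pattern, quantize to fp8, and pack for cuSPARSELt/CUTLASS acceleration.
qweight = Float8SemiSparseTensor.from_hp(weight)

# Dequantize back to a dense high-precision tensor for accuracy checks.
w_dq = qweight.dequantize(output_dtype=torch.bfloat16)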

@pytorch-bot

pytorch-bot bot commented Oct 29, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3258

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

❌ 9 New Failures

As of commit f5f7a17 with merge base 3577306:

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla bot added the CLA Signed label Oct 29, 2025
@namgyu-youn
Contributor Author

namgyu-youn commented Oct 29, 2025

@jcaip could you please check this PR?

@jcaip
Contributor

jcaip commented Oct 31, 2025

cc @namgyu-youn

Can you split this into two PRs? one for int8 and one for float8?

In general I don't think we want to introduce weight-only sparsity configs for int8 and float8 because we don't have mixed-dtype kernel support currently. The only kernels we have are for int8 x int8 2:4 sparse and fp8 x fp8 2:4 sparse.

I would like Int8SemiSparseTensor though, but I think it should live in prototype until we have a user for it.

Also cc @bbeckca who has been working on fp8xfp8 2:4 sparse tensor subclass migration in #3182.

@jerryzh168
Contributor

> cc @namgyu-youn
>
> Can you split this into two PRs? one for int8 and one for float8?
>
> In general I don't think we want to introduce weight-only sparsity configs for int8 and float8 because we don't have mixed-dtype kernel support currently. The only kernels we have are for int8 x int8 2:4 sparse and fp8 x fp8 2:4 sparse.
>
> I would like Int8SemiSparseTensor though, but I think it should live in prototype until we have a user for it.
>
> Also cc @bbeckca who has been working on fp8xfp8 2:4 sparse tensor subclass migration in #3182.

@jcaip if we want to move int8 2:4 sparse to prototype, then we don't need to migrate the tensor I think

@namgyu-youn
Contributor Author

namgyu-youn commented Oct 31, 2025

Okay, then I'll address only W8A8-FP here and keep the file structure in the prototype.

@jcaip
Contributor

jcaip commented Oct 31, 2025

cc @namgyu-youn I talked to @bbeckca and I think your PR is closer, so let's use it instead.
Can you remove the int8 changes then and I will give this a review. Thanks for picking this up!

@namgyu-youn marked this pull request as draft November 2, 2025 05:00
@namgyu-youn changed the title from "Introduce new Semi-sparse quantization APIs" to "Introduce new W8A8-FP-CSR quantization API" Nov 2, 2025
@namgyu-youn marked this pull request as ready for review November 2, 2025 09:15
@namgyu-youn
Contributor Author

cc @jcaip to request review, thanks.

Contributor

@jcaip left a comment

cc @namgyu-youn

I think there's a bit of confusion on what the tensor subclass should be storing and how to do the op overload.

Please take a look at https://github.com/pytorch/ao/pull/3182/files#diff-afc7dd21d2b704181a6fd55be989426c0217a2bbfb694af9eb9746239ec462ed for the appropriate logic / ops to be called.


class Float8SemiSparseTensor(TorchAOBaseTensor):
"""
W8A8-FP-CSR: float8 quantized tensor with 2:4 semi-structured sparsity layout

Contributor

nit: comment looks wrong; CSR is compressed sparse row, and that's not the sparse format used here (2:4 sparsity)

float8_dtype: float8 dtype variant
"""

tensor_data_names = ["qdata", "qdata_compressed", "scale"]

Contributor

I think quantized_sparse_data and quantized_sparse_metadata would be better here for variable names.

quantized_sparse_data holds the specified values and quantized_sparse_metadata holds the sparsity metadata.
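For illustration, a minimal sketch of that naming on the subclass; this is not the PR's code, and the import path is an assumption.

from torchao.utils import TorchAOBaseTensor  # import path assumed

class Float8SemiSparseTensorSketch(TorchAOBaseTensor):
    # quantized_sparse_data: the specified (kept) fp8 values of the 2:4 pattern
    # quantized_sparse_metadata: the sparsity metadata produced by the packing kernel
    # scale: per-row scale used for dequantization
    tensor_data_names = ["quantized_sparse_data", "quantized_sparse_metadata", "scale"]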

)

@property
def qdata_fp8(self):

Contributor

why do we need this?

w_sparse.view(-1, 4).scatter_(1, pruning_inds, value=0)

# Check for all-zero (sparsity=1) tensor
if w_sparse.abs().max() == 0:

Contributor

I think this should be supported actually? I don't see why we should error here.

with torch.no_grad():
w_sparse = w.clone()

pruning_inds = w_sparse.abs().view(-1, 4).argsort(dim=1)[:, :2]

Contributor

you can use this util (linked here):

def mask_creator(
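As a sketch, the pruning step could then look like the following; the import path and the exact mask_creator signature/return convention are assumptions here.

import torch
from torchao.sparsity.utils import mask_creator  # import path assumed

def prune_2_4(w: torch.Tensor) -> torch.Tensor:
    # mask_creator is assumed to return a 2:4 mask of the same shape,
    # with nonzero entries marking the positions that are kept.
    mask = mask_creator(w)
    return w * mask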

# Store fp8 data in both dense and compressed formats
fp8_data_fp16 = fp8_data.to(torch.float16)

fp8_compressed = to_sparse_semi_structured(fp8_data_fp16)

Contributor

We should use the torchao cutlass packing kernels here, not the default torch ones:

sparse, meta = to_sparse_semi_structured_cutlass_sm9x_f8(dense)
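A sketch of the suggested packing path; the call signature matches the snippet above, and the import path is an assumption.

import torch
from torchao.ops import to_sparse_semi_structured_cutlass_sm9x_f8  # import path assumed

# Toy 2:4-pruned weight standing in for the pruned weight from the step above.
w_sparse = torch.randn(128, 128, device="cuda")
w_sparse.view(-1, 4)[:, :2] = 0

fp8_data = w_sparse.to(torch.float8_e4m3fn)
sparse, meta = to_sparse_semi_structured_cutlass_sm9x_f8(fp8_data)
# `sparse` holds the specified fp8 values and `meta` the 2:4 metadata;
# these (plus the scale) are what the subclass should store instead of a dense copy.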

if not (scale > 0).all():
raise ValueError(f"Scale contains non-positive values: min={scale.min()}")

scale_expanded = scale.unsqueeze(1)

Contributor

Is this different from Float8Tensor? Can we use the same scale calculation logic as we use there?
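For reference, a sketch of a standard row-wise amax-based fp8 scale; this assumes the Float8Tensor recipe is amax / max_representable and is not copied from its internal helpers.

import torch

def rowwise_fp8_scale(w: torch.Tensor, float8_dtype=torch.float8_e4m3fn) -> torch.Tensor:
    # scale = row amax / max representable fp8 value; the clamp keeps an all-zero row
    # from producing a zero (non-positive) scale, which also relates to the
    # sparsity=1 question above.
    amax = w.abs().amax(dim=1, keepdim=True)
    return amax.clamp(min=1e-12) / torch.finfo(float8_dtype).max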

fp8_compressed = to_sparse_semi_structured(fp8_data_fp16)

return cls(
fp8_data, # dense for dequantization

Contributor

we shouldn't be storing both the dense data and the compressed data; we should be storing the sparse specified values and the sparse metadata.

float8_dtype=float8_dtype,
)

def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor:

Contributor

we should multiply by an identity matrix to dequantize, like we do here:

x_vals_fp8 = scaled_x.to(torch.float8_e4m3fn)
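A minimal sketch of the identity-multiply idea on a plain dense tensor; with the packed 2:4 fp8 weight as the left operand the same pattern dispatches to the sparse kernel and materializes the dense values, after which the scale is applied.

import torch

w = torch.randn(64, 64, dtype=torch.float16)
identity = torch.eye(w.shape[-1], dtype=w.dtype)

# W @ I reproduces W row for row, so multiplying the sparse-packed weight by an
# identity matrix recovers the dense values for dequantization.
assert torch.allclose(w @ identity, w)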

# MatMul
x_padded = SparseSemiStructuredTensorCUSPARSELT._pad_dense_input(

Contributor

We should use the torchao cutlass fp8 kernels, which fuse in scale multiplication here.

See

def _linear_fp8_act_fp8_weight_sparse_cutlass_impl(input_tensor, weight_tensor, bias):
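Roughly, the linear path would hand the fp8 activation, the packed fp8 weight, and the bias to that fused wrapper instead of padding the input, running an unscaled cuSPARSELt matmul, and multiplying by the scales afterwards. The import path below is an assumption; the signature matches the snippet.

from torchao.dtypes.floatx.cutlass_semi_sparse_layout import (  # import path assumed
    _linear_fp8_act_fp8_weight_sparse_cutlass_impl,
)

def sparse_fp8_linear(input_tensor, weight_tensor, bias=None):
    # The CUTLASS kernel applies the activation and weight scales inside the GEMM.
    return _linear_fp8_act_fp8_weight_sparse_cutlass_impl(input_tensor, weight_tensor, bias)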

Contributor

delete? We don't want this to be in prototype, I think.

Contributor

this should be added to the init file without the prototype in the path

Contributor

also need to add to Float8DynamicActivationFloat8WeightConfig?
