Introduce new W8A8-FP-CSR quantization API #3258
base: main
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3258
Note: Links to docs will display an error until the docs builds have been completed.
❗ 1 Active SEV: there is 1 currently active SEV. If your PR is affected, please view it below.
❌ 9 New Failures: as of commit f5f7a17 with merge base 3577306, the following jobs have failed:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@jcaip could you please check this PR?
cc @namgyu-youn Can you split this into two PRs? One for int8 and one for float8? In general I don't think we want to introduce weight-only sparsity configs for int8 and float8, because we don't have mixed-dtype kernel support currently. The only kernels we have are for int8 x int8 2:4 sparse and fp8 x fp8 2:4 sparse. I would like Int8SemiSparseTensor though, but I think it should live in prototype until we have a user for it. Also cc @bbeckca, who has been working on the fp8 x fp8 2:4 sparse tensor subclass migration in #3182.
@jcaip if we want to move int8 2:4 sparse to prototype, then we don't need to migrate the tensor, I think.
Okay, then I'll address only
cc @namgyu-youn I talked to @bbeckca and I think your PR is closer, so let's use it instead.
cc @jcaip to request review, thanks.
cc @namgyu-youn
I think there's a bit of confusion on what the tensor subclass should be storing and how to do the op overload.
Please take a look at https://github.com/pytorch/ao/pull/3182/files#diff-afc7dd21d2b704181a6fd55be989426c0217a2bbfb694af9eb9746239ec462ed for the appropriate logic / ops to be called.
class Float8SemiSparseTensor(TorchAOBaseTensor):
    """
    W8A8-FP-CSR: float8 quantized tensor with 2:4 semi-structured sparsity layout
nit: the comment looks wrong; CSR is compressed sparse row, and that's not the sparse format used here (2:4 sparsity).
    float8_dtype: float8 dtype variant
    """
    tensor_data_names = ["qdata", "qdata_compressed", "scale"]
I think quantized_sparse_data and quantized_sparse_metadata would be better variable names here: quantized_sparse_data holds the specified values and quantized_sparse_metadata holds the sparsity metadata.
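For illustration, a minimal sketch of that storage layout (the attribute names follow the suggestion above; the class body and tensor_attribute_names are assumptions, not the PR's actual code):

```python
import torch
from torchao.utils import TorchAOBaseTensor


class Float8SemiSparseTensorSketch(TorchAOBaseTensor):
    """Sketch: store only the 2:4 specified values, the sparsity metadata, and the scale."""

    # no dense copy of the weight is kept; dequantization reconstructs it on demand
    tensor_data_names = ["quantized_sparse_data", "quantized_sparse_metadata", "scale"]
    tensor_attribute_names = ["float8_dtype"]
```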
    )

    @property
    def qdata_fp8(self):
why do we need this?
    w_sparse.view(-1, 4).scatter_(1, pruning_inds, value=0)

    # Check for all-zero (sparsity=1) tensor
    if w_sparse.abs().max() == 0:
I think this should be supported actually? I don't see why we should error here.
    with torch.no_grad():
        w_sparse = w.clone()

    pruning_inds = w_sparse.abs().view(-1, 4).argsort(dim=1)[:, :2]
you can use this util (line 101 in 315e9b4):
    def mask_creator(
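For reference, a self-contained sketch of the 2:4 masking idea (this is only an illustration; mask_creator's actual signature is in the file linked above):

```python
import torch


def make_2to4_mask(w: torch.Tensor) -> torch.Tensor:
    """Sketch: keep the 2 largest-magnitude values in every group of 4 along the last dim."""
    groups = w.abs().view(-1, 4)
    prune_idx = groups.argsort(dim=1)[:, :2]         # 2 smallest-magnitude entries per group
    mask = torch.ones_like(groups, dtype=torch.bool)
    mask.scatter_(1, prune_idx, value=False)
    return mask.view_as(w)


w = torch.randn(128, 128)
w_sparse = w * make_2to4_mask(w)  # 2:4 semi-structured sparse weight
```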
    # Store fp8 data in both dense and compressed formats
    fp8_data_fp16 = fp8_data.to(torch.float16)

    fp8_compressed = to_sparse_semi_structured(fp8_data_fp16)
We should use the torchao cutlass packing kernels here, not the default torch ones:
    sparse, meta = to_sparse_semi_structured_cutlass_sm9x_f8(dense)
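Roughly, that swap might look like the following (the call shape follows the line quoted above; the import path and surrounding variable names are assumptions):

```python
import torch
from torchao.ops import to_sparse_semi_structured_cutlass_sm9x_f8  # assumed import path


# Pack the 2:4-pruned fp8 weight with the torchao CUTLASS SM9x kernel instead of
# torch's to_sparse_semi_structured; it returns the specified values and the
# sparsity metadata, which is what the subclass should store.
def pack_fp8_weight(w_pruned: torch.Tensor):
    dense_fp8 = w_pruned.to(torch.float8_e4m3fn)
    sparse, meta = to_sparse_semi_structured_cutlass_sm9x_f8(dense_fp8)
    return sparse, meta
```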
    if not (scale > 0).all():
        raise ValueError(f"Scale contains non-positive values: min={scale.min()}")

    scale_expanded = scale.unsqueeze(1)
Is this different from Float8Tensor? Can we use the same scale calculation logic as we use there?
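For comparison, the usual rowwise fp8 scale computation looks roughly like this (a generic amax-based sketch, not necessarily Float8Tensor's exact code path):

```python
import torch


def rowwise_fp8_scale(w: torch.Tensor, fp8_dtype=torch.float8_e4m3fn) -> torch.Tensor:
    """Sketch: scale each row so its absolute max maps to the fp8 dtype's max value."""
    fp8_max = torch.finfo(fp8_dtype).max          # 448.0 for e4m3fn
    amax = w.abs().amax(dim=-1, keepdim=True)     # per-row absolute maximum
    return (amax / fp8_max).clamp(min=1e-12)      # avoids zero scales for all-zero rows


w = torch.randn(64, 128)
scale = rowwise_fp8_scale(w)
w_fp8 = (w / scale).to(torch.float8_e4m3fn)  # quantize with the rowwise scale
```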
    fp8_compressed = to_sparse_semi_structured(fp8_data_fp16)

    return cls(
        fp8_data,  # dense for dequantization
we shouldn't be storing both the dense data and the compressed data; we should be storing the sparse specified values and the sparse metadata.
        float8_dtype=float8_dtype,
    )

    def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor:
we should multiply by identity matrix to dequantize, like we do here:
    def get_plain(self):
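A rough illustration of the identity-matrix trick, shown with torch's default fp16 2:4 path since it is easy to run end to end (the PR's fp8 path would go through get_plain and the torchao CUTLASS ops instead; sizes and device here are just for the example):

```python
import torch
from torch.sparse import to_sparse_semi_structured

# 2:4-prune an fp16 weight (same scatter-based pruning as in the PR above)
w = torch.randn(128, 128, dtype=torch.float16, device="cuda")
pruning_inds = w.abs().view(-1, 4).argsort(dim=1)[:, :2]
w.view(-1, 4).scatter_(1, pruning_inds, value=0)
w_24 = to_sparse_semi_structured(w)

# the packed layout only supports sparse @ dense matmul, so the plain dense
# weight is recovered by multiplying with an identity matrix
identity = torch.eye(w.shape[-1], dtype=torch.float16, device="cuda")
w_dense = w_24 @ identity  # equals the pruned dense weight; apply scales afterwards to dequantize
```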
    x_vals_fp8 = scaled_x.to(torch.float8_e4m3fn)

    # MatMul
    x_padded = SparseSemiStructuredTensorCUSPARSELT._pad_dense_input(
We should use the torchao cutlass fp8 kernels, which fuse in the scale multiplication here. See:
    def _linear_fp8_act_fp8_weight_sparse_cutlass_impl(input_tensor, weight_tensor, bias):
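That fused path would look roughly like this; the op name comes from torchao's CUTLASS sparse kernels, but the argument order and import path below are assumptions, so check the referenced _linear_fp8_act_fp8_weight_sparse_cutlass_impl for the real call:

```python
import torch
from torchao.ops import rowwise_scaled_linear_sparse_cutlass_f8f8  # assumed import path


# Sketch: a single fused CUTLASS kernel does the fp8 activation x fp8 2:4-sparse
# weight matmul and applies both rowwise scales, instead of a cuSPARSELt matmul
# followed by separate padding and scale multiplications.
def fp8_sparse_linear(x_fp8, x_scale, w_sparse, w_meta, w_scale, bias=None):
    # argument order here is an assumption; see the impl referenced above
    return rowwise_scaled_linear_sparse_cutlass_f8f8(
        x_fp8, x_scale, w_sparse, w_meta, w_scale, bias
    )
```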
delete? we don't want this to be in prototype I think
this should be added to the init file without the prototype in the path
also need to add to Float8DynamicActivationFloat8WeightConfig?
Summary:
Introduce a new W8A8-FP-CSR quantization API, Float8SemiSparseTensor, which specializes in the 2:4 semi-sparse pattern using cuSPARSELt acceleration (https://docs.nvidia.com/cuda/cusparselt/).
Related Issue/PR: #2752
Future Plan:
This PR only introduces the core operations (quantization/dequantization). For better API support, we will need to introduce tensor utility operations like indexing and slicing.
Test Plan:
test/prototype/quantization/quantize_/float8/test_float8_semisparse_tensor.py