introduce new int8 quantization API #3241
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3241
Note: links to docs will display an error until the docs builds have completed. ❗ There is 1 currently active SEV; if your PR is affected, please view it below. This comment was automatically generated by Dr. CI and updates every 15 minutes.
    quant_min=self.int8_min,
    quant_max=self.int8_max,
nit: we can omit these two args if they are the same as the defaults (-128, 127)
)
@common_utils.parametrize("dtype", [torch.bfloat16, torch.float16])
def test_quantization_shapes(self, dtype):
this seems to be a combination of two tests, one for dynamic quant and one for static quant; can you use something like this:
    @common_utils.parametrize("mode", ["dynamic", "weight-only"])
also I feel it might be better not to add static quant in this PR, and instead add both the tensor support and the config support for static quant in a separate PR
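Expanding the parametrize suggestion above into a rough sketch (the imports, helper names, and shapes here are illustrative, not taken from this PR):

```python
import torch
from torch.testing._internal import common_utils
from torchao.quantization import (
    Int8DynamicActivationInt8WeightConfig,
    Int8WeightOnlyConfig,
    quantize_,
)


class TestInt8Tensor(common_utils.TestCase):
    @common_utils.parametrize("dtype", [torch.bfloat16, torch.float16])
    @common_utils.parametrize("mode", ["dynamic", "weight-only"])
    def test_quantization_shapes(self, dtype, mode):
        # pick the version=2 config that matches the mode being tested
        config = (
            Int8DynamicActivationInt8WeightConfig(version=2)
            if mode == "dynamic"
            else Int8WeightOnlyConfig(version=2)
        )
        linear = torch.nn.Linear(128, 256, dtype=dtype, device="cuda")
        quantize_(linear, config)
        x = torch.randn(32, 128, dtype=dtype, device="cuda")
        self.assertEqual(linear(x).shape, (32, 256))


common_utils.instantiate_parametrized_tests(TestInt8Tensor)
```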
Okay, I wasn't sure about removing the static flags before (although static quant isn't fully implemented), but a smaller PR is always better, I feel. I will remove static_scale and all the related support.
if act_quant_kwargs is not None and act_quant_kwargs.static_scale is not None:
    # INT8 × INT8 (static)
    scale = act_quant_kwargs.static_scale
    zero_point = torch.zeros_like(scale, dtype=torch.int8)
I think the user should specify static_zero_point as well
but again, it's better to do this in a separate PR, since the current state is only half of the static quant feature (no config)
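Something like the following, purely as an illustration (static_zero_point is the field the reviewer is suggesting, not something that exists in this PR):

```python
import torch


def _resolve_static_qparams(act_quant_kwargs):
    # scale is always user-provided for static quant
    scale = act_quant_kwargs.static_scale
    # use the user-provided zero point when given, otherwise fall back to symmetric
    zero_point = getattr(act_quant_kwargs, "static_zero_point", None)
    if zero_point is None:
        zero_point = torch.zeros_like(scale, dtype=torch.int8)
    return scale, zero_point
```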
# Cast fp16 scale to float
intermediate_dtype = (
    torch.float if x_scales.dtype == torch.half else x_scales.dtype
)
# Note: CUDA doesn't support int32/int64 matmul, so we convert to float
# Error message is NotImplementedError: "addmm_cuda" not implemented for 'Int'
# This may introduce minor numerical differences compared to int arithmetic
y_dot = torch.mm(tmp.to(intermediate_dtype), w_vals_t.to(intermediate_dtype))

# Apply activation scale
is_per_tensor_act = x_scales.numel() == 1
if is_per_tensor_act:
    y_dot.mul_(x_scales.to(intermediate_dtype))
else:
    # For block-wise activation scale, reshape to match y_dot
    x_scales_reshaped = x_scales.view(y_dot.shape[0], -1)
    y_dot.mul_(x_scales_reshaped.to(intermediate_dtype))

# Apply weight scale
is_per_tensor_weight = w_scales.numel() == 1
if is_per_tensor_weight:
    result = y_dot.mul_(w_scales.to(intermediate_dtype))
else:
    # Per-row weight scale - transpose and broadcast
    w_scales_broadcast = w_scales.t().expand_as(y_dot)
    result = y_dot.mul_(w_scales_broadcast.to(intermediate_dtype))

# Reshape back to original shape
result = result.view(*x_vals.shape[:-1], result.shape[-1])
result = result.to(activation_tensor.dtype)
this should follow:
ao/torchao/dtypes/uintx/plain_layout.py, line 281 (at e9c7bea):
    def _linear_int8_act_int8_weight_impl(input_tensor, weight_tensor, bias):
I think we should
- split the static quant support into a separate PR
- follow what https://github.com/pytorch/ao/blob/main/torchao/dtypes/uintx/plain_layout.py is doing for the quantized linear implementation
This should be a refactor PR, not a refactor plus extra modifications plus feature implementations, I think.
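Roughly, the plain_layout implementation keeps the int8 × int8 matmul in integer arithmetic (via torch._int_mm / int_scaled_matmul) and only applies the scales afterwards; a simplified sketch of that pattern, reusing the variable names from the excerpt above (not the exact plain_layout code):

```python
# simplified sketch: integer matmul first, scales applied afterwards
tmp = x_vals.reshape(-1, x_vals.shape[-1])
y_dot = torch._int_mm(tmp, w_vals_t)  # int8 @ int8 -> int32 accumulation
y = y_dot.to(x_scales.dtype) * x_scales.reshape(-1, 1) * w_scales
y = y.reshape(*x_vals.shape[:-1], y.shape[-1]).to(activation_tensor.dtype)
```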
aten = torch.ops.aten

# Unsupported case for now, this would be 1 scale per data element
# Per-tensor quantization (scalar scale)
is this change related?
It is updated to support more granularities. Without this change, we can't use per-tensor (0-D scale) or per-row (1-D scale).
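For reference, a small illustration of the two granularities being discussed (shapes and scale values here are made up):

```python
import torch

w = torch.randint(-128, 128, (256, 128), dtype=torch.int8)  # (out_features, in_features)

# per-tensor: a single 0-D scale for the whole tensor
per_tensor_scale = torch.tensor(0.02)
w_dq_per_tensor = w.float() * per_tensor_scale

# per-row: a 1-D scale with one entry per output row, broadcast over columns
per_row_scale = torch.rand(256) * 0.05
w_dq_per_row = w.float() * per_row_scale.unsqueeze(-1)
```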
So maybe it's better to move this util function to a common place?
        Int8WeightOnlyConfig(version=2),
    ],
)
def test_per_row_scale_shape(self, dtype, config):
can you add a test like this one:
    def test_fp8_linear_variants(
@common_utils.parametrize("dtype", [torch.bfloat16, torch.float16])
@common_utils.parametrize("has_bias", [True, False])
def test_weight_only_linear_with_bias(self, dtype, has_bias):
this can probably be merged into the linear variants test as well
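e.g. folded into a single combined variants test, roughly like this (same imports as the earlier sketch, plus copy; names, tolerances, and shapes are placeholders):

```python
@common_utils.parametrize("dtype", [torch.bfloat16, torch.float16])
@common_utils.parametrize("has_bias", [True, False])
@common_utils.parametrize(
    "config",
    [
        Int8WeightOnlyConfig(version=2),
        Int8DynamicActivationInt8WeightConfig(version=2),
    ],
)
def test_int8_linear_variants(self, dtype, has_bias, config):
    linear = torch.nn.Linear(128, 256, bias=has_bias, dtype=dtype, device="cuda")
    ref = copy.deepcopy(linear)
    quantize_(linear, config)
    x = torch.randn(32, 128, dtype=dtype, device="cuda")
    # loose tolerance: int8 quantization error dominates the comparison
    torch.testing.assert_close(linear(x), ref(x), atol=2e-1, rtol=2e-1)
```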
thanks, I think the tensor changes look good, but we need a linear_variants test to make sure we cover different aspects (e.g. compile); see comments inline
can you also do an e2e perf check with https://github.com/pytorch/ao/blob/main/tutorials/quantize_vit/run_vit_b_quant.py to make sure the performance is the same before and after the change for the ViT model?
also, adding a kernel check might be useful to make sure we don't regress things:
    def test_expected_gpu_kernel_fbgemm(self):
Updated logs:
Hi @namgyu-youn, do you plan to submit another PR for static quantization? We also need static quantization for SmoothQuant, so we are wondering whether you have a plan or we should consider adding it ourselves. Thanks. CC @cyxlily
Yeah, static quantization support using static/dynamic flags is planned; I hope to show it to your team in the foreseeable future. Also, in the SmoothQuant case, validating its support for the new quantization APIs (below) has higher priority, I think. Could you look into it?
Thanks. Looking forward to it. If there is anything we can help with, please let us know.
By "validating them", do you mean adding test cases? And are W4A16 and W8A16 (I guess there is a typo in your comment) really needed for SmoothQuant? For W4A16 , it would be much the same as AWQ. And for W8A16, I think accuracy is generally good enough without SmoothQuant. |
Oh yes, it was a typo (W8A16 is right), and W4A16-INT (…). Because the current AWQ/SmoothQuant test only works with the old APIs (version 1), we can replace it with the new APIs, e.g. the config sketched below.
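A minimal sketch of what switching to the new API would look like (assuming the config is importable from torchao.quantization; whether SmoothQuant needs additional hooks on top of this is exactly what would need validating):

```python
import torch
from torchao.quantization import Int8DynamicActivationInt8WeightConfig, quantize_

model = torch.nn.Sequential(torch.nn.Linear(512, 512)).cuda().bfloat16()
# version=2 opts into the new Int8Tensor-based path; version=1 keeps the old behavior
quantize_(model, Int8DynamicActivationInt8WeightConfig(version=2))
```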
I see. Thanks. We will evaluate that.
Hi @namgyu-youn, may I know if you have a timeline to land this? Thanks.
        ((32, 128), 64, 256),  # 3D
    ],
)
def test_int8_linear_quantization_accuracy(
can this one be combined with test_int8_linear_variants as well?
    with self.assertRaises(NotImplementedError):
        _ = dummy.weight[::2]

def test_index_select(self):
I think all the tests should be modified to use config as well
kernels = {}

# Check for Triton kernels
if "torch.ops.triton" in code[0]:
we should add some asserts here, I think
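A rough sketch of what such an assert could look like (run_and_get_code is the torch._inductor test utility; quantized_linear and example_input are placeholders for the module and input under test):

```python
import torch
from torch._inductor.utils import run_and_get_code

compiled = torch.compile(quantized_linear)  # placeholder: the quantized module under test
out, code = run_and_get_code(compiled, example_input)  # placeholder example input
# fail loudly if the int8 path stops lowering to a Triton kernel
assert any("triton" in c for c in code), "expected a Triton kernel in the generated code"
```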
- Configs are updated to global variants
Updated logs:
Summary:
Introduce a new tensor subclass API. Main features are:
- Int8Tensor: the main API, which handles quantization and dequantization operations
- This API is integrated into the global variants (Int8WeightOnlyConfig, Int8DynamicActivationInt8WeightConfig) via the version argument, and is not the default.
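A minimal usage sketch of the opt-in (assuming the configs are exposed from torchao.quantization; exact import paths may differ):

```python
import torch
from torchao.quantization import Int8WeightOnlyConfig, quantize_

model = torch.nn.Sequential(torch.nn.Linear(128, 256)).cuda().bfloat16()
# version=2 routes the config to the new Int8Tensor subclass; the default keeps the old path
quantize_(model, Int8WeightOnlyConfig(version=2))
out = model(torch.randn(32, 128, device="cuda", dtype=torch.bfloat16))
```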
Related Issue/PR: this is a reopened PR for #3038.
Test plan:
test/quantization/quantize_/workflows/int8/test_int8_tensor.py
Performance:
The following are the results of https://github.com/pytorch/ao/blob/main/tutorials/quantize_vit/run_vit_b_quant.py with a batch size of 32:
(torch.compile results table)