Adding uint4 dtype implementation #13
Merged
Commits (17)
3e7e466 Adding uint4 dtype implementation (jerryzh168)
a11039b Update on "Adding uint4 dtype implementation" (jerryzh168)
1fc4660 Update on "Adding uint4 dtype implementation" (jerryzh168)
19f3d0d Update on "Adding uint4 dtype implementation" (jerryzh168)
734c705 Update on "Adding uint4 dtype implementation" (jerryzh168)
a9d6cdc Update on "Adding uint4 dtype implementation" (jerryzh168)
1b51eb5 Update on "Adding uint4 dtype implementation" (jerryzh168)
b5ce8c6 Update on "Adding uint4 dtype implementation" (jerryzh168)
bb483f3 Update on "Adding uint4 dtype implementation" (jerryzh168)
9c9084e Update on "Adding uint4 dtype implementation" (jerryzh168)
dfe122d Update on "Adding uint4 dtype implementation" (jerryzh168)
9a22440 Update on "Adding uint4 dtype implementation" (jerryzh168)
c69d753 Update on "Adding uint4 dtype implementation" (jerryzh168)
6f0aa7d Update on "Adding uint4 dtype implementation" (jerryzh168)
c52b123 Update on "Adding uint4 dtype implementation" (jerryzh168)
49c6a43 Update on "Adding uint4 dtype implementation" (jerryzh168)
8ab52a7 Update on "Adding uint4 dtype implementation" (jerryzh168)
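The PR introduces a UInt4Tensor type whose elements are 4-bit values stored two per uint8 byte, high nibble first (so the byte 0x23 holds the elements 2 and 3, which the slicing test below relies on). As a rough, standalone sketch of that packing scheme using only plain torch ops — the pack_uint4/unpack_uint4 helpers here are hypothetical and are not part of the PR:

    import torch

    def pack_uint4(x):
        # x: uint8 tensor of values in [0, 15] with an even last dimension;
        # adjacent pairs become one byte, high nibble first: (0x2, 0x3) -> 0x23
        return ((x[..., ::2] << 4) | x[..., 1::2]).to(torch.uint8)

    def unpack_uint4(packed):
        # inverse of pack_uint4: each byte expands back into two uint4 values
        return torch.stack([packed >> 4, packed & 0xF], dim=-1).reshape(*packed.shape[:-1], -1)

unpack_uint4(pack_uint4(x)) round-trips any uint8 tensor of values in [0, 15].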
@@ -0,0 +1,292 @@
import torch
from torchao.dtypes.int4 import UInt4Tensor
import unittest
from unittest import TestCase, main
from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
from torch.ao.quantization.quantizer import QuantizationSpec, Quantizer

from torch._export import capture_pre_autograd_graph
from torch._export import dynamic_dim
from torch.testing._internal.common_quantization import (
    NodeSpec as ns,
    QuantizationTestCase,
)
from torchao.quantization.utils import (
    compute_error,
)
from torchao.quantization.quant_api import (
    replace_with_custom_fn_if_matches_filter,
)
from torch import nn
import copy

def _dynamically_quantize_per_channel_int4(x, quant_min, quant_max, target_dtype):
    # assumes symmetric quantization
    # assumes axis == 0
    # assumes dense memory format
    # TODO(future): relax ^ as needed

    # default setup for affine quantization of activations
    eps = torch.finfo(torch.float32).eps

    # get min and max
    min_val, max_val = torch.aminmax(x, dim=1)

    # calculate scale and zero point based on min and max
    # reference: https://fburl.com/code/srbiybme
    min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
    max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
    device = min_val_neg.device

    # reference: https://fburl.com/code/4wll53rk
    max_val_pos = torch.max(-min_val_neg, max_val_pos)
    scale = max_val_pos / (float(quant_max - quant_min) / 2)
    # ensure scale is the same dtype as the original tensor
    scale = torch.clamp(scale, min=eps).to(x.dtype)
    zero_point = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device)

    # quantize based on qmin/qmax/scale/zp
    # reference: torch/ao/quantization/fx/_decomposed.py?lines=63
    x_div = x.transpose(0, 1) / scale
    x_round = torch.round(x_div)
    x_zp = x_round + zero_point
    x_zp = x_zp.transpose(0, 1)
    quant = torch.clamp(x_zp, quant_min, quant_max)
    if target_dtype == "int4":
        quant = UInt4Tensor.from_unpacked(quant.to(torch.uint8)).view(quant.size())
    else:
        quant = quant.to(target_dtype)

    return quant, scale, zero_point
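
# Illustrative usage of the helper above (a sketch, not part of the test; the
# shapes and names below are hypothetical):
#   w = torch.randn(16, 32)
#   w_uint4, scales, zp = _dynamically_quantize_per_channel_int4(w, 0, 15, "int4")
#   # w_uint4 holds the uint4 values, scales has one entry per row (channel) of w,
#   # and zp is all zeros since the quantization is symmetric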

class _WeightOnlyInt4QuantLinear(torch.nn.Linear):
    def __init__(self, *args, **kwargs):
        w_int4 = kwargs.pop("w_int4")
        scales = kwargs.pop("scales")
        super().__init__(*args, **kwargs)
        self.w_int4 = w_int4
        self.scales = scales

    def forward(self, x):
        # if len(x.shape) <= 2:
        #     y = torch.mm(x, self.w_int8.to(x.dtype)) * self.scales
        # else:  # turn x into 2d tensor, then undo it for y
        x_view = x.view(-1, x.shape[-1])
        y = torch.mm(x_view, self.w_int4.to(torch.uint8).to(x.dtype)) * self.scales
        y = y.reshape(*x.shape[:-1], -1)
        if self.bias is not None:
            y += self.bias
        return y

    @classmethod
    def from_float(cls, mod):
        w_fp32 = mod.weight
        w_int4, scales, _zp = _dynamically_quantize_per_channel_int4(
            w_fp32, 0, 15, "int4"
        )
        # create the new module with a toy size to ensure initialization is fast
        fake_in_features, fake_out_features = 8, 8
        new_mod = cls(
            fake_in_features,
            fake_out_features,
            bias=mod.bias is not None,
            w_int4=w_int4.t().contiguous(),
            scales=scales,
        )
        new_mod.in_features = mod.in_features
        new_mod.out_features = mod.out_features
        del new_mod.weight
        new_mod.bias = mod.bias
        device_to_use = next(mod.parameters()).device
        new_mod.to(device_to_use)
        return new_mod

def _apply_weight_only_int4_quant(model):
    replace_with_custom_fn_if_matches_filter(
        model,
        _WeightOnlyInt4QuantLinear.from_float,
        lambda mod, fqn: isinstance(mod, torch.nn.Linear),
    )

class TestInt4(QuantizationTestCase):
    def test_basic_tensor_ops(self):
        x = UInt4Tensor(torch.tensor([
            [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF],
            [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF],
            [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF],
        ], dtype=torch.uint8))
        # 8 packed uint8 bytes per row -> 16 uint4 elements per row
        self.assertEqual(x.shape, (3, 16))
        # make sure this works
        x.to(torch.uint8)
        expected = UInt4Tensor(torch.tensor([
            [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF],
        ], dtype=torch.uint8))
        self.assertTrue(x[0:1, :] == expected)
        expected = UInt4Tensor(torch.tensor([
            [0x23, 0x45],
            [0x23, 0x45],
            [0x23, 0x45],
        ], dtype=torch.uint8))
        self.assertTrue(x[:, 2:6] == expected)

    def test_gpu_quant(self):
        for x_shape in [[2, 4], [5, 5, 5, 4], [1, 4, 4]]:
            x = torch.randn(*x_shape)
            m = nn.Sequential(nn.Linear(4, 16))
            y_ref = m(x)
            _apply_weight_only_int4_quant(m)
            y_wo = m(x)
            # sqnr = compute_error(y_ref, y_wo)
            opt = torch.compile(m, mode="max-autotune")
            # make sure it runs
            opt(x)

    def test_aten_ir(self):
        from torch.library import Library, impl

        test_lib = Library("test_int4", "DEF")
        test_lib.define("quantize_per_tensor_int4(Tensor input, float scale, int zero_point) -> Tensor")

        @impl(test_lib, "quantize_per_tensor_int4", "CompositeExplicitAutograd")
        def quantize_per_tensor_int4(
            input: torch.Tensor,
            scale: float,
            zero_point: int,
        ) -> torch.Tensor:
            inv_scale = 1.0 / scale
            return torch.clamp(torch.round(input * inv_scale) + zero_point, 0, 15).to(torch.uint8).view(torch.bits8)

        test_lib.define("dequantize_per_tensor_int4(Tensor input, float scale, int zero_point) -> Tensor")

        @impl(test_lib, "dequantize_per_tensor_int4", "CompositeExplicitAutograd")
        def dequantize_per_tensor_int4(
            input: torch.Tensor,
            scale: float,
            zero_point: int,
        ) -> torch.Tensor:
            return (input.view(torch.uint8).to(torch.float32) - zero_point) * scale
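
        # Illustrative round trip for the two ops above (a sketch, not part of
        # the test; the scale and zero_point values are arbitrary):
        #   t = torch.rand(2, 3)
        #   q = torch.ops.test_int4.quantize_per_tensor_int4(t, 0.1, 0)
        #   dq = torch.ops.test_int4.dequantize_per_tensor_int4(q, 0.1, 0)
        #   # for inputs inside the representable range [0, 1.5], dq matches t
        #   # to within about half the scale (~0.05)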

        # class QuantizePerTensorUInt4(torch.autograd.Function):
        #     @staticmethod
        #     def forward(
        #         ctx,
        #         input: torch.Tensor,
        #         scale: float,
        #         zero_point: int,
        #     ) -> torch.Tensor:
        #         inv_scale = 1.0 / scale
        #         return UInt4Tensor(torch.clamp(torch.round(input * inv_scale) + zero_point, 0, 15).to(torch.uint8))

        # class DeQuantizePerTensorUInt4(torch.autograd.Function):
        #     @staticmethod
        #     def forward(
        #         ctx,
        #         input: torch.Tensor,
        #         scale: float,
        #         zero_point: int,
        #     ) -> torch.Tensor:
        #         return (input.to(torch.float32) - zero_point) * scale

        class M(torch.nn.Module):
            def forward(self, x, y):
                return x + y

        example_inputs = (torch.randn(1, 2, 3, 3), torch.randn(1, 2, 3, 3),)
        m = M().eval()
        m = capture_pre_autograd_graph(m, example_inputs)
        for n in m.graph.nodes:
            if n.target == torch.ops.aten.add.Tensor:
                with m.graph.inserting_before(n):
                    q = m.graph.call_function(torch.ops.test_int4.quantize_per_tensor_int4, (n.args[0], 1.0, 0), {})
                    dq = m.graph.call_function(torch.ops.test_int4.dequantize_per_tensor_int4, (q, 1.0, 0), {})
                    n.replace_input_with(n.args[0], dq)
        m.recompile()
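        # After the rewrite, the first input of each aten.add.Tensor node flows
        # through quantize_per_tensor_int4 -> dequantize_per_tensor_int4 before
        # reaching the add.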

    # TODO: need more extension points from quant flow side
    @unittest.skip("need more extension points from quant flow side")
    def test_pt2e_quant(self):
        from torch.ao.quantization import observer
        from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import (
            OP_TO_ANNOTATOR,
            QuantizationConfig,
        )

        class Int4ActQuantizer(Quantizer):
            def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
                int4_qspec = QuantizationSpec(
                    dtype=torch.int8,
                    quant_min=-2**3,
                    quant_max=2**3 - 1,
                    qscheme=torch.per_tensor_affine,
                    is_dynamic=False,
                    observer_or_fake_quant_ctr=observer.default_observer,
                )
                int8_qspec = QuantizationSpec(
                    dtype=torch.int8,
                    quant_min=-128,
                    quant_max=127,
                    qscheme=torch.per_tensor_symmetric,
                    is_dynamic=False,
                    observer_or_fake_quant_ctr=observer.default_weight_observer,
                )
                quantization_config = QuantizationConfig(
                    input_activation=int8_qspec,
                    weight=int4_qspec,
                    bias=None,
                    output_activation=int8_qspec,
                )
                OP_TO_ANNOTATOR["conv"](model, quantization_config)

            def validate(self, model: torch.fx.GraphModule) -> None:
                pass

        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 3, 3)

            def forward(self, x):
                return self.conv(x)

        quantizer = Int4ActQuantizer()
        node_occurrence = {
            # one for input of the first conv, one for output for the first conv
            torch.ops.quantized_decomposed.quantize_per_tensor.default: 2,
            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3,
        }
        node_list = [
            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
            torch.ops.aten.conv2d.default,
            torch.ops.quantized_decomposed.quantize_per_tensor.default,
        ]
        example_inputs = (torch.randn(1, 3, 3, 3),)

        # adapted from _test_quantizer in PT2EQuantizationTestCase
        # resetting dynamo cache
        export_with_dynamic_shape = False
        torch._dynamo.reset()
        m_eager = M().eval()

        # program capture
        m = copy.deepcopy(m_eager)
        m = capture_pre_autograd_graph(
            m,
            example_inputs,
            constraints=[dynamic_dim(example_inputs[0], 0)] if export_with_dynamic_shape else [],
        )

        m = prepare_pt2e(m, quantizer)
        # Calibrate
        m(*example_inputs)
        m = convert_pt2e(m, fold_quantize=True)

        pt2_quant_output = m(*example_inputs)
        node_occurrence = {
            ns.call_function(k): v for k, v in node_occurrence.items()
        }
        node_list = [ns.call_function(n) for n in node_list]
        self.checkGraphModuleNodes(
            m, expected_node_occurrence=node_occurrence, expected_node_list=node_list
        )

if __name__ == "__main__":
    main()
@@ -0,0 +1,5 @@
from .int4 import UInt4Tensor

__all__ = [
    "UInt4Tensor"
]
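
With this re-export in place, callers can pull UInt4Tensor from the package root. A minimal sketch of the construction path the tests above exercise, assuming this file is torchao/dtypes/__init__.py (the file path is not shown in this view) and using arbitrary example values:

    import torch
    from torchao.dtypes import UInt4Tensor

    # each uint8 byte packs two uint4 elements, so this row holds 0x1, 0x2, 0x3, 0x4
    x = UInt4Tensor(torch.tensor([[0x12, 0x34]], dtype=torch.uint8))
    x_u8 = x.to(torch.uint8)  # convert to uint8, as in test_basic_tensor_ops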