16 changes: 16 additions & 0 deletions test/prototype/test_awq.py

@@ -51,6 +51,10 @@ def forward(self, x):
     devices.append("cuda")
 
 
+if torch.xpu.is_available():
+    devices.append("xpu")
+
+
 class TestAWQ(TestCase):
     def test_awq_config(self):
         base_config = Int4WeightOnlyConfig()

@@ -79,6 +83,10 @@ def test_awq_functionality(self, device):
         # baseline quantization
         if device == "cuda":
             base_config = Int4WeightOnlyConfig(group_size=group_size)
+        elif device == "xpu":
+            base_config = Int4WeightOnlyConfig(
+                group_size=group_size, int4_packing_format="plain_int32"
+            )
         elif device == "cpu":
             base_config = Int4WeightOnlyConfig(
                 group_size=group_size, int4_packing_format="opaque"

@@ -137,6 +145,10 @@ def test_awq_loading(self, device):
         # calibrate
         if device == "cuda":
             base_config = Int4WeightOnlyConfig(group_size=group_size)
+        elif device == "xpu":
+            base_config = Int4WeightOnlyConfig(
+                group_size=group_size, int4_packing_format="plain_int32"
+            )
         elif device == "cpu":
             base_config = Int4WeightOnlyConfig(
                 group_size=group_size, int4_packing_format="opaque"

@@ -198,6 +210,10 @@ def test_awq_loading_vllm(self, device):
         # calibrate
         if device == "cuda":
             base_config = Int4WeightOnlyConfig(group_size=group_size)
+        elif device == "xpu":
+            base_config = Int4WeightOnlyConfig(
+                group_size=group_size, int4_packing_format="plain_int32"
+            )
        elif device == "cpu":
             base_config = Int4WeightOnlyConfig(
                 group_size=group_size, int4_packing_format="opaque"
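
Each of the three test hunks above selects the same device-dependent base
config: CUDA keeps the default int4 packing, XPU requires
int4_packing_format="plain_int32", and CPU requires "opaque". A minimal sketch
of a helper that centralizes that choice (make_base_config is hypothetical,
not part of this PR; only the Int4WeightOnlyConfig arguments shown in the diff
are assumed):

    from torchao.quantization import Int4WeightOnlyConfig

    # packing format required per device, per the branches in the tests above;
    # "cuda" is absent because it uses the default packing format
    _PACKING_BY_DEVICE = {"xpu": "plain_int32", "cpu": "opaque"}

    def make_base_config(device: str, group_size: int) -> Int4WeightOnlyConfig:
        packing = _PACKING_BY_DEVICE.get(device)
        if packing is None:
            return Int4WeightOnlyConfig(group_size=group_size)
        return Int4WeightOnlyConfig(group_size=group_size, int4_packing_format=packing)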

@@ -19,6 +19,7 @@
     Int4WeightOnlyConfig,
     quantize_,
 )
+from torchao.quantization.quantize_.common import SupportsActivationPreScaling
 from torchao.quantization.utils import compute_error
 from torchao.utils import (
     torch_version_at_least,

@@ -77,6 +78,25 @@ def test_module_path(self, dtype):
             "<class 'torchao.quantization.Int4PlainInt32Tensor'>",
         )
 
+    def test_activation_prescaling(self):
+        dtype = torch.bfloat16
+        device = "xpu"
+        input = torch.randn(1, 128, dtype=dtype, device=device)
+        linear = torch.nn.Linear(128, 256, bias=False, dtype=dtype, device=device)
+        original = linear(input)
+        quantize_(linear, get_config(128))
+        qw = linear.weight
+        assert isinstance(qw, SupportsActivationPreScaling), (
+            "Expected int4 tensor supports activation prescaling"
+        )
+        assert qw.act_pre_scale is None, "Default `act_pre_scale` is None"
+        _ACT_PRE_SCALE = 2
+        qw.act_pre_scale = _ACT_PRE_SCALE
+        quantized = linear(input)
+
+        # making sure activation pre scaling is successfully applied to the activation
+        self.assertTrue(compute_error(original * _ACT_PRE_SCALE, quantized) > 20)
+
 
 instantiate_parametrized_tests(Int4PlainInt32Tensor)
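The new test relies on an identity worth stating: for a scalar scale s,
scaling the activation before the matmul equals scaling the output,
W @ (s * x) == s * (W @ x), so the pre-scaled quantized output should track
original * _ACT_PRE_SCALE up to quantization noise (compute_error reports an
SQNR-style error in dB, and > 20 means a close match). A float-only
illustration of the identity, independent of this PR:

    import torch

    # scaling the input of a linear map commutes with the map itself;
    # with s = 2.0 (a power of two) the two orderings round identically
    x = torch.randn(128, 1)
    w = torch.randn(256, 128)
    s = 2.0
    assert torch.allclose(w @ (s * x), s * (w @ x))
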
4 changes: 4 additions & 0 deletions torchao/prototype/awq/example.py
Contributor:
    add a test like:
        def test_activation_prescaling(self):

Collaborator:
    @xiaowangintel let us add the UT as int4_tensor.

Collaborator (Author):
    done
@@ -254,6 +254,10 @@ def quantize_and_eval(
 
     if device == "cuda":
         base_config = Int4WeightOnlyConfig(group_size=group_size)
+    elif device == "xpu":
+        base_config = Int4WeightOnlyConfig(
+            group_size=group_size, int4_packing_format="plain_int32"
+        )
     elif device == "cpu":
         base_config = Int4WeightOnlyConfig(
             group_size=group_size, int4_packing_format="opaque"

@@ -5,7 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 
-from typing import List
+from typing import List, Optional
 
 import torch

@@ -38,10 +38,16 @@ class Int4PlainInt32Tensor(TorchAOBaseTensor):
         block_size: the block size for quantization, representing the granularity.
         shape: shape of the original Tensor
 
+    Optional Tensor Data Attributes:
+        act_pre_scale (Optional[Tensor]): optional scale for the activation Tensor;
+            if present, the activation Tensor is multiplied by act_pre_scale before
+            dynamic activation quantization or the quantized mm op is applied
+
     """
 
     tensor_data_names = ["qdata", "scale", "zero_point"]
     tensor_attribute_names = ["block_size", "shape"]
+    optional_tensor_data_names = ["act_pre_scale"]
 
     def __new__(
         cls,
@@ -50,21 +56,34 @@ def __new__(
         zero_point,
         block_size,
         shape,
+        act_pre_scale: Optional[torch.Tensor] = None,
     ):
         kwargs = {}
         kwargs["device"] = qdata.device
         kwargs["dtype"] = scale.dtype
         kwargs["requires_grad"] = False
         return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]
-    def __init__(self, qdata, scale, zero_point, block_size, shape):
+    def __init__(
+        self,
+        qdata,
+        scale,
+        zero_point,
+        block_size,
+        shape,
+        act_pre_scale: Optional[torch.Tensor] = None,
+    ):
         self.qdata = qdata
         self.scale = scale
         self.zero_point = zero_point
         self.block_size = block_size
+        self.act_pre_scale = act_pre_scale
 
     def _quantization_type(self):
-        return f"shape={self.shape}, block_size={self.block_size}, device={self.device}"
+        s = f"shape={self.shape}, block_size={self.block_size}, device={self.device}"
+        if self.act_pre_scale is not None:
+            s += f", act_pre_scale.shape={self.act_pre_scale.shape}"
+        return s

     @classmethod
     def from_hp(
@@ -122,6 +141,7 @@ def from_hp(
             zero_point.transpose(0, 1).contiguous().to(torch.int8),
             block_size,
             original_shape,
+            act_pre_scale=None,
         )


@@ -148,6 +168,9 @@ def _(func, types, args, kwargs):
             f"Shapes of input and weight do not match, input:{input_tensor.shape}, weight: {weight_tensor.shape}"
         )
 
+    if weight_tensor.act_pre_scale is not None:
+        input_tensor = input_tensor * weight_tensor.act_pre_scale
+
     act_mat = input_tensor
     packed_weight = weight_tensor.qdata
     scale = weight_tensor.scale
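
Taken together, the changes above give Int4PlainInt32Tensor an optional
act_pre_scale slot, thread it through __new__/__init__/from_hp, and apply it
in the linear dispatch before the quantized mm. A hedged end-to-end sketch on
XPU (the scalar scale mirrors the unit test; a real AWQ flow would compute a
calibrated scale instead):

    import torch
    from torchao.quantization import Int4WeightOnlyConfig, quantize_

    linear = torch.nn.Linear(128, 256, bias=False, dtype=torch.bfloat16, device="xpu")
    quantize_(
        linear,
        Int4WeightOnlyConfig(group_size=128, int4_packing_format="plain_int32"),
    )

    # attach the pre-scale; the dispatch multiplies the activation by it
    # before running the quantized mm
    linear.weight.act_pre_scale = torch.tensor(2.0, dtype=torch.bfloat16, device="xpu")
    y = linear(torch.randn(1, 128, dtype=torch.bfloat16, device="xpu"))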