
Commit 3a325f3

Support PLAIN_INT32 for AWQ on Intel GPU

1 parent fabebb2 · commit 3a325f3

File tree

test/prototype/test_awq.py
test/quantization/quantize_/workflows/int4/test_int4_plain_int32_tensor.py

2 files changed: +36 −0 lines changed

test/prototype/test_awq.py

Lines changed: 16 additions & 0 deletions
@@ -51,6 +51,10 @@ def forward(self, x):
     devices.append("cuda")


+if torch.xpu.is_available():
+    devices.append("xpu")
+
+
 class TestAWQ(TestCase):
     def test_awq_config(self):
         base_config = Int4WeightOnlyConfig()
@@ -79,6 +83,10 @@ def test_awq_functionality(self, device):
         # baseline quantization
         if device == "cuda":
             base_config = Int4WeightOnlyConfig(group_size=group_size)
+        elif device == "xpu":
+            base_config = Int4WeightOnlyConfig(
+                group_size=group_size, int4_packing_format="plain_int32"
+            )
         elif device == "cpu":
             base_config = Int4WeightOnlyConfig(
                 group_size=group_size, int4_packing_format="opaque"
@@ -137,6 +145,10 @@ def test_awq_loading(self, device):
         # calibrate
         if device == "cuda":
             base_config = Int4WeightOnlyConfig(group_size=group_size)
+        elif device == "xpu":
+            base_config = Int4WeightOnlyConfig(
+                group_size=group_size, int4_packing_format="plain_int32"
+            )
         elif device == "cpu":
             base_config = Int4WeightOnlyConfig(
                 group_size=group_size, int4_packing_format="opaque"
@@ -198,6 +210,10 @@ def test_awq_loading_vllm(self, device):
         # calibrate
         if device == "cuda":
             base_config = Int4WeightOnlyConfig(group_size=group_size)
+        elif device == "xpu":
+            base_config = Int4WeightOnlyConfig(
+                group_size=group_size, int4_packing_format="plain_int32"
+            )
         elif device == "cpu":
             base_config = Int4WeightOnlyConfig(
                 group_size=group_size, int4_packing_format="opaque"
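
For context on the branch added above: on an Intel GPU (XPU) device, the AWQ baseline config now packs int4 weights in the plain_int32 format, while CUDA keeps the default packing and CPU keeps "opaque". A minimal usage sketch built only from APIs visible in this diff; the helper name make_base_config and the layer shapes are illustrative, not part of the commit:

import torch
from torchao.quantization import Int4WeightOnlyConfig, quantize_

# Hypothetical helper mirroring the device dispatch this commit adds to the tests.
def make_base_config(device: str, group_size: int = 128) -> Int4WeightOnlyConfig:
    if device == "cuda":
        return Int4WeightOnlyConfig(group_size=group_size)
    elif device == "xpu":
        # New in this commit: plain int32 packing for Intel GPU
        return Int4WeightOnlyConfig(
            group_size=group_size, int4_packing_format="plain_int32"
        )
    elif device == "cpu":
        return Int4WeightOnlyConfig(
            group_size=group_size, int4_packing_format="opaque"
        )
    raise ValueError(f"unsupported device: {device}")

if torch.xpu.is_available():
    linear = torch.nn.Linear(128, 256, bias=False, dtype=torch.bfloat16, device="xpu")
    quantize_(linear, make_base_config("xpu"))  # weight is quantized in-place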

test/quantization/quantize_/workflows/int4/test_int4_plain_int32_tensor.py

Lines changed: 20 additions & 0 deletions
@@ -19,6 +19,7 @@
     Int4WeightOnlyConfig,
     quantize_,
 )
+from torchao.quantization.quantize_.common import SupportsActivationPreScaling
 from torchao.quantization.utils import compute_error
 from torchao.utils import (
     torch_version_at_least,
@@ -77,6 +78,25 @@ def test_module_path(self, dtype):
             "<class 'torchao.quantization.Int4PlainInt32Tensor'>",
         )

+    def test_activation_prescaling(self):
+        dtype = torch.bfloat16
+        device = "xpu"
+        input = torch.randn(1, 128, dtype=dtype, device=device)
+        linear = torch.nn.Linear(128, 256, bias=False, dtype=dtype, device=device)
+        original = linear(input)
+        quantize_(linear, get_config(128))
+        qw = linear.weight
+        assert isinstance(qw, SupportsActivationPreScaling), (
+            "Expected int4 tensor supports activation prescaling"
+        )
+        assert qw.act_pre_scale is None, "Default `act_pre_scale` is None"
+        _ACT_PRE_SCALE = 2
+        qw.act_pre_scale = _ACT_PRE_SCALE
+        quantized = linear(input)
+
+        # making sure activation pre scaling is successfully applied to the activation
+        self.assertTrue(compute_error(original * _ACT_PRE_SCALE, quantized) > 20)
+

 instantiate_parametrized_tests(Int4PlainInt32Tensor)

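Why the test's final assertion holds: with bias=False the layer is homogeneous, so pre-scaling the activation by s scales the output by exactly s, and the quantized output should track original * _ACT_PRE_SCALE up to quantization noise (compute_error reports an SQNR in dB; above 20 dB the outputs match closely). A standalone sketch of that identity with no torchao dependency; shapes and names are illustrative:

import torch

# Homogeneity of a bias-free linear layer: W @ (s * x) == s * (W @ x).
# This is what lets the test compare `quantized` against `original * s`.
s = 2.0
linear = torch.nn.Linear(128, 256, bias=False)
x = torch.randn(1, 128)
assert torch.allclose(linear(s * x), s * linear(x), atol=1e-5)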