Skip to content

Commit b3deb16

Browse files
authored
Fix torch.intx support in FakeQuantizeConfig (#1544)
**Summary:** Fixes the following error when passing `torch.intx` to `FakeQuantizeConfig`. These dtypes were introduced in PyTorch 2.6+: ``` ValueError: Unsupported dtype 'torch.int4', choose from [torch.int8, torch.uint8, <TorchAODType.INT1: 1>, <TorchAODType.INT2: 2>, <TorchAODType.INT3: 3>, <TorchAODType.INT4: 4>, <TorchAODType.INT5: 5>, <TorchAODType.INT6: 6>, <TorchAODType.INT7: 7>, torch.uint1, torch.uint2, torch.uint3, torch.uint4, torch.uint5, torch.uint6, torch.uint7] ``` **Test Plan:** python test/quantization/test_qat.py -k test_fake_quantize_config_torch_intx
1 parent de5c6e1 commit b3deb16

File tree

2 files changed

+47
-0
lines changed

2 files changed

+47
-0
lines changed

test/quantization/test_qat.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@
6363
from torchao.utils import (
6464
TORCH_VERSION_AT_LEAST_2_3,
6565
TORCH_VERSION_AT_LEAST_2_4,
66+
TORCH_VERSION_AT_LEAST_2_6,
6667
)
6768

6869
# TODO: put this in a common test utils file
@@ -1327,6 +1328,26 @@ def test_quantize_api_convert_path(self):
13271328
baseline_out = baseline_model(*x2)
13281329
torch.testing.assert_close(out, baseline_out, atol=0, rtol=0)
13291330

1331+
@unittest.skipIf(
1332+
not TORCH_VERSION_AT_LEAST_2_6, "skipping when torch version is 2.5 or lower"
1333+
)
1334+
def test_fake_quantize_config_torch_intx(self):
1335+
"""
1336+
Test that `FakeQuantizeConfig` works with torch.intx.
1337+
"""
1338+
group_size = 16
1339+
config1 = FakeQuantizeConfig(TorchAODType.INT4, group_size=group_size)
1340+
config2 = FakeQuantizeConfig(torch.int4, group_size=group_size)
1341+
linear1 = FakeQuantizedLinear(32, 64, weight_config=config1)
1342+
linear2 = FakeQuantizedLinear(32, 64, weight_config=config2)
1343+
linear2.weight = linear1.weight
1344+
torch.manual_seed(self.SEED)
1345+
x = torch.randn((1, 32)).to(torch.float)
1346+
x2 = copy.deepcopy(x)
1347+
out1 = linear1(*x)
1348+
out2 = linear2(*x2)
1349+
torch.testing.assert_close(out1, out2, atol=0, rtol=0)
1350+
13301351

13311352
if __name__ == "__main__":
13321353
unittest.main()

torchao/quantization/quant_primitives.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from torchao.utils import (
1919
TORCH_VERSION_AT_LEAST_2_3,
2020
TORCH_VERSION_AT_LEAST_2_5,
21+
TORCH_VERSION_AT_LEAST_2_6,
2122
_is_float8_type,
2223
_register_custom_op,
2324
)
@@ -162,6 +163,31 @@ class TorchAODType(Enum):
162163
}
163164
)
164165

166+
# torch.intX available only in PyTorch 2.6+
167+
if TORCH_VERSION_AT_LEAST_2_6:
168+
_SUB_BYTE_INT_BOUNDS.update(
169+
{
170+
torch.int1: (-(2**0), 2**0 - 1),
171+
torch.int2: (-(2**1), 2**1 - 1),
172+
torch.int3: (-(2**2), 2**2 - 1),
173+
torch.int4: (-(2**3), 2**3 - 1),
174+
torch.int5: (-(2**4), 2**4 - 1),
175+
torch.int6: (-(2**5), 2**5 - 1),
176+
torch.int7: (-(2**6), 2**6 - 1),
177+
}
178+
)
179+
_DTYPE_TO_BIT_WIDTH.update(
180+
{
181+
torch.int1: 1,
182+
torch.int2: 2,
183+
torch.int3: 3,
184+
torch.int4: 4,
185+
torch.int5: 5,
186+
torch.int6: 6,
187+
torch.int7: 7,
188+
}
189+
)
190+
165191
_DTYPE_TO_QVALUE_BOUNDS.update(_SUB_BYTE_UINT_BOUNDS)
166192
_DTYPE_TO_QVALUE_BOUNDS.update(_SUB_BYTE_INT_BOUNDS)
167193
assert _DTYPE_TO_BIT_WIDTH.keys() == _DTYPE_TO_QVALUE_BOUNDS.keys()

0 commit comments

Comments
 (0)