README.md (6 additions, 0 deletions)
@@ -60,6 +60,12 @@ TorchAO is an easy to use quantization library for native PyTorch. TorchAO works

Check out our [docs](https://docs.pytorch.org/ao/main/) for more details!

## Third-party Pipeline Status

| Backend | Inference |
| ----------- | -------------------------------------------------------------------------------------------------------------------- |
| Ascend NPU | [![Ascend NPU](https://github.com/Ascend/Ascend-CI/actions/workflows/torchao.yml/badge.svg)](https://github.com/Ascend/Ascend-CI/actions/workflows/torchao.yml) |

## 🚀 Quick Start

First, install TorchAO. We recommend installing the latest stable version:
@@ -5,12 +5,12 @@
# LICENSE file in the root directory of this source tree.

import tempfile
import unittest

import pytest
import torch
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_utils import (
TestCase,
instantiate_parametrized_tests,
parametrize,
run_tests,
)
@@ -33,9 +33,19 @@ def get_config(group_size):
)


@unittest.skipIf(not torch_version_at_least("2.8.0"), "Need pytorch 2.8+")
@unittest.skipIf(not torch.xpu.is_available(), "XPU not available")
class Int4PlainInt32Tensor(TestCase):
_MIN_VER = {
"xpu": "2.8.0",
"npu": "2.7.1",
}

    def setUp(self):
        super().setUp()
        min_req = self._MIN_VER.get(self.device_type)
        if min_req and not torch_version_at_least(min_req):
            self.skipTest(
                f"{self.device_type} requires torch >= {min_req}, current {torch.__version__}"
            )
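    # Illustrative note, not part of this diff: `torch_version_at_least` is
    # imported from torchao.utils elsewhere in this file and is assumed to
    # behave like a plain version comparison, roughly:
    #
    #     from packaging.version import parse
    #
    #     def torch_version_at_least(min_version: str) -> bool:
    #         return parse(torch.__version__) >= parse(min_version)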

@parametrize(
"sizes",
[
@@ -46,24 +56,36 @@ class Int4PlainInt32Tensor(TestCase):
)
@parametrize("dtype", [torch.bfloat16, torch.half])
@parametrize("group_size", [32, 64, 128])
def test_linear(self, sizes, dtype, group_size):
device = "xpu"
@parametrize("thresholds", [{"xpu": 20, "npu": 10}])
def test_linear(self, device, sizes, dtype, group_size, thresholds):
M, N, K = sizes
if "npu" in device and group_size == K:
pytest.skip(
f"{device} does not support group_size equal to K dimension ({group_size} == {K})"
)
threshold = thresholds.get(device.split(":")[0])

input = torch.randn(*M, K, dtype=dtype, device=device)
linear = torch.nn.Linear(K, N, dtype=dtype, device=device)
original = linear(input)
quantize_(linear, get_config(group_size))
quantized = linear(input)
self.assertTrue(compute_error(original, quantized) > 20)
self.assertTrue(compute_error(original, quantized) > threshold)

compiled_linear = torch.compile(linear)
quantized_and_compiled = compiled_linear(input)
self.assertTrue(compute_error(original, quantized_and_compiled) > 20)
if "xpu" in device:
compiled_linear = torch.compile(linear)
quantized_and_compiled = compiled_linear(input)
self.assertTrue(compute_error(original, quantized_and_compiled) > threshold)
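        # Illustrative note, not part of this diff: `compute_error` (imported
        # from torchao elsewhere in this file) is assumed to return SQNR in dB,
        # roughly
        #
        #     20 * torch.log10(torch.linalg.norm(ref) / torch.linalg.norm(ref - q))
        #
        # so the per-device thresholds above are minimum acceptable SQNR values:
        # NPU accepts a looser 10 dB bound than XPU's 20 dB.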

@parametrize("dtype", [torch.bfloat16, torch.half])
def test_module_path(self, dtype):
linear = torch.nn.Linear(128, 256, dtype=dtype, device="xpu")
quantize_(linear, get_config(group_size=128))
    def test_module_path(self, device, dtype):
K, N, group_size = 128, 256, 128
if "npu" in device:
group_size = 64

linear = torch.nn.Linear(K, N, dtype=dtype, device=device)
quantize_(linear, get_config(group_size))
self.assertEqual(
str(type(linear.weight)),
"<class 'torchao.quantization.Int4PlainInt32Tensor'>",
@@ -78,13 +100,22 @@ def test_module_path(self, dtype):
"<class 'torchao.quantization.Int4PlainInt32Tensor'>",
)

def test_activation_prescaling(self):
dtype = torch.bfloat16
device = "xpu"
input = torch.randn(1, 128, dtype=dtype, device=device)
linear = torch.nn.Linear(128, 256, bias=False, dtype=dtype, device=device)
@parametrize("dtype", [torch.float16, torch.bfloat16])
@parametrize("thresholds", [{"xpu": 20, "npu": 10}])
def test_activation_prescaling(self, device, dtype, thresholds):
device = self.device_type
if "xpu" in device and dtype == torch.float16:
pytest.skip(f"{device} test_activation_prescaling don't test {dtype}")

threshold = thresholds.get(device.split(":")[0])
K, N, group_size = 128, 256, 128
if "npu" in device:
group_size = 64

input = torch.randn(1, K, dtype=dtype, device=device)
linear = torch.nn.Linear(K, N, bias=False, dtype=dtype, device=device)
original = linear(input)
quantize_(linear, get_config(128))
quantize_(linear, get_config(group_size))
qw = linear.weight
        assert isinstance(qw, SupportsActivationPreScaling), (
            "Expected int4 tensor to support activation pre-scaling"
@@ -95,10 +126,12 @@ def test_activation_prescaling(self):
quantized = linear(input)

        # make sure activation pre-scaling was applied to the activation
self.assertTrue(compute_error(original * _ACT_PRE_SCALE, quantized) > 20)
self.assertTrue(compute_error(original * _ACT_PRE_SCALE, quantized) > threshold)
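        # Illustrative note, not part of this diff: `SupportsActivationPreScaling`
        # is assumed to expose an `act_pre_scale` attribute (set to
        # `_ACT_PRE_SCALE` in the collapsed region above), with the quantized
        # kernel computing roughly
        #
        #     out = (input * act_pre_scale) @ dequantize(weight).T
        #
        # which is why the reference output here is `original * _ACT_PRE_SCALE`.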


instantiate_parametrized_tests(Int4PlainInt32Tensor)
instantiate_device_type_tests(
Int4PlainInt32Tensor, globals(), only_for=("xpu", "npu"), allow_xpu=True
)
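# Illustrative note, not part of this diff: `instantiate_device_type_tests`
# replaces the old `instantiate_parametrized_tests` call above. It is assumed
# to generate one concrete class per backend (e.g. Int4PlainInt32TensorXPU,
# name assumed) and to inject the `device` argument used by the tests. A
# hedged usage sketch on a machine with the matching accelerator:
#
#     pytest <this file> -k xpu    # select only the XPU variants (name matching assumed)
#     python <this file>           # run everything via run_tests()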


if __name__ == "__main__":