4 changes: 3 additions & 1 deletion docs/source/quantization_overview.rst
@@ -5,7 +5,7 @@ First we want to lay out the torchao stack::

 Quantization Algorithms/Flows: weight only/dynamic/static quantization, hqq, awq, gptq etc.
 ---------------------------------------------------------------------------------------------
-Quantized Tensors (derived dtypes): Int4Tensor, Int4PreshuffledTensor, Float8Tensor
+Quantized Tensors (derived dtypes): Int4Tensor, Int4PreshuffledTensor, Int8Tensor, Float8Tensor
 ---------------------------------------------------------------------------------------------
 Quantization Primitive Ops/Efficient Kernels: matmul, quantize, dequantize
 ---------------------------------------------------------------------------------------------
@@ -88,6 +88,8 @@ So in general we structure Tensor subclasses by derived dtype and packing format
      - scaled int4
      - preshuffled (special format to optimize for loading)
      - float8 act + int4 weight dynamic quantization and int4 weight only quantization
+   * - Int8Tensor
+     - plain

 .. note::
    We don't have granularity-specific tensor subclasses, i.e. no Float8RowwiseTensor or Float8BlockwiseTensor; all granularities are implemented in the same Tensor. We typically use a general `block_size` attribute to distinguish between different granularities, and each Tensor is allowed to support only a subset of all possible granularity options.
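To make the `block_size` convention mentioned in that note concrete, here is a minimal sketch (illustrative only, not from this PR; names and shapes are assumptions) of how a single attribute can encode per-tensor, per-row, and block-wise granularity:

```python
# Hypothetical weight of shape [N, K]; block_size says how many elements
# along each dimension share one scale.
N, K = 64, 128

per_tensor = [N, K]    # one block covers the whole tensor -> a single scale
per_row = [1, K]       # one block per row -> N scale elements
block_wise = [32, 32]  # 2D tiles -> (N // 32) x (K // 32) scale elements

def scale_shape(shape, block_size):
    """Number of scale elements along each dimension."""
    return tuple(s // b for s, b in zip(shape, block_size))

print(scale_shape((N, K), per_tensor))  # (1, 1)
print(scale_shape((N, K), per_row))     # (64, 1)
print(scale_shape((N, K), block_wise))  # (2, 4)
```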
262 changes: 262 additions & 0 deletions test/quantization/quantize_/workflows/int8/test_int8_tensor.py
@@ -0,0 +1,262 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import copy
import unittest

import torch
from torch._inductor.utils import run_and_get_code
from torch.testing._internal import common_utils

from torchao.quantization import (
    Int8DynamicActivationInt8WeightConfig,
    Int8WeightOnlyConfig,
    quantize_,
)
from torchao.quantization.utils import compute_error
from torchao.testing.utils import TorchAOIntegrationTestCase


# TODO: Refactor after https://github.com/pytorch/ao/pull/2729 is merged
class ToyTwoLinearModel(torch.nn.Module):
    def __init__(
        self,
        input_dim,
        hidden_dim,
        output_dim,
        has_bias=False,
        dtype=None,
        device=None,
    ):
        super().__init__()
        self.dtype = dtype
        self.device = device
        self.linear1 = torch.nn.Linear(
            input_dim, hidden_dim, bias=has_bias, dtype=dtype, device=device
        )
        self.linear2 = torch.nn.Linear(
            hidden_dim, output_dim, bias=has_bias, dtype=dtype, device=device
        )

    def forward(self, x):
        x = self.linear1(x)
        x = self.linear2(x)
        return x


@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@common_utils.instantiate_parametrized_tests
class TestInt8Tensor(TorchAOIntegrationTestCase):
    def setUp(self):
        super().setUp()

        self.test_shape = (32, 20)
        self.dtype = torch.bfloat16
        self.batch_size = 32

        torch.manual_seed(42)
        self.weight_fp = torch.randn(*self.test_shape, dtype=self.dtype)
        self.input_fp = torch.randn(*self.test_shape, dtype=self.dtype)
        self.bias = torch.randn(self.test_shape[0], dtype=self.dtype)
        self.block_size = list(self.test_shape)

    @common_utils.parametrize(
        "config",
        [
            Int8DynamicActivationInt8WeightConfig(version=2),
            Int8WeightOnlyConfig(version=2),
        ],
    )
    def test_creation_and_attributes(self, config):
        """Test tensor creation, dtypes, and ranges"""
        linear = torch.nn.Linear(
            self.test_shape[1],
            self.test_shape[0],
            bias=False,
            dtype=self.dtype,
            device="cuda",
        )
        linear.weight.data = self.weight_fp.cuda()
        quantize_(linear, config)

        tensor = linear.weight

        self.assertEqual(tensor.shape, self.test_shape)
        self.assertEqual(tensor.qdata.dtype, torch.int8)
        self.assertTrue(
            torch.all(tensor.qdata >= -128) and torch.all(tensor.qdata <= 127)
        )

    @common_utils.parametrize("dtype", [torch.bfloat16, torch.float32])
    @common_utils.parametrize("compile", [True, False])
    @common_utils.parametrize(
        "config",
        [
            Int8DynamicActivationInt8WeightConfig(version=2),
            Int8WeightOnlyConfig(version=2),
        ],
    )
    @common_utils.parametrize(
        "sizes",
        [
            ((128,), 256, 128),  # 2D
            ((32, 128), 64, 256),  # 3D
        ],
    )
    def test_int8_linear_variants(
        self,
        dtype: torch.dtype,
        config,
        compile: bool,
        sizes: tuple,
    ):
"""Test linear operation supports including shape and compile"""
        M, N, K = sizes
        input_tensor = torch.randn(*M, K, dtype=dtype, device="cuda")
        model = ToyTwoLinearModel(K, N, K, dtype=dtype, device="cuda").eval()
        model_q = copy.deepcopy(model)

        quantize_(model_q, config)

        if compile:
            model_q = torch.compile(model_q, fullgraph=True)

        output_fp = model(input_tensor)
        output_quantized = model_q(input_tensor)
        sqnr = compute_error(output_fp, output_quantized)
        assert sqnr > 20, f"Quantization error is too high, got SQNR of {sqnr}"

    @common_utils.parametrize("dtype", [torch.bfloat16, torch.float16])
    @common_utils.parametrize(
        "config",
        [
            Int8DynamicActivationInt8WeightConfig(version=2),
            Int8WeightOnlyConfig(version=2),
        ],
    )
    def test_per_row_scale_shape(self, dtype, config):
[Review comment (Contributor)]: can you add a test like this one to test all the variations and include the shape check there?

"""Test per-row quantization maintains 1D scale"""
N, K = 64, 128
linear = torch.nn.Linear(K, N, bias=False, dtype=dtype, device="cuda")
quantize_(linear, config)

# Dynamic: per-row (1D scale [N]), Weight-only: per-tensor (scalar)
if isinstance(config, Int8DynamicActivationInt8WeightConfig):
self.assertEqual(linear.weight.scale.shape, (N,))
self.assertEqual(linear.weight.scale.ndim, 1)
else:
self.assertEqual(linear.weight.scale.numel(), 1)

    @common_utils.parametrize(
        "config",
        [
            Int8DynamicActivationInt8WeightConfig(version=2),
            Int8WeightOnlyConfig(version=2),
        ],
    )
    @common_utils.parametrize("device", ["cpu", "cuda"])
    @common_utils.parametrize("dtype", [torch.bfloat16, torch.float16])
    def test_slice(self, config, device, dtype):
        """Test tensor slicing"""
        tensor_size = 256
        slice_sizes = (64, 128)

        dummy = torch.nn.Linear(
            tensor_size, tensor_size, bias=False, dtype=dtype, device=device
        )
        quantize_(dummy, config)

        weight1 = dummy.weight.clone().narrow(0, 0, slice_sizes[0])
        weight2 = dummy.weight.clone().narrow(1, 0, slice_sizes[1])

        self.assertEqual(weight1.qdata, dummy.weight.qdata.narrow(0, 0, slice_sizes[0]))
        self.assertEqual(weight2.qdata, dummy.weight.qdata.narrow(1, 0, slice_sizes[1]))

        # Int8DynamicActivationInt8WeightConfig uses per-row (PerRow)
        # Int8WeightOnlyConfig uses per-tensor (PerTensor)
        if isinstance(config, Int8DynamicActivationInt8WeightConfig):
            # PerRow: dim 0 slicing affects scale, dim 1 doesn't
            self.assertEqual(
                weight1.scale, dummy.weight.scale.narrow(0, 0, slice_sizes[0])
            )
            self.assertEqual(weight2.scale, dummy.weight.scale)
        else:
            # PerTensor: scale unchanged by slicing
            self.assertEqual(weight1.scale, dummy.weight.scale)
            self.assertEqual(weight2.scale, dummy.weight.scale)

        with self.assertRaises(NotImplementedError):
            _ = dummy.weight[::2]

    @common_utils.parametrize(
        "config",
        [
            Int8DynamicActivationInt8WeightConfig(version=2),
            Int8WeightOnlyConfig(version=2),
        ],
    )
    def test_index_select(self, config):
        """Test that `x_0 = x[0]` works when `x` is a 2D quantized tensor."""
        N, K = 256, 512
        x = torch.randn(N, K, device="cuda", dtype=torch.bfloat16)
        linear = torch.nn.Linear(K, N, bias=False, dtype=torch.bfloat16, device="cuda")
        linear.weight.data = x
        quantize_(linear, config)

        x_int8 = linear.weight
        x_int8_0 = x_int8[0]
        torch.testing.assert_close(
            x_int8.dequantize()[0], x_int8_0.dequantize(), atol=0, rtol=0
        )

    @common_utils.parametrize(
        "config",
        [
            Int8DynamicActivationInt8WeightConfig(version=2),
            Int8WeightOnlyConfig(version=2),
        ],
    )
    def test_dequantization_accuracy(self, config):
        """Test dequantization accuracy separately"""
        test_data = torch.tensor([[1.0, -1.0]], dtype=torch.bfloat16, device="cuda")
        linear = torch.nn.Linear(2, 1, bias=False, dtype=torch.bfloat16, device="cuda")
        linear.weight.data = test_data
        quantize_(linear, config)

        tensor = linear.weight
        dequantized = tensor.dequantize()
        self.assertEqual(dequantized.shape, test_data.shape)
        self.assertLess(
            torch.abs(dequantized - test_data).max().item(),
            0.1,
            msg="Dequantization error exceeds tolerance of 0.1",
        )

    def test_available_gpu_kernels(self):
        """Check which GPU kernels are available"""
        M, K, N = 128, 256, 512
        m = torch.nn.Sequential(
            torch.nn.Linear(K, N, device="cuda", dtype=torch.bfloat16)
        )
        config = Int8DynamicActivationInt8WeightConfig(version=2)
        quantize_(m, config)
        m = torch.compile(m)
        x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)

        out, code = run_and_get_code(m, x)
        has_triton = "triton" in code[0].lower()  # Triton kernels
        has_fbgemm = "fbgemm" in code[0].lower()  # FBGEMM kernels
        has_int_mm = "_int_mm" in code[0]  # int8 matmul

        self.assertTrue(
            has_triton or has_fbgemm or has_int_mm,
            f"No int8 quantization kernels found. has_triton={has_triton}, "
            f"has_fbgemm={has_fbgemm}, has_int_mm={has_int_mm}",
        )


if __name__ == "__main__":
    common_utils.run_tests()
25 changes: 18 additions & 7 deletions torchao/float8/inference.py
@@ -140,7 +140,18 @@ def _slice_scale_for_dimension(
     """
     aten = torch.ops.aten
 
-    # Unsupported case for now, this would be 1 scale per data element
+    # Per-tensor quantization (scalar scale)
[Review comment (Contributor)]: is this change related?

[Reply (Contributor Author, @namgyu-youn, Oct 31, 2025)]: It is updated to support more granularity. Without this change, we can't use per-tensor (0D scale) and per-row (1D scale).

[Review comment (Collaborator)]: So maybe it's better to move this util function to a common place?

+    if scale.numel() == 1:
+        return scale
+
+    # Per-row quantization (1D scale)
+    if scale.ndim == 1:
+        if dim == 0:
+            return aten.slice.Tensor(scale, 0, start, end, step)
+        else:
+            return scale
+
+    # Block-wise quantization (2D scale)
     if scale.shape == data_shape:
         return aten.slice.Tensor(scale, dim, start, end, step)

@@ -158,6 +169,12 @@ def _slice_scale_for_dimension(
         # Slice away as normal
         return aten.slice.Tensor(scale, dim, start, end, step)
     else:
+        # Error on step > 1
+        if step > 1:
+            raise NotImplementedError(
+                "Slicing with step > 1 is not implemented for scale tensors."
+            )
+
         # There is blocking in this dimension
         # Calculate which scale elements correspond to the sliced data
         scale_start = start // block_size_for_dim if start is not None else None
@@ -167,12 +184,6 @@ def _slice_scale_for_dimension(
             else None
         )
 
-        # Error on Step > 1
-        if step > 1:
-            raise NotImplementedError(
-                "Slicing with step > 1 is not implemented for scale tensors."
-            )
-
         return aten.slice.Tensor(scale, dim, scale_start, scale_end, 1)


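To make the granularity point from the review thread above concrete, here is a minimal usage sketch (assumes a CUDA device; mirrors the slice test earlier in this PR rather than being part of the diff itself):

```python
import torch
from torchao.quantization import Int8DynamicActivationInt8WeightConfig, quantize_

linear = torch.nn.Linear(128, 64, bias=False, dtype=torch.bfloat16, device="cuda")
quantize_(linear, Int8DynamicActivationInt8WeightConfig(version=2))

# Per-row weight quantization: a 1D scale of shape [out_features].
assert linear.weight.scale.shape == (64,)

# Narrowing rows must slice the scale too; the 1D branch added to
# _slice_scale_for_dimension handles this case.
w = linear.weight.clone().narrow(0, 0, 32)
assert w.scale.shape == (32,)
```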
2 changes: 2 additions & 0 deletions torchao/quantization/__init__.py
@@ -97,6 +97,7 @@
     Int4PreshuffledTensor,
     Int4Tensor,
     Int4TilePackedTo4dTensor,
+    Int8Tensor,
     IntxOpaqueTensor,
     IntxUnpackedToInt8Tensor,
 )
@@ -168,6 +169,7 @@
     "IntxOpaqueTensor",
     "IntxUnpackedToInt8Tensor",
     "Int4TilePackedTo4dTensor",
+    "Int8Tensor",
     "Float8Tensor",
     "Int4OpaqueTensor",
     # smooth quant - subject to change
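With the export above in place, the new subclass is importable from the top-level package; a minimal end-to-end sketch (assumes CUDA and the version=2 config used throughout this PR):

```python
import torch
from torchao.quantization import Int8WeightOnlyConfig, quantize_

model = torch.nn.Sequential(
    torch.nn.Linear(128, 64, dtype=torch.bfloat16, device="cuda")
)
quantize_(model, Int8WeightOnlyConfig(version=2))

# The linear weight should now be the new tensor subclass.
print(type(model[0].weight).__name__)  # expected: Int8Tensor
```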