diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py index b32868b684..c8774e9426 100644 --- a/torchao/quantization/__init__.py +++ b/torchao/quantization/__init__.py @@ -19,7 +19,9 @@ MultiTensorInputRecorder, ) from .granularity import ( + Granularity, PerAxis, + PerBlock, PerGroup, PerRow, PerTensor, @@ -197,8 +199,10 @@ "MappingType", "ZeroPointDomain", "TorchAODType", + "Granularity", "PerTensor", "PerAxis", + "PerBlock", "PerGroup", "PerRow", "PerToken", diff --git a/torchao/quantization/granularity.py b/torchao/quantization/granularity.py index 8cdae9caaa..6c7b582fe5 100644 --- a/torchao/quantization/granularity.py +++ b/torchao/quantization/granularity.py @@ -71,6 +71,7 @@ class PerGroup(Granularity): group_size: int +@dataclass(frozen=True) class PerRow(Granularity): """ Represents row-wise granularity in quantization. @@ -83,6 +84,7 @@ class PerRow(Granularity): pass +@dataclass(frozen=True) class PerToken(Granularity): """ Represents per-token granularity in quantization. @@ -99,3 +101,16 @@ class PerToken(Granularity): """ pass + + +@dataclass(frozen=True) +class PerBlock(Granularity): + """ + Represents per-block granularity in quantization. See + :func:`~torchao.quantization.quant_primitives.quantize_affine` for docs for + `block_size` + Attributes: + block_size (tuple[int, ...]): The size of each quantization group + """ + + block_size: tuple[int, ...] diff --git a/torchao/quantization/observer.py b/torchao/quantization/observer.py index 6d928a4477..d12ffaf520 100644 --- a/torchao/quantization/observer.py +++ b/torchao/quantization/observer.py @@ -14,7 +14,6 @@ from .granularity import ( Granularity, - PerAxis, PerRow, PerTensor, ) @@ -24,6 +23,7 @@ _get_reduction_params, choose_qparams_affine_with_min_max, ) +from .utils import get_block_size logger = logging.getLogger(__name__) @@ -63,26 +63,6 @@ def _with_args(cls_or_self, *args, **kwargs): return r -def get_block_size( - input_shape: Tuple[int, ...], granularity: Granularity -) -> Tuple[int, ...]: - """Get the block size based on the input shape and granularity type. - - Args: - input_shape: The input tensor shape possibly more than 2 dimensions - granularity: The granularity type of the quantization - """ - if isinstance(granularity, PerTensor): - return input_shape - elif isinstance(granularity, PerAxis): - block_size = list(input_shape) - block_size[granularity.axis] = 1 - return tuple(block_size) - elif isinstance(granularity, PerRow): - return (1,) * (len(input_shape) - 1) + (input_shape[-1],) - raise ValueError(f"Unsupported Granularity: {granularity}") - - ABC: Any = ABCMeta("ABC", (object,), {}) # compatible with Python 2 *and* 3: diff --git a/torchao/quantization/pt2e/__init__.py b/torchao/quantization/pt2e/__init__.py index 8b6a99337b..c7030023dc 100644 --- a/torchao/quantization/pt2e/__init__.py +++ b/torchao/quantization/pt2e/__init__.py @@ -5,6 +5,15 @@ import torch from torch import Tensor +from torchao.quantization import ( + Granularity, + PerAxis, + PerBlock, + PerGroup, + PerRow, + PerTensor, + PerToken, +) from torchao.quantization.pt2e._numeric_debugger import ( # noqa: F401 CUSTOM_KEY, FROM_NODE_KEY, @@ -32,6 +41,7 @@ get_equivalent_types, update_equivalent_types_dict, ) +from torchao.quantization.utils import get_block_size from .fake_quantize import ( FakeQuantize, @@ -48,7 +58,6 @@ from .observer import ( AffineQuantizedObserverBase, FixedQParamsObserver, - Granularity, HistogramObserver, MappingType, MinMaxObserver, @@ -57,20 +66,13 @@ NoopObserver, ObserverBase, PartialWrapper, - PerAxis, - PerBlock, PerChannelMinMaxObserver, - PerGroup, - PerRow, - PerTensor, - PerToken, PlaceholderObserver, RecordingObserver, ReuseInputObserver, TorchAODType, UniformQuantizationObserverBase, ZeroPointDomain, - get_block_size, ) for _f in [ diff --git a/torchao/quantization/pt2e/_affine_quantization.py b/torchao/quantization/pt2e/_affine_quantization.py index e02bee03ce..c206d15a79 100644 --- a/torchao/quantization/pt2e/_affine_quantization.py +++ b/torchao/quantization/pt2e/_affine_quantization.py @@ -13,14 +13,14 @@ import torch +from torchao.quantization import Granularity from torchao.quantization.pt2e.observer import ( AffineQuantizedObserverBase, - Granularity, MappingType, TorchAODType, ZeroPointDomain, - get_block_size, ) +from torchao.quantization.utils import get_block_size ABC: Any = ABCMeta("ABC", (object,), {}) # compatible with Python 2 *and* 3: diff --git a/torchao/quantization/pt2e/observer.py b/torchao/quantization/pt2e/observer.py index a9e8c38439..f2263b8851 100644 --- a/torchao/quantization/pt2e/observer.py +++ b/torchao/quantization/pt2e/observer.py @@ -27,6 +27,15 @@ from torch.fx import Node import torchao +from torchao.quantization import ( + Granularity, + PerAxis, + PerBlock, + PerGroup, + PerRow, + PerTensor, + PerToken, +) from torchao.quantization.pt2e.utils import ( calculate_qmin_qmax, check_min_max_valid, @@ -34,6 +43,7 @@ is_per_tensor, validate_qmin_qmax, ) +from torchao.quantization.utils import get_block_size __all__ = [ "default_affine_fixed_qparams_observer", @@ -1622,7 +1632,6 @@ def calculate_qparams(self): We plan to merge the following with torchao repo after we move pt2e flow to torchao copied from https://github.com/pytorch/ao/blob/main/torchao/quantization/observer.py """ -from dataclasses import dataclass from enum import Enum, auto @@ -1679,139 +1688,6 @@ class TorchAODType(Enum): INT7 = auto() -@dataclass(frozen=True) -class Granularity: - """ - Base class for representing the granularity of quantization. - - This class serves as a parent for specific granularity types used in - quantization operations, such as per-tensor or per-axis quantization. - """ - - -@dataclass(frozen=True) -class PerBlock(Granularity): - """ - Represents per-block granularity in quantization. See - :func:`~torchao.quantization.quant_primitives.quantize_affine` for docs for - `block_size` - - Attributes: - block_size (Tuple[int, ...]): The size of each quantization group - """ - - block_size: tuple[int, ...] - - -@dataclass(frozen=True) -class PerTensor(Granularity): - """ - Represents per-tensor granularity in quantization. - - This granularity type calculates the quantization parameters - based off the entire tensor. - - """ - - -@dataclass(frozen=True) -class PerAxis(Granularity): - """ - Represents per-axis granularity in quantization. - - This granularity type calculates different quantization parameters - along a specified axis of the tensor. - - For example if the input tensor is shape [8, 16] and axis=0, then - the quantization parameters are calculated for each row of the tensor. - Giving a total of 8 quantization parameters. - - Attributes: - axis (int): The axis along which reduction is performed. - """ - - axis: int - - -@dataclass(frozen=True) -class PerGroup(Granularity): - """ - Represents per-channel group granularity in quantization. - - This granularity type calculates different quantization parameters - for each group of elements. - - For example if the input tensor is shape [8, 16], and the group size is 4, then - the input tensor is reshaped to [64, 4] - quantization parameters are calculated for each group of 4 elements, - giving a total of 64 quantization parameters. - - Attributes: - group_size (int): The size of each quantization group - - """ - - group_size: int - - -class PerRow(Granularity): - """ - Represents row-wise granularity in quantization. - - This is a special case of per-axis quantization and is unique to Float8 matmuls - where the input is quantized with a block_size of (1, ..., input.shape[-1]). And the weight - is quantized with a block_size of (1, weight.shape[1]). - """ - - -class PerToken(Granularity): - """ - Represents per-token granularity in quantization. - - This granularity type calculates a different set of quantization parameters - for each token, which is represented as the last dimension of the tensor. - - For example, if the input tensor has shape [2, 3, 4], then there are 6 tokens - with 4 elements each, and we will calculate 6 sets of quantization parameters, - one for each token. - - If the input tensor has only two dimensions, e.g. [8, 16], then this is - equivalent to `PerAxis(axis=0)`, which yields 8 sets of quantization parameters. - """ - - -def get_block_size( - input_shape: tuple[int, ...], granularity: Granularity -) -> tuple[int, ...]: - """Get the block size based on the input shape and granularity type. - - Args: - input_shape: The input tensor shape possibly more than 2 dimensions - granularity: The granularity type of the quantization - """ - assert isinstance(granularity, Granularity), ( - "Please provide an instance of Granularity, not subclass of it" - ) - if isinstance(granularity, PerTensor): - return input_shape - elif isinstance(granularity, PerAxis): - block_size = list(input_shape) - block_size[granularity.axis] = 1 - return tuple(block_size) - elif isinstance(granularity, PerRow): - return (1,) * (len(input_shape) - 1) + (input_shape[-1],) - elif isinstance(granularity, PerGroup): - assert len(input_shape) == 2, ( - f"Expecting input shape dim to be 2 for per group quantization, gotinput shape: {input_shape}" - ) - return (1, granularity.group_size) - elif isinstance(granularity, PerToken): - block_size = [1] * len(input_shape) - block_size[-1] = input_shape[-1] - return tuple(block_size) - raise ValueError(f"Unsupported Granularity: {granularity}") - - class AffineQuantizedObserverBase(ABC, torch.nn.Module): """Observer module for affine quantization (https://github.com/pytorch/ao/tree/main/torchao/quantization#affine-quantization) diff --git a/torchao/quantization/qat/fake_quantizer.py b/torchao/quantization/qat/fake_quantizer.py index 09e3fa1e59..9c06264be8 100644 --- a/torchao/quantization/qat/fake_quantizer.py +++ b/torchao/quantization/qat/fake_quantizer.py @@ -14,7 +14,6 @@ PerRow, PerToken, ) -from torchao.quantization.observer import get_block_size from torchao.quantization.quant_primitives import ( _DTYPE_TO_BIT_WIDTH, _DTYPE_TO_QVALUE_BOUNDS, @@ -28,6 +27,7 @@ ) from torchao.quantization.utils import ( _get_per_token_block_size, + get_block_size, get_group_qparams_symmetric, get_groupwise_affine_qparams, ) diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 021779f037..15caddcadc 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -64,7 +64,7 @@ from torchao.quantization.linear_activation_weight_observed_tensor import ( LinearActivationWeightObservedTensor, ) -from torchao.quantization.observer import AffineQuantizedObserverBase, get_block_size +from torchao.quantization.observer import AffineQuantizedObserverBase from torchao.quantization.quantize_.common import ( KernelPreference, ) @@ -87,6 +87,7 @@ _QUANTIZE_CONFIG_HANDLER, register_quantize_module_handler, ) +from torchao.quantization.utils import get_block_size from torchao.quantization.weight_tensor_linear_activation_quantization import ( to_weight_tensor_with_linear_activation_quantization_metadata, ) diff --git a/torchao/quantization/quantize_/workflows/float8/float8_tensor.py b/torchao/quantization/quantize_/workflows/float8/float8_tensor.py index 49c8b1cd24..bcae1fc756 100644 --- a/torchao/quantization/quantize_/workflows/float8/float8_tensor.py +++ b/torchao/quantization/quantize_/workflows/float8/float8_tensor.py @@ -23,7 +23,6 @@ preprocess_scale, ) from torchao.quantization.granularity import PerRow, PerTensor -from torchao.quantization.observer import get_block_size from torchao.quantization.quant_primitives import ( _choose_scale_float8, _dequantize_affine_float8, @@ -34,6 +33,7 @@ QuantizeTensorKwargs, _choose_quant_func_and_quantize_tensor, ) +from torchao.quantization.utils import get_block_size from torchao.utils import ( TorchAOBaseTensor, _is_fbgemm_genai_gpu_available, diff --git a/torchao/quantization/utils.py b/torchao/quantization/utils.py index d56fa0732d..b4b1a1087d 100644 --- a/torchao/quantization/utils.py +++ b/torchao/quantization/utils.py @@ -3,7 +3,7 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Tuple import torch from torch.utils._python_dispatch import TorchDispatchMode @@ -29,6 +29,16 @@ check_xpu_version, ) +from .granularity import ( + Granularity, + PerAxis, + PerBlock, + PerGroup, + PerRow, + PerTensor, + PerToken, +) + __all__ = [ "compute_error", "_quantize_activation_per_token_absmax", @@ -678,3 +688,37 @@ def recommended_inductor_config_setter(): torch._inductor.config.fx_graph_cache = True torch._inductor.config.triton.unique_kernel_names = True torch.set_float32_matmul_precision("high") + + +def get_block_size( + input_shape: Tuple[int, ...], granularity: Granularity +) -> Tuple[int, ...]: + """Get the block size based on the input shape and granularity type. + Args: + input_shape: The input tensor shape possibly more than 2 dimensions + granularity: The granularity type of the quantization + """ + if isinstance(granularity, PerTensor): + return input_shape + elif isinstance(granularity, PerAxis): + block_size = list(input_shape) + block_size[granularity.axis] = 1 + return tuple(block_size) + elif isinstance(granularity, PerBlock): + block_size = granularity.block_size + assert len(block_size) == len(input_shape), ( + f"Block size {block_size} must have the same number of dimensions as input shape {input_shape}" + ) + for i in range(len(block_size)): + assert input_shape[i] % block_size[i] == 0, ( + f"Not all shapes in input shape {input_shape} are divisible by block size {block_size}" + ) + return block_size + elif isinstance(granularity, (PerRow, PerToken)): + return (1,) * (len(input_shape) - 1) + (input_shape[-1],) + elif isinstance(granularity, PerGroup): + assert input_shape[-1] % granularity.group_size == 0, ( + f"Last dimension of input {input_shape[-1]} is not divisible by group size {granularity.group_size}" + ) + return (1,) * (len(input_shape) - 1) + (granularity.group_size,) + raise ValueError(f"Unsupported Granularity: {granularity}")