diff --git a/ruff.toml b/ruff.toml index 9545a9be96..04c9e32cca 100644 --- a/ruff.toml +++ b/ruff.toml @@ -2,8 +2,6 @@ # We plan to add files in chunks using the 'include' list below. # To add a new path: Simply add it to the 'include' list. # Example: To lint all files in every subfolder of 'test', add "test/**/*" -# To exclude a file type: Simply add it to the 'include' list. -# Example: To lint all files in every subfolder of 'test', add "test/**/*" include = [ "torchao/float8/inference.py", "torchao/float8/float8_utils.py", @@ -12,9 +10,4 @@ include = [ "torchao/float8/float8_tensor.py", "torchao/quantization/linear_activation_weight_observer.py", "test/quantization/test_observer.py", - "torchao/dtypes/*" -] - -exclude = [ - "**/*.md" ] diff --git a/torchao/dtypes/__init__.py b/torchao/dtypes/__init__.py index 8d65fafaca..e27bf6497a 100644 --- a/torchao/dtypes/__init__.py +++ b/torchao/dtypes/__init__.py @@ -1,5 +1,4 @@ from .nf4tensor import NF4Tensor, to_nf4 - # from ..prototype.dtypes.uint2 import UInt2Tensor, BitnetTensor from .uint4 import UInt4Tensor from .affine_quantized_tensor import ( @@ -22,7 +21,7 @@ __all__ = [ "NF4Tensor", "to_nf4", - "UInt4Tensor", + "UInt4Tensor" "AffineQuantizedTensor", "to_affine_quantized_intx", "to_affine_quantized_intx_static", diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py index 19decc4dbb..c6a3730859 100644 --- a/torchao/dtypes/affine_quantized_tensor.py +++ b/torchao/dtypes/affine_quantized_tensor.py @@ -1,6 +1,7 @@ import torch from typing import Tuple, Optional, Union -import torchao.ops +from collections import defaultdict +import functools import math from torchao.quantization.quant_primitives import ( choose_qparams_affine, @@ -31,13 +32,13 @@ find_multiple, TorchAOBaseTensor, TORCH_VERSION_AT_LEAST_2_5, - _is_float8_type, + _is_float8_type ) import logging -from torchao.float8.inference import Float8MMConfig logger = logging.getLogger(__name__) +from torchao.float8.inference import Float8MMConfig aten = torch.ops.aten @@ -48,7 +49,6 @@ class AQTLayout(TorchAOBaseTensor): """ Base class for the layout tensor for `AffineQuantizedTensor` """ - def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Get the plain (unpacked) Tensor for the layout Tensor @@ -68,7 +68,7 @@ def from_plain( zero_point: torch.Tensor, layout_type: LayoutType, ): - """Construct a Layout from data, scale, zero_point and the layout_type""" + """ Construct a Layout from data, scale, zero_point and the layout_type""" pass def __repr__(self): @@ -83,14 +83,11 @@ def __repr__(self): class QuantizedLinearNotImplementedError(NotImplementedError): - """Thin wrapper around NotImplementedError to make it easier to catch this error in the dispatch table""" - + """ Thin wrapper around NotImplementedError to make it easier to catch this error in the dispatch table """ pass _AQT_QLINEAR_DISPATCH_TABLE = {} - - def register_aqt_quantized_linear_dispatch(dispatch_condition, impl): """Register a dispatch for quantized linear op with dispatch_condition function and impl function both takes three arguments: @@ -107,15 +104,11 @@ def register_aqt_quantized_linear_dispatch(dispatch_condition, impl): """ _AQT_QLINEAR_DISPATCH_TABLE[dispatch_condition] = impl - def deregister_aqt_quantized_linear_dispatch(dispatch_condition): if dispatch_condition in _AQT_QLINEAR_DISPATCH_TABLE: del _AQT_QLINEAR_DISPATCH_TABLE[dispatch_condition] else: - logger.warn( - f"Attempting to remove non-existant dispatch condition 
{dispatch_condition}" - ) - + logger.warn(f"Attempting to remove non-existant dispatch condition {dispatch_condition}") class AffineQuantizedTensor(TorchAOBaseTensor): """ @@ -162,9 +155,7 @@ def __new__( kwargs = {} kwargs["device"] = layout_tensor.device kwargs["layout"] = ( - kwargs.get("layout") - if kwargs.get("layout", False) - else layout_tensor.layout + kwargs.get("layout") if kwargs.get("layout", False) else layout_tensor.layout ) kwargs["dtype"] = dtype if strides is not None: @@ -203,16 +194,9 @@ def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor output_dtype = self.dtype from torchao.dtypes.fpx import FpxTensorCoreLayoutType - if isinstance(self.layout_type, FpxTensorCoreLayoutType): int_data, scale = self.layout_tensor.get_plain() - return dequantize_affine_fpx( - int_data, - scale, - self.layout_type.ebits, - self.layout_type.mbits, - output_dtype=output_dtype, - ) + return dequantize_affine_fpx(int_data, scale, self.layout_type.ebits, self.layout_type.mbits, output_dtype=output_dtype) else: data, scale, zero_point = self.layout_tensor.get_plain() return dequantize_affine( @@ -232,28 +216,17 @@ def _quantized_linear_op(input_tensor, weight_tensor, bias): for dispatch_condition, impl in _AQT_QLINEAR_DISPATCH_TABLE.items(): if dispatch_condition(input_tensor, weight_tensor, bias): return impl(input_tensor, weight_tensor, bias) - raise QuantizedLinearNotImplementedError( - "No specialized dispatch found for quantized linear op" - ) + raise QuantizedLinearNotImplementedError("No specialized dispatch found for quantized linear op") def __tensor_flatten__(self): - return ["layout_tensor"], [ - self.block_size, - self.shape, - self.quant_min, - self.quant_max, - self.zero_point_domain, - self.dtype, - ] + return ["layout_tensor"], [self.block_size, self.shape, self.quant_min, self.quant_max, self.zero_point_domain, self.dtype] @classmethod def __tensor_unflatten__( cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride ): layout_tensor = tensor_data_dict["layout_tensor"] - block_size, shape, quant_min, quant_max, zero_point_domain, dtype = ( - tensor_attributes - ) + block_size, shape, quant_min, quant_max, zero_point_domain, dtype = tensor_attributes return cls( layout_tensor, block_size, @@ -286,58 +259,20 @@ def from_hp_to_intx( input_float = layout_type.pre_process(input_float) if use_hqq: - assert ( - zero_point_domain == ZeroPointDomain.FLOAT - and mapping_type == MappingType.ASYMMETRIC - and quant_min == 0 - ), "Invalid input parameters for HQQ quantization." + assert zero_point_domain == ZeroPointDomain.FLOAT and mapping_type == MappingType.ASYMMETRIC and quant_min==0, "Invalid input parameters for HQQ quantization." 
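# --- Hedged example (editor's sketch, not part of the diff): how the dispatch-table API
# documented above is used. register_aqt_quantized_linear_dispatch(dispatch_condition, impl)
# takes two callables, each receiving (input_tensor, weight_tensor, bias). The names
# _my_condition and _my_impl are hypothetical, and the impl is only a dequantize-then-linear
# reference path, not one of the specialized kernels registered later in this file.
import torch
from torchao.dtypes.affine_quantized_tensor import (
    AffineQuantizedTensor,
    register_aqt_quantized_linear_dispatch,
)

def _my_condition(input_tensor, weight_tensor, bias):
    # claim the op only when the weight is an AffineQuantizedTensor we can dequantize
    return isinstance(weight_tensor, AffineQuantizedTensor)

def _my_impl(input_tensor, weight_tensor, bias):
    # reference fallback: dequantize the weight and run a plain linear
    return torch.nn.functional.linear(input_tensor, weight_tensor.dequantize(), bias)

register_aqt_quantized_linear_dispatch(_my_condition, _my_impl)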
nbits = int(math.log2(quant_max + 1)) - axis = 1 if (block_size[0] == 1) else 0 + axis = 1 if (block_size[0]==1) else 0 group_size = max(block_size) - compute_dtype = ( - zero_point_dtype - if (zero_point_dtype is not None) - else input_float.dtype - ) + compute_dtype = zero_point_dtype if (zero_point_dtype is not None) else input_float.dtype device = input_float.device - data, scale, zero_point, _ = choose_qparams_and_quantize_affine_hqq( - input_float, - nbits=nbits, - group_size=group_size, - axis=axis, - compute_dtype=compute_dtype, - device=device, - verbose=False, - raw_output=False, - ) + data, scale, zero_point, _ = choose_qparams_and_quantize_affine_hqq(input_float, nbits=nbits, group_size=group_size, axis=axis, compute_dtype=compute_dtype, device=device, verbose=False, raw_output=False) data = data.to(target_dtype) else: - scale, zero_point = choose_qparams_affine( - input_float, - mapping_type, - block_size, - target_dtype, - quant_min, - quant_max, - eps, - scale_dtype, - zero_point_dtype, - preserve_zero, - zero_point_domain, - ) + scale, zero_point = choose_qparams_affine(input_float, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, scale_dtype, zero_point_dtype, preserve_zero, zero_point_domain) # choose_qparams_affine is a custom op that does support returning optional Tensors. We thus set the zero_point to None if its domain is None if zero_point_domain is None: zero_point = None - data = quantize_affine( - input_float, - block_size, - scale, - zero_point, - target_dtype, - quant_min, - quant_max, - zero_point_domain, - ) + data = quantize_affine(input_float, block_size, scale, zero_point, target_dtype, quant_min, quant_max, zero_point_domain) # Note: output will be uint8 tensor for sub byte tensors for now data = layout_type.post_process(data) @@ -350,7 +285,7 @@ def from_hp_to_intx( quant_min, quant_max, zero_point_domain, - dtype=input_float.dtype, + dtype=input_float.dtype ) @classmethod @@ -367,25 +302,12 @@ def from_hp_to_intx_static( layout_type: LayoutType = PlainLayoutType(), ): if target_dtype not in FP8_TYPES: - assert ( - zero_point_domain is not None - ), "zero_point_domain must be specified for non-fp8 types" - assert ( - zero_point is not None - ), "zero_point must be specified for non-fp8 types" + assert zero_point_domain is not None, "zero_point_domain must be specified for non-fp8 types" + assert zero_point is not None, "zero_point must be specified for non-fp8 types" original_shape = input_float.shape input_float = layout_type.pre_process(input_float) - int_data = quantize_affine( - input_float, - block_size, - scale, - zero_point, - target_dtype, - quant_min, - quant_max, - zero_point_domain, - ) + int_data = quantize_affine(input_float, block_size, scale, zero_point, target_dtype, quant_min, quant_max, zero_point_domain) int_data = layout_type.post_process(int_data) @@ -410,6 +332,7 @@ def from_hp_to_floatx( scale_dtype: Optional[torch.dtype], layout_type: LayoutType, ): + if target_dtype in FP8_TYPES: return cls.from_hp_to_intx( input_float=input_float, @@ -427,9 +350,7 @@ def from_hp_to_floatx( use_hqq=False, ) else: - raise NotImplementedError( - f"Unsupported dtype {target_dtype} for from_hp_to_floatx" - ) + raise NotImplementedError(f"Unsupported dtype {target_dtype} for from_hp_to_floatx") @classmethod def from_hp_to_floatx_static( @@ -440,6 +361,7 @@ def from_hp_to_floatx_static( target_dtype: torch.dtype, layout_type: LayoutType, ): + if target_dtype in FP8_TYPES: return cls.from_hp_to_intx_static( input_float=input_float, @@ 
-453,9 +375,7 @@ def from_hp_to_floatx_static( layout_type=layout_type, ) else: - raise NotImplementedError( - f"Unsupported dtype {target_dtype} for from_hp_to_floatx_static" - ) + raise NotImplementedError(f"Unsupported dtype {target_dtype} for from_hp_to_floatx_static") @classmethod def from_hp_to_fpx( @@ -464,10 +384,7 @@ def from_hp_to_fpx( layout_type: LayoutType, ): from torchao.dtypes.fpx import FpxTensorCoreLayoutType - - assert isinstance( - layout_type, FpxTensorCoreLayoutType - ), f"Only FpxTensorCoreLayoutType is supported for fpx, got {layout_type}" + assert isinstance(layout_type, FpxTensorCoreLayoutType), f"Only FpxTensorCoreLayoutType is supported for fpx, got {layout_type}" original_shape = input_float.shape input_float = layout_type.pre_process(input_float) # per axis quantization, where axis = 1 @@ -482,7 +399,12 @@ def from_hp_to_fpx( layout_tensor_ctr = get_layout_tensor_constructor(type(layout_type)) layout_tensor = layout_tensor_ctr(fpx_packed, scale, None, layout_type) - return cls(layout_tensor, block_size, original_shape, dtype=input_float.dtype) + return cls( + layout_tensor, + block_size, + original_shape, + dtype=input_float.dtype + ) @property def layout_type(self) -> LayoutType: @@ -534,9 +456,9 @@ def _apply_fn_to_data(self, fn): register_layout_cls = AffineQuantizedTensor.register_layout_cls get_layout_tensor_constructor = AffineQuantizedTensor.get_layout_tensor_constructor - @dataclass(frozen=True) class SemiSparseLayoutType(LayoutType): + def pre_process(self, input: torch.Tensor) -> torch.Tensor: # prune to 2:4 if not already temp = input.detach() @@ -570,10 +492,11 @@ class Float8LayoutType(LayoutType): @dataclass(frozen=True) class MarlinSparseLayoutType(LayoutType): + def pre_process(self, input: torch.Tensor) -> torch.Tensor: """Preprocess the input tensor to be in the correct format for the Marlin sparse kernel. 
- 1º: the input tensor is transposed since the linear layer keeps the weights in a transposed format - - 2º: tensor is injected with 2:4 sparsity + - 2º: tensor is injected with 2:4 sparsity - 3º: transposes it again because the quantization process will compute the scales for dim=-1 Args: @@ -583,7 +506,6 @@ def pre_process(self, input: torch.Tensor) -> torch.Tensor: torch.Tensor: the preprocessed tensor """ from torchao.sparsity.marlin import inject_24 # avoid circular import - input_t = input.t() w_24, _ = inject_24(input_t, *input_t.shape) return w_24.t() @@ -600,7 +522,6 @@ class PlainAQTLayout(AQTLayout): scale (torch.Tensor): the scale Tensor used to map between floating point tensor to quantized tensor zero_point (torch.Tensor): the zero_point Tensor used to map between floating point tensor to quantized tensor """ - def __new__( cls, int_data: torch.Tensor, @@ -637,12 +558,8 @@ def __tensor_flatten__(self): def __tensor_unflatten__( cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride ): - int_data, scale, zero_point = ( - tensor_data_dict["int_data"], - tensor_data_dict["scale"], - tensor_data_dict["zero_point"], - ) - (layout_type,) = tensor_attributes + int_data, scale, zero_point = tensor_data_dict["int_data"], tensor_data_dict["scale"], tensor_data_dict["zero_point"] + layout_type, = tensor_attributes return cls(int_data, scale, zero_point, layout_type) def to(self, *args, **kwargs): @@ -679,10 +596,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs): if func is aten.t.default: tensor = args[0] new = tensor.__class__( - tensor.int_data.view(tensor.shape[::-1]), - tensor.scale, - tensor.zero_point, - tensor.layout_type, + tensor.int_data.view(tensor.shape[::-1]), tensor.scale, tensor.zero_point, tensor.layout_type ) return return_and_correct_aliasing(func, args, kwargs, new) @@ -709,13 +623,11 @@ def from_plain( assert isinstance(layout_type, PlainLayoutType) return cls(int_data, scale, zero_point, layout_type) - @register_layout_cls(SemiSparseLayoutType) class SemiSparseAQTLayout(PlainAQTLayout): """ Layout storage class for semi_sparse_cusparselt layout for affine quantized tensor """ - @classmethod def __torch_dispatch__(cls, func, types, args, kwargs): kwargs = {} if kwargs is None else kwargs @@ -733,10 +645,10 @@ def get_plain(self): # Currently we don't have cuSPARSELt expansion routines, so we matmul by # the identity matrix to get the original dense matrix. This is slow though. cols = self.int_data.numel() * 16 // (10 * self.scale.shape[0]) - int_data_expanded = torch._cslt_sparse_mm( - self.int_data, - torch.eye(cols, dtype=self.int_data.dtype, device=self.int_data.device).t(), - ) + int_data_expanded = torch._cslt_sparse_mm(self.int_data, + torch.eye(cols, + dtype=self.int_data.dtype, + device=self.int_data.device).t()) return int_data_expanded, self.scale, self.zero_point @classmethod @@ -755,8 +667,8 @@ def from_plain( @register_layout_cls(MarlinSparseLayoutType) class MarlinSparseAQTLayout(AQTLayout): """ - Layout storage class for sparse_marlin_24 layout for affine quantized tensor. - + Layout storage class for sparse_marlin_24 layout for affine quantized tensor. + Can be used with 4 bits and 8 bits quantization. 
Original marlin documentation and information: @@ -770,7 +682,6 @@ class MarlinSparseAQTLayout(AQTLayout): group_size (int): the group size used to pack the tensor num_bits (int): the number of bits used to quantize the tensor """ - @staticmethod def __new__( cls, @@ -827,12 +738,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs): ) def __tensor_flatten__(self): - return ["int_data", "scale", "zero_point", "meta"], [ - self.layout_type, - self.original_shape, - self.group_size, - self.num_bits, - ] + return ["int_data", "scale", "zero_point", "meta"], [self.layout_type, self.original_shape, self.group_size, self.num_bits] @classmethod def __tensor_unflatten__( @@ -843,26 +749,14 @@ def __tensor_unflatten__( zero_point = tensor_data_dict["zero_point"] meta = tensor_data_dict["meta"] layout_type, original_shape, group_size, num_bits = tensor_attributes - return cls( - int_data, - scale, - zero_point, - meta, - layout_type, - original_shape, - group_size, - num_bits, - ) + return cls(int_data, scale, zero_point, meta, layout_type, original_shape, group_size, num_bits) def get_plain(self): - from torchao.sparsity.marlin import ( - unpack_from_marlin_24, - ) # avoid circular import - + from torchao.sparsity.marlin import unpack_from_marlin_24 # avoid circular import int_data_expanded, scales_expanded = unpack_from_marlin_24( - self.int_data, - self.scale, - self.meta, + self.int_data, + self.scale, + self.meta, self.original_shape, self.group_size, self.num_bits, @@ -879,11 +773,7 @@ def from_plain( zero_point: torch.Tensor, layout_type: LayoutType, ): - from torchao.sparsity.marlin import ( - pack_to_marlin_24, - const, - ) # avoid circular import - + from torchao.sparsity.marlin import pack_to_marlin_24, const # avoid circular import assert isinstance(layout_type, MarlinSparseLayoutType) # Linear layers are (in_features, out_features) but the int_data that is reaching this point @@ -893,12 +783,12 @@ def from_plain( if not torch.cuda.get_device_capability()[0] >= 8: raise ValueError( - f"Can not use Sparse Marlin 2:4 int4*fp16 kernel with a device of compute capability {torch.cuda.get_device_capability()}, the minimum compute capability is 8.0 for Marlin kernel." + f'Can not use Sparse Marlin 2:4 int4*fp16 kernel with a device of compute capability {torch.cuda.get_device_capability()}, the minimum compute capability is 8.0 for Marlin kernel.' ) if q_w_24.dtype != torch.int32: raise ValueError("Only `torch.int32` weights are supported.") - + in_features, out_features = q_w_24.shape if in_features % 128 != 0 or out_features != 256 == 0: raise ValueError( @@ -910,14 +800,14 @@ def from_plain( # Check the link for a reference: https://github.com/neuralmagic/nm-vllm/tree/main num_bits = 4 if torch.max(q_w_24) < 16 else -1 if num_bits not in [4]: - raise ValueError(f"Only {[4]} bits are supported, got {num_bits}.") + raise ValueError( + f"Only {[4]} bits are supported, got {num_bits}." + ) group_size = in_features // scale_t.shape[0] if group_size == 0: group_size = in_features - assert ( - group_size <= in_features - ), "Group size must be less than or equal to in_features." + assert group_size <= in_features, "Group size must be less than or equal to in_features." 
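# --- Hedged illustration (editor's sketch, not part of the diff): the 2:4 pattern that
# MarlinSparseLayoutType.pre_process enforces via inject_24 above. This toy mask is not the
# library routine; it only shows the invariant the sparse kernels rely on: at most two
# non-zero values in every contiguous group of four along a row.
import torch

w = torch.randn(4, 8)
groups = w.view(4, -1, 4)                                # (rows, groups, 4)
idx = groups.abs().topk(2, dim=-1).indices               # keep the 2 largest per group
mask = torch.zeros_like(groups, dtype=torch.bool).scatter(-1, idx, True)
w_24 = (groups * mask).view_as(w)                        # 2:4 semi-structured weight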
if group_size not in const.SUPPORTED_GROUP_SIZES: raise ValueError( @@ -925,21 +815,14 @@ def from_plain( ) # Compress quantized weight to marlin 2:4 format - marlin_24_q_w_comp, marlin_24_s, meta = pack_to_marlin_24( - q_w_24, scale_t, num_bits, group_size - ) + marlin_24_q_w_comp, marlin_24_s, meta = pack_to_marlin_24(q_w_24, scale_t, num_bits, group_size) return cls( - marlin_24_q_w_comp, - marlin_24_s, - zero_point, - meta, - layout_type, - q_w_24.shape, - group_size, - num_bits, + marlin_24_q_w_comp, marlin_24_s, zero_point, + meta, layout_type, q_w_24.shape, + group_size, num_bits ) - + def get_layout_type(self) -> LayoutType: return self.layout_type @@ -956,7 +839,6 @@ class Float8AQTLayout(AQTLayout): """ Layout storage class for float8 layout for affine quantized tensor """ - float8_data: torch.Tensor scale: torch.Tensor transposed: bool @@ -991,7 +873,7 @@ def __init__( self.layout_type = layout_type def _apply_fn_to_data(self, fn): - """Applys a fn to all tensor components stored on this class""" + """ Applys a fn to all tensor components stored on this class""" fn(self.float8_data) fn(self.scale) return self @@ -1013,10 +895,7 @@ def __tensor_unflatten__( cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride ): float8_data, scale = tensor_data_dict["float8_data"], tensor_data_dict["scale"] - ( - transposed, - layout_type, - ) = tensor_attributes + transposed, layout_type, = tensor_attributes return cls(float8_data, scale, transposed, layout_type) @classmethod @@ -1058,26 +937,20 @@ def from_plain( zero_point: Optional[torch.Tensor], layout_type: LayoutType, ): - """Main entrypoint for constructing Float8Layout Tensor""" - assert _is_float8_type( - data.dtype - ), f"Float8 Layout must be constructed from float8 dtype but got {data.dtype}" - assert isinstance( - layout_type, Float8LayoutType - ), f"Float8 Layout must be constructed from Float8LayoutType but got {layout_type}" + """ Main entrypoint for constructing Float8Layout Tensor""" + assert _is_float8_type(data.dtype), f"Float8 Layout must be constructed from float8 dtype but got {data.dtype}" + assert isinstance(layout_type, Float8LayoutType), f"Float8 Layout must be constructed from Float8LayoutType but got {layout_type}" return cls(data, scale, False, layout_type) def __repr__(self): float8_data, scale, _ = self.get_plain() layout_type = self.get_layout_type() - return ( - f"{self.__class__.__name__}(\n" - f"float8_data={float8_data},\n" - f"scale={scale},\n" - f"transposed={self.transposed}, " - f"layout_type={layout_type})" - ) - + return (f"{self.__class__.__name__}(\n" + f"float8_data={float8_data},\n" + f"scale={scale},\n" + f"transposed={self.transposed}, " + f"layout_type={layout_type})") + @register_layout_cls(TensorCoreTiledLayoutType) class TensorCoreTiledAQTLayout(AQTLayout): @@ -1101,9 +974,7 @@ def __new__( kwargs = {} kwargs["device"] = packed_weight.device kwargs["layout"] = ( - kwargs.get("layout") - if kwargs.get("layout", False) - else packed_weight.layout + kwargs.get("layout") if kwargs.get("layout", False) else packed_weight.layout ) kwargs["dtype"] = packed_weight.dtype kwargs["requires_grad"] = False @@ -1129,14 +1000,8 @@ def __tensor_flatten__(self): def __tensor_unflatten__( cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride ): - packed_weight, scale_and_zero = ( - tensor_data_dict["packed_weight"], - tensor_data_dict["scale_and_zero"], - ) - ( - transposed, - layout_type, - ) = tensor_attributes + packed_weight, scale_and_zero = tensor_data_dict["packed_weight"], 
tensor_data_dict["scale_and_zero"] + transposed, layout_type, = tensor_attributes return cls(packed_weight, scale_and_zero, transposed, layout_type) @classmethod @@ -1145,24 +1010,20 @@ def from_plain( int_data: torch.Tensor, scale: torch.Tensor, zero_point: Optional[torch.Tensor], - layout_type: LayoutType, + layout_type: LayoutType ): + assert isinstance(layout_type, TensorCoreTiledLayoutType) if TORCH_VERSION_AT_LEAST_2_5: int_data = (int_data[::, ::2] << 4 | int_data[::, 1::2]).to(torch.uint8) - assert ( - int_data.dtype == torch.uint8 - ), "torch.ops.aten._convert_weight_to_int4pack in torch 2.5 expects `uint8` dtype" + assert int_data.dtype == torch.uint8, "torch.ops.aten._convert_weight_to_int4pack in torch 2.5 expects `uint8` dtype" else: - assert ( - int_data.dtype == torch.int32 - ), "torch.ops.aten._convert_weight_to_int4pack in torch 2.4 expects `int32` dtype" - packed_weight = torch.ops.aten._convert_weight_to_int4pack( - int_data, layout_type.inner_k_tiles - ) + assert int_data.dtype == torch.int32, "torch.ops.aten._convert_weight_to_int4pack in torch 2.4 expects `int32` dtype" + packed_weight = torch.ops.aten._convert_weight_to_int4pack(int_data, layout_type.inner_k_tiles) scale = scale.reshape(int_data.shape[0], -1) zero_point = zero_point.reshape(int_data.shape[0], -1) + from torchao.quantization.utils import pack_tinygemm_scales_and_zeros scale_and_zero = pack_tinygemm_scales_and_zeros(scale, zero_point) return cls(packed_weight, scale_and_zero, False, layout_type) @@ -1170,9 +1031,7 @@ def to(self, *args, **kwargs): kwargs = self._get_to_kwargs(*args, **kwargs) device = kwargs["device"] if not is_device("cuda", device): - raise ValueError( - f"TensorCoreTiledAQTLayout is only available for cuda device, can't convert to {device}" - ) + raise ValueError(f"TensorCoreTiledAQTLayout is only available for cuda device, can't convert to {device}") return self.__class__( self.packed_weight.to(device), self.scale_and_zero.to(device), @@ -1218,7 +1077,6 @@ def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: quantize_affine, ) from torchao.quantization.utils import unpack_tinygemm_scales_and_zeros - scale, zero = unpack_tinygemm_scales_and_zeros(self.scale_and_zero) cur_shape = self.shape @@ -1235,26 +1093,12 @@ def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: quant_max = 15 zero_point_domain = ZeroPointDomain.FLOAT assert len(block_size) == 2 and block_size[0] == 1 - dequantized = torch.ops.aten._weight_int4pack_mm( - torch.eye(eye_shape, device=device, dtype=original_dtype), - self.packed_weight, - groupsize, - self.scale_and_zero, - ) + dequantized = torch.ops.aten._weight_int4pack_mm(torch.eye(eye_shape, device=device, dtype=original_dtype), self.packed_weight, groupsize, self.scale_and_zero) dequantized = dequantized.t().contiguous() # TODO: move this to `unpack_tinygemm_scales_and_zeros`? 
scale = scale.reshape(scale.shape[:-1]).contiguous() zero = zero.reshape(zero.shape[:-1]).contiguous() - int_data = quantize_affine( - dequantized, - block_size, - scale, - zero, - target_dtype, - quant_min, - quant_max, - zero_point_domain, - ) + int_data = quantize_affine(dequantized, block_size, scale, zero, target_dtype, quant_min, quant_max, zero_point_domain) return int_data, scale, zero def get_layout_type(self) -> LayoutType: @@ -1265,36 +1109,28 @@ def get_layout_type(self) -> LayoutType: # torch functional and aten operator implementation # ##################################################### - def _aqt_is_int8(aqt): """Check if an AffineQuantizedTensor is int8 quantized Tensor""" return ( - aqt.layout_tensor.dtype == torch.int8 - and aqt.quant_min is None - or aqt.quant_min == -128 - and aqt.quant_max is None - or aqt.quant_max == 127 + aqt.layout_tensor.dtype == torch.int8 and + aqt.quant_min is None or aqt.quant_min == -128 and + aqt.quant_max is None or aqt.quant_max == 127 ) - def _aqt_is_int8_reduced_range(aqt): return ( - aqt.layout_tensor.dtype == torch.int8 - and aqt.quant_min == -127 - and aqt.quant_max is None - or aqt.quant_max == 127 + aqt.layout_tensor.dtype == torch.int8 and + aqt.quant_min == -127 and + aqt.quant_max is None or aqt.quant_max == 127 ) - def _aqt_is_uint4(aqt): """Check if an AffineQuantizedTensor is uint4 quantized Tensor""" # TODO: use torch.uint4 return ( - aqt.layout_tensor.dtype == torch.int32 - and aqt.quant_min is None - or aqt.quant_min == 0 - and aqt.quant_max is None - or aqt.quant_max == 15 + aqt.layout_tensor.dtype == torch.int32 and + aqt.quant_min is None or aqt.quant_min == 0 and + aqt.quant_max is None or aqt.quant_max == 15 ) @@ -1306,19 +1142,17 @@ def _aqt_is_uint4(aqt): # bias: dimension is (out_features,) # so that these can be shared by F.linear, aten.mm, aten.addmm dispatches - def _linear_int8_act_int8_weight_check(input_tensor, weight_tensor, bias): return ( - isinstance(input_tensor, AffineQuantizedTensor) - and _aqt_is_int8_reduced_range(input_tensor) - and isinstance(weight_tensor, AffineQuantizedTensor) - and weight_tensor.is_cuda - and input_tensor.dtype == weight_tensor.dtype - and isinstance(input_tensor.layout_type, PlainLayoutType) - and isinstance(weight_tensor.layout_type, PlainLayoutType) + isinstance(input_tensor, AffineQuantizedTensor) and + _aqt_is_int8_reduced_range(input_tensor) and + isinstance(weight_tensor, AffineQuantizedTensor) and + weight_tensor.is_cuda and + input_tensor.dtype == weight_tensor.dtype and + isinstance(input_tensor.layout_type, PlainLayoutType) and + isinstance(weight_tensor.layout_type, PlainLayoutType) ) - def _linear_int8_act_int8_weight_impl(input_tensor, weight_tensor, bias): # # 1. 
do the matrix form of dot(X_i, W_j) @@ -1350,23 +1184,18 @@ def _linear_int8_act_int8_weight_impl(input_tensor, weight_tensor, bias): return y -def _linear_int8_act_int8_weight_semi_structured_sparse_check( - input_tensor, weight_tensor, bias -): +def _linear_int8_act_int8_weight_semi_structured_sparse_check(input_tensor, weight_tensor, bias): return ( - isinstance(input_tensor, AffineQuantizedTensor) - and _aqt_is_int8_reduced_range(input_tensor) - and isinstance(weight_tensor, AffineQuantizedTensor) - and weight_tensor.is_cuda - and input_tensor.dtype == weight_tensor.dtype - and isinstance(input_tensor.layout_type, PlainLayoutType) - and isinstance(weight_tensor.layout_type, SemiSparseLayoutType) + isinstance(input_tensor, AffineQuantizedTensor) and + _aqt_is_int8_reduced_range(input_tensor) and + isinstance(weight_tensor, AffineQuantizedTensor) and + weight_tensor.is_cuda and + input_tensor.dtype == weight_tensor.dtype and + isinstance(input_tensor.layout_type, PlainLayoutType) and + isinstance(weight_tensor.layout_type, SemiSparseLayoutType) ) - -def _linear_int8_act_int8_weight_semi_structured_sparse_impl( - input_tensor, weight_tensor, bias -): +def _linear_int8_act_int8_weight_semi_structured_sparse_impl(input_tensor, weight_tensor, bias): x_vals_int8 = input_tensor.layout_tensor.int_data x_scales = input_tensor.layout_tensor.scale w_vals_int8 = weight_tensor.layout_tensor.int_data @@ -1374,10 +1203,7 @@ def _linear_int8_act_int8_weight_semi_structured_sparse_impl( tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1]) # we fuse one of the scalar matrix multiplications (w_scales) into the sparse mm y_dot_bf16_w_scales_fused = torch._cslt_sparse_mm( - w_vals_int8, - tmp.t(), - alpha=w_scales.to(torch.float32), - out_dtype=torch.bfloat16, + w_vals_int8, tmp.t(), alpha=w_scales.to(torch.float32), out_dtype=torch.bfloat16, ).t() y = (y_dot_bf16_w_scales_fused * x_scales.reshape(-1, 1)).reshape( *x_vals_int8.shape[:-1], y_dot_bf16_w_scales_fused.shape[-1] @@ -1389,27 +1215,23 @@ def _linear_int8_act_int8_weight_semi_structured_sparse_impl( y += bias return y - def _linear_bf16_act_uint4_weight_check(input_tensor, weight_tensor, bias): return ( # input is native bfloat16 tensor - not is_traceable_wrapper_subclass(input_tensor) - and input_tensor.dtype == torch.bfloat16 - and + not is_traceable_wrapper_subclass(input_tensor) and + input_tensor.dtype == torch.bfloat16 and # weight is uint4, group quantized tensor_core_tiled layout affine quantized tensor - isinstance(weight_tensor, AffineQuantizedTensor) - and _aqt_is_uint4(weight_tensor) - and weight_tensor.dtype == torch.bfloat16 - and len(weight_tensor.shape) == 2 - and weight_tensor.zero_point_domain == ZeroPointDomain.FLOAT - and isinstance(weight_tensor.layout_type, TensorCoreTiledLayoutType) + isinstance(weight_tensor, AffineQuantizedTensor) and + _aqt_is_uint4(weight_tensor) and + weight_tensor.dtype == torch.bfloat16 and + len(weight_tensor.shape) == 2 and + weight_tensor.zero_point_domain == ZeroPointDomain.FLOAT and + isinstance(weight_tensor.layout_type, TensorCoreTiledLayoutType) ) def _linear_bf16_act_uint4_weight_impl(input_tensor, weight_tensor, bias): - assert ( - weight_tensor.block_size[0] == 1 - ), f"Requires groupwise quantization, got block_size: {weight_tensor.block_size}" + assert weight_tensor.block_size[0] == 1, f"Requires groupwise quantization, got block_size: {block_size}" assert input_tensor.shape[-1] == weight_tensor.shape[1], ( f"need input_tensor shape: {input_tensor.shape} final" f"dim to match weight_tensor 
shape: {weight_tensor.shape} second dim " @@ -1433,9 +1255,7 @@ def _linear_bf16_act_uint4_weight_impl(input_tensor, weight_tensor, bias): # groupwise int4 quantization groupsize = weight_tensor.block_size[1] - y = torch.ops.aten._weight_int4pack_mm( - act_mat.contiguous(), packed_weight, groupsize, scale_and_zero - ) + y = torch.ops.aten._weight_int4pack_mm(act_mat.contiguous(), packed_weight, groupsize, scale_and_zero) # remove out_feature padding orig_out_features = weight_tensor.shape[-2] @@ -1450,21 +1270,19 @@ def _linear_bf16_act_uint4_weight_impl(input_tensor, weight_tensor, bias): def _linear_fp_act_int8_weight_check(input_tensor, weight_tensor, bias): return ( # input is native float tensor - not is_traceable_wrapper_subclass(input_tensor) - and input_tensor.is_floating_point() - and + not is_traceable_wrapper_subclass(input_tensor) and + input_tensor.is_floating_point() and # weight is int8 per channel quantized affine quantized tensor - isinstance(weight_tensor, AffineQuantizedTensor) - and _aqt_is_int8(weight_tensor) - and len(weight_tensor.shape) == 2 - and len(weight_tensor.block_size) == 2 - and weight_tensor.block_size[0] == 1 - and weight_tensor.block_size[1] == weight_tensor.shape[1] - and weight_tensor.zero_point_domain == ZeroPointDomain.INT - and isinstance(weight_tensor.layout_type, PlainLayoutType) + isinstance(weight_tensor, AffineQuantizedTensor) and + _aqt_is_int8(weight_tensor) and + len(weight_tensor.shape) == 2 and + len(weight_tensor.block_size) == 2 and + weight_tensor.block_size[0] == 1 and + weight_tensor.block_size[1] == weight_tensor.shape[1] and + weight_tensor.zero_point_domain == ZeroPointDomain.INT and + isinstance(weight_tensor.layout_type, PlainLayoutType) ) - def _linear_fp_act_int8_weight_impl(input_tensor, weight_tensor, bias): # TODO: enable cpu and mps efficient path # is_cpu and is_mps only, some issue with is_contiguous() currently @@ -1473,6 +1291,7 @@ def _linear_fp_act_int8_weight_impl(input_tensor, weight_tensor, bias): # per channel int8 weight only quantizated mm w_vals_int8_t = weight_tensor.layout_tensor.int_data.t() scale = weight_tensor.layout_tensor.scale + orig_dtype = input_tensor.dtype m = torch.mm( input_tensor.reshape(-1, input_tensor.shape[-1]), w_vals_int8_t.to(input_tensor.dtype), @@ -1483,43 +1302,30 @@ def _linear_fp_act_int8_weight_impl(input_tensor, weight_tensor, bias): y += bias.to(m.dtype) return y - def _linear_f16_act_fpx_weight_check(input_tensor, weight_tensor, bias): from torchao.dtypes.fpx import FpxTensorCoreLayoutType - return ( # input is native float32 tensor - not is_traceable_wrapper_subclass(input_tensor) - and input_tensor.is_floating_point() - and input_tensor.dtype == torch.float16 - and + not is_traceable_wrapper_subclass(input_tensor) and + input_tensor.is_floating_point() and + input_tensor.dtype == torch.float16 and # weight is fpx Tensor - isinstance(weight_tensor, AffineQuantizedTensor) - and isinstance(weight_tensor.layout_type, FpxTensorCoreLayoutType) - and ( + isinstance(weight_tensor, AffineQuantizedTensor) and + isinstance(weight_tensor.layout_type, FpxTensorCoreLayoutType) and + ( # weight is using fp6 quantization - ( - weight_tensor.layout_type.ebits == 3 - and weight_tensor.layout_type.mbits == 2 - ) - or ( - weight_tensor.layout_type.ebits == 2 - and weight_tensor.layout_type.mbits == 3 - ) - or + (weight_tensor.layout_type.ebits == 3 and + weight_tensor.layout_type.mbits == 2) or + (weight_tensor.layout_type.ebits == 2 and + weight_tensor.layout_type.mbits == 3) or # weight is using 
fp5 quantization - ( - weight_tensor.layout_type.ebits == 2 - and weight_tensor.layout_type.mbits == 2 - ) - or ( - weight_tensor.layout_type.ebits == 3 - and weight_tensor.layout_type.mbits == 1 - ) + (weight_tensor.layout_type.ebits == 2 and + weight_tensor.layout_type.mbits == 2) or + (weight_tensor.layout_type.ebits == 3 and + weight_tensor.layout_type.mbits == 1) ) ) - def _linear_f16_act_fpx_weight_impl(input_tensor, weight_tensor, bias): from torchao.dtypes.fpx import _SPLIT_K_MAP from torchao.ops import quant_llm_linear @@ -1548,7 +1354,6 @@ def _linear_f16_act_fpx_weight_impl(input_tensor, weight_tensor, bias): return out.view(*act.shape[:-1], out_dim).to(act.dtype) - def _linear_fp_act_fp8_tensor_wise_weight_check( input_tensor: Union[torch.Tensor, AffineQuantizedTensor], weight_tensor: Union[torch.Tensor, AffineQuantizedTensor], @@ -1556,12 +1361,11 @@ def _linear_fp_act_fp8_tensor_wise_weight_check( ) -> bool: def check_aqt_tensorwise(aqt: Union[torch.Tensor, AffineQuantizedTensor]) -> bool: return ( - isinstance(aqt, AffineQuantizedTensor) - and isinstance(aqt.layout_type, Float8LayoutType) + isinstance(aqt, AffineQuantizedTensor) and + isinstance(aqt.layout_type, Float8LayoutType) and aqt.layout_tensor.dtype in [torch.float8_e4m3fn, torch.float8_e5m2] and aqt.shape == aqt.block_size ) - return check_aqt_tensorwise(input_tensor) and check_aqt_tensorwise(weight_tensor) @@ -1578,9 +1382,7 @@ def _linear_fp_act_fp8_weight_impl( ) scaled_mm_config = weight_tensor.layout_type.mm_config - scaled_mm_config = ( - scaled_mm_config if scaled_mm_config is not None else Float8MMConfig() - ) + scaled_mm_config = scaled_mm_config if scaled_mm_config is not None else Float8MMConfig() w_layout = weight_tensor.layout_tensor w_data = weight_tensor.layout_tensor.float8_data @@ -1612,17 +1414,17 @@ def _linear_fp_act_fp8_weight_impl( def _linear_fp_act_int4_weight_sparse_marlin_check(input_tensor, weight_tensor, bias): return ( - isinstance(weight_tensor, AffineQuantizedTensor) - and _aqt_is_uint4(weight_tensor) - and input_tensor.dtype == torch.float16 - and len(weight_tensor.shape) == 2 - and weight_tensor.zero_point_domain == ZeroPointDomain.INT - and isinstance(weight_tensor.layout_type, MarlinSparseLayoutType) + isinstance(weight_tensor, AffineQuantizedTensor) and + _aqt_is_uint4(weight_tensor) and + input_tensor.dtype == torch.float16 and + len(weight_tensor.shape) == 2 and + weight_tensor.zero_point_domain == ZeroPointDomain.INT and + isinstance(weight_tensor.layout_type, MarlinSparseLayoutType) ) - def _linear_fp_act_int4_weight_sparse_marlin_impl(input_tensor, weight_tensor, bias): - from torchao.sparsity.marlin import marlin_24_workspace + from torchao.sparsity.marlin import marlin_24_workspace, const + from torchao.ops import marlin_24_gemm assert isinstance(weight_tensor, AffineQuantizedTensor) @@ -1640,16 +1442,9 @@ def _linear_fp_act_int4_weight_sparse_marlin_impl(input_tensor, weight_tensor, b size_k = input_2d.shape[1] workspace_24 = marlin_24_workspace(original_shape[1]) - out = torchao.ops.marlin_24_gemm( - input_2d, - sparse_w_int4, - meta, - scale, - workspace_24, - num_bits, - size_m, - size_n, - size_k, + out = marlin_24_gemm( + input_2d, sparse_w_int4, meta, scale, + workspace_24, num_bits, size_m, size_n, size_k ) # Unfold the batch dimension @@ -1663,25 +1458,17 @@ def _linear_fp_act_int4_weight_sparse_marlin_impl(input_tensor, weight_tensor, b def _register_aqt_quantized_linear_dispatches(): for dispatch_condition, impl in [ (_linear_int8_act_int8_weight_check, 
_linear_int8_act_int8_weight_impl), - ( - _linear_int8_act_int8_weight_semi_structured_sparse_check, - _linear_int8_act_int8_weight_semi_structured_sparse_impl, - ), + (_linear_int8_act_int8_weight_semi_structured_sparse_check, _linear_int8_act_int8_weight_semi_structured_sparse_impl), (_linear_fp_act_fp8_tensor_wise_weight_check, _linear_fp_act_fp8_weight_impl), (_linear_bf16_act_uint4_weight_check, _linear_bf16_act_uint4_weight_impl), (_linear_fp_act_int8_weight_check, _linear_fp_act_int8_weight_impl), (_linear_f16_act_fpx_weight_check, _linear_f16_act_fpx_weight_impl), - ( - _linear_fp_act_int4_weight_sparse_marlin_check, - _linear_fp_act_int4_weight_sparse_marlin_impl, - ), + (_linear_fp_act_int4_weight_sparse_marlin_check, _linear_fp_act_int4_weight_sparse_marlin_impl), ]: register_aqt_quantized_linear_dispatch(dispatch_condition, impl) - _register_aqt_quantized_linear_dispatches() - @implements(torch.nn.functional.linear) def _(func, types, args, kwargs): input_tensor, weight_tensor, bias = ( @@ -1690,9 +1477,7 @@ def _(func, types, args, kwargs): args[2] if len(args) > 2 else None, ) if not input_tensor.is_floating_point(): - raise NotImplementedError( - f"{func} is not implemented for non floating point input" - ) + raise NotImplementedError(f"{func} is not implemented for non floating point input") # using try/except here so that we can have a general fallback when input_tensor/weight_tensor # is not picked up by any of the dispatch paths in `_quantized_linear_op`, this allows us to @@ -1701,11 +1486,7 @@ def _(func, types, args, kwargs): return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias) except QuantizedLinearNotImplementedError as e: # fallback path is only called when user did not specify a specfic quantized linear implementation with `layout_type.quantized_linear_impl` - if ( - isinstance(weight_tensor, AffineQuantizedTensor) - and hasattr(weight_tensor.layout_type, "quantized_linear_impl") - and weight_tensor.layout_type.quantized_linear_impl is not None - ): + if isinstance(weight_tensor, AffineQuantizedTensor) and hasattr(weight_tensor.layout_type, "quantized_linear_impl") and weight_tensor.layout_type.quantized_linear_impl is not None: raise e if isinstance(input_tensor, AffineQuantizedTensor): @@ -1714,7 +1495,6 @@ def _(func, types, args, kwargs): weight_tensor = weight_tensor.dequantize() return torch.nn.functional.linear(input_tensor, weight_tensor, bias) - @implements(aten.addmm.default) def _(func, types, args, kwargs): input_tensor, weight_tensor, bias = ( @@ -1723,9 +1503,7 @@ def _(func, types, args, kwargs): args[0], ) if not input_tensor.is_floating_point(): - raise NotImplementedError( - f"{func} is not implemented for non floating point input" - ) + raise NotImplementedError(f"{func} is not implemented for non floating point input") # using try/except here so that we can have a general fallback when input_tensor/weight_tensor # is not picked up by any of the dispatch paths in `_quantized_linear_op`, this allows us to @@ -1735,11 +1513,7 @@ def _(func, types, args, kwargs): return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias) except QuantizedLinearNotImplementedError as e: # fallback path is only called when user did not specify a specfic quantized linear implementation with `layout_type.quantized_linear_impl` - if ( - isinstance(weight_tensor, AffineQuantizedTensor) - and hasattr(weight_tensor.layout_type, "quantized_linear_impl") - and weight_tensor.layout_type.quantized_linear_impl is not None - ): + if 
isinstance(weight_tensor, AffineQuantizedTensor) and hasattr(weight_tensor.layout_type, "quantized_linear_impl") and weight_tensor.layout_type.quantized_linear_impl is not None: raise e if isinstance(input_tensor, AffineQuantizedTensor): @@ -1748,25 +1522,22 @@ def _(func, types, args, kwargs): weight_tensor = weight_tensor.dequantize() return func(bias, input_tensor, weight_tensor) - @implements(aten.mm.default) def _(func, types, args, kwargs): - input_tensor, weight_tensor, bias = (args[0], args[1], None) + input_tensor, weight_tensor, bias = ( + args[0], + args[1], + None + ) if not input_tensor.is_floating_point(): - raise NotImplementedError( - f"{func} is not implemented for non floating point input" - ) + raise NotImplementedError(f"{func} is not implemented for non floating point input") try: weight_tensor = weight_tensor.t() return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias) except QuantizedLinearNotImplementedError as e: # fallback path is only called when user did not specify a specfic quantized linear implementation with `layout_type.quantized_linear_impl` - if ( - isinstance(weight_tensor, AffineQuantizedTensor) - and hasattr(weight_tensor.layout_type, "quantized_linear_impl") - and weight_tensor.layout_type.quantized_linear_impl is not None - ): + if isinstance(weight_tensor, AffineQuantizedTensor) and hasattr(weight_tensor.layout_type, "quantized_linear_impl") and weight_tensor.layout_type.quantized_linear_impl is not None: raise e if isinstance(input_tensor, AffineQuantizedTensor): @@ -1775,7 +1546,6 @@ def _(func, types, args, kwargs): weight_tensor = weight_tensor.dequantize() return func(input_tensor, weight_tensor) - @implements(aten.detach.default) def _(func, types, args, kwargs): return return_and_correct_aliasing( @@ -1799,7 +1569,6 @@ def _(func, types, args, kwargs): args[0].to(*args[1:], **kwargs)._apply_fn_to_data(torch.clone), ) - @implements(aten.t.default) def _(func, types, args, kwargs): block_size = args[0].block_size @@ -1808,18 +1577,10 @@ def _(func, types, args, kwargs): tensor = args[0] shape = tensor.shape[::-1] new = tensor.__class__( - tensor.layout_tensor.t(), - transposed_block_size, - shape, - tensor.quant_min, - tensor.quant_max, - tensor.zero_point_domain, - dtype=tensor.dtype, - strides=tensor.stride(), + tensor.layout_tensor.t(), transposed_block_size, shape, tensor.quant_min, tensor.quant_max, tensor.zero_point_domain, dtype=tensor.dtype, strides=tensor.stride() ) return return_and_correct_aliasing(func, args, kwargs, new) - to_affine_quantized_intx = AffineQuantizedTensor.from_hp_to_intx to_affine_quantized_intx_static = AffineQuantizedTensor.from_hp_to_intx_static to_affine_quantized_floatx = AffineQuantizedTensor.from_hp_to_floatx diff --git a/torchao/dtypes/fpx/__init__.py b/torchao/dtypes/fpx/__init__.py index a62eb48283..af77685fac 100644 --- a/torchao/dtypes/fpx/__init__.py +++ b/torchao/dtypes/fpx/__init__.py @@ -1,15 +1 @@ -from .fpx import ( - FpxTensorCoreLayoutType, - FpxTensorCoreAQTLayout, - to_scaled_tc_fpx, - from_scaled_tc_fpx, - _SPLIT_K_MAP, -) - -__all__ = [ - "FpxTensorCoreAQTLayout", - "FpxTensorCoreLayoutType", - "to_scaled_tc_fpx", - "from_scaled_tc_fpx", - "_SPLIT_K_MAP", -] +from .fpx import FpxTensorCoreLayoutType, FpxTensorCoreAQTLayout, to_scaled_tc_fpx, from_scaled_tc_fpx, _SPLIT_K_MAP diff --git a/torchao/dtypes/fpx/fpx.py b/torchao/dtypes/fpx/fpx.py index 77064baefc..6afa22f560 100644 --- a/torchao/dtypes/fpx/fpx.py +++ b/torchao/dtypes/fpx/fpx.py @@ -4,14 +4,11 @@ import torch from 
torch import Tensor from torch.utils._python_dispatch import return_and_correct_aliasing -from torchao.prototype.custom_fp_utils import ( - _f32_to_fpx_unpacked, - _fpx_unpacked_to_f32, - _n_ones, -) +from torchao.prototype.custom_fp_utils import _f32_to_fpx_unpacked, _fpx_unpacked_to_f32, _n_ones from torchao.dtypes.utils import ( LayoutType, ) +from torchao.quantization.quant_api import _get_linear_subclass_inserter from dataclasses import dataclass from torchao.dtypes.affine_quantized_tensor import AQTLayout, register_layout_cls @@ -21,23 +18,11 @@ def _pack(x: Tensor, n_bits: int) -> Tensor: - return reduce( - torch.bitwise_or, - [ - x[..., i :: (8 // n_bits)] << (8 - (i + 1) * n_bits) - for i in range(8 // n_bits) - ], - ) + return reduce(torch.bitwise_or, [x[..., i::(8 // n_bits)] << (8 - (i + 1) * n_bits) for i in range(8 // n_bits)]) def _unpack(x: Tensor, n_bits: int) -> Tensor: - return torch.stack( - [ - (x >> (8 - (i + 1) * n_bits)) & ((1 << n_bits) - 1) - for i in range(8 // n_bits) - ], - dim=-1, - ).flatten(-2) + return torch.stack([(x >> (8 - (i + 1) * n_bits)) & ((1 << n_bits) - 1) for i in range(8 // n_bits)], dim=-1).flatten(-2) # https://github.com/usyd-fsalab/fp6_llm/blob/5df6737cca32f604e957e3f63f03ccc2e4d1df0d/fp6_llm/csrc/utils/weight_prepacking.h#L87-L116 @@ -51,40 +36,8 @@ def _bit_interleave(x: Tensor, n_bits: int, undo: bool = False) -> Tensor: if not undo: bit_order = { - 1: [ - 1, - 5, - 9, - 13, - 17, - 21, - 25, - 29, - 3, - 7, - 11, - 15, - 19, - 23, - 27, - 31, - 0, - 4, - 8, - 12, - 16, - 20, - 24, - 28, - 2, - 6, - 10, - 14, - 18, - 22, - 26, - 30, - ], + 1: [1, 5, 9, 13, 17, 21, 25, 29, 3, 7, 11, 15, 19, 23, 27, 31, + 0, 4, 8, 12, 16, 20, 24, 28, 2, 6, 10, 14, 18, 22, 26, 30], 2: [1, 5, 9, 13, 3, 7, 11, 15, 0, 4, 8, 12, 2, 6, 10, 14], 4: [1, 5, 3, 7, 0, 4, 2, 6], }[n_bits] @@ -93,40 +46,8 @@ def _bit_interleave(x: Tensor, n_bits: int, undo: bool = False) -> Tensor: # this is inverse of the above, obtained by running # [v.index(i) for i in range(len(v))] bit_order = { - 1: [ - 16, - 0, - 24, - 8, - 17, - 1, - 25, - 9, - 18, - 2, - 26, - 10, - 19, - 3, - 27, - 11, - 20, - 4, - 28, - 12, - 21, - 5, - 29, - 13, - 22, - 6, - 30, - 14, - 23, - 7, - 31, - 15, - ], + 1: [16, 0, 24, 8, 17, 1, 25, 9, 18, 2, 26, 10, 19, 3, 27, 11, + 20, 4, 28, 12, 21, 5, 29, 13, 22, 6, 30, 14, 23, 7, 31, 15], 2: [8, 0, 12, 4, 9, 1, 13, 5, 10, 2, 14, 6, 11, 3, 15, 7], 4: [4, 0, 6, 2, 5, 1, 7, 3], }[n_bits] @@ -162,12 +83,8 @@ def _pack_tc_fpx(tensor: Tensor, nbits: int) -> Tensor: tensor_ybit = (tensor >> (nbits - used_bits - y)) & mask tensor_ybit = _pack(tensor_ybit, y) - tensor_ybit = ( - tensor_ybit.view(32, -1, 4).permute(1, 0, 2).flip(2) - ) # Pass 2 from original code - tensor_ybit = _bit_interleave( - tensor_ybit.flatten(), y - ) # Pass 3 from original code + tensor_ybit = tensor_ybit.view(32, -1, 4).permute(1, 0, 2).flip(2) # Pass 2 from original code + tensor_ybit = _bit_interleave(tensor_ybit.flatten(), y) # Pass 3 from original code fragments.append(tensor_ybit) used_bits += y @@ -209,9 +126,7 @@ def to_scaled_tc_fpx(tensor: Tensor, ebits: int, mbits: int) -> Tuple[Tensor, Te # workaround: global lookup table exp_bias = _ONES_TABLE[ebits - 1] - max_normal = 2 ** (_ONES_TABLE[ebits] - exp_bias) * ( - _ONES_TABLE[mbits + 1] / (2**mbits) - ) + max_normal = 2 ** (_ONES_TABLE[ebits] - exp_bias) * (_ONES_TABLE[mbits + 1] / (2 ** mbits)) tensor = tensor.float() scale = tensor.abs().amax(1).clamp(min=1e-12) / max_normal @@ -237,10 +152,8 @@ def _unpack_tc_fpx(tensor: 
Tensor, nbits: int) -> Tensor: tensor_ybit = tensor[offset : offset + size_ybit] offset += size_ybit - tensor_ybit = _bit_interleave(tensor_ybit, y, undo=True) # undo Pass 3 - tensor_ybit = ( - tensor_ybit.view(-1, 32, 4).flip(2).permute(1, 0, 2) - ) # undo Pass 2 + tensor_ybit = _bit_interleave(tensor_ybit, y, undo=True) # undo Pass 3 + tensor_ybit = tensor_ybit.view(-1, 32, 4).flip(2).permute(1, 0, 2) # undo Pass 2 tensor_ybit = _unpack(tensor_ybit.flatten(), y) tensor_ybit = tensor_ybit << (nbits - used_bits - y) @@ -311,7 +224,7 @@ def from_scaled_tc_fpx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> Te 10240: 5, 14336: 7, 28672: 7, - 57344: 7, + 57344: 7 }, { # tokens: [65:128] 3072: 9, @@ -322,7 +235,7 @@ def from_scaled_tc_fpx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> Te 10240: 5, 14336: 7, 28672: 7, - 57344: 6, + 57344: 6 }, { # tokens: [129:192] 3072: 6, @@ -333,7 +246,7 @@ def from_scaled_tc_fpx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> Te 10240: 5, 14336: 5, 28672: 5, - 57344: 4, + 57344: 4 }, { # tokens: [193:256] 3072: 9, @@ -344,7 +257,7 @@ def from_scaled_tc_fpx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> Te 10240: 4, 14336: 8, 28672: 6, - 57344: 4, + 57344: 4 }, { # tokens: [257:320] 3072: 7, @@ -355,7 +268,7 @@ def from_scaled_tc_fpx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> Te 10240: 1, 14336: 3, 28672: 3, - 57344: 4, + 57344: 4 }, { # tokens: [321:384] 3072: 3, @@ -366,7 +279,7 @@ def from_scaled_tc_fpx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> Te 10240: 8, 14336: 3, 28672: 4, - 57344: 3, + 57344: 3 }, { # tokens: [385:448] 3072: 5, @@ -377,7 +290,7 @@ def from_scaled_tc_fpx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> Te 10240: 3, 14336: 1, 28672: 1, - 57344: 3, + 57344: 3 }, { # tokens: [449:512] 3072: 2, @@ -388,7 +301,7 @@ def from_scaled_tc_fpx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> Te 10240: 2, 14336: 6, 28672: 4, - 57344: 1, + 57344: 1 }, { # tokens: [513:576] 3072: 2, @@ -399,7 +312,7 @@ def from_scaled_tc_fpx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> Te 10240: 3, 14336: 3, 28672: 1, - 57344: 1, + 57344: 1 }, { # tokens: [577:640] 3072: 5, @@ -410,7 +323,7 @@ def from_scaled_tc_fpx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> Te 10240: 1, 14336: 1, 28672: 1, - 57344: 1, + 57344: 1 }, { # tokens: [641:704] 3072: 3, @@ -421,7 +334,7 @@ def from_scaled_tc_fpx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> Te 10240: 2, 14336: 1, 28672: 1, - 57344: 1, + 57344: 1 }, { # tokens: [705:768] 3072: 3, @@ -432,22 +345,20 @@ def from_scaled_tc_fpx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> Te 10240: 1, 14336: 1, 28672: 1, - 57344: 1, - }, + 57344: 1 + } ] # quantization api integrations - @dataclass(frozen=True) class FpxTensorCoreLayoutType(LayoutType): - """Layout type for FpxTensorCoreAQTLayout""" - + """Layout type for FpxTensorCoreAQTLayout + """ ebits: int mbits: int - @register_layout_cls(FpxTensorCoreLayoutType) class FpxTensorCoreAQTLayout(AQTLayout): """FpxTensorCoreAQTLayout represents a Tensor with dtype fpx(ebits=a, mbits=b), @@ -471,7 +382,6 @@ class FpxTensorCoreAQTLayout(AQTLayout): it will then pack the weight and instantiate the FpxTensorCoreAQTLayout tensor FpxTensorCoreAQTLayout.__init__() takes a packed fpx Tensor of shape (M, N // 8 * nbit) """ - def __new__( cls, packed_fpx_data: torch.Tensor, @@ -480,16 +390,11 @@ def __new__( ): assert packed_fpx_data.ndim == 2 assert packed_fpx_data.dtype == torch.uint8 - shape = 
( - packed_fpx_data.shape[0], - packed_fpx_data.shape[1] // (1 + layout_type.ebits + layout_type.mbits) * 8, - ) + shape = (packed_fpx_data.shape[0], packed_fpx_data.shape[1] // (1 + layout_type.ebits + layout_type.mbits) * 8) kwargs = {} kwargs["device"] = packed_fpx_data.device kwargs["layout"] = ( - kwargs.get("layout") - if kwargs.get("layout", False) - else packed_fpx_data.layout + kwargs.get("layout") if kwargs.get("layout", False) else packed_fpx_data.layout ) kwargs["dtype"] = packed_fpx_data.dtype kwargs["requires_grad"] = False @@ -512,17 +417,12 @@ def __tensor_flatten__(self): def __tensor_unflatten__( cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride ): - packed_fpx_data, scale = ( - tensor_data_dict["packed_fpx_data"], - tensor_data_dict["scale"], - ) - (layout_type,) = tensor_attributes + packed_fpx_data, scale = tensor_data_dict["packed_fpx_data"], tensor_data_dict["scale"] + layout_type, = tensor_attributes return cls(packed_fpx_data, scale, layout_type) def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor]: - unpacked_fpx_data = unpack_tc_fpx( - self.packed_fpx_data, 1 + self.layout_type.ebits + self.layout_type.mbits - ) + unpacked_fpx_data = unpack_tc_fpx(self.packed_fpx_data, 1 + self.layout_type.ebits + self.layout_type.mbits) return unpacked_fpx_data, self.scale @classmethod @@ -541,9 +441,7 @@ def from_plain( bit, M is mantissa bit """ assert isinstance(layout_type, FpxTensorCoreLayoutType) - packed_fpx_data = pack_tc_fpx( - unpacked_fpx_data, 1 + layout_type.ebits + layout_type.mbits - ) + packed_fpx_data = pack_tc_fpx(unpacked_fpx_data, 1 + layout_type.ebits + layout_type.mbits) return cls(packed_fpx_data, scale, layout_type) def __repr__(self): @@ -581,12 +479,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs): ) elif func is aten._to_copy.default: return return_and_correct_aliasing( - func, - args, - kwargs, - args[0]._apply_fn_to_data( - lambda x: x.to(device=kwargs.pop("device", None)) - ), + func, args, kwargs, args[0]._apply_fn_to_data(lambda x: x.to(device=kwargs.pop("device", None))), ) raise NotImplementedError( diff --git a/torchao/dtypes/uint4.py b/torchao/dtypes/uint4.py index 14bd0bd3ae..fc6eb2646c 100644 --- a/torchao/dtypes/uint4.py +++ b/torchao/dtypes/uint4.py @@ -105,6 +105,7 @@ def __new__(cls, elem, **kwargs): ) def __init__(self, elem, **kwargs): + self.elem = elem @classmethod diff --git a/torchao/dtypes/uintx/Uintx.py b/torchao/dtypes/uintx/Uintx.py index 157f40f8da..cfe75f4dc7 100644 --- a/torchao/dtypes/uintx/Uintx.py +++ b/torchao/dtypes/uintx/Uintx.py @@ -43,7 +43,6 @@ class UintxTensor(TorchAOBaseTensor): bit_width (int): number of bits for each element pack_dim: (int) dimension to pack along """ - bits_to_shard = { 1: ["int1_shard"], 2: ["int2_shard"], @@ -53,7 +52,6 @@ class UintxTensor(TorchAOBaseTensor): 6: ["int4_shard", "int2_shard"], 7: ["int4_shard", "int2_shard", "int1_shard"], } - def __new__( cls, shards: List[torch.Tensor], @@ -83,28 +81,24 @@ def __init__( self.pack_dim = pack_dim def get_shards(self): - return [getattr(self, i) for i in self.__class__.bits_to_shard[self.bit_width]] + return [getattr(self,i) for i in self.__class__.bits_to_shard[self.bit_width]] def __repr__(self): return f"Int{self.bit_width}Tensor(shape = {self.packed_shape}, data = {unpack(self.get_shards(), self.bit_width, dim = self.pack_dim)})" def __tensor_flatten__(self): - return self.__class__.bits_to_shard[self.bit_width], [ - self.packed_shape, - self.bit_width, - self.pack_dim, - ] + return 
self.__class__.bits_to_shard[self.bit_width], [self.packed_shape, self.bit_width, self.pack_dim] @classmethod def __tensor_unflatten__( cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride ): - shards = list(tensor_data_dict.values()) + shards = list(tensor_data_dict.values()) packed_shape, bit_width, pack_dim = tensor_attributes return cls(shards, packed_shape, bit_width, pack_dim) def get_plain(self): - return unpack(self.get_shards(), self.bit_width, dim=self.pack_dim) + return unpack(self.get_shards(), self.bit_width, dim = self.pack_dim) # temporary until kernels on packed tensors are created def apply_transformation(self, fn): @@ -116,21 +110,18 @@ def apply_transformation(self, fn): # temporary until kernels on packed tensors are created def apply_fn_to_shards(self, fn): new_shards = [fn(shard) for shard in self.get_shards()] - return self.__class__( - new_shards, self.packed_shape, self.bit_width, self.pack_dim - ) + return self.__class__(new_shards, self.packed_shape, self.bit_width, self.pack_dim) @classmethod def from_uint8(cls, int_data: torch.Tensor, dtype: torch.dtype, pack_dim: int = -1): - assert ( - dtype in _DTYPE_TO_BIT_WIDTH.keys() - ), "Expected dtype to be one of {_DTYPE_TO_BIT_WIDTH.keys()}" + assert dtype in _DTYPE_TO_BIT_WIDTH.keys(), "Expected dtype to be one of {_DTYPE_TO_BIT_WIDTH.keys()}" bit_width = _DTYPE_TO_BIT_WIDTH[dtype] shards = pack(int_data, bit_width, dim=pack_dim) shape = list(int_data.shape) shape[pack_dim] = shape[pack_dim] * bit_width // 8 return cls(shards, int_data.shape, bit_width, pack_dim) + def _get_to_kwargs(self, *args, **kwargs): device, dtype, _, memory_format = torch._C._nn._parse_to(*args, **kwargs) device = self.device if device is None else device @@ -159,8 +150,8 @@ def to(self, *args, **kwargs): return super().to(*args, **kwargs) -implements = UintxTensor.implements +implements = UintxTensor.implements @implements(aten.detach.default) def _(func, types, args, kwargs): @@ -168,43 +159,33 @@ def _(func, types, args, kwargs): func, args, kwargs, args[0].apply_fn_to_shards(torch.detach) ) - @implements(aten.view.default) def _(func, types, args, kwargs): return return_and_correct_aliasing( func, args, kwargs, args[0].apply_transformation(lambda x: x.view(*args[1:])) ) - @implements(aten._to_copy.default) def _(func, types, args, kwargs): - return return_and_correct_aliasing(func, args, kwargs, args[0]) - + return return_and_correct_aliasing( + func, args, kwargs, args[0] + ) @implements(aten.sub.Tensor) def _(func, types, args, kwargs): return return_and_correct_aliasing( - func, - args, - kwargs, - args[0].apply_transformation(lambda x: (x - args[1]).to(torch.uint8)), + func, args, kwargs, args[0].apply_transformation(lambda x: (x - args[1]).to(torch.uint8)) ) - @implements(aten.mul.Tensor) def _(func, types, args, kwargs): return return_and_correct_aliasing( - func, - args, - kwargs, - args[0].apply_transformation(lambda x: (x * args[1]).to(torch.uint8)), + func, args, kwargs, args[0].apply_transformation(lambda x: (x * args[1]).to(torch.uint8)) ) - # quantization api integrations to_uintx = UintxTensor.from_uint8 - @dataclass(frozen=True) class UintxLayoutType(LayoutType): dtype: torch.dtype @@ -213,9 +194,9 @@ class UintxLayoutType(LayoutType): def post_process(self, input: torch.Tensor) -> torch.Tensor: return to_uintx(input, self.dtype, self.pack_dim) - @register_layout_cls(UintxLayoutType) class UintxAQTLayout(PlainAQTLayout): + def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: return 
diff --git a/torchao/dtypes/uintx/bitpacking.py b/torchao/dtypes/uintx/bitpacking.py
index 5e0331b72c..244ca437ef 100644
--- a/torchao/dtypes/uintx/bitpacking.py
+++ b/torchao/dtypes/uintx/bitpacking.py
@@ -7,16 +7,16 @@
     1: (0x01,),
     2: (0x03,),
     3: (0x03, 0x04),
-    4: (0x0F,),
-    5: (0x0F, 0x10),
-    6: (0x0F, 0x30),
-    7: (0x0F, 0x30, 0x40),
+    4: (0x0f,),
+    5: (0x0f, 0x10),
+    6: (0x0f, 0x30),
+    7: (0x0f, 0x30, 0x40),
 }

 unpack_mask = {
-    1: (0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80),
-    2: (0x03, 0x0C, 0x30, 0xC0),
-    4: (0x0F, 0xF0),
+    1: (0x01,0x02,0x04,0x08, 0x10,0x20,0x40,0x80),
+    2: (0x03,0x0c,0x30,0xc0),
+    4: (0x0f,0xf0),
 }

 # size of each shard
@@ -41,7 +41,6 @@
     7: (0, 4, 6),
 }

-
 # for shifting groups left but right if shift is negative
 def abs_lsh(data, shift):
     if shift == 0:
@@ -62,9 +61,9 @@ def abs_rsh(data, shift):
         return data >> shift


-def pack_cpu(
-    data: torch.Tensor, elem_size: int, dim: Optional[int] = -1
-) -> List[torch.Tensor]:
+def pack_cpu(data: torch.Tensor,
+             elem_size: int,
+             dim: Optional[int] = -1) -> List[torch.Tensor]:
     """
     Inputs:
     data: a tensor of sub byte elements in uint8
@@ -112,10 +111,7 @@ def pack_cpu(
     After pack, data went from 8 elements to 6: [[0, 105, 151, 37], [39, 146]]
     In general this means pack reduces input tensor size from n * 8 to n * elem_size
     """
-    torch._assert(
-        data.shape[dim] % 8 == 0,
-        f"pack dimension size ({data.shape[dim]}) is not divisble by scale",
-    )
+    torch._assert(data.shape[dim] % 8 == 0, f"pack dimension size ({data.shape[dim]}) is not divisble by scale")
     torch._assert(data.dtype == torch.uint8, "data must be uint8")

     output_shape = list(data.shape)
@@ -135,9 +131,9 @@ def pack_cpu(

     return output

-def unpack_cpu(
-    data: List[torch.Tensor], elem_size: int, dim: Optional[int] = -1
-) -> torch.Tensor:
+def unpack_cpu(data: List[torch.Tensor],
+               elem_size: int,
+               dim: Optional[int] = -1) -> torch.Tensor:
     """
     Unpacks small dtype elements from a larger dtype.
@@ -164,37 +160,30 @@ def unpack_cpu(
             output_narrow = output.narrow(dim, j * group_size, group_size)
             group = data[i] & unpack_mask[bit_size][j]
             shift_amt = j * bit_size - rel_pos
-            output_narrow.copy_(
-                torch.bitwise_or(output_narrow, abs_rsh(group, shift_amt))
-            )
+            output_narrow.copy_(torch.bitwise_or(output_narrow, abs_rsh(group, j * bit_size - rel_pos)))
     return output

-
 # these are faster on the GPU

-
 def _pack(data, elem_size, scale, dim):
-    """
+    '''
     Inner for loop from above pack function
-    """
+    '''
     packed_shape = list(data.shape)
     packed_shape[dim] = packed_shape[dim] // scale
     packed = torch.zeros(packed_shape, dtype=data.dtype, device=data.device)
     for i in range(scale):
-        narrow_slice = data.narrow(
-            dim, data.shape[dim] * i // scale, data.shape[dim] // scale
-        )
+        narrow_slice = data.narrow(dim, data.shape[dim]*i//scale, data.shape[dim] // scale)
         packed |= narrow_slice << (elem_size * i)
     return packed

-
 def _unpack(data, element_size, scale, dim):
-    """
+    '''
     Inner for loop from above unpack function
-    """
+    '''
     unpacked_shape = list(data.shape)
     unpacked_shape[dim] *= scale
@@ -204,57 +193,30 @@ def _unpack(data, element_size, scale, dim):
     for i in range(scale):
         shift_amt = element_size * i
-        unpacked_data.narrow(
-            dim,
-            unpacked_data.shape[dim] * i // scale,
-            unpacked_data.shape[dim] // scale,
-        ).copy_((data >> shift_amt) & nbits)
+        chunk = unpacked_data.narrow(dim, unpacked_data.shape[dim]*i//scale, unpacked_data.shape[dim] // scale).copy_((data >> shift_amt) & nbits)
     return unpacked_data


-def pack(
-    data: torch.Tensor, elem_size: int, dim: Optional[int] = -1
-) -> List[torch.Tensor]:
-    """
+def pack(data: torch.Tensor,
+         elem_size: int,
+         dim: Optional[int] = -1) -> List[torch.Tensor]:
+    '''
     a less branching but more compute version so better for gpu
-    """
-    torch._assert(
-        data.shape[dim] % 8 == 0,
-        f"pack dimension size ({data.shape[dim]}) is not divisble by scale",
-    )
+    '''
+    torch._assert(data.shape[dim] % 8 == 0, f"pack dimension size ({data.shape[dim]}) is not divisble by scale")
     torch._assert(data.dtype == torch.uint8, "data must be uint8")
     container_size = 8
-    shards = [
-        (data & maskbits[elem_size][i]) >> shifts[elem_size][i]
-        for i in range(len(maskbits[elem_size]))
-    ]
-    return tuple(
-        [
-            _pack(
-                shards[i],
-                numbits[elem_size][i],
-                container_size // numbits[elem_size][i],
-                dim,
-            )
-            for i in range(len(maskbits[elem_size]))
-        ]
-    )
-
-
-def unpack(
-    data: List[torch.Tensor], elem_size: int, dim: Optional[int] = 0
-) -> torch.Tensor:
-    """
+    shards = [(data & maskbits[elem_size][i]) >> shifts[elem_size][i] for i in range(len(maskbits[elem_size]))]
+    return tuple([_pack(shards[i], numbits[elem_size][i], container_size//numbits[elem_size][i], dim) for i in range(len(maskbits[elem_size]))])
+
+def unpack(data: List[torch.Tensor],
+           elem_size: int,
+           dim: Optional[int] = 0) -> torch.Tensor:
+    '''
     a less branching but more compute version so better for gpu
-    """
+    '''
     container_size = 8
     # unpack each 4,2,1 bit shard and unshift them back to the correct position
-    data = [
-        _unpack(
-            data[i], numbits[elem_size][i], container_size // numbits[elem_size][i], dim
-        )
-        << shifts[elem_size][i]
-        for i in range(len(data))
-    ]
+    data = [_unpack(data[i], numbits[elem_size][i], container_size // numbits[elem_size][i], dim) << shifts[elem_size][i] for i in range(len(data))]
     return reduce(torch.bitwise_or, data)
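The `_pack`/`_unpack` helpers above OR shifted copies of consecutive chunks into a container byte and shift them back out again. A self-contained round-trip sketch of that loop for 2-bit elements (four per byte, packed along the last dimension); it is illustrative only and does not call the torchao helpers:

import torch

# Minimal round-trip sketch of the _pack/_unpack inner loops for 2-bit elements
# (scale = 8 // 2 = 4 values per container byte). Function names are invented.
def pack2(data: torch.Tensor) -> torch.Tensor:
    n = data.shape[-1] // 4
    packed = torch.zeros(*data.shape[:-1], n, dtype=torch.uint8)
    for i in range(4):
        # OR the i-th chunk into the container, shifted to its 2-bit slot
        packed |= data[..., i * n:(i + 1) * n] << (2 * i)
    return packed

def unpack2(packed: torch.Tensor) -> torch.Tensor:
    # Shift each 2-bit slot back down, mask it off, and lay the chunks out again
    return torch.cat([(packed >> (2 * i)) & 0x03 for i in range(4)], dim=-1)

vals = torch.randint(0, 4, (2, 8), dtype=torch.uint8)
assert torch.equal(unpack2(pack2(vals)), vals)   # 8 elements stored in 2 bytes per row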
diff --git a/torchao/dtypes/utils.py b/torchao/dtypes/utils.py
index 2407393fb9..7771bc34c5 100644
--- a/torchao/dtypes/utils.py
+++ b/torchao/dtypes/utils.py
@@ -11,8 +11,6 @@
 layout interacts with different operators, e.g. the same data representation can have different
 behaviors when running the same operator, e.g. transpose, quantized_linear.
 """
-
-
 @dataclass(frozen=True)
 class LayoutType:
     def pre_process(self, input: torch.Tensor) -> torch.Tensor:
@@ -27,21 +25,16 @@ def __repr__(self):
     def extra_repr(self) -> str:
         return ""

-
 """
 Plain LayoutType, the most basic LayoutType, also has no extra metadata, will typically be the default
 """
-
-
 @dataclass(frozen=True)
 class PlainLayoutType(LayoutType):
     pass

-
 def is_device(target_device_str: str, device: Union[str, torch.device]):
     return torch.device(device).type == target_device_str

-
 def get_out_shape(input_shape: Tuple[int], weight_shape: Tuple[int]) -> Tuple[int, int]:
     """Returns the unflattened shape of the input tensor.
     Args: