diff --git a/docs/source/api_ref_dtypes.rst b/docs/source/api_ref_dtypes.rst index 6cbec7465e..abf938c322 100644 --- a/docs/source/api_ref_dtypes.rst +++ b/docs/source/api_ref_dtypes.rst @@ -22,7 +22,6 @@ Layouts and Tensor Subclasses FloatxTensor FloatxTensorCoreLayout MarlinSparseLayout - BlockSparseLayout UintxLayout MarlinQQQTensor MarlinQQQLayout @@ -43,6 +42,17 @@ Quantization techniques to_affine_quantized_floatx_static to_marlinqqq_quantized_intx to_nf4 + +Prototype +--------- +.. currentmodule:: torchao.prototype.dtypes + +.. autosummary:: + :toctree: generated/ + :nosignatures: + + BlockSparseLayout + .. _NF4Tensor - add after fixing torchao/dtypes/nf4tensor.py:docstring of torchao.dtypes.nf4tensor.NF4Tensor.dequantize_scalers:6:Unexpected indentation. diff --git a/test/sparsity/test_sparse_api.py b/test/sparsity/test_sparse_api.py index 003a50c4d1..c9d41a98a9 100644 --- a/test/sparsity/test_sparse_api.py +++ b/test/sparsity/test_sparse_api.py @@ -253,7 +253,7 @@ def test_sparse(self, compile): quantize_(model_copy, Int8DynamicActivationInt8WeightConfig()) reference = model_copy(input) - from torchao.dtypes import BlockSparseLayout + from torchao.prototype.dtypes import BlockSparseLayout quantize_( model, @@ -267,6 +267,33 @@ def test_sparse(self, compile): torch.testing.assert_close(reference, sparse_result, rtol=1e-1, atol=1e-1) + # TODO: Remove this test once the deprecated API has been removed + def test_sparse_deprecated(self): + import sys + import warnings + + # We need to clear the cache to force re-importing and trigger the warning again. + modules_to_clear = [ + "torchao.dtypes.uintx.block_sparse_layout", + "torchao.dtypes", + ] + for mod in modules_to_clear: + if mod in sys.modules: + del sys.modules[mod] + + with warnings.catch_warnings(record=True) as w: + from torchao.dtypes import BlockSparseLayout # noqa: F401 + + warnings.simplefilter("always") # Ensure all warnings are captured + self.assertTrue( + any( + issubclass(warning.category, DeprecationWarning) + and "BlockSparseLayout" in str(warning.message) + for warning in w + ), + f"Expected deprecation warning for BlockSparseLayout, got: {[str(w.message) for w in w]}", + ) + common_utils.instantiate_parametrized_tests(TestSemiStructuredSparse) common_utils.instantiate_parametrized_tests(TestQuantSemiSparse) diff --git a/torchao/dtypes/__init__.py b/torchao/dtypes/__init__.py index 07f03c7ed9..b1e7fc9875 100644 --- a/torchao/dtypes/__init__.py +++ b/torchao/dtypes/__init__.py @@ -14,7 +14,6 @@ ) from .nf4tensor import NF4Tensor, to_nf4 from .uintx import ( - BlockSparseLayout, CutlassInt4PackedLayout, Int4CPULayout, Int4XPULayout, @@ -29,6 +28,7 @@ UintxLayout, to_marlinqqq_quantized_intx, ) +from .uintx.block_sparse_layout import BlockSparseLayout from .utils import ( Layout, PlainLayout, diff --git a/torchao/dtypes/affine_quantized_tensor_ops.py b/torchao/dtypes/affine_quantized_tensor_ops.py index ffadece729..2b6f47e692 100644 --- a/torchao/dtypes/affine_quantized_tensor_ops.py +++ b/torchao/dtypes/affine_quantized_tensor_ops.py @@ -25,10 +25,6 @@ _linear_f16_bf16_act_floatx_weight_check, _linear_f16_bf16_act_floatx_weight_impl, ) -from torchao.dtypes.uintx.block_sparse_layout import ( - _linear_int8_act_int8_weight_block_sparse_check, - _linear_int8_act_int8_weight_block_sparse_impl, -) from torchao.dtypes.uintx.cutlass_int4_packed_layout import ( _linear_int4_act_int4_weight_cutlass_check, _linear_int4_act_int4_weight_cutlass_impl, @@ -94,6 +90,10 @@ _linear_bf16_act_uint4_weight_check, 
_linear_bf16_act_uint4_weight_impl, ) +from torchao.prototype.dtypes.uintx.block_sparse_layout import ( + _linear_int8_act_int8_weight_block_sparse_check, + _linear_int8_act_int8_weight_block_sparse_impl, +) from torchao.quantization.quant_primitives import ( ZeroPointDomain, _dequantize_affine_no_zero_point, diff --git a/torchao/dtypes/uintx/__init__.py b/torchao/dtypes/uintx/__init__.py index 6d1bc95653..1d269fc4c4 100644 --- a/torchao/dtypes/uintx/__init__.py +++ b/torchao/dtypes/uintx/__init__.py @@ -1,6 +1,3 @@ -from .block_sparse_layout import ( - BlockSparseLayout, -) from .cutlass_int4_packed_layout import ( CutlassInt4PackedLayout, ) @@ -39,7 +36,6 @@ __all__ = [ "UintxLayout", - "BlockSparseLayout", "MarlinSparseLayout", "SemiSparseLayout", "TensorCoreTiledLayout", diff --git a/torchao/dtypes/uintx/block_sparse_layout.py b/torchao/dtypes/uintx/block_sparse_layout.py index 0c6046c313..6ca4e8745a 100644 --- a/torchao/dtypes/uintx/block_sparse_layout.py +++ b/torchao/dtypes/uintx/block_sparse_layout.py @@ -3,231 +3,22 @@ # # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. -import logging -from dataclasses import dataclass -from typing import Optional, Tuple -import torch -from torch.utils._python_dispatch import ( - return_and_correct_aliasing, -) +# Backward compatibility stub - imports from the new location +import warnings -from torchao.dtypes.affine_quantized_tensor import ( - AffineQuantizedTensor, - register_layout, -) -from torchao.dtypes.uintx.plain_layout import ( - PlainAQTTensorImpl, - _aqt_is_int8_reduced_range, -) -from torchao.dtypes.utils import ( - Layout, - PlainLayout, +warnings.warn( + "Importing BlockSparseLayout from torchao.dtypes is deprecated. " + "Please use 'from torchao.prototype.dtypes import BlockSparseLayout' instead. " + "This import path will be removed in a future torchao release. " + "Please check issue: https://github.com/pytorch/ao/issues/2752 for more details. ", + DeprecationWarning, + stacklevel=2, ) -logger = logging.getLogger(__name__) - -aten = torch.ops.aten - - -@dataclass(frozen=True) -class BlockSparseLayout(Layout): - """BlockSparseLayout is a data class that represents the layout of a block sparse matrix. - - Attributes: - blocksize (int): The size of the blocks in the sparse matrix. Default is 64. 
- """ - - blocksize: int = 64 - - -@register_layout(BlockSparseLayout) -class BlockSparseAQTTensorImpl(PlainAQTTensorImpl): - bsr_crow_indices: Optional[torch.Tensor] - bsr_col_indices: Optional[torch.Tensor] - bsr_values: Optional[torch.Tensor] - scale: Optional[torch.Tensor] - zero_point: Optional[torch.Tensor] - - __slots__ = [ - "bsr_crow_indices", - "bsr_col_indices", - "bsr_values", - "scale", - "zero_point", - ] - - @staticmethod - def __new__( # noqa: PYI034 - cls, - shape: torch.Size, - bsr_crow_indices: Optional[torch.Tensor], - bsr_col_indices: Optional[torch.Tensor], - bsr_values: Optional[torch.Tensor], - scale: Optional[torch.Tensor], - zero_point: Optional[torch.Tensor], - _layout: Layout, - requires_grad: bool = False, - ): - if bsr_values is None: - raise ValueError("bsr values must be provided!") - else: - previous_tensor = bsr_values - - kwargs = { - "device": previous_tensor.device, - "dtype": previous_tensor.dtype, - "layout": previous_tensor.layout, - "requires_grad": requires_grad, - } - return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] - - def __init__( # noqa: PYI034 - self, - shape: torch.Size, - bsr_crow_indices: Optional[torch.Tensor], - bsr_col_indices: Optional[torch.Tensor], - bsr_values: Optional[torch.Tensor], - scale: Optional[torch.Tensor], - zero_point: Optional[torch.Tensor], - _layout: Layout, - requires_grad: bool = False, - ): - self.bsr_crow_indices = bsr_crow_indices - self.bsr_col_indices = bsr_col_indices - self.bsr_values = bsr_values - self.scale = scale - self.zero_point = zero_point - self._layout = _layout - - def __tensor_flatten__(self): - inner_tensors = list( - filter(lambda x: getattr(self, x) is not None, self.__slots__) - ) - tensor_meta = (self.shape, self._layout, self.requires_grad) - return inner_tensors, tensor_meta - - @classmethod - def __tensor_unflatten__( - cls, - inner_tensors, - tensor_meta: Tuple[torch.Size, bool], - outer_size, - outer_stride, - ) -> torch.Tensor: - shape, _layout, requires_grad = tensor_meta - return cls( - shape=shape, - bsr_crow_indices=inner_tensors.get("bsr_crow_indices", None), - bsr_col_indices=inner_tensors.get("bsr_col_indices", None), - bsr_values=inner_tensors.get("bsr_values", None), - scale=inner_tensors.get("scale", None), - zero_point=inner_tensors.get("zero_point", None), - _layout=_layout, - requires_grad=requires_grad, - ) - - @classmethod - def from_plain(cls, int_data, scale, zero_point, _layout): - bsr_tensor = int_data.to_sparse_bsr(_layout.blocksize) - return cls( - shape=int_data.shape, - bsr_crow_indices=bsr_tensor.crow_indices(), - bsr_col_indices=bsr_tensor.col_indices(), - bsr_values=bsr_tensor.values(), - scale=scale, - zero_point=zero_point, - _layout=_layout, - requires_grad=False, - ) - - def get_plain(self): - int_data_expanded = torch.ops.blocksparse.bsr_to_dense( - self.crow_indices(), - self.col_indices(), - self.values(), - self.shape[0], - self.shape[1], - ) - return int_data_expanded, self.scale, self.zero_point - - def _apply_fn_to_data(self, func): - return self.__class__( - shape=self.shape, - bsr_crow_indices=func(self.bsr_crow_indices), - bsr_col_indices=func(self.bsr_col_indices), - bsr_values=func(self.bsr_values), - scale=self.scale, - zero_point=self.zero_point, - _layout=self._layout, - requires_grad=self.requires_grad, - ) - - @classmethod - def __torch_dispatch__(cls, func, types, args, kwargs): - kwargs = {} if kwargs is None else kwargs - - if func is aten.detach.default: - return return_and_correct_aliasing( 
- func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) - ) - if func is aten.clone.default: - return return_and_correct_aliasing( - func, args, kwargs, args[0]._apply_fn_to_data(torch.clone) - ) - - # Need the following for bsr specific functions - if func is aten.crow_indices.default: - return args[0].bsr_crow_indices.detach() - - if func is aten.col_indices.default: - return args[0].bsr_col_indices.detach() - - if func is aten.values.default: - return args[0].bsr_values.detach() - - if func is aten._nnz.default: - return args[0].bsr_values.shape[0] - - raise NotImplementedError( - f"BlockSparseAQTTensorImpl dispatch: attempting to run {func}, this is not supported" - ) - - -def _linear_int8_act_int8_weight_block_sparse_check(input_tensor, weight_tensor, bias): - return ( - isinstance(input_tensor, AffineQuantizedTensor) - and _aqt_is_int8_reduced_range(input_tensor) - and isinstance(weight_tensor, AffineQuantizedTensor) - and weight_tensor.is_cuda - and input_tensor.dtype == weight_tensor.dtype - and isinstance(input_tensor._layout, PlainLayout) - and isinstance(weight_tensor._layout, BlockSparseLayout) - ) - - -def _linear_int8_act_int8_weight_block_sparse_impl(input_tensor, weight_tensor, bias): - x_vals_int8 = input_tensor.tensor_impl.int_data - x_scales = input_tensor.tensor_impl.scale - w_vals = weight_tensor.tensor_impl - w_scales = weight_tensor.tensor_impl.scale - tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1]) - tmp_t = tmp.t() - - y = torch.ops.blocksparse.int_addmm( - w_vals.crow_indices(), - w_vals.col_indices(), - w_vals.values(), - tmp_t, - w_scales, - x_scales.reshape(-1), - ) - y_shape = (*x_vals_int8.shape[:-1], w_scales.shape[-1]) - y = y.reshape(*y_shape) - - # can downcast only at the very end - output_dtype = input_tensor.dtype - y = y.to(output_dtype) - if bias is not None: - y += bias - return y +from torchao.prototype.dtypes.uintx.block_sparse_layout import ( + BlockSparseAQTTensorImpl, # noqa: F401 + BlockSparseLayout, # noqa: F401 + _linear_int8_act_int8_weight_block_sparse_check, # noqa: F401 + _linear_int8_act_int8_weight_block_sparse_impl, # noqa: F401 +) diff --git a/torchao/prototype/dtypes/__init__.py b/torchao/prototype/dtypes/__init__.py new file mode 100644 index 0000000000..54d395e673 --- /dev/null +++ b/torchao/prototype/dtypes/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +from .uintx import BlockSparseLayout + +__all__ = [ + "BlockSparseLayout", +] diff --git a/torchao/prototype/dtypes/uintx/__init__.py b/torchao/prototype/dtypes/uintx/__init__.py new file mode 100644 index 0000000000..107e6a344b --- /dev/null +++ b/torchao/prototype/dtypes/uintx/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +from .block_sparse_layout import BlockSparseLayout + +__all__ = [ + "BlockSparseLayout", +] diff --git a/torchao/prototype/dtypes/uintx/block_sparse_layout.py b/torchao/prototype/dtypes/uintx/block_sparse_layout.py new file mode 100644 index 0000000000..0c6046c313 --- /dev/null +++ b/torchao/prototype/dtypes/uintx/block_sparse_layout.py @@ -0,0 +1,233 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. +import logging +from dataclasses import dataclass +from typing import Optional, Tuple + +import torch +from torch.utils._python_dispatch import ( + return_and_correct_aliasing, +) + +from torchao.dtypes.affine_quantized_tensor import ( + AffineQuantizedTensor, + register_layout, +) +from torchao.dtypes.uintx.plain_layout import ( + PlainAQTTensorImpl, + _aqt_is_int8_reduced_range, +) +from torchao.dtypes.utils import ( + Layout, + PlainLayout, +) + +logger = logging.getLogger(__name__) + +aten = torch.ops.aten + + +@dataclass(frozen=True) +class BlockSparseLayout(Layout): + """BlockSparseLayout is a data class that represents the layout of a block sparse matrix. + + Attributes: + blocksize (int): The size of the blocks in the sparse matrix. Default is 64. + """ + + blocksize: int = 64 + + +@register_layout(BlockSparseLayout) +class BlockSparseAQTTensorImpl(PlainAQTTensorImpl): + bsr_crow_indices: Optional[torch.Tensor] + bsr_col_indices: Optional[torch.Tensor] + bsr_values: Optional[torch.Tensor] + scale: Optional[torch.Tensor] + zero_point: Optional[torch.Tensor] + + __slots__ = [ + "bsr_crow_indices", + "bsr_col_indices", + "bsr_values", + "scale", + "zero_point", + ] + + @staticmethod + def __new__( # noqa: PYI034 + cls, + shape: torch.Size, + bsr_crow_indices: Optional[torch.Tensor], + bsr_col_indices: Optional[torch.Tensor], + bsr_values: Optional[torch.Tensor], + scale: Optional[torch.Tensor], + zero_point: Optional[torch.Tensor], + _layout: Layout, + requires_grad: bool = False, + ): + if bsr_values is None: + raise ValueError("bsr values must be provided!") + else: + previous_tensor = bsr_values + + kwargs = { + "device": previous_tensor.device, + "dtype": previous_tensor.dtype, + "layout": previous_tensor.layout, + "requires_grad": requires_grad, + } + return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] + + def __init__( # noqa: PYI034 + self, + shape: torch.Size, + bsr_crow_indices: Optional[torch.Tensor], + bsr_col_indices: Optional[torch.Tensor], + bsr_values: Optional[torch.Tensor], + scale: Optional[torch.Tensor], + zero_point: Optional[torch.Tensor], + _layout: Layout, + requires_grad: bool = False, + ): + self.bsr_crow_indices = bsr_crow_indices + self.bsr_col_indices = bsr_col_indices + self.bsr_values = bsr_values + self.scale = scale + self.zero_point = zero_point + self._layout = _layout + + def __tensor_flatten__(self): + inner_tensors = list( + filter(lambda x: getattr(self, x) is not None, self.__slots__) + ) + tensor_meta = (self.shape, self._layout, self.requires_grad) + return inner_tensors, tensor_meta + + @classmethod + def __tensor_unflatten__( + cls, + inner_tensors, + tensor_meta: Tuple[torch.Size, bool], + outer_size, + outer_stride, + ) -> torch.Tensor: + shape, _layout, requires_grad = tensor_meta + return cls( + shape=shape, + bsr_crow_indices=inner_tensors.get("bsr_crow_indices", None), + bsr_col_indices=inner_tensors.get("bsr_col_indices", None), + bsr_values=inner_tensors.get("bsr_values", None), + scale=inner_tensors.get("scale", None), + zero_point=inner_tensors.get("zero_point", None), + _layout=_layout, + requires_grad=requires_grad, + ) + + @classmethod + def from_plain(cls, int_data, scale, zero_point, _layout): + bsr_tensor = int_data.to_sparse_bsr(_layout.blocksize) + return cls( + shape=int_data.shape, + bsr_crow_indices=bsr_tensor.crow_indices(), + 
bsr_col_indices=bsr_tensor.col_indices(), + bsr_values=bsr_tensor.values(), + scale=scale, + zero_point=zero_point, + _layout=_layout, + requires_grad=False, + ) + + def get_plain(self): + int_data_expanded = torch.ops.blocksparse.bsr_to_dense( + self.crow_indices(), + self.col_indices(), + self.values(), + self.shape[0], + self.shape[1], + ) + return int_data_expanded, self.scale, self.zero_point + + def _apply_fn_to_data(self, func): + return self.__class__( + shape=self.shape, + bsr_crow_indices=func(self.bsr_crow_indices), + bsr_col_indices=func(self.bsr_col_indices), + bsr_values=func(self.bsr_values), + scale=self.scale, + zero_point=self.zero_point, + _layout=self._layout, + requires_grad=self.requires_grad, + ) + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs): + kwargs = {} if kwargs is None else kwargs + + if func is aten.detach.default: + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) + ) + if func is aten.clone.default: + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.clone) + ) + + # Need the following for bsr specific functions + if func is aten.crow_indices.default: + return args[0].bsr_crow_indices.detach() + + if func is aten.col_indices.default: + return args[0].bsr_col_indices.detach() + + if func is aten.values.default: + return args[0].bsr_values.detach() + + if func is aten._nnz.default: + return args[0].bsr_values.shape[0] + + raise NotImplementedError( + f"BlockSparseAQTTensorImpl dispatch: attempting to run {func}, this is not supported" + ) + + +def _linear_int8_act_int8_weight_block_sparse_check(input_tensor, weight_tensor, bias): + return ( + isinstance(input_tensor, AffineQuantizedTensor) + and _aqt_is_int8_reduced_range(input_tensor) + and isinstance(weight_tensor, AffineQuantizedTensor) + and weight_tensor.is_cuda + and input_tensor.dtype == weight_tensor.dtype + and isinstance(input_tensor._layout, PlainLayout) + and isinstance(weight_tensor._layout, BlockSparseLayout) + ) + + +def _linear_int8_act_int8_weight_block_sparse_impl(input_tensor, weight_tensor, bias): + x_vals_int8 = input_tensor.tensor_impl.int_data + x_scales = input_tensor.tensor_impl.scale + w_vals = weight_tensor.tensor_impl + w_scales = weight_tensor.tensor_impl.scale + tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1]) + tmp_t = tmp.t() + + y = torch.ops.blocksparse.int_addmm( + w_vals.crow_indices(), + w_vals.col_indices(), + w_vals.values(), + tmp_t, + w_scales, + x_scales.reshape(-1), + ) + y_shape = (*x_vals_int8.shape[:-1], w_scales.shape[-1]) + y = y.reshape(*y_shape) + + # can downcast only at the very end + output_dtype = input_tensor.dtype + y = y.to(output_dtype) + if bias is not None: + y += bias + return y
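
Net effect for downstream code: only the import path moves. Below is a minimal migration sketch (Python, not part of the patch), assuming a torchao build that contains these changes; the sys.modules clearing exists only to re-trigger the one-time module-level warning, mirroring test_sparse_deprecated above.

import sys
import warnings

# New canonical location introduced by this change.
from torchao.prototype.dtypes import BlockSparseLayout

layout = BlockSparseLayout(blocksize=64)  # blocksize defaults to 64

# The old path still resolves through the stub left at
# torchao/dtypes/uintx/block_sparse_layout.py, but warns on import.
# Drop the cached modules so the module-level DeprecationWarning fires again.
for mod in ("torchao.dtypes.uintx.block_sparse_layout", "torchao.dtypes"):
    sys.modules.pop(mod, None)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")  # set the filter before the import that warns
    from torchao.dtypes import BlockSparseLayout as LegacyBlockSparseLayout  # noqa: F401

assert any(issubclass(w.category, DeprecationWarning) for w in caught)
assert LegacyBlockSparseLayout is BlockSparseLayout  # the stub re-exports the same class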