diff --git a/docs/source/api_ref_dtypes.rst b/docs/source/api_ref_dtypes.rst
index abf938c322..e347dfd2e3 100644
--- a/docs/source/api_ref_dtypes.rst
+++ b/docs/source/api_ref_dtypes.rst
@@ -26,7 +26,6 @@ Layouts and Tensor Subclasses
     MarlinQQQTensor
     MarlinQQQLayout
     Int4CPULayout
-    CutlassInt4PackedLayout
     CutlassSemiSparseLayout
 
 Quantization techniques
@@ -52,6 +51,7 @@ Prototype
     :nosignatures:
 
     BlockSparseLayout
+    CutlassInt4PackedLayout
 
 .. _NF4Tensor - add after fixing torchao/dtypes/nf4tensor.py:docstring
 
diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py
index dc58470526..2d05426d73 100644
--- a/test/integration/test_integration.py
+++ b/test/integration/test_integration.py
@@ -1946,5 +1946,32 @@ def test_benchmark_model_cpu(self):
         assert self.run_benchmark_model("cpu") is not None
 
 
+# TODO: Remove this test once the deprecated API has been removed
+def test_cutlass_int4_packed_layout_deprecated():
+    import sys
+    import warnings
+
+    # We need to clear the cache to force re-importing and trigger the warning again.
+    modules_to_clear = [
+        "torchao.dtypes.uintx.cutlass_int4_packed_layout",
+        "torchao.dtypes",
+    ]
+    for mod in modules_to_clear:
+        if mod in sys.modules:
+            del sys.modules[mod]
+
+    with warnings.catch_warnings(record=True) as w:
+        warnings.simplefilter("always")  # Set before importing so the warning is captured
+        from torchao.dtypes import CutlassInt4PackedLayout  # noqa: F401
+
+        assert any(
+            issubclass(warning.category, DeprecationWarning)
+            and "CutlassInt4PackedLayout" in str(warning.message)
+            for warning in w
+        ), (
+            f"Expected deprecation warning for CutlassInt4PackedLayout, got: {[str(warning.message) for warning in w]}"
+        )
+
+
 if __name__ == "__main__":
     unittest.main()
diff --git a/torchao/dtypes/__init__.py b/torchao/dtypes/__init__.py
index b1e7fc9875..252498bc97 100644
--- a/torchao/dtypes/__init__.py
+++ b/torchao/dtypes/__init__.py
@@ -14,7 +14,6 @@
 )
 from .nf4tensor import NF4Tensor, to_nf4
 from .uintx import (
-    CutlassInt4PackedLayout,
     Int4CPULayout,
     Int4XPULayout,
     Int8DynamicActInt4WeightCPULayout,
@@ -29,6 +28,7 @@
     to_marlinqqq_quantized_intx,
 )
 from .uintx.block_sparse_layout import BlockSparseLayout
+from .uintx.cutlass_int4_packed_layout import CutlassInt4PackedLayout
 from .utils import (
     Layout,
     PlainLayout,
diff --git a/torchao/dtypes/affine_quantized_tensor_ops.py b/torchao/dtypes/affine_quantized_tensor_ops.py
index 2b6f47e692..e46809059e 100644
--- a/torchao/dtypes/affine_quantized_tensor_ops.py
+++ b/torchao/dtypes/affine_quantized_tensor_ops.py
@@ -25,12 +25,6 @@
     _linear_f16_bf16_act_floatx_weight_check,
     _linear_f16_bf16_act_floatx_weight_impl,
 )
-from torchao.dtypes.uintx.cutlass_int4_packed_layout import (
-    _linear_int4_act_int4_weight_cutlass_check,
-    _linear_int4_act_int4_weight_cutlass_impl,
-    _linear_int8_act_int4_weight_cutlass_check,
-    _linear_int8_act_int4_weight_cutlass_impl,
-)
 from torchao.dtypes.uintx.dyn_int8_act_int4_wei_cpu_layout import (
     _linear_int8_act_int4_weight_cpu_check,
     _linear_int8_act_int4_weight_cpu_impl,
@@ -94,6 +88,12 @@
     _linear_int8_act_int8_weight_block_sparse_check,
     _linear_int8_act_int8_weight_block_sparse_impl,
 )
+from torchao.prototype.dtypes.uintx.cutlass_int4_packed_layout import (
+    _linear_int4_act_int4_weight_cutlass_check,
+    _linear_int4_act_int4_weight_cutlass_impl,
+    _linear_int8_act_int4_weight_cutlass_check,
+    _linear_int8_act_int4_weight_cutlass_impl,
+)
 from torchao.quantization.quant_primitives import (
     ZeroPointDomain,
     _dequantize_affine_no_zero_point,
diff --git a/torchao/dtypes/uintx/__init__.py b/torchao/dtypes/uintx/__init__.py
index 1d269fc4c4..b76e80e0fc 100644
--- a/torchao/dtypes/uintx/__init__.py
+++ b/torchao/dtypes/uintx/__init__.py
@@ -1,6 +1,3 @@
-from .cutlass_int4_packed_layout import (
-    CutlassInt4PackedLayout,
-)
 from .dyn_int8_act_int4_wei_cpu_layout import (
     Int8DynamicActInt4WeightCPULayout,
 )
@@ -43,7 +40,6 @@
     "MarlinQQQLayout",
     "MarlinQQQTensor",
     "to_marlinqqq_quantized_intx",
-    "CutlassInt4PackedLayout",
     "PackedLinearInt8DynamicActivationIntxWeightLayout",
     "QDQLayout",
     "Int4XPULayout",
diff --git a/torchao/dtypes/uintx/cutlass_int4_packed_layout.py b/torchao/dtypes/uintx/cutlass_int4_packed_layout.py
index d680f4cf77..582dff6d50 100644
--- a/torchao/dtypes/uintx/cutlass_int4_packed_layout.py
+++ b/torchao/dtypes/uintx/cutlass_int4_packed_layout.py
@@ -3,222 +3,24 @@
 #
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
-from dataclasses import dataclass
-from typing import Optional
-
-import torch
-from torch.utils._python_dispatch import (
-    return_and_correct_aliasing,
-)
+# Backward compatibility stub - imports from the new location
+import warnings
 
-from torchao.dtypes.affine_quantized_tensor import (
-    AffineQuantizedTensor,
-    register_layout,
-)
-from torchao.dtypes.uintx.plain_layout import (
-    _aqt_is_int8,
+warnings.warn(
+    "Importing from torchao.dtypes is deprecated. "
+    "Please use 'from torchao.prototype.dtypes import CutlassInt4PackedLayout' instead. "
+    "This import path will be removed in a future torchao release. "
+    "Please check issue: https://github.com/pytorch/ao/issues/2752 for more details. ",
+    DeprecationWarning,
+    stacklevel=2,
 )
-from torchao.dtypes.utils import AQTTensorImpl, Layout, PlainLayout
-
-aten = torch.ops.aten
-
-
-def _aqt_is_int4(aqt):
-    """Check if an AffineQuantizedTensor is int4 quantized Tensor"""
-    # TODO: use torch.int4
-    return (
-        aqt.tensor_impl.dtype == torch.int8
-        and aqt.quant_min == -8
-        and aqt.quant_max == 7
-    )
-
-
-def _same_metadata(self: "Int4PackedTensorImpl", src: "Int4PackedTensorImpl") -> bool:
-    return (
-        isinstance(self, Int4PackedTensorImpl)
-        and isinstance(src, Int4PackedTensorImpl)
-        and self.shape == src.shape
-        and self.int_data.shape == src.int_data.shape
-        and self.scale.shape == src.scale.shape
-        and type(self._layout) == type(src._layout)
-    )
-
-
-@dataclass(frozen=True)
-class CutlassInt4PackedLayout(Layout):
-    """Layout class for int4 packed layout for affine quantized tensor, for cutlass kernel."""
-
-    pass
-
-
-@register_layout(CutlassInt4PackedLayout)
-class Int4PackedTensorImpl(AQTTensorImpl):
-    """
-    TensorImpl storage class for int4 packed layout for affine quantized tensor.
- """ - - @staticmethod - def __new__( - cls, - int_data: torch.Tensor, - scale: torch.Tensor, - _layout: Layout, - ): - kwargs = {} - kwargs["device"] = int_data.device - kwargs["layout"] = ( - kwargs.get("layout") if kwargs.get("layout", False) else int_data.layout - ) - kwargs["dtype"] = int_data.dtype - kwargs["requires_grad"] = False - shape = int_data.shape - return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] - - def __init__( - self, - int_data: torch.Tensor, - scale: torch.Tensor, - _layout: Layout, - ): - self.int_data = int_data - self.scale = scale - self._layout = _layout - - @classmethod - def __torch_dispatch__(cls, func, types, args, kwargs): - kwargs = {} if kwargs is None else kwargs - - if func is aten.detach.default: - return return_and_correct_aliasing( - func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) - ) - - elif func is aten.copy_.default: - self = args[0] - src = args[1] - if _same_metadata(self, src): - self_tensors = self.__tensor_flatten__()[0] - for tensor_name in self_tensors: - getattr(self, tensor_name).copy_(getattr(src, tensor_name)) - return - raise ValueError( - f"Not supported args for copy_ due to metadata mismatch: {args[0], args[1]}" - ) - - raise NotImplementedError( - f"Int4PackedTensorImpl dispatch: attempting to run {func}, this is not supported" - ) - - def __tensor_flatten__(self): - return ["int_data", "scale"], [self._layout] - @classmethod - def __tensor_unflatten__( - cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride - ): - int_data = tensor_data_dict["int_data"] - scale = tensor_data_dict["scale"] - (_layout,) = tensor_attributes - return cls(int_data, scale, _layout) - - def get_plain(self): - int_data = torch.stack( - ((self.int_data << 4) >> 4, self.int_data >> 4), dim=-1 - ).view(self.int_data.shape[:-1] + (2 * self.int_data.shape[-1],)) - return int_data, self.scale, None - - @classmethod - def from_plain( - cls, - int_data: torch.Tensor, - scale: torch.Tensor, - zero_point: Optional[torch.Tensor], - _layout: Layout, - ): - assert zero_point is None or torch.all(zero_point == 0) - int_data_s4 = ((int_data[..., 1::2] & 0xF) << 4) | (int_data[..., 0::2] & 0xF) - return cls( - int_data_s4, - scale, - _layout, - ) - - def get_layout(self) -> Layout: - return self._layout - - def _apply_fn_to_data(self, fn): - self.int_data = fn(self.int_data) - self.scale = fn(self.scale) - return self - - -def _linear_int8_act_int4_weight_cutlass_check(input_tensor, weight_tensor, bias): - return ( - isinstance(input_tensor, AffineQuantizedTensor) - and isinstance(input_tensor._layout, PlainLayout) - and _aqt_is_int8(input_tensor) - and input_tensor.dtype in (torch.float16, torch.bfloat16) - and len(input_tensor.shape) >= 2 - and input_tensor.tensor_impl.scale.dtype == torch.float32 - and len(input_tensor.tensor_impl.scale.shape) == len(input_tensor.shape) - 1 - and isinstance(weight_tensor, AffineQuantizedTensor) - and isinstance(weight_tensor._layout, CutlassInt4PackedLayout) - and _aqt_is_int4(weight_tensor) - and weight_tensor.dtype == input_tensor.dtype - and len(weight_tensor.shape) == 2 - and weight_tensor.tensor_impl.scale.dtype == torch.float32 - and len(weight_tensor.tensor_impl.scale.shape) == 1 - and (bias is None or bias.dtype == input_tensor.dtype) - and (bias is None or len(bias.shape) == 1) - ) - - -def _linear_int8_act_int4_weight_cutlass_impl(input_tensor, weight_tensor, bias): - from torchao.ops import rowwise_scaled_linear_cutlass_s8s4 - - weight = 
-    weight = weight_tensor.tensor_impl.int_data
-    weight_scale = weight_tensor.tensor_impl.scale
-    input = input_tensor.tensor_impl.int_data
-    input_scale = input_tensor.tensor_impl.scale
-    out_dtype = input_tensor.dtype
-
-    out = rowwise_scaled_linear_cutlass_s8s4(
-        input, input_scale, weight, weight_scale, bias, out_dtype
-    )
-
-    return out
-
-
-def _linear_int4_act_int4_weight_cutlass_check(input_tensor, weight_tensor, bias):
-    return (
-        isinstance(input_tensor, AffineQuantizedTensor)
-        and isinstance(input_tensor._layout, CutlassInt4PackedLayout)
-        and _aqt_is_int4(input_tensor)
-        and input_tensor.dtype in (torch.float16, torch.bfloat16)
-        and len(input_tensor.shape) >= 2
-        and input_tensor.tensor_impl.scale.dtype == torch.float32
-        and len(input_tensor.tensor_impl.scale.shape) == len(input_tensor.shape) - 1
-        and isinstance(weight_tensor, AffineQuantizedTensor)
-        and isinstance(weight_tensor._layout, CutlassInt4PackedLayout)
-        and _aqt_is_int4(weight_tensor)
-        and weight_tensor.dtype == input_tensor.dtype
-        and len(weight_tensor.shape) == 2
-        and weight_tensor.tensor_impl.scale.dtype == torch.float32
-        and len(weight_tensor.tensor_impl.scale.shape) == 1
-    )
-
-
-def _linear_int4_act_int4_weight_cutlass_impl(input_tensor, weight_tensor, bias):
-    from torchao.ops import rowwise_scaled_linear_cutlass_s4s4
-
-    weight = weight_tensor.tensor_impl.int_data
-    weight_scale = weight_tensor.tensor_impl.scale
-    input = input_tensor.tensor_impl.int_data
-    input_scale = input_tensor.tensor_impl.scale
-    out_dtype = input_tensor.dtype
-
-    out = rowwise_scaled_linear_cutlass_s4s4(
-        input, input_scale, weight, weight_scale, bias, out_dtype
-    )
-
-    return out
+
+from torchao.prototype.dtypes.uintx.cutlass_int4_packed_layout import (  # noqa: F401
+    CutlassInt4PackedLayout,  # noqa: F401
+    Int4PackedTensorImpl,  # noqa: F401
+    _linear_int4_act_int4_weight_cutlass_check,  # noqa: F401
+    _linear_int4_act_int4_weight_cutlass_impl,  # noqa: F401
+    _linear_int8_act_int4_weight_cutlass_check,  # noqa: F401
+    _linear_int8_act_int4_weight_cutlass_impl,  # noqa: F401
+)
diff --git a/torchao/prototype/dtypes/__init__.py b/torchao/prototype/dtypes/__init__.py
index 54d395e673..25f139d583 100644
--- a/torchao/prototype/dtypes/__init__.py
+++ b/torchao/prototype/dtypes/__init__.py
@@ -4,8 +4,9 @@
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
 
-from .uintx import BlockSparseLayout
+from .uintx import BlockSparseLayout, CutlassInt4PackedLayout
 
 __all__ = [
     "BlockSparseLayout",
+    "CutlassInt4PackedLayout",
 ]
diff --git a/torchao/prototype/dtypes/uintx/__init__.py b/torchao/prototype/dtypes/uintx/__init__.py
index 107e6a344b..53edddb8ac 100644
--- a/torchao/prototype/dtypes/uintx/__init__.py
+++ b/torchao/prototype/dtypes/uintx/__init__.py
@@ -5,7 +5,9 @@
 # LICENSE file in the root directory of this source tree.
 
 from .block_sparse_layout import BlockSparseLayout
+from .cutlass_int4_packed_layout import CutlassInt4PackedLayout
 
 __all__ = [
     "BlockSparseLayout",
+    "CutlassInt4PackedLayout",
 ]
diff --git a/torchao/prototype/dtypes/uintx/cutlass_int4_packed_layout.py b/torchao/prototype/dtypes/uintx/cutlass_int4_packed_layout.py
new file mode 100644
index 0000000000..d680f4cf77
--- /dev/null
+++ b/torchao/prototype/dtypes/uintx/cutlass_int4_packed_layout.py
@@ -0,0 +1,224 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+from dataclasses import dataclass
+from typing import Optional
+
+import torch
+from torch.utils._python_dispatch import (
+    return_and_correct_aliasing,
+)
+
+from torchao.dtypes.affine_quantized_tensor import (
+    AffineQuantizedTensor,
+    register_layout,
+)
+from torchao.dtypes.uintx.plain_layout import (
+    _aqt_is_int8,
+)
+from torchao.dtypes.utils import AQTTensorImpl, Layout, PlainLayout
+
+aten = torch.ops.aten
+
+
+def _aqt_is_int4(aqt):
+    """Check if an AffineQuantizedTensor is int4 quantized Tensor"""
+    # TODO: use torch.int4
+    return (
+        aqt.tensor_impl.dtype == torch.int8
+        and aqt.quant_min == -8
+        and aqt.quant_max == 7
+    )
+
+
+def _same_metadata(self: "Int4PackedTensorImpl", src: "Int4PackedTensorImpl") -> bool:
+    return (
+        isinstance(self, Int4PackedTensorImpl)
+        and isinstance(src, Int4PackedTensorImpl)
+        and self.shape == src.shape
+        and self.int_data.shape == src.int_data.shape
+        and self.scale.shape == src.scale.shape
+        and type(self._layout) == type(src._layout)
+    )
+
+
+@dataclass(frozen=True)
+class CutlassInt4PackedLayout(Layout):
+    """Layout class for int4 packed layout for affine quantized tensor, for cutlass kernel."""
+
+    pass
+
+
+@register_layout(CutlassInt4PackedLayout)
+class Int4PackedTensorImpl(AQTTensorImpl):
+    """
+    TensorImpl storage class for int4 packed layout for affine quantized tensor.
+    """
+
+    @staticmethod
+    def __new__(
+        cls,
+        int_data: torch.Tensor,
+        scale: torch.Tensor,
+        _layout: Layout,
+    ):
+        kwargs = {}
+        kwargs["device"] = int_data.device
+        kwargs["layout"] = (
+            kwargs.get("layout") if kwargs.get("layout", False) else int_data.layout
+        )
+        kwargs["dtype"] = int_data.dtype
+        kwargs["requires_grad"] = False
+        shape = int_data.shape
+        return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]
+
+    def __init__(
+        self,
+        int_data: torch.Tensor,
+        scale: torch.Tensor,
+        _layout: Layout,
+    ):
+        self.int_data = int_data
+        self.scale = scale
+        self._layout = _layout
+
+    @classmethod
+    def __torch_dispatch__(cls, func, types, args, kwargs):
+        kwargs = {} if kwargs is None else kwargs
+
+        if func is aten.detach.default:
+            return return_and_correct_aliasing(
+                func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
+            )
+
+        elif func is aten.copy_.default:
+            self = args[0]
+            src = args[1]
+            if _same_metadata(self, src):
+                self_tensors = self.__tensor_flatten__()[0]
+                for tensor_name in self_tensors:
+                    getattr(self, tensor_name).copy_(getattr(src, tensor_name))
+                return
+            raise ValueError(
+                f"Not supported args for copy_ due to metadata mismatch: {args[0], args[1]}"
+            )
+
+        raise NotImplementedError(
+            f"Int4PackedTensorImpl dispatch: attempting to run {func}, this is not supported"
+        )
+
+    def __tensor_flatten__(self):
+        return ["int_data", "scale"], [self._layout]
+
+    @classmethod
+    def __tensor_unflatten__(
+        cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
+    ):
+        int_data = tensor_data_dict["int_data"]
+        scale = tensor_data_dict["scale"]
+        (_layout,) = tensor_attributes
+        return cls(int_data, scale, _layout)
+
+    def get_plain(self):
+        int_data = torch.stack(
+            ((self.int_data << 4) >> 4, self.int_data >> 4), dim=-1
+        ).view(self.int_data.shape[:-1] + (2 * self.int_data.shape[-1],))
+        return int_data, self.scale, None
+
+    @classmethod
+    def from_plain(
+        cls,
+        int_data: torch.Tensor,
+        scale: torch.Tensor,
+        zero_point: Optional[torch.Tensor],
+        _layout: Layout,
+    ):
+        assert zero_point is None or torch.all(zero_point == 0)
+        int_data_s4 = ((int_data[..., 1::2] & 0xF) << 4) | (int_data[..., 0::2] & 0xF)
+        return cls(
+            int_data_s4,
+            scale,
+            _layout,
+        )
+
+    def get_layout(self) -> Layout:
+        return self._layout
+
+    def _apply_fn_to_data(self, fn):
+        self.int_data = fn(self.int_data)
+        self.scale = fn(self.scale)
+        return self
+
+
+def _linear_int8_act_int4_weight_cutlass_check(input_tensor, weight_tensor, bias):
+    return (
+        isinstance(input_tensor, AffineQuantizedTensor)
+        and isinstance(input_tensor._layout, PlainLayout)
+        and _aqt_is_int8(input_tensor)
+        and input_tensor.dtype in (torch.float16, torch.bfloat16)
+        and len(input_tensor.shape) >= 2
+        and input_tensor.tensor_impl.scale.dtype == torch.float32
+        and len(input_tensor.tensor_impl.scale.shape) == len(input_tensor.shape) - 1
+        and isinstance(weight_tensor, AffineQuantizedTensor)
+        and isinstance(weight_tensor._layout, CutlassInt4PackedLayout)
+        and _aqt_is_int4(weight_tensor)
+        and weight_tensor.dtype == input_tensor.dtype
+        and len(weight_tensor.shape) == 2
+        and weight_tensor.tensor_impl.scale.dtype == torch.float32
+        and len(weight_tensor.tensor_impl.scale.shape) == 1
+        and (bias is None or bias.dtype == input_tensor.dtype)
+        and (bias is None or len(bias.shape) == 1)
+    )
+
+
+def _linear_int8_act_int4_weight_cutlass_impl(input_tensor, weight_tensor, bias):
+    from torchao.ops import rowwise_scaled_linear_cutlass_s8s4
+
+    weight = weight_tensor.tensor_impl.int_data
+    weight_scale = weight_tensor.tensor_impl.scale
+    input = input_tensor.tensor_impl.int_data
+    input_scale = input_tensor.tensor_impl.scale
+    out_dtype = input_tensor.dtype
+
+    out = rowwise_scaled_linear_cutlass_s8s4(
+        input, input_scale, weight, weight_scale, bias, out_dtype
+    )
+
+    return out
+
+
+def _linear_int4_act_int4_weight_cutlass_check(input_tensor, weight_tensor, bias):
+    return (
+        isinstance(input_tensor, AffineQuantizedTensor)
+        and isinstance(input_tensor._layout, CutlassInt4PackedLayout)
+        and _aqt_is_int4(input_tensor)
+        and input_tensor.dtype in (torch.float16, torch.bfloat16)
+        and len(input_tensor.shape) >= 2
+        and input_tensor.tensor_impl.scale.dtype == torch.float32
+        and len(input_tensor.tensor_impl.scale.shape) == len(input_tensor.shape) - 1
+        and isinstance(weight_tensor, AffineQuantizedTensor)
+        and isinstance(weight_tensor._layout, CutlassInt4PackedLayout)
+        and _aqt_is_int4(weight_tensor)
+        and weight_tensor.dtype == input_tensor.dtype
+        and len(weight_tensor.shape) == 2
+        and weight_tensor.tensor_impl.scale.dtype == torch.float32
+        and len(weight_tensor.tensor_impl.scale.shape) == 1
+    )
+
+
+def _linear_int4_act_int4_weight_cutlass_impl(input_tensor, weight_tensor, bias):
+    from torchao.ops import rowwise_scaled_linear_cutlass_s4s4
+
+    weight = weight_tensor.tensor_impl.int_data
+    weight_scale = weight_tensor.tensor_impl.scale
+    input = input_tensor.tensor_impl.int_data
+    input_scale = input_tensor.tensor_impl.scale
+    out_dtype = input_tensor.dtype
+
+    out = rowwise_scaled_linear_cutlass_s4s4(
+        input, input_scale, weight, weight_scale, bias, out_dtype
+    )
+
+    return out
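
Usage note (appended for review; not part of the patch): a minimal sketch of the migration this diff implements, assuming a torchao build that contains it. The stub module only warns at import time, which is why both the regression test above and this snippet clear sys.modules before importing the old path.

    import sys
    import warnings

    # Force a re-import so the stub's import-time warning fires again:
    for mod in ("torchao.dtypes.uintx.cutlass_int4_packed_layout", "torchao.dtypes"):
        sys.modules.pop(mod, None)

    with warnings.catch_warnings(record=True) as w:
        warnings.simplefilter("always")
        from torchao.dtypes import CutlassInt4PackedLayout  # noqa: F401
    assert any(issubclass(x.category, DeprecationWarning) for x in w)

    # New, supported location:
    from torchao.prototype.dtypes import CutlassInt4PackedLayout

Packing sketch: Int4PackedTensorImpl.from_plain packs two signed int4 values into each int8 byte (even indices in the low nibble, odd indices in the high nibble), and get_plain reverses this with a left-then-right arithmetic shift to sign-extend the low nibble. A small round-trip self-check of that scheme, using the same expressions as the moved code:

    import torch

    plain = torch.tensor([[-8, 7, 3, -1]], dtype=torch.int8)  # int4 values in [-8, 7]
    # Pack: high nibble <- odd elements, low nibble <- even elements.
    packed = ((plain[..., 1::2] & 0xF) << 4) | (plain[..., 0::2] & 0xF)
    # Unpack: (x << 4) >> 4 sign-extends the low nibble; x >> 4 yields the high one.
    unpacked = torch.stack(((packed << 4) >> 4, packed >> 4), dim=-1).view(plain.shape)
    assert torch.equal(plain, unpacked)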