Merged

Changes from all commits (30 commits)
bc67c1a
Move block_sparse_layout to prototype
jainapurva Nov 3, 2025
dd2e7d6
test fixes
jainapurva Nov 3, 2025
efa74cd
Remove unused import in __init__.py
jainapurva Nov 3, 2025
3722537
Clean up exports in block_sparse_layout.py
jainapurva Nov 3, 2025
6c1b8ef
test fixes
jainapurva Nov 3, 2025
1702bdf
Fix ruff import sorting and add noqa for re-exported imports
jainapurva Nov 3, 2025
18d682a
Move block_sparse_layout to prototype/dtypes/uintx and update all imp…
jainapurva Nov 3, 2025
de054ab
Apply ruff formatting (remove trailing newlines)
jainapurva Nov 3, 2025
4ef291f
Add Prototype section to dtypes documentation with BlockSparseLayout
jainapurva Nov 3, 2025
b492530
Move cutlass_int4_packed_layout to prototype/dtypes/uintx
jainapurva Nov 3, 2025
d699ee0
Clean up imports in affine_quantized_tensor_ops.py
jainapurva Nov 3, 2025
ab84799
Update internal links
jainapurva Nov 3, 2025
83795a6
<Replace this line with a title. Use 1 line only, 67 chars or less>
jainapurva Nov 3, 2025
40ab188
test fixes
jainapurva Nov 3, 2025
894857e
ruff fixes
jainapurva Nov 3, 2025
5b45869
Fixes
jainapurva Nov 3, 2025
1d51ebf
Remove unused import from uintx init file
jainapurva Nov 3, 2025
8c18d4d
Remove __all__ exports from module
jainapurva Nov 3, 2025
081a4ed
Empty commit to trigger CI
jainapurva Nov 3, 2025
61c1986
Lint fixes
jainapurva Nov 4, 2025
8ebbb9c
Add test cases
jainapurva Nov 4, 2025
b199f7b
Merge remote-tracking branch 'origin/move_block_sparsity' into move_c…
jainapurva Nov 4, 2025
78f5e4c
Add test cases
jainapurva Nov 4, 2025
ffcaca6
Merge remote-tracking branch 'origin/move_block_sparsity' into move_c…
jainapurva Nov 5, 2025
1a75689
Add test cases
jainapurva Nov 4, 2025
8222741
Merge remote-tracking branch 'origin/move_block_sparsity' into move_c…
jainapurva Nov 5, 2025
42acafa
Updates
jainapurva Nov 5, 2025
dc1e447
Merge branch 'main' into move_cutlass_int4_packed_layout
jainapurva Nov 5, 2025
40297ba
Update block_sparse_layout.py
jainapurva Nov 5, 2025
531b07b
Modify deprecation warning for import path
jainapurva Nov 5, 2025
2 changes: 1 addition & 1 deletion docs/source/api_ref_dtypes.rst
@@ -26,7 +26,6 @@ Layouts and Tensor Subclasses
     MarlinQQQTensor
     MarlinQQQLayout
     Int4CPULayout
-    CutlassInt4PackedLayout
     CutlassSemiSparseLayout

 Quantization techniques
@@ -52,6 +51,7 @@ Prototype
     :nosignatures:

     BlockSparseLayout
+    CutlassInt4PackedLayout

 ..
   _NF4Tensor - add after fixing torchao/dtypes/nf4tensor.py:docstring
27 changes: 27 additions & 0 deletions test/integration/test_integration.py
@@ -1946,5 +1946,32 @@ def test_benchmark_model_cpu(self):
         assert self.run_benchmark_model("cpu") is not None


+# TODO: Remove this test once the deprecated API has been removed
+def test_cutlass_int4_packed_layout_deprecated():
+    import sys
+    import warnings
+
+    # We need to clear the cache to force re-importing and trigger the warning again.
+    modules_to_clear = [
+        "torchao.dtypes.uintx.cutlass_int4_packed_layout",
+        "torchao.dtypes",
+    ]
+    for mod in modules_to_clear:
+        if mod in sys.modules:
+            del sys.modules[mod]
+
+    with warnings.catch_warnings(record=True) as w:
+        warnings.simplefilter("always")  # Ensure all warnings are captured
+        from torchao.dtypes import CutlassInt4PackedLayout  # noqa: F401
+
+    assert any(
+        issubclass(warning.category, DeprecationWarning)
+        and "CutlassInt4PackedLayout" in str(warning.message)
+        for warning in w
+    ), (
+        f"Expected deprecation warning for CutlassInt4PackedLayout, got: {[str(warning.message) for warning in w]}"
+    )
+
+
 if __name__ == "__main__":
     unittest.main()
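A note on the `sys.modules` juggling in the test above: the deprecation warning is raised at module import time, so a module already cached in `sys.modules` will not warn again when re-imported. A minimal standalone sketch of that behavior (module path as touched by this PR; a fresh interpreter is assumed, so the expected count of 2 is an assumption):

```python
import sys
import warnings

with warnings.catch_warnings(record=True) as w:
    warnings.simplefilter("always")

    import torchao.dtypes.uintx.cutlass_int4_packed_layout  # first import: module body runs and warns
    import torchao.dtypes.uintx.cutlass_int4_packed_layout  # cached in sys.modules: silent

    # Evicting the cache entry forces the module body (and its warning) to run again.
    del sys.modules["torchao.dtypes.uintx.cutlass_int4_packed_layout"]
    import torchao.dtypes.uintx.cutlass_int4_packed_layout  # warns again

cutlass_warnings = [
    x
    for x in w
    if issubclass(x.category, DeprecationWarning)
    and "CutlassInt4PackedLayout" in str(x.message)
]
assert len(cutlass_warnings) == 2  # one per execution of the stub module body
```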
2 changes: 1 addition & 1 deletion torchao/dtypes/__init__.py
@@ -14,7 +14,6 @@
 )
 from .nf4tensor import NF4Tensor, to_nf4
 from .uintx import (
-    CutlassInt4PackedLayout,
     Int4CPULayout,
     Int4XPULayout,
     Int8DynamicActInt4WeightCPULayout,
@@ -29,6 +28,7 @@
     to_marlinqqq_quantized_intx,
 )
 from .uintx.block_sparse_layout import BlockSparseLayout
+from .uintx.cutlass_int4_packed_layout import CutlassInt4PackedLayout
 from .utils import (
     Layout,
     PlainLayout,
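With the re-export above, both spellings resolve to the same class; only the import location changes. A minimal migration sketch (assumes a fresh interpreter on a build that includes this PR):

```python
import warnings

# Old path: still works for now, but the backward-compatibility stub
# raises a DeprecationWarning the first time it is imported.
with warnings.catch_warnings(record=True) as w:
    warnings.simplefilter("always")
    from torchao.dtypes import CutlassInt4PackedLayout as OldLayout

assert any(issubclass(x.category, DeprecationWarning) for x in w)

# New path: the canonical prototype location going forward.
from torchao.prototype.dtypes import CutlassInt4PackedLayout

# The stub re-exports the same object, so existing isinstance checks keep working.
assert OldLayout is CutlassInt4PackedLayout
```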
12 changes: 6 additions & 6 deletions torchao/dtypes/affine_quantized_tensor_ops.py
@@ -25,12 +25,6 @@
     _linear_f16_bf16_act_floatx_weight_check,
     _linear_f16_bf16_act_floatx_weight_impl,
 )
-from torchao.dtypes.uintx.cutlass_int4_packed_layout import (
-    _linear_int4_act_int4_weight_cutlass_check,
-    _linear_int4_act_int4_weight_cutlass_impl,
-    _linear_int8_act_int4_weight_cutlass_check,
-    _linear_int8_act_int4_weight_cutlass_impl,
-)
 from torchao.dtypes.uintx.dyn_int8_act_int4_wei_cpu_layout import (
     _linear_int8_act_int4_weight_cpu_check,
     _linear_int8_act_int4_weight_cpu_impl,
@@ -94,6 +88,12 @@
     _linear_int8_act_int8_weight_block_sparse_check,
     _linear_int8_act_int8_weight_block_sparse_impl,
 )
+from torchao.prototype.dtypes.uintx.cutlass_int4_packed_layout import (
+    _linear_int4_act_int4_weight_cutlass_check,
+    _linear_int4_act_int4_weight_cutlass_impl,
+    _linear_int8_act_int4_weight_cutlass_check,
+    _linear_int8_act_int4_weight_cutlass_impl,
+)
 from torchao.quantization.quant_primitives import (
     ZeroPointDomain,
     _dequantize_affine_no_zero_point,
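The four `_check`/`_impl` pairs keep the same role after the move: each `check` is a predicate over the `(input, weight, bias)` triple that decides whether the matching CUTLASS `impl` kernel applies. A hypothetical sketch of that dispatch pattern; the registry below is illustrative only, not torchao's actual registration API:

```python
# Illustrative dispatch table: (predicate, implementation) pairs tried in order.
_QUANTIZED_LINEAR_DISPATCH = [
    (_linear_int8_act_int4_weight_cutlass_check, _linear_int8_act_int4_weight_cutlass_impl),
    (_linear_int4_act_int4_weight_cutlass_check, _linear_int4_act_int4_weight_cutlass_impl),
]


def _dispatch_quantized_linear(input_tensor, weight_tensor, bias):
    # The first predicate that accepts the argument triple selects its kernel.
    for check, impl in _QUANTIZED_LINEAR_DISPATCH:
        if check(input_tensor, weight_tensor, bias):
            return impl(input_tensor, weight_tensor, bias)
    raise NotImplementedError("no CUTLASS int4 kernel matched these tensors")
```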
4 changes: 0 additions & 4 deletions torchao/dtypes/uintx/__init__.py
@@ -1,6 +1,3 @@
-from .cutlass_int4_packed_layout import (
-    CutlassInt4PackedLayout,
-)
 from .dyn_int8_act_int4_wei_cpu_layout import (
     Int8DynamicActInt4WeightCPULayout,
 )
@@ -43,7 +40,6 @@
     "MarlinQQQLayout",
     "MarlinQQQTensor",
     "to_marlinqqq_quantized_intx",
-    "CutlassInt4PackedLayout",
     "PackedLinearInt8DynamicActivationIntxWeightLayout",
     "QDQLayout",
     "Int4XPULayout",
232 changes: 17 additions & 215 deletions torchao/dtypes/uintx/cutlass_int4_packed_layout.py
@@ -3,222 +3,24 @@
 #
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
-from dataclasses import dataclass
-from typing import Optional
-
-import torch
-from torch.utils._python_dispatch import (
-    return_and_correct_aliasing,
-)
+# Backward compatibility stub - imports from the new location
+import warnings

-from torchao.dtypes.affine_quantized_tensor import (
-    AffineQuantizedTensor,
-    register_layout,
-)
-from torchao.dtypes.uintx.plain_layout import (
-    _aqt_is_int8,
-)
+warnings.warn(
+    "Importing from torchao.dtypes is deprecated. "
+    "Please use 'from torchao.prototype.dtypes import CutlassInt4PackedLayout' instead. "
+    "This import path will be removed in a future torchao release. "
+    "Please check issue: https://github.com/pytorch/ao/issues/2752 for more details. ",
+    DeprecationWarning,
+    stacklevel=2,
+)

Review thread on this hunk:

Contributor: add #2752 to the message

Contributor: also, I'd remove torchao v0.16.0 and just say "in a future release of torchao", just in case the work gets delayed

Contributor Author: I would still prefer adding a deprecation timeline, maybe we can add it in the issue description

Contributor: yep, add it to the issue! that way it can be modified without a code change.
-from torchao.dtypes.utils import AQTTensorImpl, Layout, PlainLayout
-
-aten = torch.ops.aten
-
-
-def _aqt_is_int4(aqt):
-    """Check if an AffineQuantizedTensor is int4 quantized Tensor"""
-    # TODO: use torch.int4
-    return (
-        aqt.tensor_impl.dtype == torch.int8
-        and aqt.quant_min == -8
-        and aqt.quant_max == 7
-    )
-
-
-def _same_metadata(self: "Int4PackedTensorImpl", src: "Int4PackedTensorImpl") -> bool:
-    return (
-        isinstance(self, Int4PackedTensorImpl)
-        and isinstance(src, Int4PackedTensorImpl)
-        and self.shape == src.shape
-        and self.int_data.shape == src.int_data.shape
-        and self.scale.shape == src.scale.shape
-        and type(self._layout) == type(src._layout)
-    )
-
-
-@dataclass(frozen=True)
-class CutlassInt4PackedLayout(Layout):
-    """Layout class for int4 packed layout for affine quantized tensor, for cutlass kernel."""
-
-    pass
-
-
-@register_layout(CutlassInt4PackedLayout)
-class Int4PackedTensorImpl(AQTTensorImpl):
-    """
-    TensorImpl storage class for int4 packed layout for affine quantized tensor.
-    """
-
-    @staticmethod
-    def __new__(
-        cls,
-        int_data: torch.Tensor,
-        scale: torch.Tensor,
-        _layout: Layout,
-    ):
-        kwargs = {}
-        kwargs["device"] = int_data.device
-        kwargs["layout"] = (
-            kwargs.get("layout") if kwargs.get("layout", False) else int_data.layout
-        )
-        kwargs["dtype"] = int_data.dtype
-        kwargs["requires_grad"] = False
-        shape = int_data.shape
-        return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]
-
-    def __init__(
-        self,
-        int_data: torch.Tensor,
-        scale: torch.Tensor,
-        _layout: Layout,
-    ):
-        self.int_data = int_data
-        self.scale = scale
-        self._layout = _layout
-
-    @classmethod
-    def __torch_dispatch__(cls, func, types, args, kwargs):
-        kwargs = {} if kwargs is None else kwargs
-
-        if func is aten.detach.default:
-            return return_and_correct_aliasing(
-                func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
-            )
-
-        elif func is aten.copy_.default:
-            self = args[0]
-            src = args[1]
-            if _same_metadata(self, src):
-                self_tensors = self.__tensor_flatten__()[0]
-                for tensor_name in self_tensors:
-                    getattr(self, tensor_name).copy_(getattr(src, tensor_name))
-                return
-            raise ValueError(
-                f"Not supported args for copy_ due to metadata mismatch: {args[0], args[1]}"
-            )
-
-        raise NotImplementedError(
-            f"Int4PackedTensorImpl dispatch: attempting to run {func}, this is not supported"
-        )
-
-    def __tensor_flatten__(self):
-        return ["int_data", "scale"], [self._layout]
-
-    @classmethod
-    def __tensor_unflatten__(
-        cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
-    ):
-        int_data = tensor_data_dict["int_data"]
-        scale = tensor_data_dict["scale"]
-        (_layout,) = tensor_attributes
-        return cls(int_data, scale, _layout)
-
-    def get_plain(self):
-        int_data = torch.stack(
-            ((self.int_data << 4) >> 4, self.int_data >> 4), dim=-1
-        ).view(self.int_data.shape[:-1] + (2 * self.int_data.shape[-1],))
-        return int_data, self.scale, None
-
-    @classmethod
-    def from_plain(
-        cls,
-        int_data: torch.Tensor,
-        scale: torch.Tensor,
-        zero_point: Optional[torch.Tensor],
-        _layout: Layout,
-    ):
-        assert zero_point is None or torch.all(zero_point == 0)
-        int_data_s4 = ((int_data[..., 1::2] & 0xF) << 4) | (int_data[..., 0::2] & 0xF)
-        return cls(
-            int_data_s4,
-            scale,
-            _layout,
-        )
-
-    def get_layout(self) -> Layout:
-        return self._layout
-
-    def _apply_fn_to_data(self, fn):
-        self.int_data = fn(self.int_data)
-        self.scale = fn(self.scale)
-        return self
-
-
-def _linear_int8_act_int4_weight_cutlass_check(input_tensor, weight_tensor, bias):
-    return (
-        isinstance(input_tensor, AffineQuantizedTensor)
-        and isinstance(input_tensor._layout, PlainLayout)
-        and _aqt_is_int8(input_tensor)
-        and input_tensor.dtype in (torch.float16, torch.bfloat16)
-        and len(input_tensor.shape) >= 2
-        and input_tensor.tensor_impl.scale.dtype == torch.float32
-        and len(input_tensor.tensor_impl.scale.shape) == len(input_tensor.shape) - 1
-        and isinstance(weight_tensor, AffineQuantizedTensor)
-        and isinstance(weight_tensor._layout, CutlassInt4PackedLayout)
-        and _aqt_is_int4(weight_tensor)
-        and weight_tensor.dtype == input_tensor.dtype
-        and len(weight_tensor.shape) == 2
-        and weight_tensor.tensor_impl.scale.dtype == torch.float32
-        and len(weight_tensor.tensor_impl.scale.shape) == 1
-        and (bias is None or bias.dtype == input_tensor.dtype)
-        and (bias is None or len(bias.shape) == 1)
-    )
-
-
-def _linear_int8_act_int4_weight_cutlass_impl(input_tensor, weight_tensor, bias):
-    from torchao.ops import rowwise_scaled_linear_cutlass_s8s4
-
-    weight = weight_tensor.tensor_impl.int_data
-    weight_scale = weight_tensor.tensor_impl.scale
-    input = input_tensor.tensor_impl.int_data
-    input_scale = input_tensor.tensor_impl.scale
-    out_dtype = input_tensor.dtype
-
-    out = rowwise_scaled_linear_cutlass_s8s4(
-        input, input_scale, weight, weight_scale, bias, out_dtype
-    )
-
-    return out
-
-
-def _linear_int4_act_int4_weight_cutlass_check(input_tensor, weight_tensor, bias):
-    return (
-        isinstance(input_tensor, AffineQuantizedTensor)
-        and isinstance(input_tensor._layout, CutlassInt4PackedLayout)
-        and _aqt_is_int4(input_tensor)
-        and input_tensor.dtype in (torch.float16, torch.bfloat16)
-        and len(input_tensor.shape) >= 2
-        and input_tensor.tensor_impl.scale.dtype == torch.float32
-        and len(input_tensor.tensor_impl.scale.shape) == len(input_tensor.shape) - 1
-        and isinstance(weight_tensor, AffineQuantizedTensor)
-        and isinstance(weight_tensor._layout, CutlassInt4PackedLayout)
-        and _aqt_is_int4(weight_tensor)
-        and weight_tensor.dtype == input_tensor.dtype
-        and len(weight_tensor.shape) == 2
-        and weight_tensor.tensor_impl.scale.dtype == torch.float32
-        and len(weight_tensor.tensor_impl.scale.shape) == 1
-    )
-
-
-def _linear_int4_act_int4_weight_cutlass_impl(input_tensor, weight_tensor, bias):
-    from torchao.ops import rowwise_scaled_linear_cutlass_s4s4
-
-    weight = weight_tensor.tensor_impl.int_data
-    weight_scale = weight_tensor.tensor_impl.scale
-    input = input_tensor.tensor_impl.int_data
-    input_scale = input_tensor.tensor_impl.scale
-    out_dtype = input_tensor.dtype
-
-    out = rowwise_scaled_linear_cutlass_s4s4(
-        input, input_scale, weight, weight_scale, bias, out_dtype
-    )
-
-    return out

+from torchao.prototype.dtypes.uintx.cutlass_int4_packed_layout import (  # noqa: F401
+    CutlassInt4PackedLayout,  # noqa: F401
+    Int4PackedTensorImpl,  # noqa: F401
+    _linear_int4_act_int4_weight_cutlass_check,  # noqa: F401
+    _linear_int4_act_int4_weight_cutlass_impl,  # noqa: F401
+    _linear_int8_act_int4_weight_cutlass_check,  # noqa: F401
+    _linear_int8_act_int4_weight_cutlass_impl,  # noqa: F401
+)
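For readers following the moved implementation: `from_plain` packs two signed int4 values into each int8 byte (even-indexed element in the low nibble, odd-indexed in the high nibble), and `get_plain` reverses it, using `(x << 4) >> 4` to sign-extend the low nibble. A standalone round-trip sketch of just that packing logic, lifted from the code above:

```python
import torch

# Values already quantized to the int4 range [-8, 7], stored in int8.
plain = torch.tensor([[-8, 7, 3, -1]], dtype=torch.int8)

# Pack: even-indexed elements fill the low nibble, odd-indexed the high nibble.
packed = ((plain[..., 1::2] & 0xF) << 4) | (plain[..., 0::2] & 0xF)

# Unpack: `(x << 4) >> 4` sign-extends the low nibble, `x >> 4` the high nibble
# (both shifts are arithmetic on int8), then interleave back to the original order.
unpacked = torch.stack(((packed << 4) >> 4, packed >> 4), dim=-1).view(
    packed.shape[:-1] + (2 * packed.shape[-1],)
)

assert torch.equal(unpacked, plain)  # lossless round trip at half the storage
```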
3 changes: 2 additions & 1 deletion torchao/prototype/dtypes/__init__.py
@@ -4,8 +4,9 @@
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.

-from .uintx import BlockSparseLayout
+from .uintx import BlockSparseLayout, CutlassInt4PackedLayout

 __all__ = [
     "BlockSparseLayout",
+    "CutlassInt4PackedLayout",
 ]
2 changes: 2 additions & 0 deletions torchao/prototype/dtypes/uintx/__init__.py
@@ -5,7 +5,9 @@
 # LICENSE file in the root directory of this source tree.

 from .block_sparse_layout import BlockSparseLayout
+from .cutlass_int4_packed_layout import CutlassInt4PackedLayout

 __all__ = [
     "BlockSparseLayout",
+    "CutlassInt4PackedLayout",
 ]