Commit 0552dcf

metascroy authored and facebook-github-bot committed
Clean up linear_int8_dynamic_activation_intx_weight_subclass (#1553)
Summary:
Pull Request resolved: #1553

Cleans up the layout and quantization API:

```
int8_dynamic_activation_intx_weight(
    group_size: int = 128,
    bit_width: int = 4,
    has_weight_zeros: bool = False,
    weight_mapping_type=MappingType.ASYMMETRIC,
    act_mapping_type=MappingType.ASYMMETRIC,
    layout=PackedLinearInt8DynamicActivationIntxWeightLayout(),
)
```

int8_dynamic_activation_intx_weight is now very similar to int8_dynamic_activation_int4_weight. By passing bit_width=4, has_weight_zeros=False, and layout=PlainLayout(), it should be numerically identical (but slower). The dedicated fallback option is removed; the PlainLayout() path now covers that case.

Reviewed By: jerryzh168

Differential Revision: D67821939
1 parent 12a58cf commit 0552dcf

8 files changed: +481 -487 lines changed
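To make the claim in the summary concrete, here is a rough sketch of how the PlainLayout() path could be checked against int8_dynamic_activation_int4_weight. The keyword names follow the signature quoted in the summary (the generate.py diff below passes weight_dtype/granularity instead), and the toy model and comparison harness are illustrative assumptions, not code from this commit:

```python
# Sketch only: keyword names follow the commit summary; the toy model and
# comparison are assumptions, not part of this PR.
import copy

import torch

from torchao.dtypes import PlainLayout
from torchao.experimental.quant_api import int8_dynamic_activation_intx_weight
from torchao.quantization import int8_dynamic_activation_int4_weight, quantize_

model = torch.nn.Sequential(torch.nn.Linear(256, 256)).eval()
x = torch.randn(4, 256)

# Reference: the existing int4 dynamic-activation config.
ref = copy.deepcopy(model)
quantize_(ref, int8_dynamic_activation_int4_weight(group_size=128))

# New API with the fallback layout; per the summary this should match the
# reference numerically, just slower than the packed layout.
new = copy.deepcopy(model)
quantize_(
    new,
    int8_dynamic_activation_intx_weight(
        group_size=128,
        bit_width=4,
        has_weight_zeros=False,
        layout=PlainLayout(),
    ),
)

print(torch.allclose(ref(x), new(x)))
```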

torchao/_models/llama/generate.py

Lines changed: 6 additions & 16 deletions
@@ -543,32 +543,22 @@ def ffn_or_attn_only(mod, fqn):
         from torchao.experimental.quant_api import (
             int8_dynamic_activation_intx_weight,
         )
+        from torchao.quantization.granularity import PerGroup
 
         assert (
             precision == torch.float32
-        ), "int8_dynamic_activation_intx_weight requires fp32 precision"
-
-        try:
-            torch.ops.torchao._pack_8bit_act_4bit_weight
-        except:
-            print(
-                "Unable to load experimental torchao kernels. Performance will be slow."
-            )
-            print(
-                "To install the kernels, run `USE_CPP=1 pip install .` from ao on a machine with an ARM CPU"
-            )
+        ), "int8_dynamic_activation_intx_weight requires using precision=torch.float32"
 
         # Quantize model
         _quant_args = quantization.split("-")
-        nbit = int(_quant_args[1])
-        assert nbit >= 1 and nbit <= 8, "nbits must be 1 to 8"
-        group_size = int(_quant_args[2])
+        weight_dtype = getattr(torch, f"int{_quant_args[1]}")
+        granularity = PerGroup(int(_quant_args[2]))
         has_weight_zeros = bool(_quant_args[3])
         quantize_(
             model,
             int8_dynamic_activation_intx_weight(
-                group_size=group_size,
-                nbit=nbit,
+                weight_dtype=weight_dtype,
+                granularity=granularity,
                 has_weight_zeros=has_weight_zeros,
             ),
         )
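The net effect on generate.py is that the CLI arguments now select a torch dtype and a PerGroup granularity rather than raw nbit/group_size integers. A small sketch of that mapping, with a made-up flag string; everything else mirrors the diff above (torch.int1 through torch.int8 dtypes require a recent PyTorch):

```python
# Sketch of the argument mapping introduced above; the flag string is a
# hypothetical example, the parsing and quantize_ call follow the diff.
import torch

from torchao.experimental.quant_api import int8_dynamic_activation_intx_weight
from torchao.quantization import quantize_
from torchao.quantization.granularity import PerGroup

quantization = "int8_dynamic_activation_intx_weight-4-256-False"  # hypothetical CLI value
_quant_args = quantization.split("-")

# Positions 1-3 now pick a torch dtype, a PerGroup granularity, and the
# zero-point flag instead of the old nbit/group_size integers.
weight_dtype = getattr(torch, f"int{_quant_args[1]}")  # e.g. torch.int4
granularity = PerGroup(int(_quant_args[2]))            # e.g. PerGroup(256)
has_weight_zeros = bool(_quant_args[3])                # truthy for any non-empty string, as in the diff

model = torch.nn.Sequential(torch.nn.Linear(256, 256))  # must stay in float32, per the assert above
quantize_(
    model,
    int8_dynamic_activation_intx_weight(
        weight_dtype=weight_dtype,
        granularity=granularity,
        has_weight_zeros=has_weight_zeros,
    ),
)
# Note: the default packed layout may need the experimental kernels
# (built with USE_CPP=1 on an ARM CPU); PlainLayout() avoids that requirement.
```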

torchao/dtypes/uintx/plain_layout.py

Lines changed: 11 additions & 7 deletions
@@ -38,7 +38,7 @@ def __new__(
         cls,
         int_data: torch.Tensor,
         scale: torch.Tensor,
-        zero_point: torch.Tensor,
+        zero_point: Optional[torch.Tensor],
         _layout: Layout,
     ):
         kwargs = {}
@@ -55,7 +55,7 @@ def __init__(
         self,
         int_data: torch.Tensor,
         scale: torch.Tensor,
-        zero_point: torch.Tensor,
+        zero_point: Optional[torch.Tensor],
         _layout: Layout,
     ):
         self.int_data = int_data
@@ -64,6 +64,8 @@ def __init__(
         self._layout = _layout
 
     def __tensor_flatten__(self):
+        if self.zero_point is None:
+            return ["int_data", "scale"], [self._layout]
         return ["int_data", "scale", "zero_point"], [self._layout]
 
     @classmethod
@@ -73,7 +75,7 @@ def __tensor_unflatten__(
         int_data, scale, zero_point = (
             tensor_data_dict["int_data"],
             tensor_data_dict["scale"],
-            tensor_data_dict["zero_point"],
+            tensor_data_dict.get("zero_point", None),
         )
         (_layout,) = tensor_attributes
         return cls(int_data, scale, zero_point, _layout)
@@ -83,15 +85,17 @@ def to(self, *args, **kwargs):
         return self.__class__(
             self.int_data.to(kwargs["device"]),
             self.scale.to(kwargs["device"]),
-            self.zero_point.to(kwargs["device"]),
+            self.zero_point.to(kwargs["device"])
+            if self.zero_point is not None
+            else None,
             self._layout,
         )
 
     def _apply_fn_to_data(self, fn):
         return self.__class__(
             fn(self.int_data),
             fn(self.scale),
-            fn(self.zero_point),
+            fn(self.zero_point) if self.zero_point is not None else None,
             self._layout,
         )
@@ -134,7 +138,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
             return PlainAQTTensorImpl(
                 aten.slice.Tensor(self.int_data, dim, start, end, step),
                 self.scale.view(-1),
-                self.zero_point.view(-1),
+                self.zero_point.view(-1) if self.zero_point is not None else None,
                 self._layout,
             )
         else:
@@ -148,7 +152,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
 
     __torch_function__ = torch._C._disabled_torch_function_impl
 
-    def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
         return self.int_data, self.scale, self.zero_point
 
     def get_layout(self) -> Layout:
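The __tensor_flatten__/__tensor_unflatten__ change is the subtle part of making zero_point optional: the flatten list may only name tensor attributes that are actually present, and unflatten has to tolerate the missing key. A minimal stand-alone sketch of that contract, using a toy class rather than the real PlainAQTTensorImpl:

```python
from typing import Optional

import torch


class ToyImpl:
    """Toy stand-in (not a real torch.Tensor subclass) mirroring the pattern
    the diff gives PlainAQTTensorImpl: zero_point may be absent."""

    def __init__(self, int_data, scale, zero_point: Optional[torch.Tensor]):
        self.int_data, self.scale, self.zero_point = int_data, scale, zero_point

    def __tensor_flatten__(self):
        # List only the tensor attributes that exist; a None entry would break
        # flattening, which is why the diff adds the early return.
        if self.zero_point is None:
            return ["int_data", "scale"], []
        return ["int_data", "scale", "zero_point"], []

    @classmethod
    def __tensor_unflatten__(cls, tensor_data_dict, tensor_attributes, *_):
        return cls(
            tensor_data_dict["int_data"],
            tensor_data_dict["scale"],
            tensor_data_dict.get("zero_point", None),  # tolerate the missing key
        )


# Manual round trip for the symmetric (no zero_point) case.
t = ToyImpl(torch.zeros(4, dtype=torch.int8), torch.ones(1), None)
names, ctx = t.__tensor_flatten__()
rebuilt = ToyImpl.__tensor_unflatten__({n: getattr(t, n) for n in names}, ctx)
assert rebuilt.zero_point is None
```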

torchao/dtypes/utils.py

Lines changed: 3 additions & 3 deletions
@@ -1,5 +1,5 @@
 from dataclasses import dataclass
-from typing import Tuple, Union
+from typing import Optional, Tuple, Union
 
 import torch
 
@@ -87,7 +87,7 @@ class AQTTensorImpl(TorchAOBaseTensor):
     the underlying implementation of a AQT based on layout
     """
 
-    def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
         """Get the plain (unpacked) Tensor for the tensor impl
 
         Returns data, scale and zero_point
@@ -103,7 +103,7 @@ def from_plain(
         cls,
         data: torch.Tensor,
         scale: torch.Tensor,
-        zero_point: torch.Tensor,
+        zero_point: Optional[torch.Tensor],
         _layout: Layout,
     ):
         """Construct a TensorImpl from data, scale, zero_point and the _layout"""
