
Commit 23688a9

metascroy authored and facebook-github-bot committed
Clean up linear_int8_dynamic_activation_intx_weight_subclass (#1553)
Summary:
Pull Request resolved: #1553

Cleans up the layout and quantization API:

```
int8_dynamic_activation_intx_weight(
    group_size: int = 128,
    bit_width: int = 4,
    has_weight_zeros: bool = False,
    weight_mapping_type=MappingType.ASYMMETRIC,
    act_mapping_type=MappingType.ASYMMETRIC,
    layout=PackedLinearInt8DynamicActivationIntxWeightLayout(),
)
```

int8_dynamic_activation_intx_weight is now very similar to int8_dynamic_activation_int4_weight. By passing bit_width=4, has_weight_zeros=False, and layout=PlainLayout(), it should be numerically identical (but slower). The fallback option is removed; instead, rely on PlainLayout().

Reviewed By: jerryzh168

Differential Revision: D67821939
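For context, a minimal usage sketch of the cleaned-up API. The import paths and the model are illustrative assumptions, not part of this commit:

```python
# Hedged sketch: assumes int8_dynamic_activation_intx_weight is exposed via
# torchao.experimental.quant_api and quantize_ via torchao.quantization,
# as in torchao around the time of this change. The model is a placeholder.
import torch
from torchao.quantization import quantize_
from torchao.dtypes import PlainLayout
from torchao.experimental.quant_api import int8_dynamic_activation_intx_weight

# The API requires fp32 precision (see the generate.py assert below).
model = torch.nn.Sequential(torch.nn.Linear(256, 256)).to(torch.float32)

quantize_(
    model,
    int8_dynamic_activation_intx_weight(
        group_size=128,
        bit_width=4,
        has_weight_zeros=False,
        # Per the summary: with these arguments and PlainLayout(), results should
        # match int8_dynamic_activation_int4_weight, just slower.
        layout=PlainLayout(),
    ),
)
```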
1 parent ad61822 commit 23688a9

8 files changed: +478, -485 lines

torchao/_models/llama/generate.py

Lines changed: 3 additions & 14 deletions
```diff
@@ -546,29 +546,18 @@ def ffn_or_attn_only(mod, fqn):
 
         assert (
             precision == torch.float32
-        ), "int8_dynamic_activation_intx_weight requires fp32 precision"
-
-        try:
-            torch.ops.torchao._pack_8bit_act_4bit_weight
-        except:
-            print(
-                "Unable to load experimental torchao kernels. Performance will be slow."
-            )
-            print(
-                "To install the kernels, run `USE_CPP=1 pip install .` from ao on a machine with an ARM CPU"
-            )
+        ), "int8_dynamic_activation_intx_weight requires using precision=torch.float32"
 
         # Quantize model
         _quant_args = quantization.split("-")
-        nbit = int(_quant_args[1])
-        assert nbit >= 1 and nbit <= 8, "nbits must be 1 to 8"
+        bit_width = int(_quant_args[1])
         group_size = int(_quant_args[2])
         has_weight_zeros = bool(_quant_args[3])
         quantize_(
             model,
             int8_dynamic_activation_intx_weight(
+                bit_width=bit_width,
                 group_size=group_size,
-                nbit=nbit,
                 has_weight_zeros=has_weight_zeros,
             ),
         )
```
torchao/dtypes/uintx/plain_layout.py

Lines changed: 11 additions & 7 deletions
```diff
@@ -38,7 +38,7 @@ def __new__(
         cls,
         int_data: torch.Tensor,
         scale: torch.Tensor,
-        zero_point: torch.Tensor,
+        zero_point: Optional[torch.Tensor],
         _layout: Layout,
     ):
         kwargs = {}
@@ -55,7 +55,7 @@ def __init__(
         self,
         int_data: torch.Tensor,
         scale: torch.Tensor,
-        zero_point: torch.Tensor,
+        zero_point: Optional[torch.Tensor],
         _layout: Layout,
     ):
         self.int_data = int_data
@@ -64,6 +64,8 @@ def __init__(
         self._layout = _layout
 
     def __tensor_flatten__(self):
+        if self.zero_point is None:
+            return ["int_data", "scale"], [self._layout]
         return ["int_data", "scale", "zero_point"], [self._layout]
 
     @classmethod
@@ -73,7 +75,7 @@ def __tensor_unflatten__(
         int_data, scale, zero_point = (
             tensor_data_dict["int_data"],
             tensor_data_dict["scale"],
-            tensor_data_dict["zero_point"],
+            tensor_data_dict.get("zero_point", None),
         )
         (_layout,) = tensor_attributes
         return cls(int_data, scale, zero_point, _layout)
@@ -83,15 +85,17 @@ def to(self, *args, **kwargs):
         return self.__class__(
             self.int_data.to(kwargs["device"]),
             self.scale.to(kwargs["device"]),
-            self.zero_point.to(kwargs["device"]),
+            self.zero_point.to(kwargs["device"])
+            if self.zero_point is not None
+            else None,
             self._layout,
         )
 
     def _apply_fn_to_data(self, fn):
         return self.__class__(
             fn(self.int_data),
             fn(self.scale),
-            fn(self.zero_point),
+            fn(self.zero_point) if self.zero_point is not None else None,
             self._layout,
         )
 
@@ -134,7 +138,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
             return PlainAQTTensorImpl(
                 aten.slice.Tensor(self.int_data, dim, start, end, step),
                 self.scale.view(-1),
-                self.zero_point.view(-1),
+                self.zero_point.view(-1) if self.zero_point is not None else None,
                 self._layout,
             )
         else:
@@ -148,7 +152,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
 
     __torch_function__ = torch._C._disabled_torch_function_impl
 
-    def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
         return self.int_data, self.scale, self.zero_point
 
     def get_layout(self) -> Layout:
```
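The hunks above make zero_point optional throughout PlainAQTTensorImpl, including tensor-subclass flattening. A standalone sketch of the same pattern; the class and method names are illustrative, not the torchao implementation:

```python
from typing import List, Optional, Tuple

import torch


class OptionalZeroPointImpl:
    """Illustrative stand-in showing how an optional zero_point flattens/unflattens."""

    def __init__(
        self,
        int_data: torch.Tensor,
        scale: torch.Tensor,
        zero_point: Optional[torch.Tensor],
    ):
        self.int_data = int_data
        self.scale = scale
        self.zero_point = zero_point

    def tensor_flatten(self) -> Tuple[List[str], list]:
        # Only name tensors that exist; symmetric quantization carries no zero_point.
        if self.zero_point is None:
            return ["int_data", "scale"], []
        return ["int_data", "scale", "zero_point"], []

    @classmethod
    def tensor_unflatten(cls, tensor_data_dict: dict) -> "OptionalZeroPointImpl":
        # .get(..., None) restores zero_point as None when it was never flattened.
        return cls(
            tensor_data_dict["int_data"],
            tensor_data_dict["scale"],
            tensor_data_dict.get("zero_point", None),
        )

    def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        return self.int_data, self.scale, self.zero_point
```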

torchao/dtypes/utils.py

Lines changed: 3 additions & 3 deletions
```diff
@@ -1,5 +1,5 @@
 from dataclasses import dataclass
-from typing import Tuple, Union
+from typing import Optional, Tuple, Union
 
 import torch
 
@@ -87,7 +87,7 @@ class AQTTensorImpl(TorchAOBaseTensor):
     the underlying implementation of a AQT based on layout
     """
 
-    def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
         """Get the plain (unpacked) Tensor for the tensor impl
 
         Returns data, scale and zero_point
@@ -103,7 +103,7 @@ def from_plain(
         cls,
         data: torch.Tensor,
         scale: torch.Tensor,
-        zero_point: torch.Tensor,
+        zero_point: Optional[torch.Tensor],
         _layout: Layout,
     ):
         """Construct a TensorImpl from data, scale, zero_point and the _layout"""
```
