
Commit 358d6b4

[Feat]: Enable dyn_quant_pack_4bit aten kernels via Linear8BitActXBitWeightLayout
Signed-off-by: Nikhil Gupta <[email protected]>
1 parent 2a18e60 commit 358d6b4

3 files changed: +196 -11 lines changed

torchao/experimental/_linear_8bit_act_xbit_weight_layout.py

Lines changed: 82 additions & 8 deletions
@@ -23,6 +23,9 @@
     MappingType,
     ZeroPointDomain,
 )
+from torchao.utils import (
+    TORCH_VERSION_AT_LEAST_2_6,
+)
 
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.WARNING)
@@ -40,13 +43,16 @@ class Target(Enum):
 
     NATIVE = auto()
     FALLBACK = auto()
+    ATEN = auto()
 
 
 def target_from_str(target: str) -> Target:
     if target.lower() == "native":
         return Target.NATIVE
     elif target.lower() == "fallback":
         return Target.FALLBACK
+    elif target.lower() == "aten":
+        return Target.ATEN
     else:
         raise ValueError(f"Invalid target: {target}")
 
@@ -56,22 +62,27 @@ class Linear8BitActXBitWeightLayout(Layout):
     nbit: int
     group_size: int
 
-    # The target platform for the layout, either 'native' or 'fallback'.
+    # The target platform for the layout, 'native', 'fallback' or 'aten'
     target: Target
 
+    # Allow bias access via layout
+    bias: Optional[torch.Tensor] = None
+
     def __init__(
         self,
         nbit: int,
         group_size: int,
         target: str,
+        bias: Optional[torch.Tensor] = None,
     ):
         assert nbit <= 8
         self.nbit = nbit
         self.group_size = group_size
         self.target = target_from_str(target)
+        self.bias = bias
 
     def extra_repr(self):
-        return f"nbit={self.nbit}, group_size={self.group_size}, target={self.target}"
+        return f"nbit={self.nbit}, group_size={self.group_size}, target={self.target}, bias={self.bias}"
 
 
 def _pack_weights_native(
@@ -81,7 +92,6 @@ def _pack_weights_native(
     layout: Layout,
 ):
     assert isinstance(layout, Linear8BitActXBitWeightLayout)
-    assert layout.target == Target.NATIVE
     nbit = layout.nbit
     group_size = layout.group_size
     has_weight_zeros = zero_point is not None
@@ -100,6 +110,12 @@ def _pack_weights_native(
         torch.empty(0, group_size, dtype=torch.int8),
     ]
 
+    if TORCH_VERSION_AT_LEAST_2_6 and layout.target == Target.ATEN:
+        in_features = int_data.shape[-1]
+        out_features = int_data.shape[-2]
+        int_data = int_data.add(8)
+        int_data = (int_data[::, 1::2] << 4 | int_data[::, ::2]).to(torch.uint8)
+        return torch.ops.aten._dyn_quant_pack_4bit_weight(int_data, scale, layout.bias, group_size, in_features, out_features)
     wzp_suffix = "" if has_weight_zeros else "0zp"
     return getattr(torch.ops.torchao, f"_pack_8bit_act_{nbit}bit{wzp_suffix}_weight")(
         *args
@@ -153,7 +169,7 @@ def get_layout(self) -> Layout:
     def get_plain(
         self,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
-        if self.get_layout().target == Target.FALLBACK:
+        if self.get_layout().target == Target.FALLBACK or self.get_layout().target == Target.ATEN:
             return self.packed_weight, self.scale, self.zero_point
         raise NotImplementedError(
             "get_plain is not supported for Linear8BitActXBitWeightAQTTensorImpl when target is not fallback"
@@ -170,12 +186,17 @@ def from_plain(
         assert isinstance(layout, Linear8BitActXBitWeightLayout)
 
         try:
-            if layout.target == Target.NATIVE:
+            if layout.target == Target.NATIVE or layout.target == Target.ATEN:
                 packed_weight = _pack_weights_native(
                     int_data, scale, zero_point, layout
                 )
                 scale = None
                 zero_point = None
+                # avoid storing bias tensor but indicate if Linear layer got bias on printing as
+                # 1. aten_dynamic_quant already packed it in weights or
+                # 2. its not needed by any other op
+                if layout.bias is not None:
+                    layout.bias = True
             return cls(packed_weight, scale, zero_point, layout)
         except Exception as e:
             logger.warning(
@@ -216,7 +237,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
             )
 
     def __tensor_flatten__(self):
-        if self.get_layout().target == Target.NATIVE:
+        if self.get_layout().target == Target.NATIVE or self.get_layout().target == Target.ATEN:
            return ["packed_weight"], [self.get_layout()]
 
         # fallback
@@ -242,8 +263,11 @@ def _linear_int8_dynamic_activation_intx_weight_check(
     input_tensor, weight_tensor, bias
 ):
     layout = weight_tensor.tensor_impl.get_layout()
-    return isinstance(layout, Linear8BitActXBitWeightLayout) and bias is None
-
+    target_condition = False
+    if isinstance(layout, Linear8BitActXBitWeightLayout) and layout.target == Target.ATEN:
+        target_condition = True
+    res = isinstance(layout, Linear8BitActXBitWeightLayout) and (bias is None or target_condition)
+    return res
 
 def _linear_int8_dynamic_activation_intx_weight_fallback_impl(
     input_tensor, weight_tensor, bias
@@ -353,6 +377,51 @@ def _impl_2d(input_tensor, weight_tensor):
     return res
 
 
+def _linear_int8_dynamic_activation_intx_weight_aten_impl(
+    input_tensor, weight_tensor, bias
+):
+    assert weight_tensor.tensor_impl.get_layout().target == Target.ATEN
+    if weight_tensor.zero_point_domain != ZeroPointDomain.NONE:
+        raise NotImplementedError(
+            "MappingType.ASSYMETRIC in is not supported for Linear8BitActXBitWeightAQTTensorImpl when target is aten"
+        )
+    assert (
+        weight_tensor.tensor_impl.get_layout().nbit == 4
+    ), f"Only 4 bit is supported"
+    assert (
+        TORCH_VERSION_AT_LEAST_2_6 == 1
+    ), "Target.ATEN requires torch >= 2.6.0"
+    # aten supports bias for kleidiAI but not for default fallback op
+    if not torch.backends.kleidiai.is_available():
+        print("TODO bias == None")
+        assert bias == None
+
+    def _impl_2d(input_tensor, weight_tensor):
+        assert input_tensor.dim() == 2
+        assert weight_tensor.dim() == 2
+
+        m, k = input_tensor.shape
+        n, k_ = weight_tensor.shape
+        assert k_ == k
+        group_size = weight_tensor.tensor_impl.get_layout().group_size
+        packed_weight = weight_tensor.tensor_impl.packed_weight
+        return torch.ops.aten._dyn_quant_matmul_4bit(
+            input_tensor, packed_weight, group_size, k_, n)
+
+    if input_tensor.dim() == 2:
+        return _impl_2d(input_tensor, weight_tensor)
+
+    assert input_tensor.dim() >= 3
+    lead_shape = input_tensor.shape[0:-2]
+    m, k = input_tensor.shape[-2], input_tensor.shape[-1]
+    n, k_ = weight_tensor.shape
+    assert k_ == k
+
+    res = _impl_2d(input_tensor.reshape(-1, k), weight_tensor)
+    res = res.reshape(*lead_shape, m, n)
+    return res
+
+
 def _linear_int8_dynamic_activation_intx_weight_impl(input_tensor, weight_tensor, bias):
     target = weight_tensor.tensor_impl.get_layout().target
     if target == Target.NATIVE:
@@ -365,6 +434,11 @@ def _linear_int8_dynamic_activation_intx_weight_impl(input_tensor, weight_tensor
             input_tensor, weight_tensor, bias
         )
 
+    if target == Target.ATEN:
+        return _linear_int8_dynamic_activation_intx_weight_aten_impl(
+            input_tensor, weight_tensor, bias
+        )
+
     assert False, f"Unknown target {target}"
 
 
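Note: the ATEN branch added to _pack_weights_native above shifts the signed int4 values into the unsigned range and packs two values per byte before handing them to torch.ops.aten._dyn_quant_pack_4bit_weight. A minimal, hedged sketch of that nibble-packing arithmetic (illustrative only; int32 is used here so the shift cannot overflow, whereas the real int_data dtype comes from the quantizer):

import torch

# Toy 1x4 row of signed 4-bit weights in [-8, 7]
int_data = torch.tensor([[-8, 7, 0, -1]], dtype=torch.int32)
shifted = int_data.add(8)                                    # map [-8, 7] -> [0, 15]
packed = (shifted[:, 1::2] << 4 | shifted[:, ::2]).to(torch.uint8)
# Odd columns land in the high nibble, even columns in the low nibble:
# (15 << 4) | 0 = 240 (0xF0), (7 << 4) | 8 = 120 (0x78)
print(packed)  # tensor([[240, 120]], dtype=torch.uint8)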

torchao/experimental/quant_api.py

Lines changed: 109 additions & 1 deletion
@@ -4,15 +4,25 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
+from torchao.quantization.quant_api import (
+    MappingType,
+)
 import logging
-from typing import Optional
+from typing import Optional, Union
 
 import torch
 import torch.nn as nn
 from torch.ao.quantization.fx._decomposed import (
     dequantize_per_channel_group,
     quantize_per_channel_group,
 )
+from torchao.quantization.granularity import (
+    PerRow,
+    PerGroup,
+)
+from torchao.utils import (
+    TORCH_VERSION_AT_LEAST_2_6,
+)
 
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.WARNING)
@@ -482,6 +492,104 @@ def quantize(self, model: nn.Module) -> nn.Module:
         return model
 
 
+_intx_granularity = Union[PerGroup, PerRow]
+
+
+def int8_dynamic_activation_intx_weight_v2(
+    granularity: Optional[_intx_granularity] = PerGroup(32),
+    nbit: int = 4,
+    has_weight_zeros: bool = False,
+    target: str = "native",
+    mapping_type: MappingType = MappingType.ASYMMETRIC,
+    has_bias: bool = False,
+):
+    from torchao.experimental._linear_8bit_act_xbit_weight_layout import (
+        Linear8BitActXBitWeightLayout,
+    )
+    from torchao.quantization.quant_api import (
+        ZeroPointDomain,
+        _get_linear_subclass_inserter,
+        to_affine_quantized_intx,
+    )
+
+    def get_quant_params(weight, has_weight_zeros: bool, mapping_type: MappingType, granularity: Optional[_intx_granularity]):
+        scale_dtype = None
+        zero_point_dtype = torch.int8
+        zero_point_domain = (
+            ZeroPointDomain.INT if has_weight_zeros else ZeroPointDomain.NONE
+        )
+        target_dtype = torch.int32
+        preserve_zero = has_weight_zeros
+        if mapping_type == MappingType.ASYMMETRIC:
+            pass
+        elif mapping_type == MappingType.SYMMETRIC:
+            assert (
+                TORCH_VERSION_AT_LEAST_2_6 == 1
+            ), "MappingType.SYMMETRIC requires torch >= 2.6.0"
+            zero_point_dtype = torch.int8
+            zero_point_domain = ZeroPointDomain.NONE
+            preserve_zero = True
+            # The KleidiAI Groupwise kernel only supports bf16 scales for now
+            if torch.backends.kleidiai.is_available():
+                assert weight.dtype == torch.float32, f"Only float32 dtype is supported for KleidiAI int4 kernels. Provided {weight.dtype}"
+                if isinstance(granularity, PerGroup):
+                    scale_dtype = torch.bfloat16
+        else:
+            raise ValueError(
+                f"Only mapping_type ASYMMETRIC, SYMMETRIC are supported. Provided {mapping_type}"
+            )
+
+        return target_dtype, zero_point_dtype, scale_dtype, preserve_zero, zero_point_domain
+
+    def apply(weight, bias: Optional[torch.Tensor] = None):
+        if isinstance(granularity, PerGroup):
+            group_size = granularity.group_size
+        elif isinstance(granularity, PerRow):
+            group_size = weight.shape[1]
+        else:
+            raise ValueError(
+                f"Only granularity PerGroup(), PerRow() are supported. Provided {granularity}"
+            )
+        assert weight.shape[-1] % group_size == 0
+        assert weight.device == torch.device("cpu"), "Only CPU is supported"
+        use_hqq = False
+        layout_args = [nbit, group_size, target]
+        if bias is not None:
+            layout_args.append(bias)
+        layout = Linear8BitActXBitWeightLayout(*layout_args)
+        # mapping_type = MappingType.ASYMMETRIC
+        eps = torch.finfo(torch.float32).eps
+        block_size = (1, group_size)
+        # target_dtype = torch.int32
+        quant_min = -(1 << (nbit - 1))
+        quant_max = (1 << (nbit - 1)) - 1
+        target_dtype, zero_point_dtype, scale_dtype, preserve_zero, zero_point_domain = get_quant_params(
+            weight, has_weight_zeros, mapping_type, granularity)
+        # Note: this works differently than other quantizers because the dynamic
+        # activation quantization is fused with the kernel/op (and static activation quantization
+        # is not supported).
+        return to_affine_quantized_intx(
+            weight,
+            mapping_type,
+            block_size,
+            target_dtype,
+            quant_min,
+            quant_max,
+            eps,
+            scale_dtype=scale_dtype,
+            zero_point_dtype=zero_point_dtype,
+            preserve_zero=preserve_zero,
+            zero_point_domain=zero_point_domain,
+            _layout=layout,
+            use_hqq=use_hqq,
+        )
+
+    return _get_linear_subclass_inserter(
+        apply,
+        propagate_bias=has_bias
+    )
+
+
 def int8_dynamic_activation_intx_weight(
     group_size: int = 128,
     nbit: int = 4,
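For context, a hedged usage sketch of the new int8_dynamic_activation_intx_weight_v2 config with the "aten" target. The quantize_ call and the module shapes are assumptions, not part of this commit; torch >= 2.6 is required, and packing the bias through the layout depends on KleidiAI being available, per the assertions above:

import torch
from torchao.quantization.quant_api import quantize_, MappingType
from torchao.quantization.granularity import PerGroup
from torchao.experimental.quant_api import int8_dynamic_activation_intx_weight_v2

# CPU float32 model; only CPU is supported by this path
model = torch.nn.Sequential(torch.nn.Linear(512, 256, bias=True)).eval()

# SYMMETRIC mapping keeps ZeroPointDomain.NONE, which the aten impl requires;
# has_bias=True makes _get_linear_subclass_inserter propagate lin.bias into the layout.
quantize_(
    model,
    int8_dynamic_activation_intx_weight_v2(
        granularity=PerGroup(32),
        nbit=4,
        has_weight_zeros=False,
        target="aten",
        mapping_type=MappingType.SYMMETRIC,
        has_bias=True,
    ),
)

out = model(torch.randn(2, 512))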

torchao/quantization/quant_api.py

Lines changed: 5 additions & 2 deletions
@@ -450,15 +450,18 @@ def _linear_extra_repr(self):
     return f"in_features={self.weight.shape[1]}, out_features={self.weight.shape[0]}, weight={_quantization_type(self.weight)}"
 
 
-def _get_linear_subclass_inserter(constructor, *, allow_requires_grad=False, **kwargs):
+def _get_linear_subclass_inserter(constructor, *, allow_requires_grad=False, propagate_bias=False, **kwargs):
     """Helper function to apply the constructor that quantizes the weight Tensor (with additional kwargs)
     to the weight of linear module
     """
 
     def insert_subclass(lin):
         requires_grad = allow_requires_grad and lin.weight.requires_grad
+        args = [lin.weight]
+        if propagate_bias == True:
+            args.append(lin.bias)
         lin.weight = torch.nn.Parameter(
-            constructor(lin.weight, **kwargs), requires_grad=requires_grad
+            constructor(*args, **kwargs), requires_grad=requires_grad
         )
         lin.extra_repr = types.MethodType(_linear_extra_repr, lin)
         return lin
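A simplified stand-in for the new propagate_bias flag (illustrative only, not the torchao code): when enabled, the inserter hands the constructor (weight, bias) instead of the weight alone, which is how the layout in this commit receives the bias for aten packing.

import torch
import torch.nn as nn

def get_inserter(constructor, *, propagate_bias=False, **kwargs):
    def insert_subclass(lin: nn.Linear) -> nn.Linear:
        args = [lin.weight]
        if propagate_bias:
            args.append(lin.bias)  # bias travels with the weight into the constructor
        lin.weight = nn.Parameter(constructor(*args, **kwargs), requires_grad=False)
        return lin
    return insert_subclass

lin = nn.Linear(8, 4, bias=True)
inserter = get_inserter(lambda w, b=None: w.detach(), propagate_bias=True)
inserter(lin)  # the constructor receives (weight, bias)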
