23
23
MappingType ,
24
24
ZeroPointDomain ,
25
25
)
26
+ from torchao .utils import (
27
+ TORCH_VERSION_AT_LEAST_2_6 ,
28
+ )
26
29
27
30
logger = logging .getLogger (__name__ )
28
31
logger .setLevel (logging .WARNING )
@@ -40,13 +43,16 @@ class Target(Enum):
40
43
41
44
NATIVE = auto ()
42
45
FALLBACK = auto ()
46
+ ATEN = auto ()
43
47
44
48
45
49
def target_from_str(target: str) -> Target:
    """Map a case-insensitive target name to a Target enum member.

    Args:
        target: one of "native", "fallback", or "aten" (any case).

    Returns:
        The corresponding Target member.

    Raises:
        ValueError: if `target` does not name a known Target.
    """
    # Lowercase once instead of once per comparison.
    normalized = target.lower()
    if normalized == "native":
        return Target.NATIVE
    elif normalized == "fallback":
        return Target.FALLBACK
    elif normalized == "aten":
        return Target.ATEN
    else:
        raise ValueError(f"Invalid target: {target}")
52
58
@@ -56,22 +62,27 @@ class Linear8BitActXBitWeightLayout(Layout):
56
62
nbit : int
57
63
group_size : int
58
64
59
- # The target platform for the layout, either 'native' or 'fallback'.
65
+ # The target platform for the layout, 'native', 'fallback' or 'aten'
60
66
target : Target
61
67
68
+ # Allow bias access via layout
69
+ bias : Optional [torch .Tensor ] = None
70
+
62
71
def __init__(
    self,
    nbit: int,
    group_size: int,
    target: str,
    bias: Optional[torch.Tensor] = None,
):
    """Configure the layout.

    Args:
        nbit: weight bit width; must not exceed 8.
        group_size: quantization group size.
        target: target platform name, parsed via ``target_from_str``
            ("native", "fallback", or "aten", case-insensitive).
        bias: optional bias tensor carried on the layout so packing ops
            can access it (e.g. aten packing folds it into the weights).
    """
    assert nbit <= 8
    # Parse the target first; the remaining fields are plain assignments.
    parsed_target = target_from_str(target)
    self.target = parsed_target
    self.bias = bias
    self.group_size = group_size
    self.nbit = nbit
72
83
73
84
def extra_repr(self):
    """Return a human-readable summary of the layout configuration.

    Fixes the inconsistent formatting of the bias field ("bias= {...}")
    so every field renders uniformly as key=value.
    """
    return f"nbit={self.nbit}, group_size={self.group_size}, target={self.target}, bias={self.bias}"
75
86
76
87
77
88
def _pack_weights_native (
@@ -81,7 +92,6 @@ def _pack_weights_native(
81
92
layout : Layout ,
82
93
):
83
94
assert isinstance (layout , Linear8BitActXBitWeightLayout )
84
- assert layout .target == Target .NATIVE
85
95
nbit = layout .nbit
86
96
group_size = layout .group_size
87
97
has_weight_zeros = zero_point is not None
@@ -100,6 +110,12 @@ def _pack_weights_native(
100
110
torch .empty (0 , group_size , dtype = torch .int8 ),
101
111
]
102
112
113
+ if TORCH_VERSION_AT_LEAST_2_6 and layout .target == Target .ATEN :
114
+ in_features = int_data .shape [- 1 ]
115
+ out_features = int_data .shape [- 2 ]
116
+ int_data = int_data .add (8 )
117
+ int_data = (int_data [::,1 ::2 ] << 4 | int_data [::,::2 ] ).to (torch .uint8 )
118
+ return torch .ops .aten ._dyn_quant_pack_4bit_weight (int_data , scale , layout .bias , group_size , in_features , out_features )
103
119
wzp_suffix = "" if has_weight_zeros else "0zp"
104
120
return getattr (torch .ops .torchao , f"_pack_8bit_act_{ nbit } bit{ wzp_suffix } _weight" )(
105
121
* args
def get_plain(
    self,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
    """Return the (packed_weight, scale, zero_point) triple.

    Only the FALLBACK and ATEN targets keep these tensors in a form that
    can be handed back directly; NATIVE uses an opaque packed format.

    Raises:
        NotImplementedError: when the target is not FALLBACK or ATEN.
    """
    # Call get_layout() once; membership test replaces the or-chain.
    if self.get_layout().target in (Target.FALLBACK, Target.ATEN):
        return self.packed_weight, self.scale, self.zero_point
    raise NotImplementedError(
        # Message updated: ATEN is also supported, not just fallback.
        "get_plain is not supported for Linear8BitActXBitWeightAQTTensorImpl when target is not fallback or aten"
    )
@@ -170,12 +186,17 @@ def from_plain(
170
186
assert isinstance (layout , Linear8BitActXBitWeightLayout )
171
187
172
188
try :
173
- if layout .target == Target .NATIVE :
189
+ if layout .target == Target .NATIVE or layout . target == Target . ATEN :
174
190
packed_weight = _pack_weights_native (
175
191
int_data , scale , zero_point , layout
176
192
)
177
193
scale = None
178
194
zero_point = None
195
+ # avoid storing bias tensor but indicate if Linear layer got bias on printing as
196
+ # 1. aten_dynamic_quant already packed it in weights or
197
+ # 2. its not needed by any other op
198
+ if layout .bias is not None :
199
+ layout .bias = True
179
200
return cls (packed_weight , scale , zero_point , layout )
180
201
except Exception as e :
181
202
logger .warning (
@@ -216,7 +237,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
216
237
)
217
238
218
239
def __tensor_flatten__ (self ):
219
- if self .get_layout ().target == Target .NATIVE :
240
+ if self .get_layout ().target == Target .NATIVE or self . get_layout (). target == Target . ATEN :
220
241
return ["packed_weight" ], [self .get_layout ()]
221
242
222
243
# fallback
def _linear_int8_dynamic_activation_intx_weight_check(
    input_tensor, weight_tensor, bias
):
    """Dispatch predicate for the int8-act / intx-weight linear kernels.

    Returns True when the weight uses Linear8BitActXBitWeightLayout and the
    bias is handleable: either there is no bias, or the target is ATEN
    (whose packing folds the bias into the packed weight).
    """
    layout = weight_tensor.tensor_impl.get_layout()
    # Single isinstance check instead of evaluating it twice.
    if not isinstance(layout, Linear8BitActXBitWeightLayout):
        return False
    return bias is None or layout.target == Target.ATEN
247
271
248
272
def _linear_int8_dynamic_activation_intx_weight_fallback_impl (
249
273
input_tensor , weight_tensor , bias
@@ -353,6 +377,51 @@ def _impl_2d(input_tensor, weight_tensor):
353
377
return res
354
378
355
379
380
def _linear_int8_dynamic_activation_intx_weight_aten_impl(
    input_tensor, weight_tensor, bias
):
    """Linear implementation for Target.ATEN via torch.ops.aten._dyn_quant_matmul_4bit.

    Handles 2D activations directly; batched (>=3D) activations are flattened
    to 2D for the matmul and reshaped back to the leading batch shape.

    Args:
        input_tensor: activation tensor, dim >= 2, last dim = in_features.
        weight_tensor: AQT tensor whose impl holds the aten-packed weight.
        bias: only permitted when the KleidiAI backend is available, since the
            aten packing path folds bias into the packed weight there.

    Raises:
        NotImplementedError: for asymmetric quantization (non-NONE zero point
            domain), which this kernel does not support.
    """
    assert weight_tensor.tensor_impl.get_layout().target == Target.ATEN
    if weight_tensor.zero_point_domain != ZeroPointDomain.NONE:
        raise NotImplementedError(
            # Fixed garbled message ("MappingType.ASSYMETRIC in is not supported").
            "MappingType.ASYMMETRIC is not supported for Linear8BitActXBitWeightAQTTensorImpl when target is aten"
        )
    assert (
        weight_tensor.tensor_impl.get_layout().nbit == 4
    ), "Only 4 bit is supported"  # plain string: no placeholders, so no f-prefix
    # Truthiness check — the flag is a bool, not necessarily the int 1.
    assert TORCH_VERSION_AT_LEAST_2_6, "Target.ATEN requires torch >= 2.6.0"
    # aten supports bias for kleidiAI but not for the default fallback op.
    # (Removed leftover debug print; use `is None` for None comparison.)
    if not torch.backends.kleidiai.is_available():
        assert bias is None, "bias is only supported with the kleidiAI backend"

    def _impl_2d(input_tensor, weight_tensor):
        # 2D x 2D matmul against the pre-packed 4-bit weight.
        assert input_tensor.dim() == 2
        assert weight_tensor.dim() == 2

        k = input_tensor.shape[1]
        n, k_ = weight_tensor.shape
        assert k_ == k
        group_size = weight_tensor.tensor_impl.get_layout().group_size
        packed_weight = weight_tensor.tensor_impl.packed_weight
        return torch.ops.aten._dyn_quant_matmul_4bit(
            input_tensor, packed_weight, group_size, k_, n
        )

    if input_tensor.dim() == 2:
        return _impl_2d(input_tensor, weight_tensor)

    assert input_tensor.dim() >= 3
    lead_shape = input_tensor.shape[0:-2]
    m, k = input_tensor.shape[-2], input_tensor.shape[-1]
    n, k_ = weight_tensor.shape
    assert k_ == k

    # Flatten leading batch dims, run the 2D kernel, then restore the shape.
    res = _impl_2d(input_tensor.reshape(-1, k), weight_tensor)
    res = res.reshape(*lead_shape, m, n)
    return res
423
+
424
+
356
425
def _linear_int8_dynamic_activation_intx_weight_impl (input_tensor , weight_tensor , bias ):
357
426
target = weight_tensor .tensor_impl .get_layout ().target
358
427
if target == Target .NATIVE :
@@ -365,6 +434,11 @@ def _linear_int8_dynamic_activation_intx_weight_impl(input_tensor, weight_tensor
365
434
input_tensor , weight_tensor , bias
366
435
)
367
436
437
+ if target == Target .ATEN :
438
+ return _linear_int8_dynamic_activation_intx_weight_aten_impl (
439
+ input_tensor , weight_tensor , bias
440
+ )
441
+
368
442
assert False , f"Unknown target { target } "
369
443
370
444
0 commit comments