    _ATEN_OP_OR_TORCH_FN_TABLE,
    _register_layout_cls,
    _get_layout_tensor_constructor,
+    LayoutType,
)
+from typing import ClassVar
+from dataclasses import dataclass

aten = torch.ops.aten

+@dataclass(frozen=True)
+class PlainLayoutType(LayoutType):
+    pass
+
+@dataclass(frozen=True)
+class TensorCoreTiledLayoutType(LayoutType):
+    inner_k_tiles: int = 8
+
+    def pre_process(self, input: torch.Tensor) -> torch.Tensor:
+        orig_out_features, orig_in_features = input.shape
+        in_features = find_multiple(orig_in_features, 1024)
+        out_features = find_multiple(orig_out_features, 8)
+        input = torch.nn.functional.pad(
+            input,
+            (0, in_features - orig_in_features, 0, out_features - orig_out_features),
+        )
+        return input
+
+    def extra_repr(self):
+        return f"inner_k_tiles={self.inner_k_tiles}"
+
+
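Illustrative aside, not part of the diff: each layout is now a frozen dataclass that carries its own packing parameters and pre-/post-processing hooks. A minimal sketch, assuming this module's imports (the weight shape below is made up):

# sketch: pad a weight the way the tensor_core_tiled layout expects
layout_type = TensorCoreTiledLayoutType(inner_k_tiles=8)
w = torch.randn(1000, 1000, dtype=torch.bfloat16)   # hypothetical weight
w_padded = layout_type.pre_process(w)
# in_features is padded to a multiple of 1024, out_features to a multiple of 8
assert w_padded.shape == (1000, 1024)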
def _aqt_is_int8(aqt):
    """Check if an AffineQuantizedTensor is int8 quantized Tensor"""
    return (
@@ -52,10 +77,10 @@ class AQTLayout(torch.Tensor):
    """
    Base class for the layout tensor for `AffineQuantizedTensor`
    """
-    # this should be set for each layout class during registration
-    extended_layout: Optional[str] = None
+    def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        pass

-    def get_plain() -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    def get_layout_type(self) -> LayoutType:
        pass

    @classmethod
@@ -64,9 +89,15 @@ def from_plain(
        int_data: torch.Tensor,
        scale: torch.Tensor,
        zero_point: torch.Tensor,
+        layout_type: LayoutType,
    ):
        pass

+    def __repr__(self):
+        int_data, scale, zero_point = self.get_plain()
+        layout_type = self.get_layout_type()
+        return f"{self.__class__.__name__}(int_data={int_data}, scale={scale}, zero_point={zero_point}, layout_type={layout_type})"
+
    def _get_to_kwargs(self, *args, **kwargs):
        device, dtype, _, memory_format = torch._C._nn._parse_to(*args, **kwargs)
        device = self.device if device is None else device
@@ -194,30 +225,17 @@ def from_float(
        zero_point_dtype: Optional[torch.dtype] = None,
        preserve_zero: bool = True,
        zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
-        extended_layout: str = "plain",
-        # TODO: this is only for "tensor_core_tiled", need to figure out
-        # the proper API for this arg
-        inner_k_tiles: Optional[int] = None,
+        layout_type: LayoutType = PlainLayoutType(),
    ):
        original_shape = input_float.shape
-        if extended_layout == "tensor_core_tiled":
-            orig_out_features, orig_in_features = input_float.shape
-            in_features = find_multiple(orig_in_features, 1024)
-            out_features = find_multiple(orig_out_features, 8)
-            input_float = torch.nn.functional.pad(
-                input_float,
-                (0, in_features - orig_in_features, 0, out_features - orig_out_features),
-            )
+        input_float = layout_type.pre_process(input_float)

        scale, zero_point = choose_qparams_affine(input_float, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, scale_dtype, zero_point_dtype, preserve_zero, zero_point_domain)
        int_data = quantize_affine(input_float, block_size, scale, zero_point, target_dtype, quant_min, quant_max, zero_point_domain)
+        int_data = layout_type.post_process(int_data)

-        layout_cls_ctr = get_layout_tensor_constructor(extended_layout)
-        # TODO: this is temporary, need to come up with the proper UX
-        if extended_layout == "tensor_core_tiled":
-            layout_tensor = layout_cls_ctr(int_data, scale, zero_point, inner_k_tiles)
-        else:
-            layout_tensor = layout_cls_ctr(int_data, scale, zero_point)
+        layout_tensor_ctr = get_layout_tensor_constructor(type(layout_type))
+        layout_tensor = layout_tensor_ctr(int_data, scale, zero_point, layout_type)
        return cls(
            layout_tensor,
            block_size,
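Illustrative aside, not part of the diff: a sketch of how a from_float call site changes under the new API. The concrete values below mirror a typical int4 weight-only configuration and are assumptions, not taken from this patch.

# before: from_float(..., extended_layout="tensor_core_tiled", inner_k_tiles=8)
# after:  the layout-specific knob lives on the LayoutType instance itself
from torchao.quantization.quant_primitives import MappingType, ZeroPointDomain

weight = torch.randn(4096, 4096, dtype=torch.bfloat16, device="cuda")  # hypothetical weight
aqt = AffineQuantizedTensor.from_float(
    weight,
    mapping_type=MappingType.ASYMMETRIC,
    block_size=(1, 128),               # groupwise along in_features
    target_dtype=torch.int32,          # dtype expected by the tinygemm int4 kernel
    quant_min=0,
    quant_max=15,
    eps=1e-6,
    zero_point_dtype=torch.bfloat16,
    preserve_zero=False,
    zero_point_domain=ZeroPointDomain.FLOAT,
    layout_type=TensorCoreTiledLayoutType(inner_k_tiles=8),
)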
@@ -229,8 +247,8 @@ def from_float(
        )

    @property
-    def extended_layout(self) -> str:
-        return self.layout_tensor.extended_layout
+    def layout_type(self) -> LayoutType:
+        return self.layout_tensor.layout_type

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
@@ -308,13 +326,13 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
def implements(aten_ops_or_torch_fn):
    return _implements(AffineQuantizedTensor, aten_ops_or_torch_fn)

-def register_layout_cls(extended_layout: str):
-    return _register_layout_cls(AffineQuantizedTensor, extended_layout)
+def register_layout_cls(layout_type_class: type(LayoutType)):
+    return _register_layout_cls(AffineQuantizedTensor, layout_type_class)

-def get_layout_tensor_constructor(extended_layout: str):
-    return _get_layout_tensor_constructor(AffineQuantizedTensor, extended_layout)
+def get_layout_tensor_constructor(layout_type_class: type(LayoutType)):
+    return _get_layout_tensor_constructor(AffineQuantizedTensor, layout_type_class)

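Illustrative aside, not part of the diff: registration is now keyed by the LayoutType subclass rather than a string. A hypothetical custom layout would plug in as below (MyLayoutType and MyAQTLayout are made-up names).

@dataclass(frozen=True)
class MyLayoutType(LayoutType):
    pass

@register_layout_cls(MyLayoutType)
class MyAQTLayout(AQTLayout):
    # the usual layout-tensor methods (from_plain, get_plain, ...) would go here
    ...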
-@register_layout_cls("plain")
+@register_layout_cls(PlainLayoutType)
class PlainAQTLayout(AQTLayout):
    """
    Layout storage class for plain layout for affine quantized tensor, it stores int_data, scale, zero_point
@@ -330,6 +348,7 @@ def __new__(
        int_data: torch.Tensor,
        scale: torch.Tensor,
        zero_point: torch.Tensor,
+        layout_type: LayoutType,
    ):
        kwargs = {}
        kwargs["device"] = int_data.device
@@ -346,34 +365,39 @@ def __init__(
        int_data: torch.Tensor,
        scale: torch.Tensor,
        zero_point: torch.Tensor,
+        layout_type: LayoutType,
    ):
        self.int_data = int_data
        self.scale = scale
        self.zero_point = zero_point
+        self.layout_type = layout_type

    def __tensor_flatten__(self):
-        return ["int_data", "scale", "zero_point"], []
+        return ["int_data", "scale", "zero_point"], [self.layout_type]

    @classmethod
    def __tensor_unflatten__(
        cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
    ):
        int_data, scale, zero_point = tensor_data_dict["int_data"], tensor_data_dict["scale"], tensor_data_dict["zero_point"]
-        return cls(int_data, scale, zero_point)
+        layout_type, = tensor_attributes
+        return cls(int_data, scale, zero_point, layout_type)

    def to(self, *args, **kwargs):
        kwargs = self._get_to_kwargs(*args, **kwargs)
        return self.__class__(
            self.int_data.to(kwargs["device"]),
            self.scale.to(kwargs["device"]),
            self.zero_point.to(kwargs["device"]),
+            self.layout_type,
        )

    def _apply_fn_to_data(self, fn):
        return self.__class__(
            fn(self.int_data),
            fn(self.scale),
            fn(self.zero_point),
+            self.layout_type,
        )

    @classmethod
@@ -398,19 +422,24 @@ def __torch_dispatch__(cls, func, types, args, kwargs):

    __torch_function__ = torch._C._disabled_torch_function_impl

-    def get_plain(self):
+    def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        return self.int_data, self.scale, self.zero_point

+    def get_layout_type(self) -> LayoutType:
+        return self.layout_type
+
    @classmethod
    def from_plain(
        cls,
        int_data: torch.Tensor,
        scale: torch.Tensor,
        zero_point: torch.Tensor,
+        layout_type: LayoutType,
    ):
-        return cls(int_data, scale, zero_point)
+        assert isinstance(layout_type, PlainLayoutType)
+        return cls(int_data, scale, zero_point, layout_type)

-@register_layout_cls("tensor_core_tiled")
+@register_layout_cls(TensorCoreTiledLayoutType)
class TensorCoreTiledAQTLayout(AQTLayout):
    """
    Layout storage class for tensor_core_tiled layout for affine quantized tensor, this is for int4 only,
@@ -427,6 +456,7 @@ def __new__(
        packed_weight: torch.Tensor,
        scale_and_zero: torch.Tensor,
        transposed: bool,
+        layout_type: LayoutType,
    ):
        kwargs = {}
        kwargs["device"] = packed_weight.device
@@ -443,31 +473,40 @@ def __init__(
        packed_weight: torch.Tensor,
        scale_and_zero: torch.Tensor,
        transposed: bool,
+        layout_type: LayoutType,
    ):
        self.packed_weight = packed_weight
        self.scale_and_zero = scale_and_zero
        self.transposed = False
+        self.layout_type = layout_type

    def __tensor_flatten__(self):
-        return ["packed_weight", "scale_and_zero"], [self.transposed]
+        return ["packed_weight", "scale_and_zero"], [self.transposed, self.layout_type]

    @classmethod
    def __tensor_unflatten__(
        cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
    ):
        packed_weight, scale_and_zero = tensor_data_dict["packed_weight"], tensor_data_dict["scale_and_zero"]
-        transposed, = tensor_attributes
-        return cls(packed_weight, scale_and_zero, transposed)
+        transposed, layout_type, = tensor_attributes
+        return cls(packed_weight, scale_and_zero, transposed, layout_type)

    @classmethod
-    def from_plain(cls, int_data, scale, zero_point, inner_k_tiles=8):
+    def from_plain(
+        cls,
+        int_data: torch.Tensor,
+        scale: torch.Tensor,
+        zero_point: torch.Tensor,
+        layout_type: LayoutType
+    ):
+        assert isinstance(layout_type, TensorCoreTiledLayoutType)
        # assert int_data.dtype == torch.uint8, "torch.ops.aten._convert_weight_to_int4pack expects `uint8` dtype"
        # packed_weight = torch.ops.aten._convert_weight_to_int4pack(int_data, inner_k_tiles)
-        packed_weight = torch.ops.aten._convert_weight_to_int4pack(int_data.to(torch.int32), inner_k_tiles)
+        packed_weight = torch.ops.aten._convert_weight_to_int4pack(int_data.to(torch.int32), layout_type.inner_k_tiles)
        scale = scale.reshape(int_data.shape[0], -1)
        zero_point = zero_point.reshape(int_data.shape[0], -1)
        scale_and_zero = pack_tinygemm_scales_and_zeros(scale, zero_point)
-        return cls(packed_weight, scale_and_zero, False)
+        return cls(packed_weight, scale_and_zero, False, layout_type)

    def to(self, *args, **kwargs):
        kwargs = self._get_to_kwargs(*args, **kwargs)
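Illustrative aside, not part of the diff: the packing parameter now travels with the layout type, so callers no longer pass inner_k_tiles to from_plain by hand. A sketch with hypothetical tensors (the shapes, dtypes, and CUDA device are assumptions chosen to satisfy the tinygemm packing op):

group_size = 128
int_data = torch.randint(0, 16, (4096, 4096), dtype=torch.int32, device="cuda")
scale = torch.ones(4096, 4096 // group_size, dtype=torch.bfloat16, device="cuda")
zero_point = torch.zeros_like(scale)
layout_tensor = TensorCoreTiledAQTLayout.from_plain(
    int_data, scale, zero_point, TensorCoreTiledLayoutType(inner_k_tiles=8)
)
# equivalent to the old from_plain(int_data, scale, zero_point, inner_k_tiles=8)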
@@ -477,18 +516,15 @@ def to(self, *args, **kwargs):
        return self.__class__(
            self.packed_weight.to(device),
            self.scale_and_zero.to(device),
-            self.transposed
+            self.transposed,
+            self.layout_type,
        )

    def _apply_fn_to_data(self, fn):
        self.packed_weight = fn(self.packed_weight)
        self.scale_and_zero = fn(self.scale_and_zero)
        return self

-    def __repr__(self):
-        int_data, scale, zero_point = self.get_plain()
-        return f"TensorCoreTiledAQTLayout(int_data={int_data}, scale={scale}, zero_point={zero_point})"
-
    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs):
        kwargs = {} if kwargs is None else kwargs
@@ -511,7 +547,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs):

    __torch_function__ = torch._C._disabled_torch_function_impl

-    def get_plain(self):
+    def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        from torchao.quantization.quant_primitives import (
            ZeroPointDomain,
            quantize_affine,
@@ -542,6 +578,9 @@ def get_plain(self):
        int_data = quantize_affine(dequantized, block_size, scale, zero, target_dtype, quant_min, quant_max, zero_point_domain)
        return int_data, scale, zero

+    def get_layout_type(self) -> LayoutType:
+        return self.layout_type
+
def _quantized_linear_op(input_tensor, weight_qtensor, bias):
    """
    Quantized version of F.linear operator
@@ -565,8 +604,8 @@ def _quantized_linear_op(input_tensor, weight_qtensor, bias):
            is_cuda and
            input_is_int8 and
            input_tensor.dtype == weight_qtensor.dtype and
-            input_tensor.extended_layout == "plain" and
-            weight_qtensor.extended_layout == "plain"
+            isinstance(input_tensor.layout_type, PlainLayoutType) and
+            isinstance(weight_qtensor.layout_type, PlainLayoutType)
        ):
            #
            # 1. do the matrix form of dot(X_i, W_j)
@@ -608,7 +647,7 @@ def _quantized_linear_op(input_tensor, weight_qtensor, bias):
            weight_qtensor.dtype == torch.bfloat16 and
            len(weight_qtensor.shape) == 2 and
            weight_qtensor.zero_point_domain == ZeroPointDomain.FLOAT and
-            weight_qtensor.extended_layout == "tensor_core_tiled"
+            isinstance(weight_qtensor.layout_type, TensorCoreTiledLayoutType)
        ):
            assert weight_qtensor.block_size[0] == 1, f"Requires groupwise quantization, got block_size: {block_size}"
            assert input_tensor.shape[-1] == weight_qtensor.shape[1], (
@@ -651,7 +690,7 @@ def _quantized_linear_op(input_tensor, weight_qtensor, bias):
            weight_qtensor.block_size[0] == 1 and
            weight_qtensor.block_size[1] == weight_qtensor.shape[1] and
            weight_qtensor.zero_point_domain == ZeroPointDomain.INT and
-            weight_qtensor.extended_layout == "plain"
+            isinstance(weight_qtensor.layout_type, PlainLayoutType)
        ):
            # TODO: enable cpu and mps efficient path
            # per channel int8 weight only quantizated mm