import torch

from torch.utils._python_dispatch import return_and_correct_aliasing
-from torchao.quantization.quant_primitives import choose_qparams_affine, MappingType
+from torchao.quantization.quant_primitives import (
+    choose_qparams_affine,
+    MappingType,
+    quantize_affine,
+    dequantize_affine,
+)
from torchao.dtypes.utils import (
    LayoutType,
    PlainLayoutType,
)

aten = torch.ops.aten

+# TODO: move to torchao/utils.py
+def fill_defaults(args, n, defaults_tail):
+    """
+    __torch_dispatch__ doesn't guarantee the number of arguments you are
+    passed (e.g., defaulted arguments are not passed); but usually it is
+    convenient to pad out the arguments list with defaults. This function
+    helps you do that.
+    Args:
+        args: the list of positional arguments passed to __torch_dispatch__
+        n: the number of arguments you are expecting to get
+        defaults_tail: default values for the arguments, starting from the
+            end of the list
+    Example:
+        >>> fill_defaults([1, 2, 3], 5, [3, 4, 5])
+        [1, 2, 3, 4, 5]
+        >>> fill_defaults([1, 2, 3], 5, [None, None, None])
+        [1, 2, 3, None, None]
+    """
+    if n - len(defaults_tail) > len(args):
+        raise RuntimeError("not enough defaults to fill arguments")
+    r = list(args)
+    for i in range(len(args), n):
+        r.append(defaults_tail[i - n + len(defaults_tail)])
+    return r
+
+
###############################
# Base Layout Tensor Subclass #
###############################
@@ -140,10 +171,10 @@ def from_float(
        layout_type: LayoutType = PlainLayoutType(),
    ):
        mapping_type = MappingType.SYMMETRIC
-        block_size = input_float.shape
+        block_size = (1, input_float.shape[-1])
        dtype = torch.int16
-        scale, _ = choose_qparams_affine(input_float, mapping_type, block_size, dtype)
-        int_data = (input_float / scale).to(torch.int8)
+        scale, zero_point = choose_qparams_affine(input_float, mapping_type, block_size, dtype)
+        int_data = quantize_affine(input_float, block_size, scale, zero_point, dtype)
        layout_tensor_ctr = get_layout_tensor_constructor(type(layout_type))
        layout_tensor = layout_tensor_ctr(int_data, scale, layout_type)
        return cls(layout_tensor, input_float.shape)
@@ -160,7 +191,14 @@ def dequantize(self, output_dtype=None):
        if output_dtype is None:
            output_dtype = torch.get_default_dtype()
        int_data, scale = self.layout_tensor.get_plain()
-        return int_data.to(output_dtype) * scale
+        transposed = False
+        block_size = (1, int_data.shape[-1])
+        if hasattr(self.layout_tensor, "transposed") and self.layout_tensor.transposed:
+            transposed = True
+        res = dequantize_affine(int_data, block_size, scale, None, int_data.dtype, output_dtype=output_dtype)
+        if transposed:
+            res = res.t()
+        return res

    def __repr__(self):
        return (
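(Note, not part of the patch: the two hunks above replace the ad-hoc scale/divide logic with per-row affine quantization, i.e. block_size = (1, input_float.shape[-1]) gives one scale per output row. Below is a minimal standalone sketch of that round trip, using the same quant_primitives calls the patch imports; the tensor shape is illustrative only.)

import torch
from torchao.quantization.quant_primitives import (
    choose_qparams_affine,
    MappingType,
    quantize_affine,
    dequantize_affine,
)

w = torch.randn(128, 1024)                 # stand-in for a Linear weight
block_size = (1, w.shape[-1])              # one quantization group per row
scale, zero_point = choose_qparams_affine(w, MappingType.SYMMETRIC, block_size, torch.int16)
int_data = quantize_affine(w, block_size, scale, zero_point, torch.int16)
w_hat = dequantize_affine(int_data, block_size, scale, None, int_data.dtype, output_dtype=torch.float32)
print((w - w_hat).abs().max())             # reconstruction error should be small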
@@ -203,6 +241,7 @@ def __new__(
        cls,
        int_data: torch.Tensor,
        scale: torch.Tensor,
+        transposed: bool,
        layout_type: LayoutType,
    ):
        kwargs = {}
@@ -219,22 +258,24 @@ def __init__(
        self,
        int_data: torch.Tensor,
        scale: torch.Tensor,
+        transposed: bool,
        layout_type: LayoutType,
    ):
        self.int_data = int_data
        self.scale = scale
+        self.transposed = transposed
        self.layout_type = layout_type

    def __tensor_flatten__(self):
-        return ["int_data", "scale"], [self.layout_type]
+        return ["int_data", "scale"], [self.transposed, self.layout_type]

    @classmethod
    def __tensor_unflatten__(
        cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
    ):
        int_data, scale = tensor_data_dict["int_data"], tensor_data_dict["scale"]
-        layout_type, = tensor_attributes
-        return cls(int_data, scale, layout_type)
+        transposed, layout_type, = tensor_attributes
+        return cls(int_data, scale, transposed, layout_type)

    @classmethod
    def from_plain(
@@ -247,12 +288,13 @@ def from_plain(
        extra metadata for packing etc.
        """
        assert isinstance(layout_type, PlainLayoutType)
-        return cls(int_data, scale, layout_type)
+        return cls(int_data, scale, False, layout_type)

    def _apply_fn_to_data(self, fn):
        return self.__class__(
            fn(self.int_data),
            fn(self.scale),
+            self.transposed,
            self.layout_type,
        )

@@ -265,8 +307,36 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
                func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
            )

+        # Tensor parallel support START
+        elif func in [aten._to_copy.default, aten.clone.default]:
+            return return_and_correct_aliasing(
+                func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
+            )
+        elif func is aten.split.Tensor:
+            int_data_list = func(args[0].int_data, *args[1:], **kwargs)
+            scale_list = func(args[0].scale, *args[1:], **kwargs)
+            out = [PlainMyDTypeLayout(int_data, scale, args[0].transposed, args[0].layout_type) for int_data, scale in zip(int_data_list, scale_list)]
+            return out
+        elif func is aten.empty_like.default:
+            int_data_empty_like = func(args[0].int_data, *args[1:], **kwargs)
+            return PlainMyDTypeLayout(int_data_empty_like, args[0].scale, args[0].transposed, args[0].layout_type)
+        elif func is aten.slice.Tensor:
+            self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1])
+            if dim == 0:
+                return return_and_correct_aliasing(
+                    func, args, kwargs, args[0]._apply_fn_to_data(lambda x: aten.slice.Tensor(x, dim, start, end, step))
+                )
+            elif dim == 1:
+                return PlainMyDTypeLayout(aten.slice.Tensor(self.int_data, dim, start, end, step), self.scale.view(-1, 1), self.transposed, self.layout_type)
+            else:
+                raise NotImplementedError(f"PlainMyDTypeLayout dispatch: attempting to run {func}, with dim={dim}, that is not supported")
+        elif func is aten.t.default:
+            return return_and_correct_aliasing(func, args, kwargs, PlainMyDTypeLayout(args[0].int_data, args[0].scale, not args[0].transposed, args[0].layout_type))
+
+        # Tensor parallel support END
+
        raise NotImplementedError(
-            f"MyDTypeLayout dispatch: attempting to run {func}, this is not supported"
+            f"PlainMyDTypeLayout dispatch: attempting to run {func}, this is not supported"
        )

#####################################################
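(Note, not part of the patch: the "Tensor parallel support" branches above are keyed on the plain aten overloads that sharding and transposing a Linear weight route through __torch_dispatch__. Below is a rough sketch of those same overloads called on an ordinary tensor, purely to show which ops the new branches correspond to; the shape and split size are made up.)

import torch

aten = torch.ops.aten
w = torch.randn(128, 1024)                 # stand-in for a weight to be sharded

rows = aten.split.Tensor(w, 64)            # handled by the aten.split.Tensor branch
cols = aten.slice.Tensor(w, 1, 0, 512)     # handled by the aten.slice.Tensor branch (dim=1)
wt = aten.t.default(w)                     # handled by the aten.t.default branch
print(len(rows), cols.shape, wt.shape)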
@@ -315,15 +385,6 @@ def _(func, types, args, kwargs):
        func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
    )

-
-class M(torch.nn.Module):
-    def __init__(self, *args, **kwargs) -> None:
-        super().__init__(*args, **kwargs)
-        self.linear = torch.nn.Linear(1024, 1024)
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        return self.linear(x)
-
#####################
# Factory functions #
#####################
@@ -333,42 +394,49 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
########
# Test #
########
-
-def test():
+def main():
    from torchao.utils import benchmark_model
-
+
+    class M(torch.nn.Module):
+        def __init__(self) -> None:
+            super().__init__()
+            self.linear = torch.nn.Linear(1024, 128)
+
+        def forward(self, x: torch.Tensor) -> torch.Tensor:
+            return self.linear(x)
+
    m = M()
-    example_inputs = (100 * torch.randn(1024, 1024),)
+    example_inputs = (100 * torch.randn(512, 1024),)
    NUM_WARMUPS = 10
    NUM_RUNS = 100
-
+
    for _ in range(NUM_WARMUPS):
        m(*example_inputs)
    print("before quantization:", benchmark_model(m, NUM_RUNS, example_inputs))
-
+
    compiled = torch.compile(m, mode="max-autotune")
    for _ in range(NUM_WARMUPS):
        compiled(*example_inputs)
    print("after compile:", benchmark_model(compiled, NUM_RUNS, example_inputs))
-
+
    # convert weights to quantized weights
    m.linear.weight = torch.nn.Parameter(
        to_my_dtype(m.linear.weight), requires_grad=False
    )
-
+
    for _ in range(NUM_WARMUPS):
        m(*example_inputs)
-
+
    print("after quantization:", benchmark_model(m, NUM_RUNS, example_inputs))
-
+
    m = torch.compile(m, mode="max-autotune")
-
+
    for _ in range(NUM_WARMUPS):
        m(*example_inputs)
-
+
    # NOTE: currently there is no speedup because we just dequantize the weight in the _quantized_linear op
    # we plan to add custom op example in the future and that will help us to get speedup
    print("after quantization and compile:", benchmark_model(m, NUM_RUNS, example_inputs))

if __name__ == "__main__":
-    test()
+    main()