Commit 42c2376

Factor out the specific configurations to helper functions (#286)
Summary: int4wo, int8wo, int8dyn, and 8da4w are specific configurations for the quantize function; we factor those out into helper functions in this PR so they are easy to use.

Test Plan: python test/quantization/test_quant_api.py

Reviewers:

Subscribers:

Tasks:

Tags:
1 parent a7483f2 commit 42c2376
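
As a quick usage sketch (a minimal illustration, not part of the commit itself): the new helpers are meant to be passed to quantize, exactly as the updated tests below do. The toy two-layer model here is only an example; the import path torchao.quantization.quant_api is where this PR defines and exports the helpers.

import torch
import torch.nn as nn
from torchao.quantization.quant_api import (
    quantize,
    get_apply_int4wo_quant,
    get_apply_8da4w_quant,
)

# Any module with nn.Linear layers can be quantized; this toy model is illustrative only.
model = nn.Sequential(nn.Linear(1024, 1024), nn.Linear(1024, 1024))

# int4 weight-only quantization (the tests run this on CUDA in bfloat16):
model = model.eval().to(torch.bfloat16).to("cuda")
model = quantize(model, get_apply_int4wo_quant(groupsize=32))

# int8 dynamic activation + int4 weight ("8da4w") is a single call as well:
# model = quantize(model, get_apply_8da4w_quant(groupsize=32))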

File tree

2 files changed: +119 -91 lines changed

test/quantization/test_quant_api.py

Lines changed: 9 additions & 91 deletions
@@ -37,6 +37,10 @@
     Quantizer,
     TwoStepQuantizer,
     quantize,
+    get_apply_8da4w_quant,
+    get_apply_int4wo_quant,
+    get_apply_int8wo_quant,
+    get_apply_int8dyn_quant,
 )
 from torchao.quantization.utils import (
     TORCH_VERSION_AFTER_2_3,
@@ -416,42 +420,11 @@ def test_eval_wrapper(self):
     # TODO: move to a separate test file
     @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
     def test_quantized_tensor_subclass_8da4w(self):
-        # weight settings
         groupsize = 32
-        mapping_type = MappingType.SYMMETRIC
-        block_size = (1, groupsize)
-        target_dtype = torch.int8
-        eps = torch.finfo(torch.float32).eps
-        quant_min = -8
-        quant_max = 7
-
-        # TODO: make a general helper function?
-        # input settings
-        def get_per_token_block_size(x):
-            block_size = []
-            for i in range(len(x.shape)-1):
-                block_size.append(1)
-            block_size.append(x.shape[-1])
-            return block_size
-
-        # input settings
-        input_mapping_type = MappingType.ASYMMETRIC
-        input_target_dtype = torch.int8
-        input_quant_func = lambda x: to_aq(x, input_mapping_type, get_per_token_block_size(x), input_target_dtype)
-
         m = ToyLinearModel().eval()
         m_copy = copy.deepcopy(m)
         example_inputs = m.example_inputs()
-
-        def apply_weight_quant(weight):
-            return to_aq(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps)
-
-        def apply_act_quant(weight):
-            return to_laq(weight, input_quant_func)
-
-        # note: order is important
-        m = quantize(m, apply_weight_quant)
-        m = quantize(m, apply_act_quant)
+        m = quantize(m, get_apply_8da4w_quant(groupsize=groupsize))

         assert isinstance(m.linear1.weight, LinearActQuantizedTensor)
         assert isinstance(m.linear2.weight, LinearActQuantizedTensor)
@@ -474,27 +447,13 @@ def apply_act_quant(weight):
     @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     def test_quantized_tensor_subclass_int4(self):
-        # weight settings
-        groupsize = 32
-        mapping_type = MappingType.ASYMMETRIC
-        block_size = (1, groupsize)
-        target_dtype = torch.int32
-        quant_min = 0
-        quant_max = 15
-        eps = 1e-6
-        preserve_zero = False
-        zero_point_dtype = torch.bfloat16
-        zero_point_domain = ZeroPointDomain.FLOAT
-
         # use 1024 so that we don't need padding
         m = ToyLinearModel(1024, 1024, 1024).eval().to(torch.bfloat16).to("cuda")
         m_copy = copy.deepcopy(m)
         example_inputs = tuple(map(lambda x: x.to(torch.bfloat16).to("cuda"), m.example_inputs()))

-        def apply_weight_quant(weight):
-            return to_aq(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero, zero_point_domain=zero_point_domain)
-
-        m = quantize(m, apply_weight_quant)
+        groupsize = 32
+        m = quantize(m, get_apply_int4wo_quant(groupsize=groupsize))
         assert isinstance(m.linear1.weight, AffineQuantizedTensor)
         assert isinstance(m.linear2.weight, AffineQuantizedTensor)

@@ -511,21 +470,11 @@ def apply_weight_quant(weight):
     @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     def test_quantized_tensor_subclass_int8(self):
-        # weight settings
-        mapping_type = MappingType.SYMMETRIC
-        target_dtype = torch.int8
-        eps = torch.finfo(torch.float32).eps
-        zero_point_dtype = torch.int64
-
         m = ToyLinearModel().eval().to(torch.bfloat16)
         m_copy = copy.deepcopy(m)
         example_inputs = tuple(map(lambda x: x.to(torch.bfloat16), m.example_inputs()))

-        def apply_weight_quant(weight):
-            block_size = (1, weight.shape[1])
-            return to_aq(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype)
-
-        m = quantize(m, apply_weight_quant)
+        m = quantize(m, get_apply_int8wo_quant())

         assert isinstance(m.linear1.weight, AffineQuantizedTensor)
         assert isinstance(m.linear2.weight, AffineQuantizedTensor)
@@ -543,43 +492,12 @@ def apply_weight_quant(weight):
     @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     def test_quantized_tensor_subclass_int8_dyn_quant(self):
-        # weight settings
-        mapping_type = MappingType.SYMMETRIC
-        def get_weight_block_size(x):
-            return (1, x.shape[1])
-        target_dtype = torch.int8
-        eps = torch.finfo(torch.float32).eps
-        zero_point_dtype = torch.int64
-
-        # input settings
-        def get_per_token_block_size(x):
-            block_size = list(x.shape)
-            for i in range(len(block_size)-1):
-                block_size[i] = 1
-            return block_size
-
-        input_mapping_type = MappingType.SYMMETRIC
-        input_target_dtype = torch.int8
-        input_eps = 1e-5
-        input_quant_min = -127
-        input_quant_max = 127
-        input_quant_func = lambda x: to_aq(x, input_mapping_type, get_per_token_block_size(x), input_target_dtype, eps=input_eps, quant_min=input_quant_min, quant_max=input_quant_max, scale_dtype=torch.float32 if x.dtype == torch.float16 else None)
-
         # use 1024 so that we don't need padding
         m = ToyLinearModel(1024, 1024, 1024).eval().to(torch.bfloat16).to("cuda")
         m_copy = copy.deepcopy(m)
         # setting batch_size to 20 to be compatible with the kernel
         example_inputs = tuple(map(lambda x: x.to(torch.bfloat16).to("cuda"), m.example_inputs(batch_size=20)))
-
-        def apply_weight_quant(weight):
-            block_size = get_weight_block_size(weight)
-            return to_aq(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype)
-
-        def apply_act_quant(weight):
-            return to_laq(weight, input_quant_func)
-
-        m = quantize(m, apply_weight_quant)
-        m = quantize(m, apply_act_quant)
+        m = quantize(m, get_apply_int8dyn_quant())

         assert isinstance(m.linear1.weight, LinearActQuantizedTensor)
         assert isinstance(m.linear2.weight, LinearActQuantizedTensor)

torchao/quantization/quant_api.py

Lines changed: 110 additions & 0 deletions
@@ -32,6 +32,12 @@
     Int8DynamicallyQuantizedLinearWeight,
     Int8WeightOnlyQuantizedLinearWeight,
     QuantizedLinearWeightBase,
+    to_laq,
+)
+
+from .quant_primitives import (
+    MappingType,
+    ZeroPointDomain,
 )
 from .weight_only import WeightOnlyInt8QuantLinear
 from .unified import Quantizer, TwoStepQuantizer
@@ -56,6 +62,10 @@
     "quantize",
     "autoquant",
     "_get_subclass_inserter",
+    "get_apply_8da4w_quant",
+    "get_apply_int4wo_quant",
+    "get_apply_int8wo_quant",
+    "get_apply_int8dyn_quant",
 ]

 if TORCH_VERSION_AFTER_2_3:
@@ -287,3 +297,103 @@ def filter_fn(module, fqn):
         _is_linear if filter_fn is None else filter_fn,
     )
     return model
+
+def get_apply_8da4w_quant(groupsize=32):
+
+    def apply_8da4w_quant(weight):
+        # avoid circular dep
+        from torchao.dtypes.aqt import to_aq
+
+        # weight settings
+        mapping_type = MappingType.SYMMETRIC
+        block_size = (1, groupsize)
+        target_dtype = torch.int8
+        eps = torch.finfo(torch.float32).eps
+        quant_min = -8
+        quant_max = 7
+
+        # TODO: make a general helper function?
+        # input settings
+        def get_per_token_block_size(x):
+            block_size = []
+            for i in range(len(x.shape)-1):
+                block_size.append(1)
+            block_size.append(x.shape[-1])
+            return block_size
+
+        # input settings
+        input_mapping_type = MappingType.ASYMMETRIC
+        input_target_dtype = torch.int8
+        input_quant_func = lambda x: to_aq(x, input_mapping_type, get_per_token_block_size(x), input_target_dtype)
+
+        weight = to_aq(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps)
+        weight = to_laq(weight, input_quant_func)
+        return weight
+
+    return apply_8da4w_quant
+
+
+def get_apply_int4wo_quant(groupsize=32):
+    def apply_int4wo_quant(weight):
+        # avoid circular dep
+        from torchao.dtypes.aqt import to_aq
+
+        groupsize = 32
+        mapping_type = MappingType.ASYMMETRIC
+        block_size = (1, groupsize)
+        target_dtype = torch.int32
+        quant_min = 0
+        quant_max = 15
+        eps = 1e-6
+        preserve_zero = False
+        zero_point_dtype = torch.bfloat16
+        zero_point_domain = ZeroPointDomain.FLOAT
+        return to_aq(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero, zero_point_domain=zero_point_domain)
+
+    return apply_int4wo_quant
+
+
+def get_apply_int8wo_quant():
+    def apply_int8wo_quant(weight):
+        # avoid circular dep
+        from torchao.dtypes.aqt import to_aq
+
+        mapping_type = MappingType.SYMMETRIC
+        target_dtype = torch.int8
+        eps = torch.finfo(torch.float32).eps
+        zero_point_dtype = torch.int64
+        block_size = (1, weight.shape[1])
+        return to_aq(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype)
+    return apply_int8wo_quant
+
+def get_apply_int8dyn_quant():
+    def apply_int8dyn_quant(weight):
+        # avoid circular dep
+        from torchao.dtypes.aqt import to_aq
+        # weight settings
+        mapping_type = MappingType.SYMMETRIC
+        def get_weight_block_size(x):
+            return (1, x.shape[1])
+        target_dtype = torch.int8
+        eps = torch.finfo(torch.float32).eps
+        zero_point_dtype = torch.int64
+
+        # input settings
+        def get_per_token_block_size(x):
+            block_size = list(x.shape)
+            for i in range(len(block_size)-1):
+                block_size[i] = 1
+            return block_size
+
+        input_mapping_type = MappingType.SYMMETRIC
+        input_target_dtype = torch.int8
+        input_eps = 1e-5
+        input_quant_min = -127
+        input_quant_max = 127
+        input_quant_func = lambda x: to_aq(x, input_mapping_type, get_per_token_block_size(x), input_target_dtype, eps=input_eps, quant_min=input_quant_min, quant_max=input_quant_max, scale_dtype=torch.float32 if x.dtype == torch.float16 else None)
+
+        block_size = get_weight_block_size(weight)
+        weight = to_aq(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype)
+        weight = to_laq(weight, input_quant_func)
+        return weight
+    return apply_int8dyn_quant
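
All four helpers follow the same shape: a factory that captures the configuration (e.g. groupsize) in a closure and returns an apply function, which quantize then maps over each linear weight. A custom configuration can be written the same way; the sketch below is hypothetical (get_apply_my_int8wo_quant is not part of this PR) and simply mirrors get_apply_int8wo_quant above.

import torch
from torchao.quantization.quant_api import quantize
from torchao.quantization.quant_primitives import MappingType

def get_apply_my_int8wo_quant(group_cols=None):
    # Factory: capture the configuration in a closure.
    def apply_my_int8wo_quant(weight):
        # avoid circular dep, same as the helpers in this PR
        from torchao.dtypes.aqt import to_aq
        mapping_type = MappingType.SYMMETRIC
        target_dtype = torch.int8
        eps = torch.finfo(torch.float32).eps
        zero_point_dtype = torch.int64
        # one quantization group per output row by default, or a fixed column count
        block_size = (1, weight.shape[1] if group_cols is None else group_cols)
        return to_aq(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype)
    return apply_my_int8wo_quant

# model = quantize(model, get_apply_my_int8wo_quant())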
