Skip to content

Commit f8f74c7

Browse files
authored
Move AffineQuantizedTensor to torchao/dtypes (#272)
Summary: att Test Plan: regression tests in test/quantization/test_quant_api.py Reviewers: Subscribers: Tasks: Tags:
1 parent bc46bdc commit f8f74c7

File tree

4 files changed

+465
-461
lines changed

4 files changed

+465
-461
lines changed

test/quantization/test_quant_api.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -18,17 +18,18 @@
1818
get_symmetric_quantization_config,
1919
)
2020

21-
from torchao.quantization.subclass import (
22-
to_aqt,
23-
to_laqt,
21+
from torchao.dtypes import (
22+
to_aq,
2423
AffineQuantizedTensor,
25-
LinearActQuantizedTensor,
2624
)
2725
from torchao.quantization.quant_primitives import (
2826
MappingType,
2927
ZeroPointDomain,
3028
)
31-
29+
from torchao.quantization.subclass import (
30+
to_laq,
31+
LinearActQuantizedTensor,
32+
)
3233
from torchao.quantization.quant_api import (
3334
_replace_with_custom_fn_if_matches_filter,
3435
apply_dynamic_quant,
@@ -429,17 +430,17 @@ def get_per_token_block_size(x):
429430
# input settings
430431
input_mapping_type = MappingType.ASYMMETRIC
431432
input_target_dtype = torch.int8
432-
input_quant_func = lambda x: to_aqt(x, input_mapping_type, get_per_token_block_size(x), input_target_dtype)
433+
input_quant_func = lambda x: to_aq(x, input_mapping_type, get_per_token_block_size(x), input_target_dtype)
433434

434435
m = ToyLinearModel().eval()
435436
m_copy = copy.deepcopy(m)
436437
example_inputs = m.example_inputs()
437438

438439
def apply_weight_quant(weight):
439-
return to_aqt(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps)
440+
return to_aq(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps)
440441

441442
def apply_act_quant(weight):
442-
return to_laqt(weight, input_quant_func)
443+
return to_laq(weight, input_quant_func)
443444

444445
# note: order is important
445446
m = quantize(m, apply_weight_quant)
@@ -484,7 +485,7 @@ def test_quantized_tensor_subclass_int4(self):
484485
example_inputs = tuple(map(lambda x: x.to(torch.bfloat16).to("cuda"), m.example_inputs()))
485486

486487
def apply_weight_quant(weight):
487-
return to_aqt(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero, zero_point_domain=zero_point_domain)
488+
return to_aq(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero, zero_point_domain=zero_point_domain)
488489

489490
m = quantize(m, apply_weight_quant)
490491
assert isinstance(m.linear1.weight, AffineQuantizedTensor)
@@ -515,7 +516,7 @@ def test_quantized_tensor_subclass_int8(self):
515516

516517
def apply_weight_quant(weight):
517518
block_size = (1, weight.shape[1])
518-
return to_aqt(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype)
519+
return to_aq(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype)
519520

520521
m = quantize(m, apply_weight_quant)
521522

@@ -555,7 +556,7 @@ def get_per_token_block_size(x):
555556
input_eps = 1e-5
556557
input_quant_min = -127
557558
input_quant_max = 127
558-
input_quant_func = lambda x: to_aqt(x, input_mapping_type, get_per_token_block_size(x), input_target_dtype, eps=input_eps, quant_min=input_quant_min, quant_max=input_quant_max, scale_dtype=torch.float32 if x.dtype == torch.float16 else None)
559+
input_quant_func = lambda x: to_aq(x, input_mapping_type, get_per_token_block_size(x), input_target_dtype, eps=input_eps, quant_min=input_quant_min, quant_max=input_quant_max, scale_dtype=torch.float32 if x.dtype == torch.float16 else None)
559560

560561
# use 1024 so that we don't need padding
561562
m = ToyLinearModel(1024, 1024, 1024).eval().to(torch.bfloat16).to("cuda")
@@ -565,10 +566,10 @@ def get_per_token_block_size(x):
565566

566567
def apply_weight_quant(weight):
567568
block_size = get_weight_block_size(weight)
568-
return to_aqt(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype)
569+
return to_aq(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype)
569570

570571
def apply_act_quant(weight):
571-
return to_laqt(weight, input_quant_func)
572+
return to_laq(weight, input_quant_func)
572573

573574
m = quantize(m, apply_weight_quant)
574575
m = quantize(m, apply_act_quant)

torchao/dtypes/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
from .nf4tensor import NF4Tensor, to_nf4
22
from .uint4 import UInt4Tensor
3+
from .aqt import AffineQuantizedTensor, to_aq
34

45
__all__ = [
56
"NF4Tensor",
67
"to_nf4",
78
"UInt4Tensor"
9+
"AffineQuantizedTensor",
[NOTE(review): the resulting `__all__` is missing a comma after "UInt4Tensor" — adjacent string literals concatenate, so `__all__` would contain the single entry "UInt4TensorAffineQuantizedTensor" instead of two names. The diff should add a comma to the existing line as well.]
10+
"to_aq",
811
]

0 commit comments

Comments (0)