Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 14 additions & 13 deletions test/quantization/test_quant_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,18 @@
get_symmetric_quantization_config,
)

from torchao.quantization.subclass import (
to_aqt,
to_laqt,
from torchao.dtypes import (
to_aq,
AffineQuantizedTensor,
LinearActQuantizedTensor,
)
from torchao.quantization.quant_primitives import (
MappingType,
ZeroPointDomain,
)

from torchao.quantization.subclass import (
to_laq,
LinearActQuantizedTensor,
)
from torchao.quantization.quant_api import (
_replace_with_custom_fn_if_matches_filter,
apply_dynamic_quant,
Expand Down Expand Up @@ -429,17 +430,17 @@ def get_per_token_block_size(x):
# input settings
input_mapping_type = MappingType.ASYMMETRIC
input_target_dtype = torch.int8
input_quant_func = lambda x: to_aqt(x, input_mapping_type, get_per_token_block_size(x), input_target_dtype)
input_quant_func = lambda x: to_aq(x, input_mapping_type, get_per_token_block_size(x), input_target_dtype)

m = ToyLinearModel().eval()
m_copy = copy.deepcopy(m)
example_inputs = m.example_inputs()

def apply_weight_quant(weight):
return to_aqt(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps)
return to_aq(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps)

def apply_act_quant(weight):
return to_laqt(weight, input_quant_func)
return to_laq(weight, input_quant_func)

# note: order is important
m = quantize(m, apply_weight_quant)
Expand Down Expand Up @@ -484,7 +485,7 @@ def test_quantized_tensor_subclass_int4(self):
example_inputs = tuple(map(lambda x: x.to(torch.bfloat16).to("cuda"), m.example_inputs()))

def apply_weight_quant(weight):
return to_aqt(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero, zero_point_domain=zero_point_domain)
return to_aq(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero, zero_point_domain=zero_point_domain)

m = quantize(m, apply_weight_quant)
assert isinstance(m.linear1.weight, AffineQuantizedTensor)
Expand Down Expand Up @@ -515,7 +516,7 @@ def test_quantized_tensor_subclass_int8(self):

def apply_weight_quant(weight):
block_size = (1, weight.shape[1])
return to_aqt(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype)
return to_aq(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype)

m = quantize(m, apply_weight_quant)

Expand Down Expand Up @@ -555,7 +556,7 @@ def get_per_token_block_size(x):
input_eps = 1e-5
input_quant_min = -127
input_quant_max = 127
input_quant_func = lambda x: to_aqt(x, input_mapping_type, get_per_token_block_size(x), input_target_dtype, eps=input_eps, quant_min=input_quant_min, quant_max=input_quant_max, scale_dtype=torch.float32 if x.dtype == torch.float16 else None)
input_quant_func = lambda x: to_aq(x, input_mapping_type, get_per_token_block_size(x), input_target_dtype, eps=input_eps, quant_min=input_quant_min, quant_max=input_quant_max, scale_dtype=torch.float32 if x.dtype == torch.float16 else None)

# use 1024 so that we don't need padding
m = ToyLinearModel(1024, 1024, 1024).eval().to(torch.bfloat16).to("cuda")
Expand All @@ -565,10 +566,10 @@ def get_per_token_block_size(x):

def apply_weight_quant(weight):
block_size = get_weight_block_size(weight)
return to_aqt(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype)
return to_aq(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype)

def apply_act_quant(weight):
return to_laqt(weight, input_quant_func)
return to_laq(weight, input_quant_func)

m = quantize(m, apply_weight_quant)
m = quantize(m, apply_act_quant)
Expand Down
3 changes: 3 additions & 0 deletions torchao/dtypes/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from .nf4tensor import NF4Tensor, to_nf4
from .uint4 import UInt4Tensor
from .aqt import AffineQuantizedTensor, to_aq

# Public API of this package: the names exported by a star-import.
# Every entry must end with a comma -- adjacent string literals without a
# separating comma are silently concatenated into one (bogus) name.
__all__ = [
    "NF4Tensor",
    "to_nf4",
    "UInt4Tensor",
    "AffineQuantizedTensor",
    "to_aq",
]
Loading