Skip to content

Commit 77e5905

Browse files
committed
Test cases
1 parent 0ba6a2c commit 77e5905

File tree

2 files changed

+13
-2
lines changed

2 files changed

+13
-2
lines changed

test/integration/test_integration.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@
     AQInt8WeightOnlyQuantizedLinearWeight2,
     AQInt8WeightOnlyQuantizedLinearWeight3,
     AutoQuantizableLinearWeight,
-
+    AQFloat8WeightOnlyQuantizedLinearWeight,
 )
 from torch.ao.quantization.quantize_fx import convert_to_reference_fx, prepare_fx
 import os
@@ -744,6 +744,13 @@ def test_aq_int8_weight_only_quant_3_subclass(self, device, dtype):
             AQInt8WeightOnlyQuantizedLinearWeight3.from_float, device, 35, test_dtype=dtype
         )

+    @parameterized.expand(COMMON_DEVICE_DTYPE)
+    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant+aqt needs newer pytorch")
+    def test_aq_float8_weight_only_quant_subclass(self, device, dtype):
+        self._test_lin_weight_subclass_impl(
+            AQFloat8WeightOnlyQuantizedLinearWeight.from_float, device, 30, test_dtype=dtype
+        )
+
     @parameterized.expand(COMMON_DEVICE_DTYPE)
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.")
     # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now")

torchao/quantization/autoquant.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -501,7 +501,6 @@ def from_float(cls, weight):
     # AQInt8WeightOnlyQuantizedLinearWeight3,
     # TODO this gets picked in places where it makes perf worse, why?
     AQInt8DynamicallyQuantizedLinearWeight,
-    AQFloat8WeightOnlyQuantizedLinearWeight,
 ]

 DEFAULT_INT4_AUTOQUANT_CLASS_LIST = [
@@ -510,6 +509,11 @@ def from_float(cls, weight):
     AQInt4G64WeightOnlyQuantizedLinearWeight
 ]

+OTHER_AUTOQUANT_CLASS_LIST = [
+    AQFloat8WeightOnlyQuantizedLinearWeight,
+]
+
+
 def _change_linears_to_autoquantizable(model, **kwargs):
     """
     Converts all linear weight tensors to the

0 commit comments

Comments
 (0)