diff --git a/test/core/test_config.py b/test/core/test_config.py
index 0df31194ac..3fb9d435fa 100644
--- a/test/core/test_config.py
+++ b/test/core/test_config.py
@@ -23,6 +23,7 @@
     AWQConfig,
     AWQStep,
 )
+from torchao.quantization import PerBlock
 from torchao.quantization.quant_api import (
     Float8DynamicActivationFloat8WeightConfig,
     Float8DynamicActivationInt4WeightConfig,
@@ -45,6 +46,9 @@
     Float8DynamicActivationFloat8WeightConfig(),
     Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()),
     Float8DynamicActivationFloat8WeightConfig(granularity=[PerRow(), PerRow()]),
+    Float8DynamicActivationFloat8WeightConfig(
+        granularity=[PerBlock([1, 128]), PerBlock([128, 128])]
+    ),
     Float8WeightOnlyConfig(
         weight_dtype=torch.float8_e4m3fn,
     ),
diff --git a/test/quantization/quantize_/workflows/float8/test_float8_tensor.py b/test/quantization/quantize_/workflows/float8/test_float8_tensor.py
index 4871b48849..884dc3f798 100644
--- a/test/quantization/quantize_/workflows/float8/test_float8_tensor.py
+++ b/test/quantization/quantize_/workflows/float8/test_float8_tensor.py
@@ -91,7 +91,7 @@ def setUp(self):
     @common_utils.parametrize("compile", [True, False])
     @common_utils.parametrize(
         "granularity",
-        [PerTensor(), PerRow(), (PerBlock((1, 128)), PerBlock((128, 128)))],
+        [PerTensor(), PerRow(), (PerBlock([1, 128]), PerBlock([128, 128]))],
     )
     @common_utils.parametrize(
         "kernel_preference",
@@ -125,7 +125,7 @@ def test_fp8_linear_variants(
         elif mode == "weight-only":
             return unittest.skip("unimplemented")
 
-        elif granularity == (PerBlock((1, 128)), PerBlock((128, 128))):
+        elif granularity == (PerBlock([1, 128]), PerBlock([128, 128])):
             if dtype is not torch.bfloat16:
                 return unittest.skip("unimplemented")
             elif mode != "dynamic":
@@ -199,7 +199,7 @@ def test_fp8_linear_variants(
             assert qs1.shape == (N, 1)
             assert qs2.shape == (K, 1)
         else:
-            assert granularity == (PerBlock((1, 128)), PerBlock((128, 128)))
+            assert granularity == (PerBlock([1, 128]), PerBlock([128, 128]))
             assert qs1.shape == (N // 128, K // 128)
             assert qs2.shape == (K // 128, N // 128)
diff --git a/torchao/_models/llama/eval.py b/torchao/_models/llama/eval.py
index 676d3569e7..df071fe9d2 100644
--- a/torchao/_models/llama/eval.py
+++ b/torchao/_models/llama/eval.py
@@ -173,7 +173,7 @@ def run_evaluation(
             )
         if quantization == "float8_a1x128_w128x128":
             config = Float8DynamicActivationFloat8WeightConfig(
-                granularity=(PerBlock((1, 128)), PerBlock((128, 128))),
+                granularity=(PerBlock([1, 128]), PerBlock([128, 128])),
                 activation_value_lb=1e-12,
             )
             # TODO(future): all workflows in this file should be skipping quantization
diff --git a/torchao/float8/inference.py b/torchao/float8/inference.py
index 19c77d43d7..212df9c5db 100644
--- a/torchao/float8/inference.py
+++ b/torchao/float8/inference.py
@@ -225,7 +225,7 @@ def _granularity_is_a_1_128_w_128_128(
         list[FP8Granularity],
     ],
 ) -> bool:
-    return len(g) == 2 and g[0] == PerBlock((1, 128)) and g[1] == PerBlock((128, 128))
+    return len(g) == 2 and g[0] == PerBlock([1, 128]) and g[1] == PerBlock([128, 128])
 
 
 def _normalize_granularity(
diff --git a/torchao/quantization/granularity.py b/torchao/quantization/granularity.py
index ccf7099c54..d83032d7be 100644
--- a/torchao/quantization/granularity.py
+++ b/torchao/quantization/granularity.py
@@ -126,4 +126,8 @@ class PerBlock(Granularity):
     #   1. `block_size` in this class can support tensors of multiple ranks
    #   2. `block_size` in other places in the codebase has rank equal to the
    #      corresponding tensor
+    # TODO(future PR): change to list or support serialization with tuples,
+    # currently serialization only works when `block_size` is specified as a
+    # list. Example error:
+    # https://gist.github.com/vkuzo/ab4d6aec83cb98ad9417898d2c024a2c
     block_size: tuple[int, ...]
diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index 235cd85a0f..59b65d8841 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -1782,7 +1782,6 @@ def __post_init__(self):
             KernelPreference.AUTO,
             KernelPreference.TORCH,
         ), "unimplemented"
-        assert self.mm_config is None, "unimplemented"
         assert self.version >= 2, "unimplemented"
 
         default_use_fast_accum = False
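Below is a minimal usage sketch, not part of the patch, of the list-based `PerBlock` spelling this diff standardizes on. It assumes torchao's `config_to_dict` / `config_from_dict` helpers in `torchao.core.config` are available, and mirrors the serialization round-trip that test/core/test_config.py exercises.

# Sketch only: shows why `block_size` is passed as a list.
# Per the TODO added to granularity.py above, PerBlock currently serializes
# only when `block_size` is specified as a list, not a tuple.
from torchao.core.config import config_from_dict, config_to_dict  # assumed helper names
from torchao.quantization import PerBlock
from torchao.quantization.quant_api import Float8DynamicActivationFloat8WeightConfig

config = Float8DynamicActivationFloat8WeightConfig(
    # 1x128 blocks for activations, 128x128 blocks for weights
    granularity=[PerBlock([1, 128]), PerBlock([128, 128])],
)

# Round-trip through the config (de)serialization helpers.
reloaded = config_from_dict(config_to_dict(config))
print(reloaded.granularity)

# Passing PerBlock((1, 128)) (a tuple) here is expected to fail to serialize;
# see the gist linked in the granularity.py comment.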