Skip to content

Commit 4cdca74

Browse files
committed
make float8 a1x128_w128x128 granularity serializable
Summary:

Change the granularity for this scaling type from `PerBlock((1, 128)), PerBlock((128, 128))` to `PerBlock([1, 128]), PerBlock([128, 128])`, to get around the current limitation of config serialization not supporting tuples.

Error when serializing tuples: https://gist.github.com/vkuzo/ab4d6aec83cb98ad9417898d2c024a2c

Supporting both tuples and lists would be a better long-term fix, but this workaround will unblock further work in the short term.

Test Plan:

```
pytest test/quantization/quantize_/workflows/float8/test_float8_tensor.py -s -x
```

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 68799cc
ghstack-comment-id: 3480842282
Pull-Request: #3279
1 parent 5537624 commit 4cdca74

File tree

6 files changed

+13
-6
lines changed

6 files changed

+13
-6
lines changed

test/core/test_config.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
AWQConfig,
2424
AWQStep,
2525
)
26+
from torchao.quantization import PerBlock
2627
from torchao.quantization.quant_api import (
2728
Float8DynamicActivationFloat8WeightConfig,
2829
Float8DynamicActivationInt4WeightConfig,
@@ -45,6 +46,9 @@
4546
Float8DynamicActivationFloat8WeightConfig(),
4647
Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()),
4748
Float8DynamicActivationFloat8WeightConfig(granularity=[PerRow(), PerRow()]),
49+
Float8DynamicActivationFloat8WeightConfig(
50+
granularity=[PerBlock([1, 128]), PerBlock([128, 128])]
51+
),
4852
Float8WeightOnlyConfig(
4953
weight_dtype=torch.float8_e4m3fn,
5054
),

test/quantization/quantize_/workflows/float8/test_float8_tensor.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ def setUp(self):
9090
@common_utils.parametrize("compile", [True, False])
9191
@common_utils.parametrize(
9292
"granularity",
93-
[PerTensor(), PerRow(), (PerBlock((1, 128)), PerBlock((128, 128)))],
93+
[PerTensor(), PerRow(), (PerBlock([1, 128]), PerBlock([128, 128]))],
9494
)
9595
@common_utils.parametrize(
9696
"kernel_preference",
@@ -124,7 +124,7 @@ def test_fp8_linear_variants(
124124
elif mode == "weight-only":
125125
return unittest.skip("unimplemented")
126126

127-
elif granularity == (PerBlock((1, 128)), PerBlock((128, 128))):
127+
elif granularity == (PerBlock([1, 128]), PerBlock([128, 128])):
128128
if dtype is not torch.bfloat16:
129129
return unittest.skip("unimplemented")
130130
elif mode != "dynamic":
@@ -198,7 +198,7 @@ def test_fp8_linear_variants(
198198
assert qs1.shape == (N, 1)
199199
assert qs2.shape == (K, 1)
200200
else:
201-
assert granularity == (PerBlock((1, 128)), PerBlock((128, 128)))
201+
assert granularity == (PerBlock([1, 128]), PerBlock([128, 128]))
202202
assert qs1.shape == (N // 128, K // 128)
203203
assert qs2.shape == (K // 128, N // 128)
204204

torchao/_models/llama/eval.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ def run_evaluation(
173173
)
174174
if quantization == "float8_a1x128_w128x128":
175175
config = Float8DynamicActivationFloat8WeightConfig(
176-
granularity=(PerBlock((1, 128)), PerBlock((128, 128))),
176+
granularity=(PerBlock([1, 128]), PerBlock([128, 128])),
177177
activation_value_lb=1e-12,
178178
)
179179
# TODO(future): all workflows in this file should be skipping quantization

torchao/float8/inference.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,7 @@ def _granularity_is_a_1_128_w_128_128(
225225
list[FP8Granularity],
226226
],
227227
) -> bool:
228-
return len(g) == 2 and g[0] == PerBlock((1, 128)) and g[1] == PerBlock((128, 128))
228+
return len(g) == 2 and g[0] == PerBlock([1, 128]) and g[1] == PerBlock([128, 128])
229229

230230

231231
def _normalize_granularity(

torchao/quantization/granularity.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,4 +126,8 @@ class PerBlock(Granularity):
126126
# 1. `block_size` in this class can support tensors of multiple ranks
127127
# 2. `block_size` in other places in the codebase has rank equal to the
128128
# corresponding tensor
129+
# TODO(future PR): change to list or support serialization with tuples,
130+
# currently serialization only works when `block_size` is specified as a
131+
# list. Example error:
132+
# https://gist.github.com/vkuzo/ab4d6aec83cb98ad9417898d2c024a2c
129133
block_size: tuple[int, ...]

torchao/quantization/quant_api.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1782,7 +1782,6 @@ def __post_init__(self):
17821782
KernelPreference.AUTO,
17831783
KernelPreference.TORCH,
17841784
), "unimplemented"
1785-
assert self.mm_config is None, "unimplemented"
17861785
assert self.version >= 2, "unimplemented"
17871786
default_use_fast_accum = False
17881787

0 commit comments

Comments
 (0)