Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions test/core/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
AWQConfig,
AWQStep,
)
from torchao.quantization import PerBlock
from torchao.quantization.quant_api import (
Float8DynamicActivationFloat8WeightConfig,
Float8DynamicActivationInt4WeightConfig,
Expand All @@ -45,6 +46,9 @@
Float8DynamicActivationFloat8WeightConfig(),
Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()),
Float8DynamicActivationFloat8WeightConfig(granularity=[PerRow(), PerRow()]),
Float8DynamicActivationFloat8WeightConfig(
granularity=[PerBlock([1, 128]), PerBlock([128, 128])]
),
Float8WeightOnlyConfig(
weight_dtype=torch.float8_e4m3fn,
),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def setUp(self):
@common_utils.parametrize("compile", [True, False])
@common_utils.parametrize(
"granularity",
[PerTensor(), PerRow(), (PerBlock((1, 128)), PerBlock((128, 128)))],
[PerTensor(), PerRow(), (PerBlock([1, 128]), PerBlock([128, 128]))],
)
@common_utils.parametrize(
"kernel_preference",
Expand Down Expand Up @@ -125,7 +125,7 @@ def test_fp8_linear_variants(
elif mode == "weight-only":
return unittest.skip("unimplemented")

elif granularity == (PerBlock((1, 128)), PerBlock((128, 128))):
elif granularity == (PerBlock([1, 128]), PerBlock([128, 128])):
if dtype is not torch.bfloat16:
return unittest.skip("unimplemented")
elif mode != "dynamic":
Expand Down Expand Up @@ -199,7 +199,7 @@ def test_fp8_linear_variants(
assert qs1.shape == (N, 1)
assert qs2.shape == (K, 1)
else:
assert granularity == (PerBlock((1, 128)), PerBlock((128, 128)))
assert granularity == (PerBlock([1, 128]), PerBlock([128, 128]))
assert qs1.shape == (N // 128, K // 128)
assert qs2.shape == (K // 128, N // 128)

Expand Down
2 changes: 1 addition & 1 deletion torchao/_models/llama/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ def run_evaluation(
)
if quantization == "float8_a1x128_w128x128":
config = Float8DynamicActivationFloat8WeightConfig(
granularity=(PerBlock((1, 128)), PerBlock((128, 128))),
granularity=(PerBlock([1, 128]), PerBlock([128, 128])),
activation_value_lb=1e-12,
)
# TODO(future): all workflows in this file should be skipping quantization
Expand Down
2 changes: 1 addition & 1 deletion torchao/float8/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ def _granularity_is_a_1_128_w_128_128(
list[FP8Granularity],
],
) -> bool:
return len(g) == 2 and g[0] == PerBlock((1, 128)) and g[1] == PerBlock((128, 128))
return len(g) == 2 and g[0] == PerBlock([1, 128]) and g[1] == PerBlock([128, 128])


def _normalize_granularity(
Expand Down
4 changes: 4 additions & 0 deletions torchao/quantization/granularity.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,4 +126,8 @@ class PerBlock(Granularity):
# 1. `block_size` in this class can support tensors of multiple ranks
# 2. `block_size` in other places in the codebase has rank equal to the
# corresponding tensor
# TODO(future PR): change to list or support serialization with tuples,
# currently serialization only works when `block_size` is specified as a
# list. Example error:
# https://gist.github.com/vkuzo/ab4d6aec83cb98ad9417898d2c024a2c
block_size: tuple[int, ...]
1 change: 0 additions & 1 deletion torchao/quantization/quant_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -1782,7 +1782,6 @@ def __post_init__(self):
KernelPreference.AUTO,
KernelPreference.TORCH,
), "unimplemented"
assert self.mm_config is None, "unimplemented"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this supported now?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, this is tested with the serialization flow. When the object is created a default value of mm_config is set, and then when object is saved and reloaded, that value then makes it to the constructor.

assert self.version >= 2, "unimplemented"
default_use_fast_accum = False

Expand Down
Loading