
Commit d985b80

Add a_1_128_w_128_128 (DeepSeek style) float8 scaling for inference
Summary: Basic enablement of the a_1_128_w_128_128 float8 scaling recipe in torchao inference. In detail:

1. bring the 128x128 gemm triton kernel we have out of prototype and wrap it with a custom op for `torch.compile` compatibility
2. enable the new granularity in various utility functions
3. wire the new granularity through the float8 inference configs
4. add a test for e2e numerical correctness via SQNR comparison vs a high precision baseline

For now I added a fallback which only requires triton and is numerically correct but may not reach optimal performance. Performance optimization is left for future PRs:

1. map the gemm to `torch._scaled_mm` for CUDA 12.9+
2. enable an fbgemm_gpu_genai path, if available in the user's environment
3. quantize the weights with a triton kernel, as `torch.compile` is currently known to be slow for 128x128 block quantization

Test Plan:
Reviewers:
Subscribers:
Tasks:
Tags:

ghstack-source-id: c9e22bd
ghstack-comment-id: 3460951962
Pull-Request: #3257
1 parent 4b67833 commit d985b80
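For orientation, here is a minimal usage sketch of the recipe this commit enables, assembled from the config, granularity, and kernel preference exercised in the updated test below. The `KernelPreference` import path and the explicit `kernel_preference=KernelPreference.TORCH` argument are assumptions based on the new `__post_init__` checks, not something this diff spells out.

import torch
import torch.nn as nn

from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    KernelPreference,  # import path assumed
    PerBlock,
    quantize_,
)

# Toy model; feature dims are multiples of 128 so the 1x128 activation and
# 128x128 weight scaling blocks tile the tensors evenly.
m = nn.Sequential(nn.Linear(256, 128, bias=False)).cuda().to(torch.bfloat16)

# DeepSeek-style scaling: 1x128 blocks for activations, 128x128 blocks for weights.
config = Float8DynamicActivationFloat8WeightConfig(
    granularity=(PerBlock((1, 128)), PerBlock((128, 128))),
    kernel_preference=KernelPreference.TORCH,
)
quantize_(m, config)

x = torch.randn(32, 256, device="cuda", dtype=torch.bfloat16)
with torch.no_grad():
    y = m(x)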

4 files changed, +127 −33 lines changed

test/quantization/quantize_/workflows/float8/test_float8_tensor.py

Lines changed: 24 additions & 6 deletions
@@ -18,6 +18,7 @@
 from torchao.quantization import (
     Float8DynamicActivationFloat8WeightConfig,
     Float8WeightOnlyConfig,
+    PerBlock,
     PerRow,
     PerTensor,
     quantize_,
@@ -61,20 +62,37 @@ def setUp(self):
     @unittest.skipIf(
         not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
     )
-    @common_utils.parametrize("dtype", [torch.bfloat16, torch.float32])
-    @common_utils.parametrize("mode", ["dynamic", "weight-only"])
-    @common_utils.parametrize("compile", [True, False])
-    @common_utils.parametrize("granularity", [PerTensor(), PerRow()])
+    # @common_utils.parametrize("dtype", [torch.bfloat16, torch.float32])
+    @common_utils.parametrize(
+        "dtype",
+        [
+            torch.bfloat16,
+        ],
+    )
+    # @common_utils.parametrize("mode", ["dynamic", "weight-only"])
+    @common_utils.parametrize(
+        "mode",
+        [
+            "dynamic",
+        ],
+    )
+    # @common_utils.parametrize("compile", [True, False])
+    @common_utils.parametrize("compile", [False])
+    # @common_utils.parametrize("granularity", [PerTensor(), PerRow()])
+    @common_utils.parametrize(
+        "granularity", [(PerBlock((1, 128)), PerBlock((128, 128)))]
+    )
     @common_utils.parametrize(
         "kernel_preference",
-        [KernelPreference.AUTO, KernelPreference.TORCH, KernelPreference.FBGEMM],
+        # [KernelPreference.AUTO, KernelPreference.TORCH, KernelPreference.FBGEMM],
+        [KernelPreference.TORCH],
     )
     # Inputs are (M,..), K, N
     @common_utils.parametrize(
         "sizes",
         [
             ((128,), 256, 128),
-            ((32, 128), 64, 256),
+            # ((32, 128), 64, 256),
         ],
     )
     def test_fp8_linear_variants(

torchao/float8/inference.py

Lines changed: 58 additions & 16 deletions
@@ -14,6 +14,7 @@
 from torchao.float8.float8_utils import is_row_major, pad_tensor_for_matmul
 from torchao.float8.types import FP8Granularity
 from torchao.quantization.granularity import (
+    PerBlock,
     PerRow,
     PerTensor,
 )
@@ -196,6 +197,36 @@ def _is_tensorwise_scaled(x: torch.Tensor) -> bool:
     )


+def _is_1_128_scaled(x: torch.Tensor) -> bool:
+    """Checks if a quantized tensor is scaled with a block size of 1x128
+    Args:
+        x: quantized tensor (should have `block_size` attribute)
+    """
+    assert hasattr(x, "block_size"), "Expecting input to have `block_size` attribute"
+    b = x.block_size
+    return len(b) == 2 and b[0] == 1 and b[1] == 128
+
+
+def _is_128_128_scaled(x: torch.Tensor) -> bool:
+    """Checks if a quantized tensor is scaled with a block size of 128x128
+    Args:
+        x: quantized tensor (should have `block_size` attribute)
+    """
+    assert hasattr(x, "block_size"), "Expecting input to have `block_size` attribute"
+    b = x.block_size
+    return len(b) == 2 and b[0] == 128 and b[1] == 128
+
+
+def _granularity_is_a_1_128_w_128_128(
+    g: Union[
+        FP8Granularity,
+        Tuple[FP8Granularity, FP8Granularity],
+        list[FP8Granularity],
+    ],
+) -> bool:
+    return len(g) == 2 and g[0] == PerBlock((1, 128)) and g[1] == PerBlock((128, 128))
+
+
 def _normalize_granularity(
     granularity: Optional[
         Union[
@@ -211,22 +242,23 @@ def _normalize_granularity(
     elif isinstance(granularity, (PerTensor, PerRow)):
         processed_granularity = (granularity, granularity)
     elif isinstance(granularity, (tuple, list)) and len(granularity) == 2:
-        if not (
-            isinstance(granularity[0], (PerTensor, PerRow))
-            and isinstance(granularity[1], (PerTensor, PerRow))
-        ):
-            raise ValueError(
-                f"Invalid granularity types: {granularity}, only PerTensor or PerRow are supported."
-            )
+        is_per_tensor = isinstance(granularity[0], PerTensor) and isinstance(
+            granularity[1], PerTensor
+        )
+        is_per_row = isinstance(granularity[0], PerRow) and isinstance(
+            granularity[1], PerRow
+        )
+        is_a_1_128_w_128_128 = _granularity_is_a_1_128_w_128_128(granularity)
+
+        if not (is_per_tensor or is_per_row or is_a_1_128_w_128_128):
+            raise ValueError(f"Unsupported granularity types: {granularity}.")
         if not isinstance(granularity[0], type(granularity[1])):
             raise ValueError(
-                f"Different granularities for activation and weight are not supported: {granularity}, only PerTensor or PerRow are supported."
+                f"Different granularities for activation and weight are not supported: {granularity}."
            )
         processed_granularity = tuple(granularity)
     else:
-        raise ValueError(
-            f"Invalid granularity specification: {granularity}, only PerTensor or PerRow are supported."
-        )
+        raise ValueError(f"Invalid granularity specification: {granularity}.")
     return processed_granularity


@@ -243,12 +275,22 @@ def _check_hardware_support(
         AssertionError: If hardware doesn't support the requested granularity
         ValueError: If invalid granularity type is provided
     """
-    for _granularity in granularities:
-        if not isinstance(_granularity, (PerTensor, PerRow)):
-            raise ValueError(
-                f"Invalid granularity type: {_granularity}, only PerTensor or PerRow are supported."
-            )
+    is_per_tensor = isinstance(granularities[0], PerTensor) and isinstance(
+        granularities[1], PerTensor
+    )
+    is_per_row = isinstance(granularities[0], PerRow) and isinstance(
+        granularities[1], PerRow
+    )
+    is_a_1_128_w_128_128 = _granularity_is_a_1_128_w_128_128(granularities)

+    if is_per_tensor or is_per_row:
         assert is_sm_at_least_89() or is_MI300(), (
             "Float8 dynamic quantization requires CUDA compute capability ≥8.9 or MI300+."
         )
+    elif is_a_1_128_w_128_128:
+        # TODO(future PR): look into AMD support
+        assert is_sm_at_least_89(), (
+            "Float8 1x128 activation and 128x128 weight scaling requires CUDA compute capability ≥8.9."
+        )
+    else:
+        raise ValueError(f"Invalid granularities {granularities}.")

torchao/quantization/quant_api.py

Lines changed: 12 additions & 2 deletions
@@ -62,6 +62,7 @@
     Float8MMConfig,
     FP8Granularity,
     _check_hardware_support,
+    _granularity_is_a_1_128_w_128_128,
     _normalize_granularity,
 )
 from torchao.quantization.linear_activation_weight_observed_tensor import (
@@ -1770,13 +1771,22 @@ def __post_init__(self):
         torch._C._log_api_usage_once(
             "torchao.quantization.Float8DynamicActivationFloat8WeightConfig"
         )
-        if self.mm_config is None:
-            self.mm_config = Float8MMConfig(use_fast_accum=True)
         activation_granularity, weight_granularity = _normalize_granularity(
             self.granularity
         )
         self.granularity = [activation_granularity, weight_granularity]

+        default_use_fast_accum = True
+        if _granularity_is_a_1_128_w_128_128(self.granularity):
+            assert self.activation_value_lb is None, "unimplemented"
+            assert self.activation_value_ub is None, "unimplemented"
+            assert self.kernel_preference is KernelPreference.TORCH, "unimplemented"
+            assert self.mm_config is None, "unimplemented"
+            default_use_fast_accum = False
+
+        if self.mm_config is None:
+            self.mm_config = Float8MMConfig(use_fast_accum=default_use_fast_accum)
+

 # for bc
 float8_dynamic_activation_float8_weight = _ConfigDeprecationWrapper(

torchao/quantization/quantize_/workflows/float8/float8_tensor.py

Lines changed: 33 additions & 9 deletions
@@ -15,13 +15,18 @@
 from torchao.float8.inference import (
     Float8MMConfig,
     FP8Granularity,
+    _is_1_128_scaled,
+    _is_128_128_scaled,
     _is_rowwise_scaled,
     _is_tensorwise_scaled,
     _slice_scale_for_dimension,
     addmm_float8_unwrapped_inference,
     preprocess_data,
     preprocess_scale,
 )
+from torchao.kernel.blockwise_quantization import (
+    blockwise_fp8_gemm,
+)
 from torchao.quantization.granularity import PerRow, PerTensor
 from torchao.quantization.quant_primitives import (
     _choose_scale_float8,
@@ -337,19 +342,38 @@ def _(func, types, args, kwargs):
                 "Input tensor must be rowwise block size"
             )
             w_scale = w_scale.transpose(-1, -2)
+        elif _is_128_128_scaled(weight_tensor):
+            assert _is_1_128_scaled(input_tensor), (
+                "input_tensor must be 1x128 scaled"
+            )
+            w_scale = w_scale.transpose(-1, -2)

         input_scale = preprocess_scale(input_scale, input_tensor.shape)
         inpt_data, w_data = preprocess_data(inpt_data, w_data.T, scaled_mm_config)

-        return addmm_float8_unwrapped_inference(
-            inpt_data,
-            input_scale,
-            w_data,
-            w_scale,
-            output_dtype=input_tensor.dtype,
-            bias=bias,
-            use_fast_accum=scaled_mm_config.use_fast_accum,
-        ).reshape(out_shape)
+        if _is_128_128_scaled(weight_tensor):
+            # TODO(future PR): add testing for torch._scaled_mm with
+            # blockwise scaling on CUDA 12.9
+            # TODO(future PR): add fbgemm_gpu_genai path if available
+            assert _is_1_128_scaled(input_tensor), "unsupported"
+            res = blockwise_fp8_gemm(
+                inpt_data,
+                input_scale,
+                w_data.t(),
+                w_scale,
+                block_size=128,
+            )
+        else:
+            res = addmm_float8_unwrapped_inference(
+                inpt_data,
+                input_scale,
+                w_data,
+                w_scale,
+                output_dtype=input_tensor.dtype,
+                bias=bias,
+                use_fast_accum=scaled_mm_config.use_fast_accum,
+            )
+        return res.reshape(out_shape)
     else:
         assert not isinstance(input_tensor, TorchAOBaseTensor), (
             "Expecting input_tensor to be unquantized"
