
Commit 7dff17a

Float8 autoquant weight only (#866)
1 parent 728d629 commit 7dff17a

5 files changed: +61 additions, −4 deletions

test/integration/test_integration.py

Lines changed: 10 additions & 1 deletion
@@ -72,7 +72,7 @@
     AQInt8WeightOnlyQuantizedLinearWeight2,
     AQInt8WeightOnlyQuantizedLinearWeight3,
     AutoQuantizableLinearWeight,
-
+    AQFloat8WeightOnlyQuantizedLinearWeight,
 )
 from torch.ao.quantization.quantize_fx import convert_to_reference_fx, prepare_fx
 import os
@@ -98,6 +98,7 @@
 COMMON_DTYPES = [torch.float32, torch.float16, torch.bfloat16]

 COMMON_DEVICE_DTYPE = list(itertools.product(COMMON_DEVICES, COMMON_DTYPES)).copy()
+is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)

 def _int8wo_api(mod):
     if TORCH_VERSION_AT_LEAST_2_4:
@@ -744,6 +745,14 @@ def test_aq_int8_weight_only_quant_3_subclass(self, device, dtype):
             AQInt8WeightOnlyQuantizedLinearWeight3.from_float, device, 35, test_dtype=dtype
         )

+    @parameterized.expand(COMMON_DEVICE_DTYPE)
+    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant+aqt needs newer pytorch")
+    @unittest.skipIf(not is_H100, "Need H100 to run")
+    def test_aq_float8_weight_only_quant_subclass(self, device, dtype):
+        self._test_lin_weight_subclass_impl(
+            AQFloat8WeightOnlyQuantizedLinearWeight.from_float, device, 30, test_dtype=dtype
+        )
+
     @parameterized.expand(COMMON_DEVICE_DTYPE)
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.")
     # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now")
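For context, a standalone sketch of roughly what the new test exercises: wrap a linear weight in the new float8 weight-only autoquant subclass and run the layer again. The toy shapes and the helper-free structure are illustrative, not from the test; it assumes a CUDA GPU with compute capability >= 8.9 and torchao at this commit.

# Hedged sketch, not part of the commit.
import torch
from torchao.quantization.autoquant import AQFloat8WeightOnlyQuantizedLinearWeight

lin = torch.nn.Linear(64, 32, dtype=torch.bfloat16, device="cuda")
x = torch.randn(8, 64, dtype=torch.bfloat16, device="cuda")
ref = lin(x)

# Replace the weight with its float8 weight-only autoquant wrapper.
lin.weight = torch.nn.Parameter(
    AQFloat8WeightOnlyQuantizedLinearWeight.from_float(lin.weight),
    requires_grad=False,
)
out = lin(x)  # dispatches through the subclass (dequantize + F.linear)
print(torch.nn.functional.mse_loss(out.float(), ref.float()))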

test/kernel/test_autotuner.py

Lines changed: 20 additions & 0 deletions
@@ -16,6 +16,7 @@

 logging.basicConfig(level=logging.INFO)

+is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)

 class TestQuantFlow(unittest.TestCase):

@@ -49,6 +50,25 @@ def test_int_mm(self, device, dtype):
         assert out32_2.dtype == out32_1.dtype
         torch.testing.assert_allclose(out32_1, out32_2)

+    @parameterized.expand(
+        [
+            ("cuda", torch.bfloat16),
+            ("cuda", torch.float16),
+        ]
+    )
+    @unittest.skipIf(not is_H100, "Needs H100")
+    def test_int_mm_float8(self, device, dtype):
+        from torchao.kernel import intmm
+
+        dtype = torch.bfloat16
+        m, k, n = (128, 64, 16)
+        x = torch.randn(m, k, dtype=dtype, device=device)
+        w = torch.randn(n, k, dtype=dtype, device=device).t()
+        x_float8 = x.to(dtype=torch.float8_e4m3fn)
+        w_float8 = w.to(dtype=torch.float8_e4m3fn)
+        out32_1 = intmm.safe_int_mm(x_float8, w_float8)
+        assert out32_1.dtype == torch.int32
+
     @parameterized.expand(
         [
             ("cuda", torch.bfloat16),

torchao/dtypes/affine_quantized_tensor.py

Lines changed: 1 addition & 1 deletion
@@ -335,8 +335,8 @@ def from_hp_to_floatx(
         input_float: torch.Tensor,
         block_size: Tuple[int, ...],
         target_dtype: torch.dtype,
-        scale_dtype: Optional[torch.dtype],
         layout_type: LayoutType,
+        scale_dtype: Optional[torch.dtype] = None,
     ):

         if target_dtype in FP8_TYPES:
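With scale_dtype now trailing and defaulting to None, float8 callers can omit it entirely. An illustrative call shape, mirroring how the new autoquant subclass below uses this constructor (the tensor values and keyword usage here are assumptions, not taken from the commit):

# Hedged sketch, not part of the commit.
import torch
from torchao.dtypes import AffineQuantizedTensor, Float8LayoutType

weight = torch.randn(32, 64, dtype=torch.bfloat16, device="cuda")
fp8_weight = AffineQuantizedTensor.from_hp_to_floatx(
    weight,
    block_size=(1, weight.shape[1]),   # one scale per output row
    target_dtype=torch.float8_e4m3fn,
    layout_type=Float8LayoutType(),    # scale_dtype left at its new default of None
)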

torchao/kernel/intmm.py

Lines changed: 6 additions & 1 deletion
@@ -69,7 +69,12 @@ def safe_int_mm(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor:
         input = (
             input.contiguous()
         )  # (it seems the transpose makes cublas check the above j constraint on i)
-        return out_dtype(torch.ops.aten.mm.default, torch.int32, input, mat2)
+        try:
+            return out_dtype(torch.ops.aten.mm.default, torch.int32, input, mat2)
+        except Exception:
+            # fallback path, would run on H100 for float8 dtypes
+            # Exception on H100 float8 dtype : "addmm_cuda" not implemented for 'Float8_e4m3fn'
+            return torch.matmul(input.to(torch.float32), mat2.to(torch.float32)).to(torch.int32)
 else:
     def safe_int_mm(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor:
         """

torchao/quantization/autoquant.py

Lines changed: 24 additions & 1 deletion
@@ -9,7 +9,7 @@
     Int8WeightOnlyQuantizedLinearWeight,
     QuantizedLinearWeightBase,
 )
-from torchao.dtypes import AffineQuantizedTensor, PlainLayoutType, TensorCoreTiledLayoutType
+from torchao.dtypes import AffineQuantizedTensor, PlainLayoutType, TensorCoreTiledLayoutType, Float8LayoutType
 from torchao.quantization.linear_activation_quantized_tensor import LinearActivationQuantizedTensor
 from torch.utils._python_dispatch import return_and_correct_aliasing
 from .quant_primitives import (
@@ -477,6 +477,22 @@ def _quantized_linear_op(act_mat, w_qtensor, bias):
     def from_float(cls, weight):
         return weight

+class AQFloat8WeightOnlyQuantizedLinearWeight(AffineQuantizedTensor, AQMixin):
+    """
+    AutoQuantizable version of Float8WeightOnlyQuantizedLinearWeight for target_dtype=torch.float8_e4m3fn
+    """
+    target_dtype: torch.dtype = torch.float8_e4m3fn
+
+    @staticmethod
+    def _quantized_linear_op(act_mat, w_qtensor, bias):
+        return torch.nn.functional.linear(act_mat, w_qtensor.dequantize(), bias)
+
+    @classmethod
+    def from_float(cls, weight):
+        block_size = (1, weight.shape[1])
+        return super(AQFloat8WeightOnlyQuantizedLinearWeight, cls).from_hp_to_floatx(weight, block_size, target_dtype=cls.target_dtype, layout_type=Float8LayoutType())
+
+
 # here we don't include int4 quantization in since int8 tends to be a better apples to apples comparison
 DEFAULT_AUTOQUANT_CLASS_LIST = [
     AQFloatLinearWeight,
@@ -493,6 +509,11 @@ def from_float(cls, weight):
     AQInt4G64WeightOnlyQuantizedLinearWeight
 ]

+OTHER_AUTOQUANT_CLASS_LIST = [
+    AQFloat8WeightOnlyQuantizedLinearWeight,
+]
+
+
 def _change_linears_to_autoquantizable(model, **kwargs):
     """
     Converts all linear weight tensors to the
@@ -617,6 +638,8 @@ def autoquant(
     if set_inductor_config:
         torchao.quantization.utils.recommended_inductor_config_setter()

+    if qtensor_class_list in OTHER_AUTOQUANT_CLASS_LIST:
+        assert torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9), "float8 requires CUDA arch >= 8.9"

     # perform initial swap from linear weights
     # to AutoQuantizableLinearWeight
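Putting the pieces together, a hedged end-to-end usage sketch (the toy model, shapes, and compile call are illustrative, not from the commit): passing the new OTHER_AUTOQUANT_CLASS_LIST restricts autoquant to the float8 weight-only candidate, and the capability check added above is expected to fire on GPUs below compute capability 8.9.

# Hedged sketch, not part of the commit.
import torch
import torchao
from torchao.quantization.autoquant import OTHER_AUTOQUANT_CLASS_LIST

model = torch.nn.Sequential(torch.nn.Linear(64, 64)).to("cuda", torch.bfloat16)
model = torchao.autoquant(
    torch.compile(model),
    qtensor_class_list=OTHER_AUTOQUANT_CLASS_LIST,  # float8 weight-only candidate only
)
# First forward records shapes, benchmarks the candidates, and swaps in the chosen weights.
model(torch.randn(8, 64, device="cuda", dtype=torch.bfloat16))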
