Commit b18c876

Added op support for float8
1 parent 9229df9 commit b18c876

2 files changed (+48, -11 lines)

Lines changed: 8 additions & 5 deletions

@@ -1,12 +1,15 @@
 from torchao.testing.utils import copy_tests, TorchAOTensorParallelTestCase
 from torch.testing._internal.common_utils import run_tests
-from torchao.quantization import int8_weight_only
+from torchao.quantization import int8_weight_only, float8_weight_only
 
-class TestAffineQuantizedTensorParallel(TorchAOTensorParallelTestCase):
-    pass
+class TestInt8woAffineQuantizedTensorParallel(TorchAOTensorParallelTestCase):
+    QUANT_METHOD_FN = staticmethod(int8_weight_only)
+copy_tests(TorchAOTensorParallelTestCase, TestInt8woAffineQuantizedTensorParallel, "int8wo_tp")
 
-
-copy_tests(TorchAOTensorParallelTestCase, TestAffineQuantizedTensorParallel, "aqt_tp")
+# Uncomment the following to test float8wo_affine : works only on H100
+class TestFloat8woAffineQuantizedTensorParallel(TorchAOTensorParallelTestCase):
+    QUANT_METHOD_FN = staticmethod(float8_weight_only)
+copy_tests(TorchAOTensorParallelTestCase, TestFloat8woAffineQuantizedTensorParallel, "fp8wo_tp")
 
 if __name__ == "__main__":
     run_tests()
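
For orientation, a minimal sketch of what the new float8 test variant ends up exercising: weight-only float8 quantization applied through torchao's quantize_ API with the float8_weight_only config imported above. Module, shapes, and device are illustrative only, and per the comment in the diff the float8 path currently needs H100-class hardware.

import torch
from torchao.quantization import quantize_, float8_weight_only

# Toy module; the real tests drive this through TorchAOTensorParallelTestCase
# via the QUANT_METHOD_FN hook set on each test class above.
linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
quantize_(linear, float8_weight_only())  # weight becomes a float8-quantized AffineQuantizedTensor

x = torch.randn(16, 128, dtype=torch.bfloat16, device="cuda")
y = linear(x)  # dispatched through the float8 weight-only quantized linear path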

torchao/dtypes/affine_quantized_tensor.py

Lines changed: 40 additions & 6 deletions
@@ -1094,20 +1094,31 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
             return return_and_correct_aliasing(
                 func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
             )
-        if func is aten.clone.default:
+        elif func is aten.clone.default:
             return return_and_correct_aliasing(
                 func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
             )
-        if func is aten.t.default:
+        elif func is aten.t.default:
             """we don't need to repack the weight and just rely on external
             shape being changed and record the status of transpose/no-transpose
             """
             args[0].transposed = not args[0].transposed
             return return_and_correct_aliasing(func, args, kwargs, args[0])
-
-        raise NotImplementedError(
-            f"Float8AQTLayout dispatch: attempting to run {func}, this is not supported"
-        )
+        elif func is aten.slice.Tensor:
+            self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1])
+            if dim == 0:
+                return return_and_correct_aliasing(
+                    func, args, kwargs, args[0]._apply_fn_to_data(lambda x: aten.slice.Tensor(x, dim, start, end, step))
+                )
+            elif dim == 1:
+                assert len(self.scale.shape) == 1, f"slice dim==1 only works when len(scale.shape) == 1 currently, got: {self.scale.shape}"
+                return Float8AQTLayout(aten.slice.Tensor(self.float8_data, dim, start, end, step), self.scale, None, self.layout_type)
+            else:
+                raise NotImplementedError(f"Float8AQTLayout dispatch: attempting to run {func}, with dim={dim}, that is not supported")
+        else:
+            raise NotImplementedError(
+                f"Float8AQTLayout dispatch: attempting to run {func}, this is not supported"
+            )
 
     __torch_function__ = torch._C._disabled_torch_function_impl
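
The new aten.slice.Tensor branch is what tensor-parallel sharding exercises: distributing a weight-only quantized Linear slices the quantized weight along dim 0 (a contiguous block of output rows per rank), which previously fell through to the NotImplementedError. A rough sketch of that kind of slice, assuming a float8 weight-only quantized module and an illustrative rank/world-size pair:

import torch
from torchao.quantization import quantize_, float8_weight_only

linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
quantize_(linear, float8_weight_only())

world_size, rank = 2, 0  # hypothetical tensor-parallel configuration
shard = linear.out_features // world_size
# narrow() lowers to aten.slice.Tensor on dim 0, the case handled above
local_weight = linear.weight.narrow(0, rank * shard, shard)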

@@ -1644,6 +1655,28 @@ def _linear_fp8_act_fp8_weight_impl(
         use_fast_accum=scaled_mm_config.use_fast_accum,
     ).reshape(out_shape)
 
+def _linear_fp_act_fp8_weight_check(
+    input_tensor: Union[torch.Tensor, AffineQuantizedTensor],
+    weight_tensor: Union[torch.Tensor, AffineQuantizedTensor],
+    bias: Optional[torch.Tensor],
+) -> bool:
+    return (
+        # input is native float tensor
+        not is_traceable_wrapper_subclass(input_tensor) and
+        input_tensor.is_floating_point() and
+        # weight is float8 quantized affine quantized tensor
+        isinstance(weight_tensor, AffineQuantizedTensor) and
+        isinstance(weight_tensor.layout_type, Float8LayoutType)
+        and weight_tensor.layout_tensor.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]
+        and (weight_tensor.shape == weight_tensor.block_size or _is_rowwise_scaled(weight_tensor))
+    )
+
+def _linear_fp_act_fp8_weight_impl(
+    input_tensor: torch.Tensor,
+    weight_tensor: AffineQuantizedTensor,
+    bias: Optional[torch.Tensor],
+):
+    return torch.nn.functional.linear(input_tensor, weight_tensor.dequantize(), bias)
 
 def _linear_fp_act_int4_weight_sparse_marlin_check(input_tensor, weight_tensor, bias):
     return (
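
Two things about the pair of functions just added: the check accepts a plain floating-point activation (not a tensor subclass) against a float8 weight-only AffineQuantizedTensor whose scaling is either per-tensor (weight_tensor.shape == weight_tensor.block_size) or rowwise (_is_rowwise_scaled), and the impl simply dequantizes the weight and falls back to torch.nn.functional.linear rather than calling a fused float8 kernel. A small illustrative sketch of the two accepted block_size layouts for a 2D weight; the helper below only mirrors the intent of _is_rowwise_scaled and is not the torchao implementation.

# Illustrative values for a (out_features, in_features) = (256, 128) weight
weight_shape = (256, 128)
per_tensor_block_size = weight_shape         # shape == block_size: one scale for the whole weight
rowwise_block_size = (1, weight_shape[-1])   # one scale per output row

def looks_rowwise_scaled(shape, block_size):  # hypothetical stand-in for _is_rowwise_scaled
    return tuple(block_size) == (1,) * (len(shape) - 1) + (shape[-1],)

assert looks_rowwise_scaled(weight_shape, rowwise_block_size)
assert not looks_rowwise_scaled(weight_shape, per_tensor_block_size)
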
@@ -1694,6 +1727,7 @@ def _register_aqt_quantized_linear_dispatches():
         (_linear_int8_act_int8_weight_semi_structured_sparse_check, _linear_int8_act_int8_weight_semi_structured_sparse_impl),
         (_linear_int8_act_int8_weight_block_sparse_check, _linear_int8_act_int8_weight_block_sparse_impl),
         (_linear_fp8_act_fp8_weight_check, _linear_fp8_act_fp8_weight_impl),
+        (_linear_fp_act_fp8_weight_check, _linear_fp_act_fp8_weight_impl),
         (_linear_bf16_act_uint4_weight_check, _linear_bf16_act_uint4_weight_impl),
         (_linear_fp_act_int8_weight_check, _linear_fp_act_int8_weight_impl),
         (_linear_f16_act_floatx_weight_check, _linear_f16_act_floatx_weight_impl),
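
The registration hunk appends the new (check, impl) pair to the quantized-linear dispatch table, directly after the existing fused fp8-activation pair, so (assuming first-match-wins ordering, as sketched below) the fused path keeps precedence when both checks would pass. A rough sketch of that dispatch pattern; the function name and loop are illustrative, not torchao's actual implementation.

def dispatch_quantized_linear(dispatch_table, input_tensor, weight_tensor, bias):
    # First matching check wins; its paired impl handles the linear call.
    for check, impl in dispatch_table:
        if check(input_tensor, weight_tensor, bias):
            return impl(input_tensor, weight_tensor, bias)
    raise NotImplementedError("no quantized linear dispatch matched these tensors")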
