Commit bcbec9e

Added op support for float8

1 parent 9229df9 commit bcbec9e

File tree

3 files changed: +62 −19 lines
Lines changed: 9 additions & 4 deletions
@@ -1,12 +1,17 @@
 from torchao.testing.utils import copy_tests, TorchAOTensorParallelTestCase
 from torch.testing._internal.common_utils import run_tests
-from torchao.quantization import int8_weight_only
+from torchao.quantization import int8_weight_only, float8_weight_only

-class TestAffineQuantizedTensorParallel(TorchAOTensorParallelTestCase):
-    pass
+# class TestAffineQuantizedTensorParallel(TorchAOTensorParallelTestCase):
+#     pass

+class TestAffineQuantizedTensorParallel(TorchAOTensorParallelTestCase):
+    QUANT_METHOD_FN = staticmethod(float8_weight_only)

-copy_tests(TorchAOTensorParallelTestCase, TestAffineQuantizedTensorParallel, "aqt_tp")
+print('Copy test started...')
+copy_tests(TorchAOTensorParallelTestCase, TestAffineQuantizedTensorParallel, "fp8wo_tp")
+print('Copy test finished')

 if __name__ == "__main__":
+    print("Running TestAffineQuantizedTensorParallel")
     run_tests()
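For reference, `QUANT_METHOD_FN` points the copied tests at the quantization entry point under test. Assuming the usual torchao flow (the standard `quantize_`/`float8_weight_only` entry points; exact signatures can vary across versions), applying the same quantization outside the test harness looks roughly like:

```python
import torch
from torchao.quantization import quantize_, float8_weight_only

# Rough sketch, assuming the standard torchao quantize_ flow; needs a CUDA
# device and a torchao build that ships float8_weight_only.
linear = torch.nn.Linear(1024, 2048, bias=False, device="cuda", dtype=torch.bfloat16)
quantize_(linear, float8_weight_only())  # weight becomes a float8-backed AffineQuantizedTensor
y = linear(torch.randn(8, 1024, device="cuda", dtype=torch.bfloat16))
```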

torchao/dtypes/affine_quantized_tensor.py

Lines changed: 40 additions & 6 deletions
@@ -1094,20 +1094,31 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
             return return_and_correct_aliasing(
                 func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
             )
-        if func is aten.clone.default:
+        elif func is aten.clone.default:
             return return_and_correct_aliasing(
                 func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
             )
-        if func is aten.t.default:
+        elif func is aten.t.default:
             """we don't need to repack the weight and just rely on external
             shape being changed and record the status of transpose/no-transpose
             """
             args[0].transposed = not args[0].transposed
             return return_and_correct_aliasing(func, args, kwargs, args[0])
-
-        raise NotImplementedError(
-            f"Float8AQTLayout dispatch: attempting to run {func}, this is not supported"
-        )
+        elif func is aten.slice.Tensor:
+            self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1])
+            if dim == 0:
+                return return_and_correct_aliasing(
+                    func, args, kwargs, args[0]._apply_fn_to_data(lambda x: aten.slice.Tensor(x, dim, start, end, step))
+                )
+            elif dim == 1:
+                assert len(self.scale.shape) == 1, f"slice dim==1 only works when len(scale.shape) == 1 currently, got: {self.scale.shape}"
+                return Float8AQTLayout(aten.slice.Tensor(self.float8_data, dim, start, end, step), self.scale, None, self.layout_type)
+            else:
+                raise NotImplementedError(f"Float8AQTLayout dispatch: attempting to run {func}, with dim={dim}, that is not supported")
+        else:
+            raise NotImplementedError(
+                f"Float8AQTLayout dispatch: attempting to run {func}, this is not supported"
+            )

     __torch_function__ = torch._C._disabled_torch_function_impl
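The new `aten.slice.Tensor` branch treats the two dims differently: dim 0 applies the slice to every tensor attribute via `_apply_fn_to_data` (data and scale alike), while dim 1 slices only `float8_data` and reuses `self.scale`, which is why the assert restricts it to 1-D (per-row) scales. A standalone illustration of those semantics, using plain tensors rather than the layout class:

```python
import torch

# Hypothetical per-row-scaled float8 weight: one scale per output row.
float8_data = torch.randn(4, 8).to(torch.float8_e4m3fn)
scale = torch.rand(4) + 0.5            # len(scale.shape) == 1

rows = float8_data[0:2]                # dim 0: slice the data...
rows_scale = scale[0:2]                # ...and the matching per-row scales
cols = float8_data[:, 0:4]             # dim 1: per-row scales are unaffected
assert rows.shape == (2, 8) and rows_scale.shape == (2,)
assert cols.shape == (4, 4) and scale.shape == (4,)
```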

@@ -1644,6 +1655,28 @@ def _linear_fp8_act_fp8_weight_impl(
         use_fast_accum=scaled_mm_config.use_fast_accum,
     ).reshape(out_shape)

+def _linear_fp_act_fp8_weight_check(
+    input_tensor: Union[torch.Tensor, AffineQuantizedTensor],
+    weight_tensor: Union[torch.Tensor, AffineQuantizedTensor],
+    bias: Optional[torch.Tensor],
+) -> bool:
+    return (
+        # input is native float tensor
+        not is_traceable_wrapper_subclass(input_tensor) and
+        input_tensor.is_floating_point() and
+        # weight is float8 quantized affine quantized tensor
+        isinstance(weight_tensor, AffineQuantizedTensor) and
+        isinstance(weight_tensor.layout_type, Float8LayoutType)
+        and weight_tensor.layout_tensor.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]
+        and (weight_tensor.shape == weight_tensor.block_size or _is_rowwise_scaled(weight_tensor))
+    )
+
+def _linear_fp_act_fp8_weight_impl(
+    input_tensor: torch.Tensor,
+    weight_tensor: AffineQuantizedTensor,
+    bias: Optional[torch.Tensor],
+):
+    return torch.nn.functional.linear(input_tensor, weight_tensor.dequantize(), bias)

 def _linear_fp_act_int4_weight_sparse_marlin_check(input_tensor, weight_tensor, bias):
     return (
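`_linear_fp_act_fp8_weight_impl` is a dequantize-then-compute fallback: no fused float8 kernel, just `F.linear` on the dequantized weight. A minimal standalone sketch of those numerics with hypothetical per-row scales (the real path goes through `weight_tensor.dequantize()`):

```python
import torch

# Quantize a weight to float8_e4m3fn with per-row scales, dequantize, and run
# a plain linear -- the same shape of computation as the fallback impl above.
w = torch.randn(32, 64)
scale = w.abs().amax(dim=1, keepdim=True) / torch.finfo(torch.float8_e4m3fn).max
w_fp8 = (w / scale).to(torch.float8_e4m3fn)   # stored float8 data
w_deq = w_fp8.to(torch.float32) * scale       # dequantized weight
x = torch.randn(8, 64)
y = torch.nn.functional.linear(x, w_deq)
err = (y - x @ w.t()).abs().max()             # small; only fp8 round-off
```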
@@ -1694,6 +1727,7 @@ def _register_aqt_quantized_linear_dispatches():
         (_linear_int8_act_int8_weight_semi_structured_sparse_check, _linear_int8_act_int8_weight_semi_structured_sparse_impl),
         (_linear_int8_act_int8_weight_block_sparse_check, _linear_int8_act_int8_weight_block_sparse_impl),
         (_linear_fp8_act_fp8_weight_check, _linear_fp8_act_fp8_weight_impl),
+        (_linear_fp_act_fp8_weight_check, _linear_fp_act_fp8_weight_impl),
         (_linear_bf16_act_uint4_weight_check, _linear_bf16_act_uint4_weight_impl),
         (_linear_fp_act_int8_weight_check, _linear_fp_act_int8_weight_impl),
         (_linear_f16_act_floatx_weight_check, _linear_f16_act_floatx_weight_impl),
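Ordering in this list matters: as I read the dispatch mechanism, the registered (check, impl) pairs are tried in sequence and the first predicate that matches wins, so the new generic fp-activation/fp8-weight pair sits after the fused fp8-act/fp8-weight kernel and only catches inputs that path rejects. A hedged sketch of the pattern (`dispatch_linear` and the table are illustrative names, not torchao internals):

```python
from typing import Callable, List, Optional, Tuple
import torch

Check = Callable[[torch.Tensor, torch.Tensor, Optional[torch.Tensor]], bool]
Impl = Callable[[torch.Tensor, torch.Tensor, Optional[torch.Tensor]], torch.Tensor]
_TABLE: List[Tuple[Check, Impl]] = []

def register(check: Check, impl: Impl) -> None:
    _TABLE.append((check, impl))

def dispatch_linear(x, w, bias=None):
    # First match wins: specific/fused paths must be registered before fallbacks.
    for check, impl in _TABLE:
        if check(x, w, bias):
            return impl(x, w, bias)
    raise NotImplementedError("no quantized linear dispatch matched")

# e.g. a catch-all float fallback registered last:
register(lambda x, w, b: x.is_floating_point(),
         lambda x, w, b: torch.nn.functional.linear(x, w, b))
```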

torchao/testing/utils.py

Lines changed: 13 additions & 9 deletions
@@ -285,49 +285,53 @@ def test_tp(self, dtype):
         device = "cuda"
         # To make sure different ranks create the same module
         torch.manual_seed(5)
-
+        print('Step 1')
         class M(torch.nn.Module):
             def __init__(self, in_features, out_features, **kwargs) -> None:
                 super().__init__(**kwargs)
                 self.linear = torch.nn.Linear(in_features, out_features, bias=False, device="cuda")

             def forward(self, x: torch.Tensor) -> torch.Tensor:
                 return self.linear(x)
-
+        print('Step 2')
         # Get rank and device
         device = torch.device(f"cuda:{self.rank % torch.cuda.device_count()}")
-
+        print('Step 3')
         # Original model
         proj_up = M(1024, 2048).to(device).to(dtype)
         proj_dn = M(2048, 1024).to(device).to(dtype)
         example_input = 100 * torch.randn(128, 1024, device=device, dtype=dtype)
         y = proj_dn(proj_up(example_input))
-
+        print('Step 4')
         # Quantize the model
         up_quant = self.quantize(proj_up)
         dn_quant = self.quantize(proj_dn)
         y_q = dn_quant(up_quant(example_input))
-
+        print('Step 5')
         mesh = self.build_device_mesh()
         # Shard the models
         up_dist = self.colwise_shard(up_quant, mesh)
         dn_dist = self.rowwise_shard(dn_quant, mesh)
-
+        print('Step 6')
         # We need to turn inputs into DTensor form as well -- just a format change
         input_dtensor = DTensor.from_local(
             example_input, mesh, [Replicate()]
         )
-
+        print('Step 7')
         y_d = dn_dist(up_dist(input_dtensor))
-
+        print('Step 8')
         if not TORCH_VERSION_AT_LEAST_2_5:
             # Need torch 2.5 to support compiled tensor parallelism
             return
-
+        print('Step 9')
         up_compiled = torch.compile(up_dist)
+        print('Step 10')
         y_up = up_compiled(input_dtensor)
+        print('Step 11')
         dn_compiled = torch.compile(dn_dist)
+        print('Step 12')
         y_dn = dn_compiled(y_up)
+        print('Step 13')

 common_utils.instantiate_parametrized_tests(TorchAOBasicTestCase)
 common_utils.instantiate_parametrized_tests(TorchAOCompileTestCase)
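For orientation, `colwise_shard`/`rowwise_shard` implement the standard two-linear tensor-parallel split: the up-projection's weight is sharded along its output dim, the down-projection's along its input dim, and activations travel as DTensors. A hedged single-rank sketch of that layout (not the test's helpers; run under gloo so it works without torchrun):

```python
import os
import torch
import torch.distributed as dist
from torch.distributed._tensor import DTensor, Replicate, Shard
from torch.distributed.device_mesh import init_device_mesh

# Single-rank stand-in for the test's mesh; Shard(0) on a linear weight is the
# column-wise split, Shard(1) would be the row-wise one.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)
mesh = init_device_mesh("cpu", (1,))

w_up = DTensor.from_local(torch.randn(2048, 1024), mesh, [Shard(0)])
x = DTensor.from_local(torch.randn(128, 1024), mesh, [Replicate()])
y_up = torch.nn.functional.linear(x, w_up)  # output sharded on the feature dim
print(y_up.shape)  # global shape (128, 2048)
dist.destroy_process_group()
```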
