
Commit 71be315

Authored by weifengpy, andrewor14, vaishnavi17, jainapurva, and jerryzh168
[float8] improve eager numerics for dynamic scales and get on par with torch.compile (#904)
* [float8] improve eager numerics for dynamic scales
* leave torch.linalg.vector_norm for another PR
* cuda
* remove _data and investigate
* remove _data comment
* upcast to float32 is enough
* explain why float32
* _data parity
* handle sm8.9
* fix transformer unit test
* print if error
* Add tutorial for trainable tensor subclass (#908)

  Summary: The new tutorial provides an example of how to implement a trainable tensor subclass that wraps quantized data. This extends the existing `MyDTypeTensor` with the steps needed to ensure proper gradient updates, namely:
  1. Define a differentiable constructor
  2. Define the backward pass for ops of interest (e.g. torch.nn.functional.linear)
  3. Handle special ops used by the optimizer (e.g. aten.add, aten.add_)

  Test Plan: python tutorials/developer_api_guide/my_trainable_tensor_subclass.py
* Introducing 1-bit quantization for Llama in torchchat (#910). Differential Revision: D63052325. Pull Request resolved: #911
* Rename Floating point to fp8 (#909)
* [float8] fix typo in bitwise_identical unit test (#918)
* Adding example for quantized tensor + tensor parallelism (#785)

  Summary: This PR adds an example of how a quantized tensor subclass can work with DTensor: https://github.com/pytorch/pytorch/blob/main/torch/distributed/_tensor/README.md. The end goal is to rewrite https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/models/llama2.py with a normal llama2 implementation and show that with DTensor + AffineQuantizedTensor + torch.compile we can get on-par performance with the custom tensor parallel implementation.

  Test Plan: torchrun --standalone --nnodes=1 --nproc-per-node=4 tutorials/developer_api_guide/tensor_parallel.py

  Interim commits: tensor parallel file; use DTensor.from instead of distribute_tensor; implementing aten.slice.Tensor (WIP); working; some shape fixes and use of more quant primitive ops; add rowwise test; make rowwise sharding work; compile still not working yet; fake tensor didn't pick up shape changes from transpose; backend='eager'; change transpose to a non-inplace op; add error message; works now with torch nightly; remove print; ruff; clean up; fix device id.

  Co-authored-by: Ke Wen <[email protected]>
* rename cuda mode -> gpu mode (#925)
* Add workaround to recover the perf for quantized vit in torch.compile (#926)

  Summary: We recently found a perf drop in quantized vit due to #898 (comment). This PR adds a temporary fix until we figure out the longer-term fix. Ideally we should figure out why the tensor subclass check fails in torch.compile (https://github.com/pytorch/pytorch/blob/e4d294221b140fdbb49a64f297bc60c9fcc2f80e/torch/nn/modules/activation.py#L1286) and fix that.

  Test Plan: python tutorials/quantize_vit/run_vit_b_quant.py
* clean up device checks in float8 unit test files (#923)

  Summary: While working on rowwise scaling I noticed that some of the CUDA device capability checks in the test files did not make sense; this cleans them up.

  Test Plan: tests pass on my H100. CI should now skip fewer tests, since CI only has CUDA capability 8, 9.
* [low-bit optim] Change 8-bit and FP8 optim block size from 2048 to 256 to match new bnb v0.44 (#927)
* Float8 autoquant weight only (#866)
* Fix failing FP6 benchmark (#931)
* Remove two if statements in fp8 padding (#935). Reviewed By: vkuzo. Differential Revision: D63051205. Approved by: https://github.com/vkuzo
* [Distributed] Improve sharding example (#937); add comment
* Add composable QAT quantizer (#938)

  Summary: This is a utility for users who wish to apply multiple QAT quantizers to their models. In the near future, we expect to add an embedding QAT quantizer that composes with the existing linear QAT quantizers.

  Test Plan: python test/quantization/test_qat.py -k test_composable_qat_quantizer
* resolve conflict with latest main. Differential Revision: D63048850. Pull Request resolved: #912
* Add torchchat quantizer. Differential Revision: D62394341. Pull Request resolved: #897
* Add compile tests to test suite (#906)

  Summary: This is a follow-up PR addressing #839 (comment). We can add more compiler-related tests in the future. Next: refactor a bit to use the quantize_ API directly; use the test suite in existing API tests.

  Test Plan: python torchao/testing/utils.py

  Interim commits: rename; add result check.
* Fix up CMakeLists and reorganize some code locations. Differential Revision: D62711903. Pull Request resolved: #948
* [float8] all-reduce amax on dp mesh instead of global pg (#933)

  Interim commits: linter fixes (×7); improve comments; move hp tensor inside if.
* int8 dynamic quant + bsr support (#821)

  This PR adds int8 dynamic quant + bsr support. Changes:
  * Use i8i8 -> bf16 matmul to maintain accuracy
  * Added a block sparse layout type to AffineQuantizedTensor + check/impl
  * Cleaned up the benchmark.py script and added a single-line `benchmark.sh` file for acceleration numbers
  * Updated eval.py and added a single-line `evaluate.sh` file for accuracy numbers
  * Lots of lint formatting and README updates
  * torch.compile now working and correct
* fixing some issues with our support for 70/405B models (#941)

  Summary: the download and convert scripts needed to be updated alongside the model.py config files.

  Test Plan: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-70B/model.pth
* Update INT8 mixed-precision training test to be less flaky (#950)
* Add executorch parallel. Differential Revision: D62711909. Pull Request resolved: #953
* test CI
* better comment on why upcasting
* control seed
* move unit test to test_compile
* fix typo
* float64 upcasting after allreduce
* use LinearMMConfig

Co-authored-by: andrewor14 <[email protected]>
Co-authored-by: Vaishnavi Gupta <[email protected]>
Co-authored-by: Apurva Jain <[email protected]>
Co-authored-by: Jerry Zhang <[email protected]>
Co-authored-by: Ke Wen <[email protected]>
Co-authored-by: Mark Saroufim <[email protected]>
Co-authored-by: Vasiliy Kuznetsov <[email protected]>
Co-authored-by: Thien Tran <[email protected]>
Co-authored-by: Tobias van der Werff <[email protected]>
Co-authored-by: Shuqi Yang <[email protected]>
Co-authored-by: Scott Roy <[email protected]>
Co-authored-by: Jesse Cai <[email protected]>
Co-authored-by: HDCharles <[email protected]>
1 parent 1137f39 commit 71be315

File tree

5 files changed: +85, -10 lines
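Before the per-file diffs, here is a minimal standalone sketch of the eager-vs-compile parity the commit message describes. The function and variable names are illustrative (not the torchao API), and it assumes a recent PyTorch build with float8 dtypes and torch.compile support; the amax reciprocal is taken in float64 and the scale multiply in float32, mirroring the hunks below.

```python
import torch

# Illustrative stand-in for torchao's dynamic float8 cast (hypothetical names):
# the amax reciprocal is computed in float64 and the scale multiply in float32,
# which is what lets eager match torch.compile bit-for-bit.
def dynamic_cast_to_fp8(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    amax = x.abs().amax().to(torch.float64)
    scale = (torch.finfo(torch.float8_e4m3fn).max / torch.clamp(amax, min=1e-12)).to(torch.float32)
    x_scaled = x.to(torch.float32) * scale
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    bits = torch.clamp(x_scaled, -fp8_max, fp8_max).to(torch.float8_e4m3fn)
    return bits, scale

x = torch.randn(16, 16, dtype=torch.bfloat16)
eager_bits, eager_scale = dynamic_cast_to_fp8(x)
compiled_bits, compiled_scale = torch.compile(dynamic_cast_to_fp8)(x.clone())
print("scales equal:", torch.equal(eager_scale, compiled_scale))
print("fp8 bits equal:", torch.equal(eager_bits.to(torch.float32), compiled_bits.to(torch.float32)))
```

The new `test_dynamic_scale_numeric_parity` test in `test/float8/test_compile.py` below makes the same comparison against the real `hp_tensor_to_float8_dynamic` on CUDA.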

test/float8/test_compile.py

Lines changed: 69 additions & 2 deletions

```diff
@@ -25,8 +25,15 @@
     get_float8_layers,
     sync_float8_amax_and_scale_history,
 )
-from torchao.float8.float8_scaling_utils import hp_tensor_to_float8_delayed
-from torchao.float8.float8_tensor import LinearMMConfig
+from torchao.float8.float8_scaling_utils import (
+    hp_tensor_to_float8_delayed,
+    hp_tensor_to_float8_dynamic,
+)
+from torchao.float8.float8_tensor import (
+    LinearMMConfig,
+    GemmInputRole,
+    ScaledMMConfig,
+)
 from torchao.float8.float8_utils import e4m3_dtype
 
 from torch._dynamo.test_case import TestCase as DynamoTestCase
@@ -353,5 +360,65 @@ def test_sync_amax_func_cuda_graph_success():
     assert "skipping cudagraphs due to mutaton on input" not in stderr[0]
 
 
+@unittest.skipIf(
+    not is_cuda_8_9,
+    "CUDA not available",
+)
+@pytest.mark.parametrize(
+    "dtype",
+    [
+        torch.float32,
+        torch.bfloat16,
+        torch.float16,
+    ],
+)
+def test_dynamic_scale_numeric_parity(dtype: torch.dtype):
+    scaling_type_weight = ScalingType.DYNAMIC
+    torch.manual_seed(42)
+    hp_tensor1 = torch.randn(16, 16, device="cuda", dtype=dtype)
+    hp_tensor2 = hp_tensor1.detach().clone()
+    float8_config = Float8LinearConfig(
+        cast_config_weight=CastConfig(scaling_type=scaling_type_weight),
+    )
+    linear_mm_config = LinearMMConfig(
+        # output
+        ScaledMMConfig(
+            False,
+            float8_config.gemm_config_output.use_fast_accum,
+            False,
+            float8_config.pad_inner_dim,
+        ),
+        # grad_input
+        ScaledMMConfig(
+            False,
+            float8_config.gemm_config_grad_input.use_fast_accum,
+            False,
+            float8_config.pad_inner_dim,
+        ),
+        # grad_weight
+        ScaledMMConfig(
+            False,
+            float8_config.gemm_config_grad_weight.use_fast_accum,
+            False,
+            float8_config.pad_inner_dim,
+        ),
+    )
+    float8_eager = hp_tensor_to_float8_dynamic(
+        hp_tensor1,
+        torch.float8_e4m3fn,
+        linear_mm_config,
+        gemm_input_role=GemmInputRole.WEIGHT,
+    )
+    torch._dynamo.reset()
+    float8_compile = torch.compile(hp_tensor_to_float8_dynamic)(
+        hp_tensor2,
+        torch.float8_e4m3fn,
+        linear_mm_config,
+        gemm_input_role=GemmInputRole.WEIGHT,
+    )
+    assert torch.equal(float8_eager._scale, float8_compile._scale)
+    assert torch.equal(float8_eager._data, float8_compile._data)
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
```
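For reference, the new test can be run on its own; the invocation below is an illustration (not part of the commit) that mirrors the file's `pytest.main([__file__])` entry point and assumes a repo checkout plus a CUDA device satisfying the `is_cuda_8_9` skip condition.

```python
import pytest

# Run only the new eager-vs-compile parity test from a repo checkout.
pytest.main(["test/float8/test_compile.py", "-k", "test_dynamic_scale_numeric_parity"])
```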

torchao/float8/float8_tensor.py

Lines changed: 4 additions & 1 deletion

```diff
@@ -163,7 +163,10 @@ def forward(
 
         DTensor Invariant: DTensor must always be the outer most tensor subclass
         """
-        tensor_scaled = tensor * scale
+        # Note: when the line below is compiled with `torch.compile`, `tensor` is automatically
+        # upcasted to `float32` to multiply with the scale
+        # In order to match numerics between eager and compile, we upcast manually here.
+        tensor_scaled = tensor.to(torch.float32) * scale
         bits_fp8 = to_fp8_saturated(tensor_scaled, float8_dtype)
 
         if isinstance(bits_fp8, DTensor):
```
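As a standalone sketch of the step this hunk changes (hypothetical helper, not torchao's `to_fp8_saturated`): the high-precision tensor is upcast to float32 before the scale multiply, then saturated to the float8 range before the dtype cast.

```python
import torch

def scale_and_saturate(tensor: torch.Tensor, scale: torch.Tensor,
                       float8_dtype: torch.dtype = torch.float8_e4m3fn) -> torch.Tensor:
    # torch.compile upcasts `tensor` to float32 implicitly for this multiply;
    # doing it explicitly keeps eager numerics identical to compile.
    tensor_scaled = tensor.to(torch.float32) * scale
    fp8_max = torch.finfo(float8_dtype).max
    # Saturated cast: clamp to the representable float8 range instead of overflowing.
    return torch.clamp(tensor_scaled, min=-fp8_max, max=fp8_max).to(float8_dtype)

x = torch.randn(8, 8, dtype=torch.bfloat16)
scale = torch.tensor(64.0)  # illustrative scale
fp8_bits = scale_and_saturate(x, scale)
```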

torchao/float8/float8_utils.py

Lines changed: 3 additions & 0 deletions

```diff
@@ -42,6 +42,9 @@ def amax_to_scale(
         float8_dtype: The float8 dtype.
         orig_dtype: The original dtype of the tensor.
     """
+    # torch.compile and eager show different numerics for 1.0 / float32,
+    # upcast to float64 to ensure same numeric between compile and eager
+    amax = amax.to(torch.float64)
     if float8_dtype in FP8_TYPES:
         res = torch.finfo(float8_dtype).max / torch.clamp(amax, min=EPS)
     else:
```
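A worked example of the scale formula above (values are illustrative; the EPS clamp is a no-op for a non-zero amax): for `amax = 10.0` and `torch.float8_e4m3fn`, whose finfo max is 448.0, the scale is 448.0 / 10.0 = 44.8, with the division carried out in float64 and the result stored back as float32.

```python
import torch

amax = torch.tensor(10.0, dtype=torch.float32)
# Divide in float64 (as the diff above does), then store the scale as float32.
scale = (torch.finfo(torch.float8_e4m3fn).max / amax.to(torch.float64)).to(torch.float32)
print(scale)  # tensor(44.8000)
```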

torchao/float8/fsdp_utils.py

Lines changed: 8 additions & 3 deletions

```diff
@@ -64,12 +64,17 @@ def precompute_float8_dynamic_scale_for_fsdp(module: nn.Module) -> None:
     # clamp is dispatched through DTensor
     # it will issue a single all-reduce
     amax_tensor = torch.clamp(amax_tensor, EPS)  # Replicate
+    # keep consistent with float8_utils.amax_to_scale
+    # torch.compile and eager show different numerics for 1.0 / float32,
+    # upcast to float64 to ensure same numeric between compile and eager
+    origin_dtype = amax_tensor.dtype
+    amax_tensor = amax_tensor.to(torch.float64)
     scale_tensor = torch.finfo(torch.float8_e4m3fn).max / amax_tensor  # Replicate
-    if amax_tensor.dtype is torch.float16:
+    if origin_dtype is torch.float16:
         scale_tensor = torch.clamp(scale_tensor, max=torch.finfo(torch.float16).max)
-    local_scale_tensor = scale_tensor.to_local()
+    local_scale_tensor = scale_tensor.to_local().to(torch.float32)
     for i, float8_linear in enumerate(float8_linears):
-        float8_linear.weight._local_tensor._precomputed_scale = local_scale_tensor[i].to(torch.float32)
+        float8_linear.weight._local_tensor._precomputed_scale = local_scale_tensor[i]
 
 
 # FSDP pads its local tensor on dim-0. The subclass should be preserved such
```
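For readers without an FSDP2 setup, a simplified single-device sketch of the same pattern (illustrative names; no DTensor, no all-reduce, and the EPS value is assumed): stack one amax per weight, divide in float64, and store the precomputed scales as float32.

```python
import torch

weights = [torch.randn(16, 16, dtype=torch.bfloat16) for _ in range(3)]
# One amax per weight, stacked so a single division produces all scales.
amax_tensor = torch.stack([w.abs().amax() for w in weights])
amax_tensor = torch.clamp(amax_tensor, 1e-12)  # avoid division by zero (EPS is illustrative)
origin_dtype = amax_tensor.dtype
# Divide in float64 to keep eager and torch.compile numerics identical.
scale_tensor = torch.finfo(torch.float8_e4m3fn).max / amax_tensor.to(torch.float64)
if origin_dtype is torch.float16:
    scale_tensor = torch.clamp(scale_tensor, max=torch.finfo(torch.float16).max)
precomputed_scales = scale_tensor.to(torch.float32)  # one float32 scale per weight
print(precomputed_scales.shape)  # torch.Size([3])
```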

torchao/testing/float8/fsdp2_utils.py

Lines changed: 1 addition & 4 deletions

```diff
@@ -48,10 +48,7 @@ def check_parity_no_mp(
             ):
                 precompute_float8_dynamic_scale_for_fsdp(model)
 
-        if compile_transformer_block:
-            test_cls.assertEqual(losses[0], losses[1], atol=1e-4, rtol=1e-4, msg = f"iter: {iter_idx}, loss-ref: {losses[0]}, loss-fp8: {losses[1]}")
-        else:
-            test_cls.assertEqual(losses[0], losses[1], msg = f"iter: {iter_idx}, loss-ref: {losses[0]}, loss-fp8: {losses[1]}")
+        test_cls.assertEqual(losses[0], losses[1], msg = f"iter: {iter_idx}, loss-ref: {losses[0]}, loss-fp8: {losses[1]}")
 
 
 def check_parity_bf16_mp(
```
