From 771675a9876097c8d4d3d3c88673379a0cc6b756 Mon Sep 17 00:00:00 2001 From: Shuqi Yang Date: Wed, 25 Sep 2024 12:21:14 -0700 Subject: [PATCH] Fix fp8-all-gather buck errors (#912) Summary: Pull Request resolved: https://github.com/pytorch/ao/pull/912 Reviewed By: vkuzo Differential Revision: D63048850 --- test/float8/test_fsdp2/test_fsdp2.py | 2 +- torchao/testing/__init__.py | 0 torchao/testing/float8/__init__.py | 0 .../testing/float8/fsdp2_utils.py | 6 +++--- 4 files changed, 4 insertions(+), 4 deletions(-) create mode 100644 torchao/testing/__init__.py create mode 100644 torchao/testing/float8/__init__.py rename test/float8/test_fsdp2/fsdp2_common.py => torchao/testing/float8/fsdp2_utils.py (90%) diff --git a/test/float8/test_fsdp2/test_fsdp2.py b/test/float8/test_fsdp2/test_fsdp2.py index e2e7097f6b..ecde051e36 100644 --- a/test/float8/test_fsdp2/test_fsdp2.py +++ b/test/float8/test_fsdp2/test_fsdp2.py @@ -18,7 +18,7 @@ from torchao.float8.config import CastConfig, Float8LinearConfig, ScalingType from torchao.float8.float8_linear_utils import convert_to_float8_training from torchao.float8.fsdp_utils import WeightWithDynamicFloat8CastTensor -from fsdp2_common import check_parity_bf16_mp, check_parity_no_mp +from torchao.testing.float8.fsdp2_utils import check_parity_bf16_mp, check_parity_no_mp from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy from torch.distributed._tensor import DTensor from torch.testing._internal.common_cuda import TEST_CUDA diff --git a/torchao/testing/__init__.py b/torchao/testing/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/torchao/testing/float8/__init__.py b/torchao/testing/float8/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/test/float8/test_fsdp2/fsdp2_common.py b/torchao/testing/float8/fsdp2_utils.py similarity index 90% rename from test/float8/test_fsdp2/fsdp2_common.py rename to torchao/testing/float8/fsdp2_utils.py index 
333206ba41..f558bb11f9 100644 --- a/test/float8/test_fsdp2/fsdp2_common.py +++ b/torchao/testing/float8/fsdp2_utils.py @@ -49,9 +49,9 @@ def check_parity_no_mp( precompute_float8_dynamic_scale_for_fsdp(model) if compile_transformer_block: - test_cls.assertEqual(losses[0], losses[1], atol=1e-4, rtol=1e-4) + test_cls.assertEqual(losses[0], losses[1], atol=1e-4, rtol=1e-4, msg=f"iter: {iter_idx}, loss-ref: {losses[0]}, loss-fp8: {losses[1]}") else: - test_cls.assertEqual(losses[0], losses[1]) + test_cls.assertEqual(losses[0], losses[1], msg=f"iter: {iter_idx}, loss-ref: {losses[0]}, loss-fp8: {losses[1]}") def check_parity_bf16_mp( @@ -86,4 +86,4 @@ def check_parity_bf16_mp( ref_model.parameters(), ref_model_bf16.parameters() ): param_bf16.detach().copy_(param_fp32) - test_cls.assertEqual(losses[0], losses[1]) + test_cls.assertEqual(losses[0], losses[1], msg=f"iter: {iter_idx}, loss-ref: {losses[0]}, loss-fp8: {losses[1]}")