Skip to content

Commit 09da3c7

Browse files
y-sq authored and facebook-github-bot committed
Fix fp8-all-gather buck errors
Differential Revision: D63048850
1 parent 0bdde92 commit 09da3c7

File tree

2 files changed

+11
-4
lines changed

2 files changed

+11
-4
lines changed

test/float8/test_fsdp2/test_fsdp2.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
from torchao.float8.config import CastConfig, Float8LinearConfig, ScalingType
1919
from torchao.float8.float8_linear_utils import convert_to_float8_training
2020
from torchao.float8.fsdp_utils import WeightWithDynamicFloat8CastTensor
21-
from fsdp2_common import check_parity_bf16_mp, check_parity_no_mp
2221
from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy
2322
from torch.distributed._tensor import DTensor
2423
from torch.testing._internal.common_cuda import TEST_CUDA
@@ -36,6 +35,14 @@
3635
TransformerBlock,
3736
)
3837

38+
# OSS and fbcode need different import statements
39+
# TODO: fix the issue and remove the try-except block.
40+
try:
41+
from test_fsdp2_utils import check_parity_bf16_mp, check_parity_no_mp
42+
except ImportError:
43+
from .test_fsdp2_utils import check_parity_bf16_mp, check_parity_no_mp
44+
45+
3946
is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)
4047
if not is_cuda_8_9:
4148
pytest.skip("Unsupported CUDA device capability version", allow_module_level=True)

test/float8/test_fsdp2/fsdp2_common.py renamed to test/float8/test_fsdp2/test_fsdp2_utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,9 @@ def check_parity_no_mp(
4949
precompute_float8_dynamic_scale_for_fsdp(model)
5050

5151
if compile_transformer_block:
52-
test_cls.assertEqual(losses[0], losses[1], atol=1e-4, rtol=1e-4)
52+
test_cls.assertEqual(losses[0], losses[1], atol=1e-4, rtol=1e-4, msg = f"iter: {iter_idx}, loss-ref: {losses[0]}, loss-fp8: {losses[1]}")
5353
else:
54-
test_cls.assertEqual(losses[0], losses[1])
54+
test_cls.assertEqual(losses[0], losses[1], msg = f"iter: {iter_idx}, loss-ref: {losses[0]}, loss-fp8: {losses[1]}")
5555

5656

5757
def check_parity_bf16_mp(
@@ -86,4 +86,4 @@ def check_parity_bf16_mp(
8686
ref_model.parameters(), ref_model_bf16.parameters()
8787
):
8888
param_bf16.detach().copy_(param_fp32)
89-
test_cls.assertEqual(losses[0], losses[1])
89+
test_cls.assertEqual(losses[0], losses[1], msg = f"iter: {iter_idx}, loss-ref: {losses[0]}, loss-fp8: {losses[1]}")

0 commit comments

Comments
 (0)