
Commit 298494b

import float8_experimental only when fp8 is enabled and install it in CI (#464)
Make sure float8_experimental is imported only when fp8 is enabled (for the 4-GPU CI). Also make sure float8_experimental can be imported in CI by installing it: `python -m pip install git+https://github.com/pytorch-labs/float8_experimental.git`
1 parent f025335 commit 298494b
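
The core of the change is a deferred import: float8_experimental is imported only inside the fp8 code paths, so machines (and CI jobs) without the package can still run everything else. Below is a minimal sketch of that pattern; the helper name, the `fp8_enabled` flag, and the error message are illustrative and not the actual torchtitan API, while the import path comes from the diff itself.

# Minimal sketch of the deferred-import pattern this commit applies.
# `maybe_get_float8_swap` and the error text are illustrative only; the real
# logic lives in torchtitan/float8_linear.py.
def maybe_get_float8_swap(fp8_enabled: bool):
    if not fp8_enabled:
        return None  # fp8 off: float8_experimental is never imported
    try:
        # Only evaluated when fp8 is actually requested.
        from float8_experimental.float8_linear_utils import (
            swap_linear_with_float8_linear,
        )
    except ImportError as exc:
        raise ImportError(
            "fp8 is enabled but float8_experimental is not installed; run "
            "python -m pip install "
            "git+https://github.com/pytorch-labs/float8_experimental.git"
        ) from exc
    return swap_linear_with_float8_linear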

3 files changed: +8 additions, -4 deletions

.github/workflows/integration_test_4gpu.yaml

Lines changed: 1 addition & 0 deletions
@@ -39,5 +39,6 @@ jobs:
     python -m pip install --force-reinstall --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121
     python -m pip install --pre torchdata --index-url https://download.pytorch.org/whl/nightly/
+    python -m pip install git+https://github.com/pytorch-labs/float8_experimental.git
     mkdir artifacts-to-be-uploaded
     python ./test_runner.py artifacts-to-be-uploaded --ngpu 4

torchtitan/float8_linear.py

Lines changed: 3 additions & 3 deletions
@@ -15,18 +15,17 @@
 import contextlib
 from typing import Optional
 
-import float8_experimental.config as config
-
 import torch
 import torch.nn as nn
-from float8_experimental.float8_linear import TensorScalingType
 
 from torchtitan.config_manager import JobConfig
 from torchtitan.logging_utils import logger
 
 
 @contextlib.contextmanager
 def set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather: bool):
+    import float8_experimental.config as config
+
     prev = config.enable_fsdp_fp8_all_gather
     torch.distributed.barrier()
     config.enable_fsdp_fp8_all_gather = enable_fsdp_fp8_all_gather
@@ -51,6 +50,7 @@ def build_fp8_linear(
         job_config.training.enable_fsdp_fp8_all_gather and dp_enabled
     )
     try:
+        from float8_experimental.float8_linear import TensorScalingType
         from float8_experimental.float8_linear_utils import (
             swap_linear_with_float8_linear,
         )
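
One way to see what moving these imports into function bodies buys (a hedged check, not part of the commit): `torchtitan.float8_linear` itself can now be imported on a machine without float8_experimental installed, since the optional package is only needed once an fp8 code path actually runs. Assuming torchtitan is on the Python path:

# Hedged illustration, not part of the commit: importing the module should
# succeed even when float8_experimental is absent, because the optional
# imports now live inside the functions that need them.
import importlib

float8_linear = importlib.import_module("torchtitan.float8_linear")
print(hasattr(float8_linear, "build_fp8_linear"))  # expected: True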

train.py

Lines changed: 4 additions & 1 deletion
@@ -19,7 +19,6 @@
 
 import torch
 import torch.nn.functional as F
-from float8_experimental.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp
 from torch.distributed import destroy_process_group
 from torch.distributed.checkpoint.stateful import Stateful
 from torch.distributed.elastic.multiprocessing.errors import record
@@ -404,6 +403,10 @@ def loss_fn(pred, labels):
             and job_config.training.enable_fsdp_fp8_all_gather
             and job_config.training.precompute_float8_dynamic_scale_for_fsdp
         ):
+            from float8_experimental.fsdp_utils import (
+                precompute_float8_dynamic_scale_for_fsdp,
+            )
+
             # calculate float8 dynamic amax/scale for all-parameter for FSDP2
             # it issues a single all-reduce for all parameters at once for better performance
             precompute_float8_dynamic_scale_for_fsdp(model)
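
The train.py change follows the same idea: the import now sits inside the very `if` that gates the feature, so it is only evaluated when the fp8 all-gather precompute path is taken. A simplified sketch of that guarded-import pattern, where `fp8_allgather_enabled` and `precompute_scale_enabled` are stand-ins for the job_config.training fields above and `model` is the FSDP-wrapped model:

# Simplified sketch of the guarded-import pattern from the hunk above;
# the two flags are hypothetical stand-ins for the job_config fields.
if fp8_allgather_enabled and precompute_scale_enabled:
    from float8_experimental.fsdp_utils import (
        precompute_float8_dynamic_scale_for_fsdp,
    )

    # One all-reduce covers all parameters at once (per the comment in the diff).
    precompute_float8_dynamic_scale_for_fsdp(model)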
