3 files changed: +8 -4 lines changed
File 1 of 3 (CI integration-test setup):

  python -m pip install --force-reinstall --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121
  python -m pip install --pre torchdata --index-url https://download.pytorch.org/whl/nightly/
+ python -m pip install git+https://github.com/pytorch-labs/float8_experimental.git
  mkdir artifacts-to-be-uploaded
  python ./test_runner.py artifacts-to-be-uploaded --ngpu 4
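The added line installs float8_experimental directly from its GitHub repository so the 4-GPU integration tests can exercise the fp8 code paths. As a minimal sanity check (a sketch, not part of the PR; the script name and error message are illustrative only), the test environment could verify that the optional dependency resolves before launching the test runner:

    # check_fp8_dep.py -- hypothetical helper, not included in the PR
    import importlib.util

    # find_spec returns None when the package cannot be imported.
    if importlib.util.find_spec("float8_experimental") is None:
        raise SystemExit(
            "float8_experimental is not installed; install it with "
            "pip install git+https://github.com/pytorch-labs/float8_experimental.git"
        )
    print("float8_experimental is available")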
File 2 of 3 (float8 linear utilities):

  import contextlib
  from typing import Optional

- import float8_experimental.config as config
-
  import torch
  import torch.nn as nn
- from float8_experimental.float8_linear import TensorScalingType

  from torchtitan.config_manager import JobConfig
  from torchtitan.logging_utils import logger


  @contextlib.contextmanager
  def set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather: bool):
+     import float8_experimental.config as config
+
      prev = config.enable_fsdp_fp8_all_gather
      torch.distributed.barrier()
      config.enable_fsdp_fp8_all_gather = enable_fsdp_fp8_all_gather

@@ -51,6 +50,7 @@ def build_fp8_linear(
          job_config.training.enable_fsdp_fp8_all_gather and dp_enabled
      )
      try:
+         from float8_experimental.float8_linear import TensorScalingType
          from float8_experimental.float8_linear_utils import (
              swap_linear_with_float8_linear,
          )
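Both hunks in this file apply the same deferred-import pattern: the float8_experimental imports move from module scope into the bodies of set_enable_fsdp_fp8_all_gather and build_fp8_linear, so importing the module no longer requires the package unless the fp8 path is actually used. A minimal sketch of the pattern (assuming the context manager restores the previous setting on exit, which the visible hunk does not show):

    import contextlib

    @contextlib.contextmanager
    def set_enable_fsdp_fp8_all_gather_sketch(enable: bool):
        # Deferred import: float8_experimental is only needed once this
        # context manager is actually entered.
        import float8_experimental.config as config

        prev = config.enable_fsdp_fp8_all_gather
        config.enable_fsdp_fp8_all_gather = enable
        try:
            yield
        finally:
            # Restore the previous value (assumed behaviour of the original).
            config.enable_fsdp_fp8_all_gather = prev

The torch.distributed.barrier() calls from the original are omitted here, since the sketch does not assume an initialized process group.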
File 3 of 3 (training script):

  import torch
  import torch.nn.functional as F
- from float8_experimental.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp
  from torch.distributed import destroy_process_group
  from torch.distributed.checkpoint.stateful import Stateful
  from torch.distributed.elastic.multiprocessing.errors import record

@@ -404,6 +403,10 @@ def loss_fn(pred, labels):
          and job_config.training.enable_fsdp_fp8_all_gather
          and job_config.training.precompute_float8_dynamic_scale_for_fsdp
      ):
+         from float8_experimental.fsdp_utils import (
+             precompute_float8_dynamic_scale_for_fsdp,
+         )
+
          # calculate float8 dynamic amax/scale for all-parameter for FSDP2
          # it issues a single all-reduce for all parameters at once for better performance
          precompute_float8_dynamic_scale_for_fsdp(model)
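As in the other files, the top-level import of precompute_float8_dynamic_scale_for_fsdp is removed and the import is deferred into the branch gated by the fp8 training flags, so the training script runs without float8_experimental unless that feature combination is enabled. A sketch of the resulting control flow, wrapped in a hypothetical helper for readability (the PR keeps this logic inline in the training loop, and its full gate may include conditions not visible in the hunk):

    def maybe_precompute_fp8_scales(model, job_config):
        # Gate mirrors the conditions visible in the hunk.
        if not (
            job_config.training.enable_fsdp_fp8_all_gather
            and job_config.training.precompute_float8_dynamic_scale_for_fsdp
        ):
            return
        # Deferred import: only required when this branch actually executes.
        from float8_experimental.fsdp_utils import (
            precompute_float8_dynamic_scale_for_fsdp,
        )

        # Computes float8 dynamic amax/scale for all parameters (FSDP2),
        # issuing a single all-reduce for better performance.
        precompute_float8_dynamic_scale_for_fsdp(model)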