@@ -14,17 +14,19 @@
 from torch._subclasses.fake_tensor import FakeTensorMode
 from torch.distributed import destroy_process_group
 from torch.distributed._tools.fsdp2_mem_tracker import FSDPMemTracker
-from torch.distributed.tensor.parallel import loss_parallel
 from torch.testing._internal.distributed.fake_pg import FakeStore
 
 from torchtitan.config_manager import JobConfig
 from torchtitan.datasets import create_tokenizer
-from torchtitan.float8_linear import build_fp8_linear
+from torchtitan.float8_linear import (
+    maybe_build_fp8_linear,
+    maybe_precompute_fp8_dynamic_scale_for_fsdp,
+)
 from torchtitan.logging_utils import init_logger, logger
 from torchtitan.lr_scheduling import get_lr_schedulers
 from torchtitan.models import model_name_to_cls, model_name_to_tokenizer, models_config
 from torchtitan.parallelisms import models_parallelize_fns, ParallelDims
-from train import build_optimizers
+from train import build_optimizers, get_train_context
 
 
 def estimate_memory(job_config: JobConfig):
@@ -61,9 +63,10 @@ def estimate_memory(job_config: JobConfig):
         logger.info("Compiled RMSNorm is not supported yet. Switching to RMSNorm.")
         job_config.model.norm_type = "rmsnorm"
 
-    if job_config.training.compile:
+    if job_config.training.compile or job_config.experimental.enable_compiled_autograd:
         logger.info("Compile mode is not supported yet. Switching to eager mode.")
         job_config.training.compile = False
+        job_config.experimental.enable_compiled_autograd = False
 
     parallel_dims = ParallelDims(
         dp=job_config.training.data_parallel_degree,
@@ -96,9 +99,9 @@ def estimate_memory(job_config: JobConfig):
     tokenizer_type = model_name_to_tokenizer[model_name]
     tokenizer = create_tokenizer(tokenizer_type, job_config.model.tokenizer_path)
 
-    # loss_parallel enables dispatching to efficient loss operators
-    loss_parallel_ctx = (
-        loss_parallel if parallel_dims.loss_parallel_enabled else contextlib.nullcontext
+    train_context = get_train_context(
+        parallel_dims.loss_parallel_enabled,
+        job_config.experimental.enable_compiled_autograd,
     )
 
     # loss fn can be shared by pipeline-parallel or non-pp execution
@@ -124,9 +127,8 @@ def loss_fn(pred, labels):
     with torch.device("meta"):
         whole_model = model_cls.from_model_args(model_config)
 
-    # apply fp8 linear module swap
-    if job_config.training.enable_fp8_linear:
-        build_fp8_linear(whole_model, job_config, parallel_dims.dp_enabled)
+    # swap to Float8Linear based on fp8 config
+    maybe_build_fp8_linear(whole_model, job_config, parallel_dims.dp_enabled)
 
     # apply PT-D DP/TP parallelisms and activation checkpointing
     model_parts = [whole_model]
@@ -171,7 +173,7 @@ def loss_fn(pred, labels):
         for iter_idx in range(2):
            input_ids, labels = batch
            # train step
-           with loss_parallel_ctx():
+           with train_context():
                pred = whole_model(input_ids)
                loss = loss_fn(pred, labels)
                del pred
@@ -185,6 +187,10 @@ def loss_fn(pred, labels):
            # optimizer step
            optimizers.step()
            lr_schedulers.step()
+           # when fp8 config is on, calculate float8 dynamic amax/scale
+           # for all parameters for FSDP2; this issues a single all-reduce
+           # for all parameters at once for better performance
+           maybe_precompute_fp8_dynamic_scale_for_fsdp(whole_model, job_config)
            optimizers.zero_grad()
            print(f"Peak Memory at iter: {iter_idx}")
            fsdp_memtracker.display_snapshot("peak", units="MiB", tabulate=True)
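
Note for readers: get_train_context is imported from train.py and its body is not part of this diff. The sketch below is a hypothetical illustration of the pattern the call site implies (one factory that stacks the optional loss-parallel and compiled-autograd contexts), assuming an ExitStack-based implementation; a null context stands in for where the real helper would enter the compiled-autograd context.

import contextlib

from torch.distributed.tensor.parallel import loss_parallel


def get_train_context(enable_loss_parallel: bool, enable_compiled_autograd: bool):
    # hypothetical sketch, not the train.py implementation referenced by this PR
    @contextlib.contextmanager
    def context():
        with contextlib.ExitStack() as stack:
            if enable_loss_parallel:
                # dispatch to the sharded loss operators when loss parallelism is on
                stack.enter_context(loss_parallel())
            if enable_compiled_autograd:
                # the real helper would enter torch._dynamo's compiled-autograd
                # context here; a null context keeps this sketch self-contained
                stack.enter_context(contextlib.nullcontext())
            yield

    return context

The call site in the diff then simply wraps the forward pass and loss computation in with train_context(), regardless of which features are enabled.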
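Similarly, the maybe_* fp8 helpers replace the explicit if job_config.training.enable_fp8_linear: guard at the call site. A hypothetical sketch of that gating follows; the real helpers live in torchtitan/float8_linear.py and are not shown in this diff, only their call-site signatures and the config flag are taken from it.

import torch.nn as nn


def maybe_build_fp8_linear(model: nn.Module, job_config, dp_enabled: bool) -> None:
    # no-op unless fp8 is enabled, so the call site needs no if-guard
    if not job_config.training.enable_fp8_linear:
        return
    # ... swap eligible nn.Linear modules for Float8Linear here ...


def maybe_precompute_fp8_dynamic_scale_for_fsdp(model: nn.Module, job_config) -> None:
    # no-op unless fp8 is enabled; otherwise precompute dynamic amax/scale for
    # all fp8 parameters, issuing a single all-reduce for better performance
    if not job_config.training.enable_fp8_linear:
        return
    # ... precompute per-parameter scales for FSDP2-sharded weights here ...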