Commit 9fa6105

Author: Sanket Jayant Purandare (committed)

Re-enable FSDP2 Mem Tracker integration tests

ghstack-source-id: 8344603
Pull Request resolved: #485

1 parent 00a3c21 · commit 9fa6105
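
For context, the FSDP2 memory tracker exercised by these tests is torch.distributed._tools.fsdp2_mem_tracker.FSDPMemTracker, which estimation.py (below) drives for two training iterations. The following is a minimal sketch of that usage pattern, not code from this commit; the constructor keyword names (mod, optm) and the track_inputs call are assumptions about the tracker's API, while display_snapshot matches the call visible in the diff.

    # Minimal sketch (not from this commit) of driving FSDPMemTracker for two iterations.
    import torch
    from torch.distributed._tools.fsdp2_mem_tracker import FSDPMemTracker

    def track_two_iters(model, optimizer, input_ids, labels, loss_fn):
        # keyword names below are assumed from the tracker's API
        mem_tracker = FSDPMemTracker(mod=model, optm=optimizer)
        mem_tracker.track_inputs((input_ids, labels))
        with mem_tracker:
            for iter_idx in range(2):
                loss = loss_fn(model(input_ids), labels)
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
                # same reporting call that estimation.py uses per iteration
                mem_tracker.display_snapshot("peak", units="MiB", tabulate=True)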

File tree: 2 files changed (+19, -11 lines)

estimation.py

Lines changed: 17 additions & 11 deletions
@@ -14,17 +14,19 @@
 from torch._subclasses.fake_tensor import FakeTensorMode
 from torch.distributed import destroy_process_group
 from torch.distributed._tools.fsdp2_mem_tracker import FSDPMemTracker
-from torch.distributed.tensor.parallel import loss_parallel
 from torch.testing._internal.distributed.fake_pg import FakeStore

 from torchtitan.config_manager import JobConfig
 from torchtitan.datasets import create_tokenizer
-from torchtitan.float8_linear import build_fp8_linear
+from torchtitan.float8_linear import (
+    maybe_build_fp8_linear,
+    maybe_precompute_fp8_dynamic_scale_for_fsdp,
+)
 from torchtitan.logging_utils import init_logger, logger
 from torchtitan.lr_scheduling import get_lr_schedulers
 from torchtitan.models import model_name_to_cls, model_name_to_tokenizer, models_config
 from torchtitan.parallelisms import models_parallelize_fns, ParallelDims
-from train import build_optimizers
+from train import build_optimizers, get_train_context


 def estimate_memory(job_config: JobConfig):
@@ -61,9 +63,10 @@ def estimate_memory(job_config: JobConfig):
         logger.info("Compiled RMSNorm is not supported yet. Switching to RMSNorm.")
         job_config.model.norm_type = "rmsnorm"

-    if job_config.training.compile:
+    if job_config.training.compile or job_config.experimental.enable_compiled_autograd:
         logger.info("Compile mode is not supported yet. Switching to eager mode.")
         job_config.training.compile = False
+        job_config.experimental.enable_compiled_autograd = False

     parallel_dims = ParallelDims(
         dp=job_config.training.data_parallel_degree,
@@ -96,9 +99,9 @@ def estimate_memory(job_config: JobConfig):
     tokenizer_type = model_name_to_tokenizer[model_name]
     tokenizer = create_tokenizer(tokenizer_type, job_config.model.tokenizer_path)

-    # loss_parallel enables dispatching to efficient loss operators
-    loss_parallel_ctx = (
-        loss_parallel if parallel_dims.loss_parallel_enabled else contextlib.nullcontext
+    train_context = get_train_context(
+        parallel_dims.loss_parallel_enabled,
+        job_config.experimental.enable_compiled_autograd,
     )

     # loss fn can be shared by pipeline-parallel or non-pp execution
@@ -124,9 +127,8 @@ def loss_fn(pred, labels):
     with torch.device("meta"):
         whole_model = model_cls.from_model_args(model_config)

-    # apply fp8 linear module swap
-    if job_config.training.enable_fp8_linear:
-        build_fp8_linear(whole_model, job_config, parallel_dims.dp_enabled)
+    # swap to Float8Linear based on fp8 config
+    maybe_build_fp8_linear(whole_model, job_config, parallel_dims.dp_enabled)

     # apply PT-D DP/TP parallelisms and activation checkpointing
     model_parts = [whole_model]
@@ -171,7 +173,7 @@ def loss_fn(pred, labels):
         for iter_idx in range(2):
            input_ids, labels = batch
            # train step
-            with loss_parallel_ctx():
+            with train_context():
                pred = whole_model(input_ids)
                loss = loss_fn(pred, labels)
                del pred
@@ -185,6 +187,10 @@ def loss_fn(pred, labels):
            # optimizer step
            optimizers.step()
            lr_schedulers.step()
+            # when fp8 config is on,
+            # calculate float8 dynamic amax/scale for all-parameter for FSDP2
+            # it issues a single all-reduce for all parameters at once for better performance
+            maybe_precompute_fp8_dynamic_scale_for_fsdp(whole_model, job_config)
            optimizers.zero_grad()
            print(f"Peak Memory at iter: {iter_idx}")
            fsdp_memtracker.display_snapshot("peak", units="MiB", tabulate=True)
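
The diff above replaces the loss_parallel-only context with get_train_context imported from train.py, which also folds in the compiled-autograd toggle (disabled for estimation, per the earlier hunk). Below is a minimal sketch of the assumed shape of that helper; the real implementation lives in torchtitan's train.py, and the compiled-autograd branch is elided here since estimation.py forces that flag off anyway.

    # Sketch only: assumed shape of train.py's get_train_context.
    import contextlib

    from torch.distributed.tensor.parallel import loss_parallel

    def get_train_context(enable_loss_parallel: bool, enable_compiled_autograd: bool):
        @contextlib.contextmanager
        def context():
            with contextlib.ExitStack() as stack:
                if enable_loss_parallel:
                    # dispatch to efficient loss operators under tensor parallelism
                    stack.enter_context(loss_parallel())
                if enable_compiled_autograd:
                    # train.py is assumed to enter a compiled-autograd context here;
                    # elided because estimation.py disables this flag
                    pass
                yield
        return context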

test_runner.py

Lines changed: 2 additions & 0 deletions
@@ -314,6 +314,8 @@ def run_test(test_flavor: OverrideDefinitions, full_path: str, output_dir: str):

     for override_arg in test_flavor.override_args:
         cmd = f"CONFIG_FILE={full_path} NGPU={test_flavor.ngpu} LOG_RANK={all_ranks} ./run_llama_train.sh"
+        if test_name == "fsdp2_mem_tracker":
+            cmd = f"CONFIG_FILE={full_path} NGPU={test_flavor.ngpu} LOG_RANK={all_ranks} ./run_memory_estimation.sh"
         cmd += " " + dump_folder_arg
         cmd += " " + model_flavor_arg
         if override_arg:
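
The new branch routes a test flavor named "fsdp2_mem_tracker" to ./run_memory_estimation.sh instead of ./run_llama_train.sh. A hypothetical sketch of how such a flavor could be registered follows; the OverrideDefinitions field order, the description string, and the override flag shown are assumptions inferred from how test_flavor.override_args, test_flavor.ngpu, and test_name are used in the hunk above, not code from this commit.

    # Hypothetical registration of the re-enabled integration test (illustrative only).
    OverrideDefinitions(
        [
            [
                "--memory_estimation.enabled",  # assumed flag name
            ],
        ],
        "FSDP2 Memory Tracking and Estimation",
        "fsdp2_mem_tracker",
        ngpu=4,
    )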
