Commit 449a359

Add support for DDP and experimental CompiledAutograd
Summary: Address the comments in #319 and resubmit the PR to fit the current code base.

Test Plan:
```
CONFIG_FILE=./train_configs/debug_model.toml ./run_llama_train.sh --comm.train_timeout_seconds=3600 --training.tensor_parallel_degree=1 --training.data_parallel_degree=8 --experimental.data_parallel_type=ddp --training.steps=1000 --metrics.log_freq=10 --profiling.profile_freq=1000
```

ghstack-source-id: 18bb9b5
Pull Request resolved: #432
1 parent: 0bb6980

File tree

6 files changed: +81 -8 lines

estimation.py

Lines changed: 1 addition & 0 deletions
@@ -71,6 +71,7 @@ def estimate_memory(job_config: JobConfig):
         pp=job_config.experimental.pipeline_parallel_degree,
         world_size=world_size,
         enable_loss_parallel=job_config.training.enable_loss_parallel,
+        dp_type=job_config.training.data_parallel_type,
     )
 
     device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}")

test_runner.py

Lines changed: 11 additions & 0 deletions
@@ -273,6 +273,17 @@ def build_test_list():
             "fsdp2_mem_tracker",
             ngpu=4,
         ),
+        OverrideDefinitions(
+            [
+                [
+                    "--training.data_parallel_type ddp",
+                    "--experimental.enable_compiled_autograd",
+                ]
+            ],
+            "CompiledDDP",
+            "compiled_ddp",
+            ngpu=4,
+        ),
     ]
     return integration_tests_flavors
 

torchtitan/config_manager.py

Lines changed: 11 additions & 0 deletions
@@ -312,6 +312,17 @@ def __init__(self):
                 The default value will be the number of pipeline stages, if unspecified.
             """,
         )
+        self.parser.add_argument(
+            "--training.data_parallel_type",
+            type=str,
+            default="fsdp",
+            help="Data parallelism type. TorchTitan currently supports FSDP and DDP.",
+        )
+        self.parser.add_argument(
+            "--experimental.enable_compiled_autograd",
+            action="store_true",
+            help="Enable CompiledAutograd to compile the backward.",
+        )
         self.parser.add_argument(
             "--training.mixed_precision_param",
             type=str,
torchtitan/parallelisms/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -28,8 +28,10 @@ class ParallelDims:
     pp: int
     world_size: int
     enable_loss_parallel: bool
+    dp_type: str
 
     def __post_init__(self):
+        self.dp_type = self.dp_type.lower()
         self._validate()
 
     def _validate(self):
@@ -42,6 +44,7 @@ def _validate(self):
         assert (
             dp * tp * pp == self.world_size
         ), f"Invalid parallel dims: dp({dp}) * tp({tp}) * pp({pp}) != WORLD_SIZE({self.world_size})"
+        assert self.dp_type in ("fsdp", "ddp")
 
     def build_mesh(self, device_type):
         dims = []
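To make the new validation concrete, here is a trimmed, self-contained re-implementation of the behavior added above (for illustration only; the dp/tp fields are inferred from the existing _validate body, and the real ParallelDims carries more logic).

```
# Trimmed sketch of the dp_type handling added to ParallelDims (not torchtitan code).
from dataclasses import dataclass


@dataclass
class ParallelDimsSketch:
    dp: int
    tp: int
    pp: int
    world_size: int
    dp_type: str

    def __post_init__(self):
        # Lower-case first, so "DDP" and "ddp" are treated the same.
        self.dp_type = self.dp_type.lower()
        assert self.dp * self.tp * self.pp == self.world_size
        # Only the two supported data-parallel flavors pass validation.
        assert self.dp_type in ("fsdp", "ddp")


ParallelDimsSketch(dp=8, tp=1, pp=1, world_size=8, dp_type="DDP")     # ok
# ParallelDimsSketch(dp=8, tp=1, pp=1, world_size=8, dp_type="hsdp")  # AssertionError
```
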

torchtitan/parallelisms/parallelize_llama.py

Lines changed: 33 additions & 3 deletions
@@ -16,6 +16,8 @@
 from torch.distributed import DeviceMesh
 
 from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy
+
+from torch.distributed._composable.replicate import replicate
 from torch.distributed._tensor import Replicate, Shard
 from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
     checkpoint_wrapper as ptd_checkpoint_wrapper,
@@ -449,13 +451,15 @@ def apply_compile(model: nn.Module, job_config: JobConfig):
     return model
 
 
-def apply_dp(
+def apply_fsdp(
     model: nn.Module,
     world_mesh: DeviceMesh,
     parallel_dims: "ParallelDims",
     job_config: JobConfig,
 ):
-    """Apply data parallelism (FSDP2) to the model."""
+    """
+    Apply data parallelism to the model. FSDP2 is used here.
+    """
 
     dp_mesh = world_mesh["dp"] if world_mesh.ndim > 1 else world_mesh
     assert dp_mesh.mesh_dim_names == ("dp",), dp_mesh.mesh_dim_names
@@ -488,6 +492,29 @@ def apply_dp(
     return model
 
 
+def apply_ddp(
+    model: nn.Module,
+    world_mesh: DeviceMesh,
+    parallel_dims: "ParallelDims",
+    job_config: JobConfig,
+):
+    if world_mesh.ndim > 1:
+        raise RuntimeError("DDP has not supported > 1D parallelism.")
+
+    if job_config.training.compile:
+        if job_config.experimental.enable_compiled_autograd:
+            torch._dynamo.config.optimize_ddp = (
+                "python_reducer_without_compiled_forward"
+            )
+        else:
+            torch._dynamo.config.optimize_ddp = "ddp_optimizer"
+
+    model = replicate(model, device_mesh=world_mesh, bucket_cap_mb=100)
+
+    logger.info("Applied DDP to the model")
+    return model
+
+
 def parallelize_llama(
     model: nn.Module,
     world_mesh: DeviceMesh,
@@ -512,6 +539,9 @@
         model = apply_compile(model, job_config)
 
     if parallel_dims.dp_enabled:
-        model = apply_dp(model, world_mesh, parallel_dims, job_config)
+        if parallel_dims.dp_type == "fsdp":
+            model = apply_fsdp(model, world_mesh, parallel_dims, job_config)
+        else:
+            model = apply_ddp(model, world_mesh, parallel_dims, job_config)
 
     return model
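A rough standalone sketch of the path apply_ddp sets up: composable replicate() plus the Dynamo knob that CompiledAutograd depends on. Assumptions: it is launched with torchrun, uses the gloo backend on CPU so no GPUs are needed, and the ENABLE_COMPILED_AUTOGRAD environment variable is a hypothetical stand-in for the job config flag; this is not torchtitan code.

```
# Rough sketch of the apply_ddp code path, as a standalone script (not torchtitan code).
import os

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed._composable.replicate import replicate


def main():
    # Assumption: launched via `torchrun --nproc_per_node=2 this_script.py`.
    dist.init_process_group(backend="gloo")
    model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 1))

    # The Dynamo setting only takes effect once the model is torch.compile'd:
    # CompiledAutograd needs the python reducer; otherwise fall back to the
    # graph-splitting DDP optimizer (mirrors the branch in apply_ddp).
    if os.environ.get("ENABLE_COMPILED_AUTOGRAD", "0") == "1":  # hypothetical env var
        torch._dynamo.config.optimize_ddp = "python_reducer_without_compiled_forward"
    else:
        torch._dynamo.config.optimize_ddp = "ddp_optimizer"

    # replicate() is the composable counterpart of DistributedDataParallel;
    # bucket_cap_mb mirrors the value used in the diff.
    model = replicate(model, bucket_cap_mb=100)

    out = model(torch.randn(4, 16))
    out.sum().backward()  # gradients are all-reduced across ranks
    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```
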

train.py

Lines changed: 22 additions & 5 deletions
@@ -135,6 +135,22 @@ def zero_grad(self):
     return OptimizersContainer([_build_optimizer(model) for model in model_parts])
 
 
+def get_train_context(enable_loss_parallel: bool, enable_compiled_autograd: bool):
+    @contextlib.contextmanager
+    def context():
+        with contextlib.ExitStack() as stack:
+            if enable_loss_parallel:
+                stack.enter_context(loss_parallel())
+            if enable_compiled_autograd:
+                stack.enter_context(
+                    torch._dynamo.utils.maybe_enable_compiled_autograd(True)
+                )
+
+            yield
+
+    return context
+
+
 # Enable debug tracing on failure: https://pytorch.org/docs/stable/elastic/errors.html
 @record
 def main(job_config: JobConfig):
@@ -157,6 +173,7 @@ def main(job_config: JobConfig):
         pp=job_config.experimental.pipeline_parallel_degree,
         world_size=world_size,
         enable_loss_parallel=job_config.training.enable_loss_parallel,
+        dp_type=job_config.training.data_parallel_type,
     )
     device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}")
     torch.cuda.set_device(device)
@@ -191,9 +208,9 @@ def main(job_config: JobConfig):
         dp_rank,
     )
 
-    # loss_parallel enables dispatching to efficient loss operators
-    loss_parallel_ctx = (
-        loss_parallel if parallel_dims.loss_parallel_enabled else contextlib.nullcontext
+    train_context = get_train_context(
+        parallel_dims.loss_parallel_enabled,
+        job_config.experimental.enable_compiled_autograd,
     )
 
     # loss fn can be shared by pipeline-parallel or non-pp execution
@@ -362,7 +379,7 @@ def loss_fn(pred, labels):
             # pipeline parallel forward / backward inside step() call
             is_last_stage = pp_mesh.get_local_rank() == pp_mesh.size() - 1
 
-            with loss_parallel_ctx():
+            with train_context():
                 if pp_mesh.get_local_rank() == 0:
                     pp_schedule.step(input_ids)
                 elif is_last_stage:
@@ -379,7 +396,7 @@ def loss_fn(pred, labels):
             )
         else:
             # Non-PP forward / backward
-            with loss_parallel_ctx():
+            with train_context():
                 pred = model(input_ids)
                 loss = loss_fn(pred, labels)
                 # pred.shape=(bs, seq_len, vocab_size)
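The ExitStack pattern in get_train_context is the key piece: a single `with train_context():` enters zero, one, or both of the loss-parallel and compiled-autograd contexts. Below is a standalone sketch with dummy stand-ins; the dummy context managers are placeholders, not the real loss_parallel() or maybe_enable_compiled_autograd().

```
# Standalone sketch of the get_train_context pattern (dummy contexts, not torchtitan code).
import contextlib


@contextlib.contextmanager
def dummy_loss_parallel():
    # Placeholder for the loss_parallel() context used in train.py.
    print("loss_parallel: enter")
    yield
    print("loss_parallel: exit")


@contextlib.contextmanager
def dummy_compiled_autograd(enabled: bool):
    # Placeholder for torch._dynamo.utils.maybe_enable_compiled_autograd(True).
    print(f"compiled_autograd enabled: {enabled}")
    yield


def get_train_context(enable_loss_parallel: bool, enable_compiled_autograd: bool):
    @contextlib.contextmanager
    def context():
        with contextlib.ExitStack() as stack:
            if enable_loss_parallel:
                stack.enter_context(dummy_loss_parallel())
            if enable_compiled_autograd:
                stack.enter_context(dummy_compiled_autograd(True))
            yield

    return context


train_context = get_train_context(enable_loss_parallel=True, enable_compiled_autograd=True)
with train_context():
    print("forward / backward / loss here")
```
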
