
Commit 343c3ad

[RFC] Enable HSDP + CP

Summary: This PR demonstrates how to use DeviceMesh reshape to enable HSDP + CP.

ghstack-source-id: 8f42a91
Pull Request resolved: #463

1 parent 3f3bc38 · commit 343c3ad
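
The core mechanism: ParallelDims.build_mesh now builds a single flattened "dp" mesh dimension of size dp * cp, and each parallelism reshapes that dimension into the 2-D view it needs, ("dp", "cp") for context parallelism and ("dp_replicate", "dp_shard") for HSDP. The sketch below illustrates the two factorizations with plain tensors standing in for a DeviceMesh; the degrees are toy assumptions, and the reshape method itself is the API this RFC proposes.

# Illustration only: how the flattened "dp" mesh dimension factors for CP and
# for HSDP. Plain tensors stand in for DeviceMesh; all sizes are assumed.
import torch

dp_replicate, dp_shard, cp = 2, 4, 2   # assumed toy degrees
dp = dp_replicate * dp_shard           # data-parallel degree = 8
flat_dp = torch.arange(dp * cp)        # ranks along the flattened "dp" dim (16)

# View taken by apply_cp and train.py: peel off the cp sub-dimension.
dp_cp_view = flat_dp.view(dp, cp)      # dims ("dp", "cp")

# View taken by apply_fsdp for HSDP: replicate over the outer dim, shard over
# the inner one (which still folds in cp when cp > 1).
hsdp_view = flat_dp.view(dp_replicate, (dp * cp) // dp_replicate)  # ("dp_replicate", "dp_shard")

print(dp_cp_view.shape, hsdp_view.shape)   # torch.Size([8, 2]) torch.Size([2, 8])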

3 files changed: +31 -14 lines changed


torchtitan/parallelisms/__init__.py

Lines changed: 4 additions & 2 deletions
@@ -30,6 +30,7 @@ class ParallelDims:
     world_size: int
     enable_loss_parallel: bool
     dp_type: str
+    dp_replicate: int
 
     def __post_init__(self):
         self.dp_type = self.dp_type.lower()
@@ -40,20 +41,21 @@ def _validate(self):
         if dp == -1:
            self.dp = dp = self.world_size // (cp * tp * pp)
         assert dp >= 1, dp
+        assert dp % self.dp_replicate == 0, (self.dp_replicate, dp)
         assert cp >= 1, cp
         assert tp >= 1, tp
         assert pp >= 1, pp
         assert dp * cp * tp * pp == self.world_size, (
             f"Invalid parallel dims: dp({dp}) * cp ({cp}) * tp({tp}) * pp({pp}) "
             f"!= WORLD_SIZE({self.world_size})"
         )
-        assert self.dp_type in ("fsdp", "ddp")
+        assert self.dp_type in ("fsdp", "ddp", "hsdp")
 
     def build_mesh(self, device_type):
         dims = []
         names = []
         for d, name in zip(
-            [self.pp, self.dp, self.cp, self.tp], ["pp", "dp", "cp", "tp"], strict=True
+            [self.pp, self.dp * self.cp, self.tp], ["pp", "dp", "tp"], strict=True
         ):
             if d > 1:
                 dims.append(d)

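For intuition, the new validation above reduces to two constraints: dp (possibly inferred as world_size // (cp * tp * pp)) must be divisible by dp_replicate, and the product of all degrees must equal the world size. build_mesh then folds cp into the "dp" mesh dimension. A standalone sketch with assumed degrees:

# Standalone sketch of the ParallelDims arithmetic; all degrees are assumed.
world_size = 16
pp, tp, cp, dp_replicate = 1, 2, 2, 2

dp = world_size // (cp * tp * pp)                    # inferred dp degree -> 4
assert dp % dp_replicate == 0, (dp_replicate, dp)    # HSDP needs dp to split evenly
assert dp * cp * tp * pp == world_size

# What build_mesh registers: cp is folded into the "dp" mesh dimension.
dims_and_names = [(d, n) for d, n in zip((pp, dp * cp, tp), ("pp", "dp", "tp")) if d > 1]
print(dims_and_names)                                # [(8, 'dp'), (2, 'tp')]
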
torchtitan/parallelisms/parallelize_llama.py

Lines changed: 13 additions & 12 deletions
@@ -27,7 +27,6 @@
 from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
     checkpoint_wrapper as ptd_checkpoint_wrapper,
 )
-from torch.distributed.device_mesh import init_device_mesh
 from torch.distributed.pipelining import pipeline, PipelineStage, SplitPoint
 from torch.distributed.tensor.parallel import (
     ColwiseParallel,
@@ -463,7 +462,10 @@ def apply_cp(model, world_mesh, parallel_dims, job_config: JobConfig):
     """
     if parallel_dims.tp_enabled or parallel_dims.pp_enabled:
         raise NotImplementedError("CP + TP or CP + PP are not supported yet.")
-    cp_mesh = world_mesh["cp"]
+    dp_mesh = world_mesh["dp"]
+    cp_mesh = dp_mesh.reshape(
+        (dp_mesh.size() // parallel_dims.cp, parallel_dims.cp), ("dp", "cp")
+    )
     callers = []
     for layer_id, transformer_block in model.layers.items():
         callers.append(transformer_block.attention)
@@ -479,22 +481,21 @@ def apply_fsdp(
     parallel_dims: "ParallelDims",
     job_config: JobConfig,
 ):
-
     """
     Apply data parallelism to the model. FSDP2 is used here.
     """
 
-    if parallel_dims.cp_enabled:
-        # Temporary solution to enable FSDP + CP
-        dp_mesh = init_device_mesh(
-            world_mesh.device_type,
-            (parallel_dims.dp * parallel_dims.cp,),
-            mesh_dim_names=["dp"],
-        )
+    # This mesh also includes cp degree if it is larger than 1.
+    if parallel_dims.dp_type == "fsdp":
+        dp_mesh = world_mesh["dp"]
     else:
+        assert parallel_dims.dp_type == "hsdp", parallel_dims.dp_type
         dp_mesh = world_mesh["dp"]
-
-    assert dp_mesh.mesh_dim_names == ("dp",), dp_mesh.mesh_dim_names
+        dp_mesh = dp_mesh.reshape(
+            (parallel_dims.dp_replicate, dp_mesh.size() // parallel_dims.dp_replicate),
+            ("dp_replicate", "dp_shard"),
+        )
+    # assert dp_mesh.mesh_dim_names == ("dp",), dp_mesh.mesh_dim_names
 
     mp_policy = MixedPrecisionPolicy(
         param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],

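With the reshaped ("dp_replicate", "dp_shard") mesh, the fully_shard calls later in apply_fsdp perform hybrid sharding: parameters are sharded across the inner mesh dimension and replicated across the outer one. The sketch below shows that call pattern in isolation; it assumes a torchrun launch on 8 GPUs and uses FSDP2's fully_shard (the same API apply_fsdp builds on), while the toy model and sizes are not from this PR.

# Sketch only: a 2-D device mesh turns FSDP2's fully_shard into HSDP.
# Launch assumption: torchrun --nproc_per_node=8 this_script.py
import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed.device_mesh import init_device_mesh

# 2 replica groups x 4 shards, mirroring what apply_fsdp's reshape produces.
hsdp_mesh = init_device_mesh("cuda", (2, 4), mesh_dim_names=("dp_replicate", "dp_shard"))

model = nn.Sequential(nn.Linear(1024, 1024), nn.Linear(1024, 1024)).cuda()
for layer in model:
    fully_shard(layer, mesh=hsdp_mesh)   # per-submodule wrapping, like torchtitan's per-block wrapping
fully_shard(model, mesh=hsdp_mesh)       # root wrap
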
train.py

Lines changed: 14 additions & 0 deletions
@@ -25,6 +25,7 @@
 from torch.distributed.checkpoint.stateful import Stateful
 from torch.distributed.elastic.multiprocessing.errors import record
 from torch.distributed.tensor.parallel import loss_parallel
+from torch.distributed.utils import _sync_module_states_with_mesh
 
 from torchtitan.checkpoint import CheckpointManager
 from torchtitan.config_manager import JobConfig
@@ -186,6 +187,11 @@ def main(job_config: JobConfig):
     world_mesh = parallel_dims.build_mesh(device_type="cuda")
     if parallel_dims.dp_enabled:
         dp_mesh = world_mesh["dp"]
+        if parallel_dims.cp_enabled:
+            dp_mesh = dp_mesh.reshape(
+                (dp_mesh.size() // parallel_dims.cp, parallel_dims.cp),
+                ("dp", "cp")
+            )["dp"]
         dp_degree = dp_mesh.size()
         dp_rank = dp_mesh.get_local_rank()
     else:
@@ -347,6 +353,14 @@ def loss_fn(pred, labels):
             "Please run `./create_seed_checkpoint.sh` and rerun training with `--checkpoint.enable_checkpoint`"
         )
 
+    if not checkpoint_loaded and parallel_dims.dp_enabled and parallel_dims.dp_replicate > 1:
+        # Sync parameters if HSDP is enabled.
+        replicate_mesh = dp_mesh.reshape(
+            (parallel_dims.dp_replicate, dp_mesh.size() // parallel_dims.dp_replicate)
+        )
+        _sync_module_states_with_mesh(model, replicate_mesh)
+
+
     # plot losses loaded from checkpoint (if any) to TensorBoard
     # NOTE: Loss info after the last log step before checkpoint saving will not be ploted.
     # This can be avoided by setting checkpoint.interval to be a multiple of metrics.log_freq

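The last hunk synchronizes module states across HSDP replicas when training starts without a checkpoint, so every replica group begins from identical weights. _sync_module_states_with_mesh is imported from torch.distributed.utils in the diff above; purely as a conceptual illustration (not that helper's implementation), the same effect for a plain, unsharded module could be obtained by broadcasting along the replicate mesh dimension:

# Conceptual sketch only: broadcast parameters and buffers from the first
# replica along the replicate mesh dimension so all replicas start identical.
# A real helper must also handle FSDP2's sharded DTensor parameters.
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import DeviceMesh


@torch.no_grad()
def sync_replicas(module: torch.nn.Module, replicate_mesh: DeviceMesh) -> None:
    group = replicate_mesh.get_group(mesh_dim=0)   # ranks that hold the same shard
    src = dist.get_global_rank(group, 0)           # broadcast source: replica 0
    for tensor in list(module.parameters()) + list(module.buffers()):
        dist.broadcast(tensor, src=src, group=group)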