
Commit 41a2ec1

[RFC] Enable HSDP + CP
Summary: This PR demonstrates how to use DeviceMesh reshape to enable HSDP + CP.

ghstack-source-id: 189ab55
Pull Request resolved: #463
1 parent 5ba01b8 commit 41a2ec1
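
The whole change hinges on one idea: `build_mesh` folds the cp degree into a single flattened "dp" mesh dimension, and each consumer recovers the sub-mesh it needs by re-splitting that dimension with the `DeviceMesh.reshape` API this RFC exercises. Below is a minimal rank-layout sketch of those two re-splits; it assumes the row-major ordering that `init_device_mesh` uses, and the sizes (8 ranks, cp=2, dp_replicate=2) are illustrative, not taken from this PR.

```python
# Rank-layout sketch for one flattened "dp" mesh dimension of size dp * cp.
# Assumes row-major ordering (how init_device_mesh lays out ranks); the sizes
# here (8 ranks, cp=2, dp_replicate=2) are illustrative only.
ranks = list(range(8))                 # ranks along the flattened "dp" dim
cp, dp_replicate = 2, 2
dp_shard = len(ranks) // dp_replicate  # 4

# reshape((-1, cp), ("dp", "cp")): adjacent ranks form a CP group.
cp_groups = [ranks[i:i + cp] for i in range(0, len(ranks), cp)]
print(cp_groups)        # [[0, 1], [2, 3], [4, 5], [6, 7]]

# reshape((dp_replicate, -1), ("dp_replicate", "dp_shard")): each row is one
# shard group; ranks in the same column are replicas of each other (HSDP).
shard_groups = [ranks[i:i + dp_shard] for i in range(0, len(ranks), dp_shard)]
replica_groups = [ranks[j::dp_shard] for j in range(dp_shard)]
print(shard_groups)     # [[0, 1, 2, 3], [4, 5, 6, 7]]
print(replica_groups)   # [[0, 4], [1, 5], [2, 6], [3, 7]]
```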

File tree

3 files changed: +30, -15 lines changed


torchtitan/parallelisms/__init__.py

Lines changed: 4 additions & 2 deletions
```diff
@@ -30,6 +30,7 @@ class ParallelDims:
     world_size: int
     enable_loss_parallel: bool
     dp_type: str
+    dp_replicate: int
 
     def __post_init__(self):
         self.dp_type = self.dp_type.lower()
@@ -40,20 +41,21 @@ def _validate(self):
         if dp == -1:
             self.dp = dp = self.world_size // (cp * tp * pp)
         assert dp >= 1, dp
+        assert dp % self.dp_replicate == 0, (self.dp_replicate, dp)
         assert cp >= 1, cp
         assert tp >= 1, tp
         assert pp >= 1, pp
         assert dp * cp * tp * pp == self.world_size, (
             f"Invalid parallel dims: dp({dp}) * cp ({cp}) * tp({tp}) * pp({pp}) "
             f"!= WORLD_SIZE({self.world_size})"
         )
-        assert self.dp_type in ("fsdp", "ddp")
+        assert self.dp_type in ("fsdp", "ddp", "hsdp")
 
     def build_mesh(self, device_type):
         dims = []
         names = []
         for d, name in zip(
-            [self.pp, self.dp, self.cp, self.tp], ["pp", "dp", "cp", "tp"], strict=True
+            [self.pp, self.dp * self.cp, self.tp], ["pp", "dp", "tp"], strict=True
         ):
             if d > 1:
                 dims.append(d)
```
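
With cp folded into the data-parallel dimension, the world mesh no longer has a "cp" dim at all. A quick sketch of the dims/names the updated `build_mesh` loop produces, with illustrative degrees (pp=2, dp=2, cp=2, tp=2 on 16 GPUs) that are not taken from the PR:

```python
# Sketch of the dims/names produced by the updated build_mesh loop;
# the parallel degrees below are illustrative, not from the PR.
pp, dp, cp, tp = 2, 2, 2, 2
dims, names = [], []
for d, name in zip([pp, dp * cp, tp], ["pp", "dp", "tp"], strict=True):
    if d > 1:
        dims.append(d)
        names.append(name)
print(dims, names)  # [2, 4, 2] ['pp', 'dp', 'tp'] -- "dp" now spans dp * cp ranks
```

Downstream code that needs the cp or HSDP sub-meshes re-splits that widened "dp" dimension, as the next two files show.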

torchtitan/parallelisms/parallelize_llama.py

Lines changed: 11 additions & 12 deletions
```diff
@@ -27,7 +27,6 @@
 from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
     checkpoint_wrapper as ptd_checkpoint_wrapper,
 )
-from torch.distributed.device_mesh import init_device_mesh
 from torch.distributed.pipelining import pipeline, PipelineStage, SplitPoint
 from torch.distributed.tensor.parallel import (
     ColwiseParallel,
@@ -463,7 +462,8 @@ def apply_cp(model, world_mesh, parallel_dims, job_config: JobConfig):
     """
     if parallel_dims.tp_enabled or parallel_dims.pp_enabled:
         raise NotImplementedError("CP + TP or CP + PP are not supported yet.")
-    cp_mesh = world_mesh["cp"]
+    dp_mesh = world_mesh["dp"]
+    cp_mesh = dp_mesh.reshape((-1, parallel_dims.cp), ("dp", "cp"))["cp"]
     callers = []
     for layer_id, transformer_block in model.layers.items():
         callers.append(transformer_block.attention)
@@ -479,22 +479,21 @@ def apply_fsdp(
     parallel_dims: "ParallelDims",
     job_config: JobConfig,
 ):
-
     """
     Apply data parallelism to the model. FSDP2 is used here.
     """
 
-    if parallel_dims.cp_enabled:
-        # Temporary solution to enable FSDP + CP
-        dp_mesh = init_device_mesh(
-            world_mesh.device_type,
-            (parallel_dims.dp * parallel_dims.cp,),
-            mesh_dim_names=["dp"],
-        )
+    # This mesh also includes cp degree if it is larger than 1.
+    if parallel_dims.dp_type == "fsdp":
+        dp_mesh = world_mesh["dp"]
     else:
+        assert parallel_dims.dp_type == "hsdp", parallel_dims.dp_type
         dp_mesh = world_mesh["dp"]
-
-    assert dp_mesh.mesh_dim_names == ("dp",), dp_mesh.mesh_dim_names
+        dp_mesh = dp_mesh.reshape(
+            (parallel_dims.dp_replicate, -1),
+            ("dp_replicate", "dp_shard"),
+        )
+    # assert dp_mesh.mesh_dim_names == ("dp",), dp_mesh.mesh_dim_names
 
     mp_policy = MixedPrecisionPolicy(
         param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
```
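
The reason a 2-D reshape is all `apply_fsdp` needs: FSDP2's `fully_shard` treats a 1-D mesh as plain FSDP and a 2-D mesh as HSDP, replicating across the first dimension and sharding across the second. A minimal sketch of that behavior, not the PR's code: it assumes an 8-GPU `torchrun` launch, builds the mesh with `init_device_mesh` directly instead of reshaping the world mesh, and uses a toy model as a stand-in.

```python
# Minimal FSDP-vs-HSDP sketch with FSDP2 (not the PR's apply_fsdp).
# Assumes `torchrun --nproc_per_node=8 sketch.py`; model and sizes are stand-ins.
import os

import torch
from torch.distributed._composable.fsdp import MixedPrecisionPolicy, fully_shard
from torch.distributed.device_mesh import init_device_mesh

torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
model = torch.nn.Sequential(torch.nn.Linear(1024, 1024), torch.nn.Linear(1024, 1024)).cuda()
mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32)

# dp_type == "fsdp": a 1-D mesh shards parameters across all 8 ranks.
# mesh = init_device_mesh("cuda", (8,), mesh_dim_names=("dp",))

# dp_type == "hsdp": a 2-D mesh replicates across "dp_replicate" (2 groups)
# and shards within "dp_shard" (4 ranks each) -- the same shape the reshape
# in apply_fsdp produces.
mesh = init_device_mesh("cuda", (2, 4), mesh_dim_names=("dp_replicate", "dp_shard"))
fully_shard(model, mesh=mesh, mp_policy=mp_policy)
```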

train.py

Lines changed: 15 additions & 1 deletion
```diff
@@ -25,6 +25,7 @@
 from torch.distributed.checkpoint.stateful import Stateful
 from torch.distributed.elastic.multiprocessing.errors import record
 from torch.distributed.tensor.parallel import loss_parallel
+from torch.distributed.utils import _sync_module_states_with_mesh
 
 from torchtitan.checkpoint import CheckpointManager
 from torchtitan.config_manager import JobConfig
@@ -186,6 +187,8 @@ def main(job_config: JobConfig):
     world_mesh = parallel_dims.build_mesh(device_type="cuda")
     if parallel_dims.dp_enabled:
         dp_mesh = world_mesh["dp"]
+        if parallel_dims.cp_enabled:
+            dp_mesh = dp_mesh.reshape((-1, parallel_dims.cp), ("dp", "cp"))["dp"]
         dp_degree = dp_mesh.size()
         dp_rank = dp_mesh.get_local_rank()
     else:
@@ -217,7 +220,7 @@ def main(job_config: JobConfig):
     )
 
     if parallel_dims.cp_enabled:
-        cp_mesh = world_mesh["cp"]
+        cp_mesh = world_mesh["dp"].reshape((-1, parallel_dims.cp), ("dp", "cp"))["cp"]
         context_parallel_ctx = partial(
             context_parallel_buffers,
             cp_rank=cp_mesh.get_local_rank(),
@@ -347,6 +350,17 @@ def loss_fn(pred, labels):
             "Please run `./create_seed_checkpoint.sh` and rerun training with `--checkpoint.enable_checkpoint`"
         )
 
+    if (
+        not checkpoint_loaded
+        and parallel_dims.dp_enabled
+        and parallel_dims.dp_replicate > 1
+    ):
+        # Sync parameters if HSDP is enabled.
+        replicate_mesh = dp_mesh.reshape(
+            (parallel_dims.dp_replicate, dp_mesh.size() // parallel_dims.dp_replicate)
+        )
+        _sync_module_states_with_mesh(model, replicate_mesh)
+
     # plot losses loaded from checkpoint (if any) to TensorBoard
     # NOTE: Loss info after the last log step before checkpoint saving will not be ploted.
     # This can be avoided by setting checkpoint.interval to be a multiple of metrics.log_freq
```
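
When no seed checkpoint is loaded, the replicas along "dp_replicate" must start from identical weights, which is what the `_sync_module_states_with_mesh` call above is for; otherwise each replica group would keep its own random initialization. A hedged sketch of the same idea using only existing collectives, not the PR's helper: it assumes a 2-D mesh with named dims ("dp_replicate", "dp_shard") and an already-sharded FSDP2 model, and `sync_replicas` is a hypothetical name.

```python
# Hypothetical stand-in for _sync_module_states_with_mesh (not the PR's code):
# broadcast each rank's local shard across its "dp_replicate" group so that
# all replicas start from the same weights. Assumes a 2-D mesh with named
# dims ("dp_replicate", "dp_shard") and an already-sharded FSDP2 model.
import torch.distributed as dist
from torch.distributed._tensor import DTensor


def sync_replicas(model, mesh_2d):
    group = mesh_2d.get_group("dp_replicate")
    src = dist.get_global_rank(group, 0)  # replica 0 is the source of truth
    for tensor in list(model.parameters()) + list(model.buffers()):
        # FSDP2 parameters are DTensors; broadcast the local shard in place.
        local = tensor.to_local() if isinstance(tensor, DTensor) else tensor
        dist.broadcast(local, src=src, group=group)
```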
