[RFC] Enable HSDP + CP #463

Closed
wants to merge 6 commits into from
6 changes: 4 additions & 2 deletions torchtitan/parallelisms/__init__.py
@@ -30,6 +30,7 @@ class ParallelDims:
world_size: int
enable_loss_parallel: bool
dp_type: str
dp_replicate: int

def __post_init__(self):
self.dp_type = self.dp_type.lower()
@@ -40,20 +41,21 @@ def _validate(self):
if dp == -1:
self.dp = dp = self.world_size // (cp * tp * pp)
assert dp >= 1, dp
assert dp % self.dp_replicate == 0, (self.dp_replicate, dp)
assert cp >= 1, cp
assert tp >= 1, tp
assert pp >= 1, pp
assert dp * cp * tp * pp == self.world_size, (
f"Invalid parallel dims: dp({dp}) * cp ({cp}) * tp({tp}) * pp({pp}) "
f"!= WORLD_SIZE({self.world_size})"
)
assert self.dp_type in ("fsdp", "ddp")
assert self.dp_type in ("fsdp", "ddp", "hsdp")

def build_mesh(self, device_type):
dims = []
names = []
for d, name in zip(
[self.pp, self.dp, self.cp, self.tp], ["pp", "dp", "cp", "tp"], strict=True
[self.pp, self.dp * self.cp, self.tp], ["pp", "dp", "tp"], strict=True
):
if d > 1:
dims.append(d)
23 changes: 11 additions & 12 deletions torchtitan/parallelisms/parallelize_llama.py
@@ -27,7 +27,6 @@
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
checkpoint_wrapper as ptd_checkpoint_wrapper,
)
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.pipelining import pipeline, PipelineStage, SplitPoint
from torch.distributed.tensor.parallel import (
ColwiseParallel,
@@ -467,7 +466,8 @@ def apply_cp(model, world_mesh, parallel_dims, job_config: JobConfig):
"""
if parallel_dims.tp_enabled or parallel_dims.pp_enabled:
raise NotImplementedError("CP + TP or CP + PP are not supported yet.")
cp_mesh = world_mesh["cp"]
dp_mesh = world_mesh["dp"]
cp_mesh = dp_mesh.reshape((-1, parallel_dims.cp), ("dp", "cp"))["cp"]
@fegin (Contributor Author) commented on Jul 17, 2024:
We may extend ParallelDims to have helpers to get different submeshes. However, the naming can be tricky, as there are multiple meanings of dp_mesh.
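
A minimal sketch of what such helpers could look like, assuming the DeviceMesh.reshape API this RFC relies on elsewhere; the method names cp_mesh and dp_loader_mesh are hypothetical and only illustrate the idea:

# Hypothetical ParallelDims helpers for carving submeshes out of the flattened
# "dp" mesh dimension, whose size is dp * cp once CP is folded in by build_mesh.
# Assumes the DeviceMesh.reshape API proposed/used in this RFC.
class ParallelDims:
    # ... existing fields: dp, cp, dp_replicate, dp_type, ...

    def cp_mesh(self, world_mesh):
        # Submesh that context parallelism shards sequence activations over.
        return world_mesh["dp"].reshape((-1, self.cp), ("dp", "cp"))["cp"]

    def dp_loader_mesh(self, world_mesh):
        # Submesh the data loader uses for per-rank batch assignment (CP folded out).
        return world_mesh["dp"].reshape((-1, self.cp), ("dp", "cp"))["dp"]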

callers = []
for layer_id, transformer_block in model.layers.items():
callers.append(transformer_block.attention)
@@ -483,22 +483,21 @@ def apply_fsdp(
parallel_dims: "ParallelDims",
job_config: JobConfig,
):

"""
Apply data parallelism to the model. FSDP2 is used here.
"""

if parallel_dims.cp_enabled:
# Temporary solution to enable FSDP + CP
dp_mesh = init_device_mesh(
world_mesh.device_type,
(parallel_dims.dp * parallel_dims.cp,),
mesh_dim_names=["dp"],
)
# This mesh also includes cp degree if it is larger than 1.
if parallel_dims.dp_type == "fsdp":
dp_mesh = world_mesh["dp"]
else:
assert parallel_dims.dp_type == "hsdp", parallel_dims.dp_type
dp_mesh = world_mesh["dp"]

assert dp_mesh.mesh_dim_names == ("dp",), dp_mesh.mesh_dim_names
dp_mesh = dp_mesh.reshape(
(parallel_dims.dp_replicate, -1),
("dp_replicate", "dp_shard"),
A Collaborator commented:
For HSDP + CP, suppose it's a (2, 2, 2) mesh where HSDP + CP are applied. How could CP be used together with HSDP? Does CP need to be merged into one of the HSDP dimensions? I.e., the HSDP mesh would be (2, 2), and after merging CP it would become (2, 4) or (4, 2)?

@fegin (Contributor Author) replied:
It would be (2, 4) for fully_shard, and it would be (4,) (the merged dimension of the first two dimensions of the world_mesh) for the data loader.
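
To make those shapes concrete, a small worked example for the (2, 2, 2) case, assuming dp_replicate = 2, dp_shard = 2, cp = 2 (an illustrative configuration, not one tested in this PR):

# Shape arithmetic for the (2, 2, 2) case discussed above (assumed configuration).
dp_replicate, dp_shard, cp = 2, 2, 2
dp = dp_replicate * dp_shard            # 4 data-parallel ranks in total

# build_mesh folds cp into the "dp" mesh dimension, so its flattened size is:
flat_dp = dp * cp                       # 8

# fully_shard reshapes the flat "dp" dim into (dp_replicate, dp_shard * cp):
fsdp_shape = (dp_replicate, flat_dp // dp_replicate)    # (2, 4)

# The data loader reshapes it into (dp, cp) and keeps only the "dp" part:
loader_shape = (flat_dp // cp,)                         # (4,)

print(fsdp_shape, loader_shape)         # (2, 4) (4,)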

)
# assert dp_mesh.mesh_dim_names == ("dp",), dp_mesh.mesh_dim_names

mp_policy = MixedPrecisionPolicy(
param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
16 changes: 15 additions & 1 deletion train.py
@@ -25,6 +25,7 @@
from torch.distributed.checkpoint.stateful import Stateful
from torch.distributed.elastic.multiprocessing.errors import record
from torch.distributed.tensor.parallel import loss_parallel
from torch.distributed.utils import _sync_module_states_with_mesh

from torchtitan.checkpoint import CheckpointManager
from torchtitan.config_manager import JobConfig
@@ -189,6 +190,8 @@ def main(job_config: JobConfig):
world_mesh = parallel_dims.build_mesh(device_type="cuda")
if parallel_dims.dp_enabled:
dp_mesh = world_mesh["dp"]
if parallel_dims.cp_enabled:
dp_mesh = dp_mesh.reshape((-1, parallel_dims.cp), ("dp", "cp"))["dp"]
dp_degree = dp_mesh.size()
dp_rank = dp_mesh.get_local_rank()
else:
@@ -220,7 +223,7 @@
)

if parallel_dims.cp_enabled:
cp_mesh = world_mesh["cp"]
cp_mesh = world_mesh["dp"].reshape((-1, parallel_dims.cp), ("dp", "cp"))["cp"]
context_parallel_ctx = partial(
context_parallel_buffers,
cp_rank=cp_mesh.get_local_rank(),
@@ -349,6 +352,17 @@ def loss_fn(pred, labels):
"Please run `./create_seed_checkpoint.sh` and rerun training with `--checkpoint.enable_checkpoint`"
)

if (
not checkpoint_loaded
and parallel_dims.dp_enabled
and parallel_dims.dp_replicate > 1
):
# Sync parameters if HSDP is enabled.
replicate_mesh = dp_mesh.reshape(
(parallel_dims.dp_replicate, dp_mesh.size() // parallel_dims.dp_replicate)
)
_sync_module_states_with_mesh(model, replicate_mesh)

# plot losses loaded from checkpoint (if any) to TensorBoard
# NOTE: Loss info after the last log step before checkpoint saving will not be plotted.
# This can be avoided by setting checkpoint.interval to be a multiple of metrics.log_freq