[RFC] Enable HSDP + CP #463
@@ -27,7 +27,6 @@
 from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
     checkpoint_wrapper as ptd_checkpoint_wrapper,
 )
-from torch.distributed.device_mesh import init_device_mesh
 from torch.distributed.pipelining import pipeline, PipelineStage, SplitPoint
 from torch.distributed.tensor.parallel import (
     ColwiseParallel,
@@ -467,7 +466,8 @@ def apply_cp(model, world_mesh, parallel_dims, job_config: JobConfig):
     """
     if parallel_dims.tp_enabled or parallel_dims.pp_enabled:
         raise NotImplementedError("CP + TP or CP + PP are not supported yet.")
-    cp_mesh = world_mesh["cp"]
+    dp_mesh = world_mesh["dp"]
+    cp_mesh = dp_mesh.reshape((-1, parallel_dims.cp), ("dp", "cp"))["cp"]
     callers = []
     for layer_id, transformer_block in model.layers.items():
         callers.append(transformer_block.attention)
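To make the reshape above concrete, here is a small, self-contained illustration (not code from the PR) of the rank layout it produces, assuming 8 ranks with dp = 4 and cp = 2; the degrees are placeholders chosen only for this sketch.

```python
import torch

# Assumed degrees for illustration only: 8 ranks, dp = 4, cp = 2.
dp, cp = 4, 2
flat_dp = torch.arange(dp * cp)  # ranks in the flattened "dp" mesh: 0..7

# reshape((-1, cp), ("dp", "cp")) views them as a (dp, cp) grid with cp innermost.
grid = flat_dp.reshape(dp, cp)
print(grid)
# tensor([[0, 1],
#         [2, 3],
#         [4, 5],
#         [6, 7]])
# Each row is one cp group of size 2 (the submesh selected by ["cp"]);
# each column is one data-parallel group of size 4.
```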
@@ -483,22 +483,21 @@ def apply_fsdp(
     parallel_dims: "ParallelDims",
     job_config: JobConfig,
 ):
     """
     Apply data parallelism to the model. FSDP2 is used here.
     """
     if parallel_dims.cp_enabled:
-        # Temporary solution to enable FSDP + CP
-        dp_mesh = init_device_mesh(
-            world_mesh.device_type,
-            (parallel_dims.dp * parallel_dims.cp,),
-            mesh_dim_names=["dp"],
-        )
+        # This mesh also includes cp degree if it is larger than 1.
+        if parallel_dims.dp_type == "fsdp":
+            dp_mesh = world_mesh["dp"]
+        else:
+            assert parallel_dims.dp_type == "hsdp", parallel_dims.dp_type
+            dp_mesh = world_mesh["dp"]

-    assert dp_mesh.mesh_dim_names == ("dp",), dp_mesh.mesh_dim_names
+    dp_mesh = dp_mesh.reshape(
+        (parallel_dims.dp_replicate, -1),
+        ("dp_replicate", "dp_shard"),
+    )
+    # assert dp_mesh.mesh_dim_names == ("dp",), dp_mesh.mesh_dim_names

     mp_policy = MixedPrecisionPolicy(
         param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],

Review comment (on the dp_mesh.reshape((parallel_dims.dp_replicate, -1), ("dp_replicate", "dp_shard")) lines): For HSDP + CP, suppose it is a (2, 2, 2) mesh where HSDP + CP is applied. How can CP be used together with HSDP? Does CP need to be merged into one of the HSDP dimensions, i.e., the HSDP mesh would be (2, 2), and after merging CP it would become (2, 4) or (4, 2)?

Reply: It would be (2, 4) for fully_shard, and it would be (4,) (the merged dimension of the first two dimensions of the world_mesh) for the data loader.
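To make the reply above concrete, here is an illustrative sketch (not code from the PR) of the rank layout for the reviewer's (2, 2, 2) example, assuming dp_replicate = 2, dp_shard = 2, cp = 2 on 8 ranks, with cp as the innermost dimension of the flattened "dp" mesh as in this PR.

```python
import torch

# Assumed degrees for illustration only: 8 ranks, HSDP (2 replicate x 2 shard) + CP 2.
dp_replicate, dp_shard, cp = 2, 2, 2
ranks = torch.arange(dp_replicate * dp_shard * cp)  # flattened "dp" mesh: 0..7

# Mesh handed to fully_shard: (dp_replicate, dp_shard * cp) == (2, 4).
# Each row of 4 ranks is one shard group; the 2 rows replicate each other.
fsdp_mesh = ranks.reshape(dp_replicate, dp_shard * cp)
print(fsdp_mesh)
# tensor([[0, 1, 2, 3],
#         [4, 5, 6, 7]])

# Data loader view: (dp_replicate * dp_shard,) == (4,).
# The cp ranks in each row consume the same batch (split along the sequence
# dimension), so only 4 distinct data-parallel shards are needed.
cp_groups = ranks.reshape(dp_replicate * dp_shard, cp)
print(cp_groups)
# tensor([[0, 1],
#         [2, 3],
#         [4, 5],
#         [6, 7]])
```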
Review comment: We may extend ParallelDims to have helpers to get different submeshes. However, the naming can be tricky, as there are multiple meanings of dp_mesh.
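One possible shape for such helpers, as a hypothetical sketch only: the class name ParallelDimsHelpers, the method names, and the dataclass layout are assumptions and not part of this PR, and the bodies rely on the DeviceMesh.reshape calls used in the diff above rather than on a released PyTorch API.

```python
from dataclasses import dataclass

from torch.distributed.device_mesh import DeviceMesh


@dataclass
class ParallelDimsHelpers:
    """Hypothetical helpers for the different "dp" meshes (illustration only)."""

    dp_replicate: int
    dp_shard: int
    cp: int

    def fsdp_mesh(self, world_mesh: DeviceMesh) -> DeviceMesh:
        # Mesh to pass to fully_shard: cp is folded into the shard dimension,
        # giving (dp_replicate, dp_shard * cp).
        return world_mesh["dp"].reshape(
            (self.dp_replicate, -1), ("dp_replicate", "dp_shard")
        )

    def cp_mesh(self, world_mesh: DeviceMesh) -> DeviceMesh:
        # Submesh that context parallelism shards the sequence over.
        return world_mesh["dp"].reshape((-1, self.cp), ("dp", "cp"))["cp"]

    @property
    def dataloader_dp_degree(self) -> int:
        # Number of distinct data shards; cp ranks consume the same data.
        return self.dp_replicate * self.dp_shard
```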