[RFC] Enable HSDP + CP #463
@@ -463,7 +462,8 @@ def apply_cp(model, world_mesh, parallel_dims, job_config: JobConfig):
    """
    if parallel_dims.tp_enabled or parallel_dims.pp_enabled:
        raise NotImplementedError("CP + TP or CP + PP are not supported yet.")
    cp_mesh = world_mesh["cp"]
    dp_mesh = world_mesh["dp"]
    cp_mesh = dp_mesh.reshape((-1, parallel_dims.cp), ("dp", "cp"))["cp"]
We may extend ParallelDims to have helpers to get different submeshes. However, naming them can be tricky because dp_mesh has multiple meanings.
assert dp_mesh.mesh_dim_names == ("dp",), dp_mesh.mesh_dim_names
dp_mesh = dp_mesh.reshape(
    (parallel_dims.dp_replicate, -1),
    ("dp_replicate", "dp_shard"),
For HSDP + CP, suppose we have a (2, 2, 2) mesh with HSDP + CP applied. How would CP be used together with HSDP? Does CP need to be merged into one of the HSDP dimensions? That is, the HSDP mesh would be (2, 2), and after merging CP it would become (2, 4) or (4, 2)?
It would be (2, 4) for fully_shard, and (4,) (the first two dimensions of the world_mesh merged into one) for the data loader.
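A rough way to see where these two shapes come from, assuming the (2, 2, 2) world mesh is ordered (dp_replicate, dp_shard, cp) with ranks numbered row-major. This is plain Python over lists of ranks, not the DeviceMesh API, and `reshape_ranks` is a hypothetical helper used only for illustration:

```python
def reshape_ranks(ranks, shape):
    """Reshape a flat list of ranks into a 2D nested list (row-major).

    Hypothetical helper for illustration only; one entry of `shape` may be -1
    and is then inferred from the number of ranks.
    """
    rows, cols = shape
    if rows == -1:
        rows = len(ranks) // cols
    if cols == -1:
        cols = len(ranks) // rows
    assert rows * cols == len(ranks), "shape does not match number of ranks"
    return [ranks[r * cols:(r + 1) * cols] for r in range(rows)]


# Assumed layout: 8 ranks, world mesh (dp_replicate, dp_shard, cp) = (2, 2, 2).
world = list(range(8))
dp_replicate, dp_shard, cp = 2, 2, 2

# Mesh seen by fully_shard: (dp_replicate, dp_shard * cp) = (2, 4).
# Each row is one shard group; CP ranks are folded into the shard dimension.
fsdp_mesh = reshape_ranks(world, (dp_replicate, -1))
# fsdp_mesh == [[0, 1, 2, 3], [4, 5, 6, 7]]

# CP groups: (dp_replicate * dp_shard, cp) = (4, 2).
# Each row is one group of ranks that split the same sequences.
cp_groups = reshape_ranks(world, (-1, cp))
# cp_groups == [[0, 1], [2, 3], [4, 5], [6, 7]]

# Data loader index: ranks in the same CP group read the same batch,
# so there are dp_replicate * dp_shard = 4 distinct loader ranks.
loader_rank = {rank: i for i, group in enumerate(cp_groups) for rank in group}

print(fsdp_mesh)
print(cp_groups)
print(loader_rank)
```

In this sketch, ranks 4 and 5 share loader index 2 because they form one CP group, while both sit in the second shard group of the (2, 4) mesh.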
Stack from ghstack (oldest at bottom):
Summary:
This PR demonstrates how to use DeviceMesh reshape to enable HSDP + CP.