[SimpleFSDP] Add support for ddp+tp #1250
Conversation
tianyu-l left a comment:
Makes sense to me!
Two more comments:
- For the comment at
https://github.com/pytorch/torchtitan/pull/1250/files#diff-02c09227aed7868aae47b1b0b6cb3b5105b84f2543cc2dea9c5f3a7cb265eeadR180
I think we need to update it, because:
  - For FSDP, it's all-gather in forward and reduce-scatter in backward.
  - For DDP, it's all-reduce in backward.
  Note these are in addition to the mixed-precision dtype conversion. Let's actually verify this behavior with a trace in the PR summary, as we haven't verified it before (see the sketch after this list).
- Let's also verify the numerics by comparing "FSDP 2" vs. "DDP2+TP2" (where we assume FSDP is the ground truth).
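As a reference for what the trace should show, here is a minimal sketch (not SimpleFSDP's actual code) of the collectives each mode maps to, written against raw `torch.distributed` primitives; the function names and shapes are illustrative assumptions:

```python
import torch
import torch.distributed as dist

def fsdp_forward_allgather(param_shard: torch.Tensor, world_size: int) -> torch.Tensor:
    # FSDP: all-gather the parameter shards into the full parameter in forward.
    full_param = torch.empty(
        param_shard.numel() * world_size,
        dtype=param_shard.dtype, device=param_shard.device,
    )
    dist.all_gather_into_tensor(full_param, param_shard)
    return full_param

def fsdp_backward_reduce_scatter(full_grad: torch.Tensor, world_size: int) -> torch.Tensor:
    # FSDP: reduce-scatter the full gradient back into per-rank shards in backward.
    grad_shard = torch.empty(
        full_grad.numel() // world_size,
        dtype=full_grad.dtype, device=full_grad.device,
    )
    dist.reduce_scatter_tensor(grad_shard, full_grad)
    return grad_shard

def ddp_backward_allreduce(grad: torch.Tensor) -> torch.Tensor:
    # DDP: parameters stay replicated, so the only collective is a gradient
    # all-reduce in backward (ReduceOp.AVG requires the NCCL backend).
    dist.all_reduce(grad, op=dist.ReduceOp.AVG)
    return grad
```

In a profiler trace this shows up as all-gather/reduce-scatter pairs per FSDP'd module versus a single stream of all-reduces for DDP, on top of any dtype casts from mixed precision.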
Updated. Thank you!
tianyu-l left a comment:
> Numerical convergence: As seen, the loss convergence is close for [ddp:2, tp:2] and [fsdp:2, tp:2].

This actually looks concerning. I would expect the loss to be exactly the same between the two, if the random seed, determinism, and the same parameter initialization are used.
Thinking about the possible reasons, I think parameter init is not controlled -- FSDP would init a sharded tensor on the dp mesh, whereas DDP would init a replicated tensor across the dp mesh.
To remove this factor, let's init a seed checkpoint first, and then kick off two separate runs loading the same checkpoint:
https://github.com/pytorch/torchtitan/blob/main/docs/checkpoint.md#how-to-create-a-seed-checkpoint
(Note that you may have to copy/move/remove checkpoints to make sure each run actually loads from the step-0 seed checkpoint.)
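Concretely, the "same starting point" requirement could look like the sketch below; the file path, model, and save/load flow are placeholders rather than torchtitan's actual checkpointing API:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)              # identical seed in both runs
model = nn.Linear(16, 16)         # stand-in for the real model

# One run saves the freshly-initialized weights as the shared seed checkpoint...
torch.save(model.state_dict(), "seed_step0.pt")

# ...and every run (DDP+TP and FSDP+TP) loads it before training, so any
# remaining loss gap is attributable to DDP-vs-FSDP math rather than init.
model.load_state_dict(torch.load("seed_step0.pt"))
```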
Seems like I forgot to set the seed to be the same. With the newly updated plot, the discrepancy between DDP+TP & FSDP+TP is much smaller. Sorry for the confusion here.
DDP + TP performance is twice as fast as FSDP + TP. Is this expected? Does this mean the all-gathers are exposed? Or are there performance optimizations that are not turned on yet?
Yes, with only the front-end, SimpleFSDP exposes all of its communication. The optimizations (prefetching & bucketing) are performed in the compiler backend, which has not been turned on here.
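For context, a hedged sketch of what "turning on the compiler backend" could look like, assuming `torch.compile` with the TorchInductor backend is the entry point (the model here is a stand-in, not SimpleFSDP's confirmed API):

```python
import torch
import torch.nn as nn

model = nn.Linear(1024, 1024)  # stand-in for a SimpleFSDP-wrapped model

# Eager mode issues each parameter's collective inline, fully exposed.
# Under compile, TorchInductor sees the whole graph and can bucket small
# collectives together and prefetch them to overlap with compute.
compiled_model = torch.compile(model, backend="inductor")
out = compiled_model(torch.randn(8, 1024))
```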
tianyu-l left a comment:
Nice job! Thank you for doing all the tests & verifications!
I agree we've isolated the issue to DDP+MPT. Let's follow up in a separate PR.
Is SimpleFSDP also supported in torchtune? Hoping that both projects can share more code and not spend twice the time reimplementing the same distributed features...
This is a follow-up on the previous dtensor redistribute PR: #150740, which enables SimpleFSDP's mixed-precision training.

In the most recent integration in TorchTitan (pytorch/torchtitan#1250), we found some discrepancies between SimpleFSDP's `fully_shard` and `replicate` modes when MPT is enabled. After debugging, I found the problem is in dtensor redistribute: `local_tensor` is taken out again from the original `input`, so the dtensor used for communication keeps its original precision instead of using `forward_dtype`. This PR fixes the issue and corrects previously added test cases.

After fixing the bug, the loss curves of `fully_shard` and `replicate` mode match perfectly.



Pull Request resolved: #154975
Approved by: https://github.com/tianyu-l
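A simplified illustration of the bug class described above (not the actual DTensor internals; the function and variable names are hypothetical): the buffer handed to the collective must come from the dtype-converted tensor, not be re-read from the original input.

```python
import torch

def comm_buffer_buggy(input_tensor: torch.Tensor, forward_dtype: torch.dtype) -> torch.Tensor:
    casted = input_tensor.to(forward_dtype)  # MPT cast, e.g. fp32 -> bf16
    local_tensor = input_tensor              # BUG: re-reads the original fp32 tensor
    return local_tensor                      # the collective would run in full precision

def comm_buffer_fixed(input_tensor: torch.Tensor, forward_dtype: torch.dtype) -> torch.Tensor:
    local_tensor = input_tensor.to(forward_dtype)  # fix: communicate in forward_dtype
    return local_tensor

x = torch.randn(4, dtype=torch.float32)
assert comm_buffer_buggy(x, torch.bfloat16).dtype == torch.float32   # wrong dtype on the wire
assert comm_buffer_fixed(x, torch.bfloat16).dtype == torch.bfloat16
```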
SimpleFSDP is not supported in torchtune yet. SimpleFSDP is more of a type of FSDP that users can apply on top of their model. For the front-end wrapping, all users need to do is call the wrapper in simple_fsdp.py for FSDP; the rest of the parallelism definitions are unchanged. The FSDP optimizations (bucketing & reordering) are done in the TorchInductor backend. I agree that for pre-training and post-training the optimal operator bucketing strategy may be different. But the bucketing & reordering are done in TorchInductor and should be independent of torchtitan, torchtune, or any other repo.



As titled, this PR adds support for DDP+TP under SimpleFSDP's `replicate` mode; in this mode gradients are synchronized via all-reduce in backward.
The loss convergence is the same for [ddp:2, tp:2] and [fsdp:2, tp:2] (without mixed-precision training).
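A minimal sketch of the 2-D layout being tested, assuming the standard `DeviceMesh` API; the mesh shape and dim names are illustrative:

```python
from torch.distributed.device_mesh import init_device_mesh

# 4 ranks arranged as a (dp=2, tp=2) mesh, matching the [ddp:2, tp:2] run above.
mesh = init_device_mesh("cuda", (2, 2), mesh_dim_names=("dp", "tp"))
dp_mesh, tp_mesh = mesh["dp"], mesh["tp"]

# In replicate (DDP) mode, each parameter stays replicated on dp_mesh, so the
# only data-parallel communication is the gradient all-reduce in backward;
# TP then shards the parameter on tp_mesh as usual.
```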
