
Conversation

@ruisizhang123 (Contributor) commented Jun 1, 2025

As titled, this PR adds support for DDP+TP under SimpleFSDP's replicate mode.

  1. Profile trace for DDP. As seen in the trace below, the DDP backward communication is all-reduce (a sketch of how such a trace can be collected follows after this description).
     [screenshot: profiler trace, 2025-06-01, 1:10 PM]
  2. Numerical convergence. The loss-convergence discrepancy is within ~1e-3 for [ddp:2, tp:2] vs. [fsdp:2, tp:2] (with mixed-precision training).
     [screenshot: loss curves, 2025-06-01, 11:39 PM]

The loss convergence is identical for [ddp:2, tp:2] and [fsdp:2, tp:2] (without mixed-precision training).
[screenshot: loss curves, 2025-06-02, 11:59 AM]
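For context (not part of the PR itself), here is a minimal sketch of how such a profiler trace can be collected; the model, optimizer, and data below are stand-ins for the actual torchtitan DDP+TP run.

```python
# Minimal sketch (stand-in model/data; in the real run this is the DDP+TP
# parallelized Llama model from torchtitan): collect a trace for a few steps
# and export it for inspection in chrome://tracing or Perfetto.
import torch
from torch import nn
from torch.profiler import profile, ProfilerActivity

model = nn.Linear(1024, 1024)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
batches = [torch.randn(8, 1024) for _ in range(3)]

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    for batch in batches:
        loss = model(batch).sum()
        loss.backward()           # with DDP, the gradient all-reduce is issued here
        optimizer.step()
        optimizer.zero_grad()

prof.export_chrome_trace("ddp_tp_trace.json")
# In the exported trace, DDP's backward should show all-reduce collectives,
# whereas FSDP's backward would show reduce-scatter (plus forward all-gather).
```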

@facebook-github-bot added the CLA Signed label on Jun 1, 2025
@ruisizhang123 requested a review from tianyu-l on June 1, 2025
@tianyu-l (Contributor) left a comment


Makes sense to me!

Two more comments:

  1. For the comment on
     https://github.com/pytorch/torchtitan/pull/1250/files#diff-02c09227aed7868aae47b1b0b6cb3b5105b84f2543cc2dea9c5f3a7cb265eeadR180
     I think we need to update it, because:
     for FSDP, it's all-gather in forward and reduce-scatter in backward;
     for DDP, it's all-reduce in backward.
     Note these are in addition to the mixed-precision dtype conversions.
     Let's actually verify this behavior with a trace in the PR summary, as we haven't verified it before (see the sketch after this comment).
  2. Let's also verify the numerics by comparing "FSDP 2" vs. "DDP2+TP2" (where we assume FSDP is the ground truth).
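For reference, a minimal sketch (illustrative only, not code from this PR) of the collective patterns item 1 describes, written with plain torch.distributed calls; it assumes the process group is already initialized and the shapes are arbitrary.

```python
# Minimal sketch of the collectives described above (illustrative only;
# assumes torch.distributed is already initialized on `world_size` GPUs).
import torch
import torch.distributed as dist

world_size = dist.get_world_size()
shard = torch.randn(1024, device="cuda")                # this rank's parameter shard
grad = torch.randn(1024 * world_size, device="cuda")    # full gradient on this rank

# FSDP-style: all-gather parameters in forward ...
full_param = torch.empty(1024 * world_size, device="cuda")
dist.all_gather_into_tensor(full_param, shard)

# ... and reduce-scatter gradients in backward.
grad_shard = torch.empty(1024, device="cuda")
dist.reduce_scatter_tensor(grad_shard, grad, op=dist.ReduceOp.AVG)

# DDP-style: parameters stay replicated; backward all-reduces the full gradient.
dist.all_reduce(grad, op=dist.ReduceOp.AVG)
```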

@ruisizhang123 force-pushed the ruisi/ddp+tp branch 2 times, most recently from d144900 to a201f90, on June 1, 2025
@ruisizhang123 (Contributor, Author)

> Makes sense to me!
>
> Two more comments:
>
> 1. For the comment on https://github.com/pytorch/torchtitan/pull/1250/files#diff-02c09227aed7868aae47b1b0b6cb3b5105b84f2543cc2dea9c5f3a7cb265eeadR180 I think we need to update it, because: for FSDP, it's all-gather in forward and reduce-scatter in backward; for DDP, it's all-reduce in backward. Note these are in addition to the mixed-precision dtype conversions. Let's actually verify this behavior with a trace in the PR summary, as we haven't verified it before.
> 2. Let's also verify the numerics by comparing "FSDP 2" vs. "DDP2+TP2" (where we assume FSDP is the ground truth).

Updated. Thank you!

@tianyu-l (Contributor) left a comment


> Numerical convergence: As seen, the loss convergence is close for [ddp:2, tp:2] and [fsdp:2, tp:2].

This actually looks concerning. I would expect the loss to be exactly the same between the two, if random seed, determinism, and the same initialization of parameters are used.

Thinking about the possible reasons, I think parameter init is not controlled: FSDP would init a sharded tensor on the dp mesh, whereas DDP would init a replicated tensor across the dp mesh.

To remove this factor, let's init a seed checkpoint first, and then kick off two separate runs loading the same checkpoint.
https://github.com/pytorch/torchtitan/blob/main/docs/checkpoint.md#how-to-create-a-seed-checkpoint
(Note that you may have to copy/move/remove checkpoints so that both runs actually load from the step-0 seed checkpoint.)
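As a side note, here is a minimal sketch (not torchtitan's actual code path, which is driven by its config and the seed-checkpoint flow linked above) of the determinism setup that makes such an A/B comparison meaningful once both runs load the same checkpoint:

```python
# Minimal sketch of determinism setup for an apples-to-apples DDP-vs-FSDP run.
import torch

def make_deterministic(seed: int = 0) -> None:
    torch.manual_seed(seed)                   # seed CPU RNG
    torch.cuda.manual_seed_all(seed)          # seed CUDA RNGs on all devices
    torch.use_deterministic_algorithms(True)  # error out on nondeterministic ops
    torch.backends.cudnn.benchmark = False    # avoid autotuned-kernel variance

make_deterministic(42)
# Both runs then load identical parameters from the step-0 seed checkpoint,
# so any remaining loss gap comes from the parallelism/communication path.
```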

@ruisizhang123 (Contributor, Author)

> > Numerical convergence: As seen, the loss convergence is close for [ddp:2, tp:2] and [fsdp:2, tp:2].
>
> This actually looks concerning. I would expect the loss to be exactly the same between the two, if random seed, determinism, and the same initialization of parameters are used.
>
> Thinking about the possible reasons, I think parameter init is not controlled: FSDP would init a sharded tensor on the dp mesh, whereas DDP would init a replicated tensor across the dp mesh.
>
> To remove this factor, let's init a seed checkpoint first, and then kick off two separate runs loading the same checkpoint. https://github.com/pytorch/torchtitan/blob/main/docs/checkpoint.md#how-to-create-a-seed-checkpoint (Note that you may have to copy/move/remove checkpoints so that both runs actually load from the step-0 seed checkpoint.)

It seems I forgot to set the same seed for both runs. With the newly updated plot, the discrepancy between DDP+TP and FSDP+TP is much smaller. Sorry for the confusion here.

@fegin (Contributor) commented Jun 2, 2025

DDP + TP is twice as fast as FSDP + TP. Is this expected? Does this mean the all-gathers are exposed? Or are there performance optimizations that are not turned on yet?

@ruisizhang123 (Contributor, Author) commented Jun 2, 2025

> DDP + TP is twice as fast as FSDP + TP. Is this expected? Does this mean the all-gathers are exposed? Or are there performance optimizations that are not turned on yet?

Yes, with only the front-end, SimpleFSDP exposes all of its communication. The optimizations (pre-fetching & bucketing) are performed in the compiler backend, which has not been turned on here.
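To illustrate what turning on the compiler backend would look like (an assumption about usage, not something enabled in this PR): the SimpleFSDP-wrapped module would go through torch.compile so Inductor can bucket and pre-fetch the exposed collectives.

```python
# Illustrative only: eager-mode SimpleFSDP leaves every collective exposed;
# compiling the wrapped model is what lets the Inductor backend bucket and
# reorder (pre-fetch) those collectives. `parallelized_model` is a placeholder
# for the SimpleFSDP(+TP)-wrapped module.
import torch

compiled_model = torch.compile(parallelized_model)
# Training then uses `compiled_model`; the comm/compute overlap comes from
# compiler passes rather than from the front-end wrapper.
```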

@ruisizhang123 (Contributor, Author) commented Jun 2, 2025

> It seems I forgot to set the same seed for both runs. With the newly updated plot, the discrepancy between DDP+TP and FSDP+TP is much smaller. Sorry for the confusion here.

As shown in the PR description, there are still some minor differences between SimpleFSDP+TP's replicate and fully_shard modes. After turning off TP, the gap still exists and stays in a similar range of ~1e-3.

[loss curves with MPT]

Tianyu suggested we could turn off MPT. After turning it off, the discrepancy shrinks to ~1e-4. We need to look into DTensor redistribute to see whether it handles DTensor precision differently in replicate (all-reduce) and fully_shard (reduce-scatter) modes.

[loss curves without MPT]

I also tested FSDP2 vs. DDP loss. The discrepancy is within ~1e-4, which is similar to the SimpleFSDP replicate vs. fully_shard comparison above (without MPT). We should be good after fixing SimpleFSDP's MPT bug here.

[screenshot: FSDP2 vs. DDP loss curves, 2025-06-02, 12:25 AM]
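To illustrate the magnitude involved (a standalone numeric example, not taken from this PR): performing the same averaging in bf16 vs. fp32 already produces differences on roughly the scale discussed above, which is why a dtype mismatch in the communication path shows up as a small but visible loss gap.

```python
# Standalone illustration: averaging the same values in bf16 vs. fp32 differs
# at roughly the 1e-3 scale, so a precision mismatch between the all-reduce
# (replicate) and reduce-scatter (fully_shard) paths is visible in the loss.
import torch

torch.manual_seed(0)
grads = torch.randn(8, 4096)                               # pretend per-rank gradients
avg_fp32 = grads.mean(dim=0)                               # reduce in fp32
avg_bf16 = grads.to(torch.bfloat16).mean(dim=0).float()    # reduce in bf16
print((avg_fp32 - avg_bf16).abs().max())                   # roughly 1e-3
```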

@tianyu-l (Contributor) left a comment


Nice job! Thank you for doing all the tests & verifications!
I agree we've isolated the issue to DDP+MPT. Let's follow up in a separate PR.

@tianyu-l merged commit 768cde1 into main on Jun 2, 2025
8 checks passed
@tianyu-l deleted the ruisi/ddp+tp branch on June 2, 2025
@vadimkantorov

Is SimpleFSDP also supported in torchtune? Hoping that both projects share more code and don't spend time twice reimplementing the same distributed features...

pytorchmergebot pushed a commit to pytorch/pytorch that referenced this pull request Jun 3, 2025
This is a follow-up on the previous dtensor redistribute PR: #150740, which enables SimpleFSDP's mixed-precision training.

In the most recent integration in TorchTitan (pytorch/torchtitan#1250), we found some discrepancies between SimpleFSDP's `fully_shard` and `replicate` modes when MPT is enabled. After debugging, I found the problem is in dtensor redistribute: `local_tensor` is taken out again from the original `input`, so the dtensor used for communication keeps its original precision instead of using `forward_dtype`.

This PR fixes this issue and corrects previously added test cases.

After fixing the bug, the loss curves of `fully_shard` and `replicate` mode match perfectly.

![loss](https://github.com/user-attachments/assets/a8faddae-a476-48c0-a411-3fe04d2233bd)

Pull Request resolved: #154975
Approved by: https://github.com/tianyu-l
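Conceptually, the fix amounts to something like the sketch below (a paraphrase with hypothetical helper names, not the actual torch.distributed.tensor diff): the local shard must already be cast to `forward_dtype` before it is handed to the collective, rather than being re-read at its original precision from `input`.

```python
# Hypothetical paraphrase of the fix, not the real PyTorch code: make sure the
# tensor the collective sees has already been cast to forward_dtype.
import torch

def local_tensor_for_comm(input_dtensor, forward_dtype):
    local = input_dtensor.to_local()                 # local shard of the DTensor
    if forward_dtype is not None and local.dtype != forward_dtype:
        local = local.to(forward_dtype)              # the buggy path skipped this cast
    return local                                     # what all-reduce / reduce-scatter operates on
```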
@ruisizhang123 (Contributor, Author) commented Jun 3, 2025

> Is SimpleFSDP also supported in torchtune? Hoping that both projects share more code and don't spend time twice reimplementing the same distributed features...

SimpleFSDP is not supported in torchtune yet. SimpleFSDP is more of a style of FSDP that users can apply on top of their model.

For the front-end wrapping, all users need to do is call into simple_fsdp.py for FSDP; the rest of the parallelism definitions are unchanged (a rough sketch follows below).
SimpleFSDP is still experimental; we can explore a better way of sharing simple_fsdp.py across repos, but it seems to me we won't be reinventing the wheel for such an integration. (Maybe @tianyu-l can confirm this.)
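For illustration, a rough sketch of the front-end wrapping mentioned above; the import path, function name, and signature are assumptions based on simple_fsdp.py rather than an exact copy of torchtitan's code.

```python
# Rough sketch of SimpleFSDP's front-end wrapping (names/signature assumed):
# only the data-parallel wrapping changes; TP/PP application stays the same.
from torchtitan.experiments.simple_fsdp.simple_fsdp import data_parallel  # assumed path

# `model` and `dp_mesh` come from the usual torchtitan setup.
model = data_parallel(model, dp_mesh, mode="replicate")      # DDP-style (this PR)
# model = data_parallel(model, dp_mesh, mode="fully_shard")  # FSDP-style
```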

The FSDP optimizations (bucketing & reordering) are done in the TorchInductor backend. I agree the optimal operator bucketing strategy may differ between pre-training and post-training, but since the bucketing & reordering happen in TorchInductor, they should be independent of torchtitan, torchtune, or any other repo.

iupaikov-amd pushed a commit to ROCm/pytorch that referenced this pull request Jun 4, 2025
angelayi pushed a commit to angelayi/pytorch that referenced this pull request Jun 5, 2025