Add NCCL PreMul Sum to c10d reduce ops #84243


Closed
wants to merge 5 commits

Conversation

crcrpar
Collaborator

@crcrpar crcrpar commented Aug 30, 2022

This is based on #81272, but this version conforms to the TorchScript compiler.

cc @ptrblck @kwen2501 @aazzolini
cc @zasdfgbnm for visibility to the TODO above
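
A minimal usage sketch of what PreMul Sum enables on the Python side — assuming a multi-GPU NCCL job launched with torchrun, and the private helper `_make_nccl_premul_sum` in `torch.distributed.distributed_c10d` (treat the helper's exact name and location as an assumption based on later PyTorch releases):

```python
# Minimal sketch: PreMul Sum scales every rank's tensor by a factor
# before the NCCL summation, i.e. it computes sum_i(factor * t_i).
# Assumes the NCCL backend and an initialized process group;
# _make_nccl_premul_sum is a private helper, name per later releases.
import torch
import torch.distributed as dist
from torch.distributed.distributed_c10d import _make_nccl_premul_sum

dist.init_process_group(backend="nccl")
rank = dist.get_rank()
t = torch.ones(4, device=f"cuda:{rank}")

premul_sum = _make_nccl_premul_sum(2.0)  # factor applied pre-reduction
dist.all_reduce(t, op=premul_sum)        # t == 2.0 * world_size * ones
```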

@facebook-github-bot
Contributor

facebook-github-bot commented Aug 30, 2022

✅ No Failures (0 Pending)

As of commit 7ebdd05 (more details on the Dr. CI page):

💚 💚 Looks good so far! There are no failures yet. 💚 💚

This comment was automatically generated by Dr. CI.

@facebook-github-bot facebook-github-bot added the oncall: distributed label (add this issue/PR to distributed oncall triage queue) Aug 30, 2022
@crcrpar
Collaborator Author

crcrpar commented Aug 30, 2022

If this works at least on the public CI, I'll close #84059.

@crcrpar
Collaborator Author

crcrpar commented Aug 30, 2022

Regarding the failure:

2022-08-30T02:10:49.9879118Z The PR is introducing backward incompatible changes to the operator library. Please contact PyTorch team to confirm whether this change is wanted or not. 
2022-08-30T02:10:49.9879135Z 
2022-08-30T02:10:49.9879245Z Broken ops: [
2022-08-30T02:10:49.9879772Z 	c10d::reduce_(Tensor[] _0, __torch__.torch.classes.c10d.ProcessGroup _1, int _2, int _3, int _4, int _5) -> __torch__.torch.classes.c10d.Work _0
2022-08-30T02:10:49.9880297Z 	c10d::reduce_scatter_(Tensor[] _0, Tensor[][] _1, __torch__.torch.classes.c10d.ProcessGroup _2, int _3, int _4) -> __torch__.torch.classes.c10d.Work _0
2022-08-30T02:10:49.9880791Z 	c10d::allreduce_(Tensor[] _0, __torch__.torch.classes.c10d.ProcessGroup _1, int _2, int _3) -> __torch__.torch.classes.c10d.Work _0
2022-08-30T02:10:49.9880884Z ]

Should I add c10d reduce ops to ALLOW_LIST like https://github.com/crcrpar/pytorch/blob/a0c6e7499ea81fb0da4858a7ebf27a88c0612493/test/forward_backward_compatibility/check_forward_backward_compatibility.py#L123?
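
For reference, entries in that file are (operator-name pattern, expiry date) tuples. A sketch of what such additions might look like — the expiry dates here are placeholders, not values actually merged:

```python
# Hypothetical ALLOW_LIST additions for the BC check in
# test/forward_backward_compatibility/check_forward_backward_compatibility.py;
# the expiry dates below are placeholders.
import datetime

ALLOW_LIST = [
    # ... existing entries ...
    ("c10d::allreduce_", datetime.date(2022, 10, 31)),
    ("c10d::reduce_", datetime.date(2022, 10, 31)),
    ("c10d::reduce_scatter_", datetime.date(2022, 10, 31)),
]
```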

@crcrpar crcrpar changed the title Resubmit #81272 Add NCCL PreMul Sum to c10d reduce ops Aug 30, 2022
@kwen2501
Contributor

> Regarding the failure … Should I add c10d reduce ops to ALLOW_LIST like https://github.com/crcrpar/pytorch/blob/a0c6e7499ea81fb0da4858a7ebf27a88c0612493/test/forward_backward_compatibility/check_forward_backward_compatibility.py#L123?

@H-Huang seems to have encountered a similar warning. Maybe he knows how to respond to it.

@crcrpar

This comment was marked as outdated.

@H-Huang
Member

H-Huang commented Sep 1, 2022

> Regarding the failure … Should I add c10d reduce ops to ALLOW_LIST like https://github.com/crcrpar/pytorch/blob/a0c6e7499ea81fb0da4858a7ebf27a88c0612493/test/forward_backward_compatibility/check_forward_backward_compatibility.py#L123?

@crcrpar @kwen2501 FYI: I am going to update the allow list in this PR to allow all changes to all ops for the dispatchable collectives feature: https://github.com/pytorch/pytorch/pull/83735/files#diff-236fbde71e59cb1597cac177a83e49fb62b30770eec55c4e7a0f2650b9eb6203R274-R275. The PR will be merged in the next day or two. Feel free to also include this change; a blanket entry of that kind is sketched below.
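
A sketch of what such a blanket entry might look like (hypothetical; the actual entry is in the linked diff):

```python
import datetime

# Hypothetical blanket ALLOW_LIST entry permitting BC changes to every
# c10d op while the dispatchable-collectives work is in flight;
# the expiry date is a placeholder.
ALLOW_LIST = [
    ("c10d::.*", datetime.date(2022, 10, 31)),
]
```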

@soulitzer soulitzer added the triaged label (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module) Sep 1, 2022
crcrpar and others added 5 commits September 2, 2022 08:12
- have `_SupplementBase` and `ReduceOp` inherit `torch::CustomClassHolder`
- `def` only `c10d::ReduceOp` in `Ops.cpp`
- use `c10::intrusive_ptr<ReduceOp>` rather than `int64_t` in dispatch

Signed-off-by: Masaki Kozuki <[email protected]>
Contributor

@kwen2501 kwen2501 left a comment

LGTM! Thanks for the contribution!

@crcrpar
Collaborator Author

crcrpar commented Sep 2, 2022

@pytorchmergebot merge

@pytorchmergebot
Collaborator

@pytorchbot successfully started a merge job. Check the current status here.
The merge job was triggered without a flag. This means that your change will be merged once all checks on your PR have passed (ETA: 0-4 Hours). If this is not the intended behavior, feel free to use some of the other merge options in the wiki.
Please reach out to the PyTorch DevX Team with feedback or questions!

@github-actions
Contributor

github-actions bot commented Sep 2, 2022

Hey @crcrpar.
You've committed this PR, but it does not have both a 'release notes: ...' and 'topics: ...' label. Please add one of each to the PR. The 'release notes: ...' label should represent the part of PyTorch that this PR changes (fx, autograd, distributed, etc) and the 'topics: ...' label should represent the kind of PR it is (not user facing, new feature, bug fix, perf improvement, etc). The list of valid labels can be found here for the 'release notes: ...' and here for the 'topics: ...'.
For changes that are 'topic: not user facing' there is no need for a release notes label.

facebook-github-bot pushed a commit that referenced this pull request Sep 7, 2022
Summary:
This is based on #81272, but this version conforms to the TorchScript compiler.

## TODO
- [ ] Update https://github.com/pytorch/pytorch/blob/abaf8112e6d6bed2a5d33dcbc1d46ed20b8e80de/torch/csrc/distributed/c10d/ProcessGroupUCC.cpp#L64-L73 to use `ReduceOp::RedOpType`. In my first try with `USE_SYSTEM_UCC=1`, this change wasn't necessary (I think) because of the `ReduceOp::RedOpType` operator. That being said, I want to make it more explicit.

cc ptrblck kwen2501 aazzolini
cc zasdfgbnm for visibility to the TODO above

Pull Request resolved: #84243
Approved by: https://github.com/kwen2501

Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/ab6c57217a97438c8e13952a407e42873e2259f3

Reviewed By: mehtanirav, izaitsevfb

Differential Revision: D39277627

fbshipit-source-id: 039c6eef8c4d1c42a18273edb43b40888176d867
@crcrpar crcrpar deleted the ncclpremulsum branch September 30, 2022 23:20
pytorchmergebot pushed a commit that referenced this pull request Nov 15, 2022
Summary:
- Customize the metaclass of `torch.distributed.distributed_c10d.ReduceOp` for the sake of custom `__instancecheck__`
- Add `copy.copy`, `copy.deepcopy`, and `pickle` support with tests

Rel:
- #81272
- #84243
- #87191
- #87303
- #87555

Ref:
- pybind/pybind11#2696

Pull Request resolved: #88275
Approved by: https://github.com/wanchaol
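
A sketch of the behavior described in that summary — the custom `__instancecheck__` plus copy/pickle support (exact equality and printing semantics may vary across PyTorch versions):

```python
import copy
import pickle

import torch.distributed as dist

# The custom metaclass __instancecheck__ lets enum-style members pass
# as ReduceOp instances -- the headline behavior of #88275.
assert isinstance(dist.ReduceOp.SUM, dist.ReduceOp)

# copy.copy, copy.deepcopy, and pickle round-trips were added in the
# same change; printed rather than asserted, since exact equality
# semantics may differ across versions.
op = dist.ReduceOp.SUM
print(copy.copy(op), copy.deepcopy(op), pickle.loads(pickle.dumps(op)))
```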
kulinseth pushed a commit to kulinseth/pytorch that referenced this pull request Dec 10, 2022
(Same change as #88275 above.)
Labels
cla signed · Merged · oncall: distributed · open source · release notes: distributed (c10d) · triaged