File tree Expand file tree Collapse file tree 1 file changed +7
-4
lines changed Expand file tree Collapse file tree 1 file changed +7
-4
lines changed Original file line number Diff line number Diff line change 1- #include < c10d/ProcessGroupMPI.hpp>
1+ #define USE_C10D_MPI
2+ #include < torch/csrc/distributed/c10d/Work.hpp>
3+ #include < torch/csrc/distributed/c10d/ProcessGroup.hpp>
4+ #include < torch/csrc/distributed/c10d/ProcessGroupMPI.hpp>
25#include < torch/torch.h>
36#include < iostream>
47
@@ -35,8 +38,8 @@ struct Model : torch::nn::Module {
3538};
3639
3740void waitWork (
38- std::shared_ptr <c10d::ProcessGroupMPI> pg,
39- std::vector<std::shared_ptr <c10d::ProcessGroup ::Work>> works) {
41+ c10::intrusive_ptr <c10d::ProcessGroupMPI> pg,
42+ std::vector<c10::intrusive_ptr <c10d::Work>> works) {
4043 for (auto & work : works) {
4144 try {
4245 work->wait ();
@@ -115,7 +118,7 @@ int main(int argc, char* argv[]) {
115118 // since this synchronizes parameters after backward pass while DDP
116119 // overlaps synchronizing parameters and computing gradients in backward
117120 // pass
118- std::vector<std::shared_ptr<::c10d::ProcessGroup ::Work>> works;
121+ std::vector<::c10::intrusive_ptr<::c10d ::Work>> works;
119122 for (auto & param : model->named_parameters ()) {
120123 std::vector<torch::Tensor> tmp = {param.value ().grad ()};
121124 auto work = pg->allreduce (tmp);
You can’t perform that action at this time.
0 commit comments