+#ifdef USE_C10D_XCCL
+
 #include <torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp>
 #include <fstream>
-#include <mutex>
-#include <sstream>
-
-#ifdef USE_C10D_XCCL
 #include <comm/XPUGuard.h>
 #include <exception>
 #include <map>
+#include <sstream>
 #include <stdexcept>
 #include <tuple>
 #include <unordered_set>
 #include <utility>
 
 #include <ATen/detail/FunctionTraits.h>
 #include <c10/core/DeviceType.h>
-#include <c10/util/CallOnce.h>
-#include <c10/util/Exception.h>
-#include <c10/util/Logging.h>
 #include <c10/util/Optional.h>
-#include <c10/util/irange.h>
-#include <torch/csrc/distributed/c10d/ParamCommsUtils.hpp>
-#include <torch/csrc/distributed/c10d/TraceUtils.h>
-#include <torch/csrc/distributed/c10d/Utils.hpp>
-#include <torch/torch.h>
 
 namespace c10d {
 
@@ -61,36 +52,6 @@ std::map<at::ScalarType, ccl::datatype> xcclDatatypes = {
     {at::kFloat8_e5m2fnuz, ccl::datatype::uint8},
 };
 
-XCCL_KVS kvs;
-std::mutex kvs_mutex;
-
-XCCL_KVS get_kvs(int rank, c10d::Store& store) {
-  std::lock_guard<std::mutex> lock(kvs_mutex);
-  if (kvs)
-    return kvs;
-  std::string storeKey = "xccl_kvs";
-
-  // Rank 0 broadcast the bootstrap network information to other ranks
-  if (rank == 0) {
-    kvs = ccl::create_main_kvs();
-    ccl::kvs::address_type main_addr = kvs->get_address();
-    auto ccl_kvs_addr =
-        std::vector<uint8_t>(main_addr.begin(), main_addr.end());
-    store.set(storeKey, ccl_kvs_addr);
-  } else {
-    auto ccl_kvs_addr = store.get(storeKey);
-    if (ccl_kvs_addr.size() != ccl::kvs::address_max_size) {
-      throw std::runtime_error("Unexpected ccl kvs addr from the store\n");
-    }
-    ccl::kvs::address_type main_addr;
-    std::copy_n(
-        ccl_kvs_addr.begin(), ccl::kvs::address_max_size, main_addr.begin());
-    kvs = ccl::create_kvs(main_addr);
-  }
-
-  return kvs;
-}
-
 bool check_same_size(const std::vector<at::Tensor>& input_tensors) {
   for (const auto& input_tensor : input_tensors) {
     if (!input_tensors[0].is_same_size(input_tensor)) {
@@ -159,23 +120,9 @@ ccl::reduction getXcclReduceOp(const ReduceOp& reduceOp, at::Tensor& input) {
     }
     return xcclOps.at(reduceOp);
   } catch (const std::out_of_range&) {
-    switch (reduceOp) {
-      case ReduceOp::AVG:
-        C10_THROW_ERROR(ValueError, "Cannot use ReduceOp AVG with XCCL");
-        break;
-      case ReduceOp::BAND:
-        C10_THROW_ERROR(ValueError, "Cannot use ReduceOp.BAND with XCCL");
-        break;
-      case ReduceOp::BOR:
-        C10_THROW_ERROR(ValueError, "Cannot use ReduceOp.BOR with XCCL");
-        break;
-      case ReduceOp::BXOR:
-        C10_THROW_ERROR(ValueError, "Cannot use ReduceOp.BXOR with XCCL");
-        break;
-      default:
-        C10_THROW_ERROR(ValueError, "Unhandled ReduceOp");
-        break;
-    }
+    C10_THROW_ERROR(
+        ValueError,
+        "Cannot use ReduceOp." + reduce_op_to_string(reduceOp) + " with XCCL");
   }
 }
 
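Note: the new error path relies on a reduce_op_to_string() helper instead of the per-op switch. That helper is not shown in this hunk; the following is only a hypothetical sketch of what such a mapping could look like (assuming the ReduceOp enum from torch/csrc/distributed/c10d/Types.hpp), not the file's actual definition.

// Hypothetical sketch of a ReduceOp -> string mapping; the real helper is
// defined elsewhere in ProcessGroupXCCL.cpp and may differ.
#include <string>
#include <torch/csrc/distributed/c10d/Types.hpp>

std::string reduce_op_to_string(c10d::ReduceOp op) {
  // ReduceOp converts implicitly to its RedOpType enum, so switch works here.
  switch (op) {
    case c10d::ReduceOp::SUM:
      return "SUM";
    case c10d::ReduceOp::AVG:
      return "AVG";
    case c10d::ReduceOp::PRODUCT:
      return "PRODUCT";
    case c10d::ReduceOp::MIN:
      return "MIN";
    case c10d::ReduceOp::MAX:
      return "MAX";
    case c10d::ReduceOp::BAND:
      return "BAND";
    case c10d::ReduceOp::BOR:
      return "BOR";
    case c10d::ReduceOp::BXOR:
      return "BXOR";
    case c10d::ReduceOp::PREMUL_SUM:
      return "PREMUL_SUM";
    default:
      return "UNKNOWN";
  }
}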
@@ -210,20 +157,6 @@ ProcessGroupXCCL::WorkXCCL::WorkXCCL(const WorkXCCL& w)
 
 ProcessGroupXCCL::WorkXCCL::~WorkXCCL() = default;
 
-bool ProcessGroupXCCL::WorkXCCL::checkTimeout(
-    std::optional<std::chrono::milliseconds> timeout) {
-  auto currentTimepoint = std::chrono::steady_clock::now();
-  auto timeElapsed = std::chrono::duration_cast<std::chrono::milliseconds>(
-      currentTimepoint - workStartTime_);
-  std::chrono::milliseconds opTimeout = std::chrono::milliseconds(60000);
-
-  auto workTimeout = timeout ? *timeout : opTimeout;
-
-  if (timeElapsed < workTimeout)
-    return false;
-  return true;
-}
-
 bool ProcessGroupXCCL::WorkXCCL::isCompleted() {
   if (xcclEndEvent_ && xcclEndEvent_->query()) {
     return true;
@@ -235,23 +168,23 @@ void ProcessGroupXCCL::WorkXCCL::synchronize() {
   synchronizeInternal(kNoTimeout);
 }
 
-void ProcessGroupXCCL::WorkXCCL::synchronizeStream() {
-  auto currentStream = at::xpu::getCurrentXPUStream(device_.index());
-  // Block the current stream on the XCCL stream
-  xcclEndEvent_->block(currentStream);
-}
-
 void ProcessGroupXCCL::WorkXCCL::synchronizeInternal(
     std::chrono::milliseconds timeout) {
-  synchronizeStream();
-
+  auto currentStream = at::xpu::getCurrentXPUStream(device_.index());
+  xcclEndEvent_->block(currentStream);
   if (blockingWait_) {
     while (!isCompleted()) {
-      bool timedOut = checkTimeout(
-          timeout == kNoTimeout ? std::nullopt : std::make_optional(timeout));
-      if (timedOut) {
-        break;
+      auto currentTimepoint = std::chrono::steady_clock::now();
+      auto timeElapsed = std::chrono::duration_cast<std::chrono::milliseconds>(
+          currentTimepoint - workStartTime_);
+      if (timeElapsed >= timeout) {
+        std::string exceptionMsg = c10::str(
+            "Work ran for ",
+            timeElapsed.count(),
+            " milliseconds before timing out.");
+        TORCH_CHECK(false, exceptionMsg)
       }
+
       std::this_thread::sleep_for(
           std::chrono::milliseconds(kSynchronizeBusyWaitMillis));
     }
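With blockingWait_ enabled, synchronizeInternal() first blocks the current XPU stream on xcclEndEvent_, then polls isCompleted() every kSynchronizeBusyWaitMillis and raises once the elapsed time exceeds the caller-supplied timeout. A hypothetical caller-side sketch (pg, tensors, and run_allreduce are illustrative names, not from this diff), assuming the standard c10d Work::wait() API:

// Hypothetical usage sketch: wait() eventually reaches synchronizeInternal(),
// which enforces the timeout shown above when blocking wait is enabled.
#include <chrono>
#include <torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp>

void run_allreduce(
    c10d::ProcessGroupXCCL& pg,
    std::vector<at::Tensor>& tensors) {
  auto work = pg.allreduce(tensors); // returns c10::intrusive_ptr<Work>
  // With blockingWait_ set, this throws once the op runs past 60 s.
  work->wait(std::chrono::milliseconds(60000));
}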