@@ -322,7 +322,9 @@ bool ProcessGroupXCCL::WorkXCCL::wait(std::chrono::milliseconds timeout) {
322
322
return true ;
323
323
}
324
324
325
- ProcessGroupXCCL::Options::Options () : Backend::Options(XCCL_BACKEND_NAME) {}
325
+ ProcessGroupXCCL::Options::Options (bool is_high_priority_stream)
326
+ : Backend::Options(XCCL_BACKEND_NAME),
327
+ is_high_priority_stream(is_high_priority_stream) {}
326
328
327
329
static std::atomic<size_t > process_group_id = 0 ;
328
330
@@ -351,7 +353,7 @@ const std::string& ProcessGroupXCCL::logPrefix() const {
351
353
}
352
354
353
355
ProcessGroupXCCL::ProcessGroupXCCL (
354
- const c10::intrusive_ptr<Store>& store,
356
+ c10::intrusive_ptr<Store> store,
355
357
int rank,
356
358
int size,
357
359
c10::intrusive_ptr<Options> options)
@@ -377,7 +379,10 @@ ProcessGroupXCCL::ProcessGroupXCCL(
377
379
std::string torch_distributed_debug =
378
380
getCvarString ({" TORCH_DISTRIBUTED_DEBUG" }, OFF.c_str ());
379
381
LOG (INFO) << logPrefix () << " ProcessGroupXCCL initialization options: "
380
- << " size: " << size << " , global rank: " << rank_;
382
+ << " size: " << size << " , global rank: " << rank_
383
+ << " , USE_HIGH_PRIORITY_STREAM: "
384
+ << options_->is_high_priority_stream
385
+ << " , PG Name: " << options_->group_name ;
381
386
382
387
LOG (INFO) << logPrefix () << " ProcessGroupXCCL environments: "
383
388
<< " XCCL version: " << XcclVersion
@@ -534,9 +539,9 @@ std::shared_ptr<xcclComm_t> ProcessGroupXCCL::getXCCLComm(
534
539
rank = p2pRank;
535
540
}
536
541
537
- c10::impl::VirtualGuardImpl impl (device. type () );
538
- c10::Stream stream =
539
- impl. getStreamFromGlobalPool (device, /* isHighPriority= */ false );
542
+ bool force_high = getCvarBool (TORCH_XCCL_HIGH_PRIORITY, false );
543
+ c10::Stream stream = at::xpu::getStreamFromPool (
544
+ options_-> is_high_priority_stream || force_high );
540
545
sycl::queue& q = c10::xpu::XPUStream (stream).queue ();
541
546
542
547
auto ctx = ccl::create_context (q.get_context ());
0 commit comments