Skip to content

Commit 74b11bf

Browse files
Chao1Han and mengfei25
authored
support high priority stream (#1715)
Support high-priority streams for XCCL. The test case is added in #2049. We need to merge this PR first, together with the upstream op registration in pytorch/pytorch#163049, and then the test case can pass. --------- Co-authored-by: mengfei25 <[email protected]>
1 parent df1d7ad commit 74b11bf

File tree

2 files changed

+27
-13
lines changed

2 files changed

+27
-13
lines changed

src/xccl/ProcessGroupXCCL.cpp

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -322,7 +322,9 @@ bool ProcessGroupXCCL::WorkXCCL::wait(std::chrono::milliseconds timeout) {
322322
return true;
323323
}
324324

325-
ProcessGroupXCCL::Options::Options() : Backend::Options(XCCL_BACKEND_NAME) {}
325+
ProcessGroupXCCL::Options::Options(bool is_high_priority_stream)
326+
: Backend::Options(XCCL_BACKEND_NAME),
327+
is_high_priority_stream(is_high_priority_stream) {}
326328

327329
static std::atomic<size_t> process_group_id = 0;
328330

@@ -351,7 +353,7 @@ const std::string& ProcessGroupXCCL::logPrefix() const {
351353
}
352354

353355
ProcessGroupXCCL::ProcessGroupXCCL(
354-
const c10::intrusive_ptr<Store>& store,
356+
c10::intrusive_ptr<Store> store,
355357
int rank,
356358
int size,
357359
c10::intrusive_ptr<Options> options)
@@ -377,7 +379,10 @@ ProcessGroupXCCL::ProcessGroupXCCL(
377379
std::string torch_distributed_debug =
378380
getCvarString({"TORCH_DISTRIBUTED_DEBUG"}, OFF.c_str());
379381
LOG(INFO) << logPrefix() << "ProcessGroupXCCL initialization options: "
380-
<< "size: " << size << ", global rank: " << rank_;
382+
<< "size: " << size << ", global rank: " << rank_
383+
<< ", USE_HIGH_PRIORITY_STREAM: "
384+
<< options_->is_high_priority_stream
385+
<< ", PG Name: " << options_->group_name;
381386

382387
LOG(INFO) << logPrefix() << "ProcessGroupXCCL environments: "
383388
<< "XCCL version: " << XcclVersion
@@ -534,9 +539,9 @@ std::shared_ptr<xcclComm_t> ProcessGroupXCCL::getXCCLComm(
534539
rank = p2pRank;
535540
}
536541

537-
c10::impl::VirtualGuardImpl impl(device.type());
538-
c10::Stream stream =
539-
impl.getStreamFromGlobalPool(device, /*isHighPriority=*/false);
542+
bool force_high = getCvarBool(TORCH_XCCL_HIGH_PRIORITY, false);
543+
c10::Stream stream = at::xpu::getStreamFromPool(
544+
options_->is_high_priority_stream || force_high);
540545
sycl::queue& q = c10::xpu::XPUStream(stream).queue();
541546

542547
auto ctx = ccl::create_context(q.get_context());

src/xccl/ProcessGroupXCCL.hpp

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,9 @@
2424
#include <xccl/ProcessGroupXCCLMonitor.hpp>
2525
namespace c10d {
2626

27+
static std::vector<std::string> TORCH_XCCL_HIGH_PRIORITY = {
28+
"TORCH_XCCL_HIGH_PRIORITY"};
29+
2730
static std::vector<std::string> TORCH_XCCL_BLOCKING_WAIT = {
2831
"TORCH_XCCL_BLOCKING_WAIT",
2932
"XCCL_BLOCKING_WAIT"};
@@ -118,18 +121,19 @@ class TORCH_API ProcessGroupXCCL : public Backend {
118121
};
119122

120123
struct Options : public Backend::Options {
121-
explicit Options();
124+
explicit Options(bool is_high_priority_stream = false);
122125

123-
static c10::intrusive_ptr<Options> create() {
124-
return c10::make_intrusive<Options>();
126+
static c10::intrusive_ptr<Options> create(
127+
bool is_high_priority_stream = false) {
128+
return c10::make_intrusive<Options>(is_high_priority_stream);
125129
}
126-
130+
bool is_high_priority_stream;
127131
std::vector<uint64_t> global_ranks_in_group;
128132
std::string group_name;
129133
};
130134

131135
ProcessGroupXCCL(
132-
const c10::intrusive_ptr<Store>& store,
136+
c10::intrusive_ptr<Store> store,
133137
int rank,
134138
int size,
135139
c10::intrusive_ptr<Options> options = Options::create());
@@ -138,11 +142,16 @@ class TORCH_API ProcessGroupXCCL : public Backend {
138142
const c10::intrusive_ptr<Store>& store,
139143
int rank,
140144
int size,
141-
const std::string& groupName)
142-
: ProcessGroupXCCL(store, rank, size) {}
145+
const std::string& groupName,
146+
c10::intrusive_ptr<Options> options = Options::create())
147+
: ProcessGroupXCCL(store, rank, size, std::move(options)) {}
143148

144149
~ProcessGroupXCCL() override;
145150

151+
c10::intrusive_ptr<Options> getOptions() {
152+
return options_;
153+
}
154+
146155
const std::string getBackendName() const override {
147156
return std::string(XCCL_BACKEND_NAME);
148157
}

0 commit comments

Comments
 (0)