Skip to content

Commit afa2adc

Browse files
committed
Support broadcast
1 parent 4c3f49f commit afa2adc

File tree

2 files changed

+37
-3
lines changed

2 files changed

+37
-3
lines changed

torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -546,6 +546,42 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::allreduce_coalesced(
546546
OpType::COALESCED);
547547
}
548548

549+
// Broadcast entry point for the XCCL backend.
//
// Accepts exactly one tensor (enforced by TORCH_CHECK). A complex tensor is
// replaced by its at::view_as_real view before validation, and the element
// count and data type handed to ccl::broadcast are taken from that view.
//
// The root passed to ccl::broadcast is opts.rootRank + opts.rootTensor.
//
// Returns whatever collective() produces for OpType::BROADCAST.
c10::intrusive_ptr<Work> ProcessGroupXCCL::broadcast(
    std::vector<at::Tensor>& tensors,
    const BroadcastOptions& opts) {
  TORCH_CHECK(
      tensors.size() == 1, "Expecting one tensor only but got multiple");

  // Exactly one element is present at this point, so front() is that tensor.
  at::Tensor payload = tensors.front();
  if (payload.is_complex()) {
    payload = at::view_as_real(payload);
  }
  check_xpu_single_tensor(payload);

  const auto root = opts.rootRank + opts.rootTensor;

  // The same tensor is supplied as both input and output of the collective;
  // the callback only ever touches `input`.
  auto launchBroadcast = [root](
                             at::Tensor& input,
                             at::Tensor& output,
                             ccl::broadcast_attr attr,
                             xcclComm_t& comm,
                             ccl::stream& stream) {
    const auto elementType = getXcclDataType(input.scalar_type());
    const auto elementCount = static_cast<size_t>(input.numel());
    ccl::event completion = ccl::broadcast(
        input.data_ptr(), elementCount, elementType, root, comm, stream, attr);
    return completion;
  };

  return collective(payload, payload, launchBroadcast, OpType::BROADCAST);
}
584+
549585
} // namespace c10d
550586

551587
#endif // USE_C10D_XCCL

torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -202,9 +202,7 @@ class TORCH_API ProcessGroupXCCL : public Backend {
202202

203203
c10::intrusive_ptr<Work> broadcast(
204204
std::vector<at::Tensor>& tensors,
205-
const BroadcastOptions& opts = BroadcastOptions()) override {
206-
TORCH_CHECK(false, "ProcessGroupXCCL::broadcast not implemented");
207-
}
205+
const BroadcastOptions& opts = BroadcastOptions()) override;
208206

209207
c10::intrusive_ptr<Work> allgather(
210208
std::vector<std::vector<at::Tensor>>& outputTensors,

0 commit comments

Comments
 (0)