2 files changed: +37 -3 lines changed
torch/csrc/distributed/c10d Expand file tree Collapse file tree 2 files changed +37
-3
lines changed Original file line number Diff line number Diff line change @@ -546,6 +546,42 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::allreduce_coalesced(
546546 OpType::COALESCED);
547547}
548548
549+ c10::intrusive_ptr<Work> ProcessGroupXCCL::broadcast (
550+ std::vector<at::Tensor>& tensors,
551+ const BroadcastOptions& opts) {
552+ TORCH_CHECK (
553+ tensors.size () == 1 , " Expecting one tensor only but got multiple" );
554+ auto tensor = tensors.back ();
555+ if (tensor.is_complex ()) {
556+ tensor = at::view_as_real (tensor);
557+ }
558+ check_xpu_single_tensor (tensor);
559+
560+ const auto root = opts.rootRank + opts.rootTensor ;
561+
562+ return collective (
563+ tensor,
564+ tensor,
565+ [&](at::Tensor& input,
566+ at::Tensor& output,
567+ ccl::broadcast_attr attr,
568+ xcclComm_t& comm,
569+ ccl::stream& stream) {
570+ auto xcclDataType = getXcclDataType (input.scalar_type ());
571+ ccl::event ret_evt;
572+ ret_evt = ccl::broadcast (
573+ input.data_ptr (),
574+ (size_t )input.numel (),
575+ xcclDataType,
576+ root,
577+ comm,
578+ stream,
579+ attr);
580+ return ret_evt;
581+ },
582+ OpType::BROADCAST);
583+ }
584+
549585} // namespace c10d
550586
551587#endif // USE_C10D_XCCL
Original file line number Diff line number Diff line change @@ -202,9 +202,7 @@ class TORCH_API ProcessGroupXCCL : public Backend {
202202
// broadcast(): this hunk drops the inline stub that failed with a "not implemented" TORCH_CHECK and keeps a declaration only; opts defaults to BroadcastOptions().
203203 c10::intrusive_ptr<Work> broadcast (
204204 std::vector<at::Tensor>& tensors,
205- const BroadcastOptions& opts = BroadcastOptions()) override {
206- TORCH_CHECK (false , " ProcessGroupXCCL::broadcast not implemented" );
207- }
205+ const BroadcastOptions& opts = BroadcastOptions()) override ;
208206
209207 c10::intrusive_ptr<Work> allgather (
210208 std::vector<std::vector<at::Tensor>>& outputTensors,
You can’t perform that action at this time.
0 commit comments