|
| 1 | +/** |
| 2 | + * Copyright (c) 2018-present, Facebook, Inc. |
| 3 | + * All rights reserved. |
| 4 | + * |
| 5 | + * This source code is licensed under the BSD-style license found in the |
| 6 | + * LICENSE file in the root directory of this source tree. An additional grant |
| 7 | + * of patent rights can be found in the PATENTS file in the same directory. |
| 8 | + */ |
| 9 | + |
#include "gloo/allgather.h"

#include <array>
#include <cstdint>
#include <cstring>
#include <string>

#include "gloo/common/logging.h"
#include "gloo/types.h"
| 16 | + |
| 17 | +namespace gloo { |
| 18 | + |
| 19 | +void allgather( |
| 20 | + const std::shared_ptr<Context>& context, |
| 21 | + AllgatherOptions& opts) { |
| 22 | + std::unique_ptr<transport::UnboundBuffer> tmpInBuffer; |
| 23 | + std::unique_ptr<transport::UnboundBuffer> tmpOutBuffer; |
| 24 | + transport::UnboundBuffer* in = nullptr; |
| 25 | + transport::UnboundBuffer* out = nullptr; |
| 26 | + const auto slot = Slot::build(kAllgatherSlotPrefix, opts.tag); |
| 27 | + |
| 28 | + // Sanity checks |
| 29 | + GLOO_ENFORCE(opts.elementSize > 0); |
| 30 | + const auto recvRank = (context->size + context->rank - 1) % context->size; |
| 31 | + GLOO_ENFORCE( |
| 32 | + context->getPair(recvRank), |
| 33 | + "missing connection between rank " + std::to_string(context->rank) + |
| 34 | + " (this process) and rank " + std::to_string(recvRank)); |
| 35 | + const auto sendRank = (context->size + context->rank + 1) % context->size; |
| 36 | + GLOO_ENFORCE( |
| 37 | + context->getPair(sendRank), |
| 38 | + "missing connection between rank " + std::to_string(context->rank) + |
| 39 | + " (this process) and rank " + std::to_string(sendRank)); |
| 40 | + |
| 41 | + // Figure out pointer to input buffer |
| 42 | + if (opts.inBuffer) { |
| 43 | + in = opts.inBuffer.get(); |
| 44 | + } else if (opts.inPtr != nullptr) { |
| 45 | + GLOO_ENFORCE(opts.inElements > 0); |
| 46 | + tmpInBuffer = context->createUnboundBuffer( |
| 47 | + opts.inPtr, opts.inElements * opts.elementSize); |
| 48 | + in = tmpInBuffer.get(); |
| 49 | + } |
| 50 | + |
| 51 | + // Figure out pointer to output buffer |
| 52 | + if (opts.outBuffer) { |
| 53 | + out = opts.outBuffer.get(); |
| 54 | + } else { |
| 55 | + GLOO_ENFORCE(opts.outPtr != nullptr); |
| 56 | + GLOO_ENFORCE(opts.outElements > 0); |
| 57 | + tmpOutBuffer = context->createUnboundBuffer( |
| 58 | + opts.outPtr, opts.outElements * opts.elementSize); |
| 59 | + out = tmpOutBuffer.get(); |
| 60 | + } |
| 61 | + |
| 62 | + if (in != nullptr) { |
| 63 | + GLOO_ENFORCE_EQ(out->size, in->size * context->size); |
| 64 | + } else { |
| 65 | + GLOO_ENFORCE_EQ(out->size % context->size, 0); |
| 66 | + } |
| 67 | + |
| 68 | + const size_t inBytes = out->size / context->size; |
| 69 | + const size_t outBytes = out->size; |
| 70 | + |
| 71 | + // If the input buffer is specified, this is NOT an in place operation, |
| 72 | + // and the output buffer needs to be primed with the input. |
| 73 | + if (in != nullptr) { |
| 74 | + memcpy( |
| 75 | + static_cast<uint8_t*>(out->ptr) + context->rank * in->size, |
| 76 | + static_cast<uint8_t*>(in->ptr), |
| 77 | + in->size); |
| 78 | + } |
| 79 | + |
| 80 | + // The chunk size may not be divisible by 2; use dynamic lookup. |
| 81 | + std::array<size_t, 2> chunkSize; |
| 82 | + chunkSize[0] = inBytes / 2; |
| 83 | + chunkSize[1] = inBytes - chunkSize[0]; |
| 84 | + std::array<size_t, 2> chunkOffset; |
| 85 | + chunkOffset[0] = 0; |
| 86 | + chunkOffset[1] = chunkSize[0]; |
| 87 | + |
| 88 | + for (auto i = 0; i < (context->size - 1) * 2; i++) { |
| 89 | + const size_t sendSegment = context->size + context->rank - (i / 2); |
| 90 | + const size_t recvSegment = sendSegment - 1; |
| 91 | + size_t sendOffset = |
| 92 | + ((sendSegment * inBytes) + chunkOffset[i & 0x1]) % outBytes; |
| 93 | + size_t recvOffset = |
| 94 | + ((recvSegment * inBytes) + chunkOffset[i & 0x1]) % outBytes; |
| 95 | + size_t size = chunkSize[i & 0x1]; |
| 96 | + if (i < 2) { |
| 97 | + out->send(sendRank, slot, sendOffset, size); |
| 98 | + out->recv(recvRank, slot, recvOffset, size); |
| 99 | + continue; |
| 100 | + } |
| 101 | + |
| 102 | + // Wait for pending operations to complete to synchronize with the |
| 103 | + // previous iteration. Because we kick off two operations before |
| 104 | + // getting here we always wait for the next-to-last operation. |
| 105 | + out->waitSend(); |
| 106 | + out->waitRecv(); |
| 107 | + out->send(sendRank, slot, sendOffset, size); |
| 108 | + out->recv(recvRank, slot, recvOffset, size); |
| 109 | + } |
| 110 | + |
| 111 | + // Wait for completes |
| 112 | + for (auto i = 0; i < 2; i++) { |
| 113 | + out->waitSend(); |
| 114 | + out->waitRecv(); |
| 115 | + } |
| 116 | +} |
| 117 | + |
| 118 | +} // namespace gloo |
0 commit comments