Skip to content

Commit 86d3674

Browse files
pieternfacebook-github-bot
authored andcommitted
Add new style allgather collective (#137)
Summary: Like gather in the previous commit, this is a single function and doesn't hold on to any state. This implements a ring allgather. Rebase needed after #136 is merged. Pull Request resolved: #137 Reviewed By: teng-li Differential Revision: D10376419 Pulled By: pietern fbshipit-source-id: bc561bff4c8952531bada14dfe0bc2e9d9c68f78
1 parent 86f1bcb commit 86d3674

File tree

7 files changed

+234
-0
lines changed

7 files changed

+234
-0
lines changed

gloo/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ set(GLOO_HDRS)
77
# Compiled sources in root directory
88
list(APPEND GLOO_SRCS
99
"${CMAKE_CURRENT_SOURCE_DIR}/algorithm.cc"
10+
"${CMAKE_CURRENT_SOURCE_DIR}/allgather.cc"
1011
"${CMAKE_CURRENT_SOURCE_DIR}/allreduce_local.cc"
1112
"${CMAKE_CURRENT_SOURCE_DIR}/context.cc"
1213
"${CMAKE_CURRENT_SOURCE_DIR}/gather.cc"
@@ -15,6 +16,7 @@ list(APPEND GLOO_SRCS
1516

1617
list(APPEND GLOO_HDRS
1718
"${CMAKE_CURRENT_SOURCE_DIR}/algorithm.h"
19+
"${CMAKE_CURRENT_SOURCE_DIR}/allgather.h"
1820
"${CMAKE_CURRENT_SOURCE_DIR}/allgather_ring.h"
1921
"${CMAKE_CURRENT_SOURCE_DIR}/allreduce_halving_doubling.h"
2022
"${CMAKE_CURRENT_SOURCE_DIR}/allreduce_bcube.h"

gloo/allgather.cc

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
/**
2+
* Copyright (c) 2018-present, Facebook, Inc.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree. An additional grant
7+
* of patent rights can be found in the PATENTS file in the same directory.
8+
*/
9+
10+
#include "gloo/allgather.h"
11+
12+
#include <cstring>
13+
14+
#include "gloo/common/logging.h"
15+
#include "gloo/types.h"
16+
17+
namespace gloo {
18+
19+
void allgather(
20+
const std::shared_ptr<Context>& context,
21+
AllgatherOptions& opts) {
22+
std::unique_ptr<transport::UnboundBuffer> tmpInBuffer;
23+
std::unique_ptr<transport::UnboundBuffer> tmpOutBuffer;
24+
transport::UnboundBuffer* in = nullptr;
25+
transport::UnboundBuffer* out = nullptr;
26+
const auto slot = Slot::build(kAllgatherSlotPrefix, opts.tag);
27+
28+
// Sanity checks
29+
GLOO_ENFORCE(opts.elementSize > 0);
30+
const auto recvRank = (context->size + context->rank - 1) % context->size;
31+
GLOO_ENFORCE(
32+
context->getPair(recvRank),
33+
"missing connection between rank " + std::to_string(context->rank) +
34+
" (this process) and rank " + std::to_string(recvRank));
35+
const auto sendRank = (context->size + context->rank + 1) % context->size;
36+
GLOO_ENFORCE(
37+
context->getPair(sendRank),
38+
"missing connection between rank " + std::to_string(context->rank) +
39+
" (this process) and rank " + std::to_string(sendRank));
40+
41+
// Figure out pointer to input buffer
42+
if (opts.inBuffer) {
43+
in = opts.inBuffer.get();
44+
} else if (opts.inPtr != nullptr) {
45+
GLOO_ENFORCE(opts.inElements > 0);
46+
tmpInBuffer = context->createUnboundBuffer(
47+
opts.inPtr, opts.inElements * opts.elementSize);
48+
in = tmpInBuffer.get();
49+
}
50+
51+
// Figure out pointer to output buffer
52+
if (opts.outBuffer) {
53+
out = opts.outBuffer.get();
54+
} else {
55+
GLOO_ENFORCE(opts.outPtr != nullptr);
56+
GLOO_ENFORCE(opts.outElements > 0);
57+
tmpOutBuffer = context->createUnboundBuffer(
58+
opts.outPtr, opts.outElements * opts.elementSize);
59+
out = tmpOutBuffer.get();
60+
}
61+
62+
if (in != nullptr) {
63+
GLOO_ENFORCE_EQ(out->size, in->size * context->size);
64+
} else {
65+
GLOO_ENFORCE_EQ(out->size % context->size, 0);
66+
}
67+
68+
const size_t inBytes = out->size / context->size;
69+
const size_t outBytes = out->size;
70+
71+
// If the input buffer is specified, this is NOT an in place operation,
72+
// and the output buffer needs to be primed with the input.
73+
if (in != nullptr) {
74+
memcpy(
75+
static_cast<uint8_t*>(out->ptr) + context->rank * in->size,
76+
static_cast<uint8_t*>(in->ptr),
77+
in->size);
78+
}
79+
80+
// The chunk size may not be divisible by 2; use dynamic lookup.
81+
std::array<size_t, 2> chunkSize;
82+
chunkSize[0] = inBytes / 2;
83+
chunkSize[1] = inBytes - chunkSize[0];
84+
std::array<size_t, 2> chunkOffset;
85+
chunkOffset[0] = 0;
86+
chunkOffset[1] = chunkSize[0];
87+
88+
for (auto i = 0; i < (context->size - 1) * 2; i++) {
89+
const size_t sendSegment = context->size + context->rank - (i / 2);
90+
const size_t recvSegment = sendSegment - 1;
91+
size_t sendOffset =
92+
((sendSegment * inBytes) + chunkOffset[i & 0x1]) % outBytes;
93+
size_t recvOffset =
94+
((recvSegment * inBytes) + chunkOffset[i & 0x1]) % outBytes;
95+
size_t size = chunkSize[i & 0x1];
96+
if (i < 2) {
97+
out->send(sendRank, slot, sendOffset, size);
98+
out->recv(recvRank, slot, recvOffset, size);
99+
continue;
100+
}
101+
102+
// Wait for pending operations to complete to synchronize with the
103+
// previous iteration. Because we kick off two operations before
104+
// getting here we always wait for the next-to-last operation.
105+
out->waitSend();
106+
out->waitRecv();
107+
out->send(sendRank, slot, sendOffset, size);
108+
out->recv(recvRank, slot, recvOffset, size);
109+
}
110+
111+
// Wait for completes
112+
for (auto i = 0; i < 2; i++) {
113+
out->waitSend();
114+
out->waitRecv();
115+
}
116+
}
117+
118+
} // namespace gloo

gloo/allgather.h

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
/**
2+
* Copyright (c) 2018-present, Facebook, Inc.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree. An additional grant
7+
* of patent rights can be found in the PATENTS file in the same directory.
8+
*/
9+
10+
#pragma once
11+
12+
#include "gloo/context.h"
13+
#include "gloo/transport/unbound_buffer.h"
14+
15+
namespace gloo {
16+
17+
struct AllgatherOptions {
18+
// The input and output can either be specified as a unbound buffer
19+
// (that can be cached and reused by the caller), or a literal
20+
// pointer and number of elements stored at that pointer.
21+
//
22+
// The operation is executed in place on the output if the input is
23+
// set to null. The input for this process is assumed to be at the
24+
// location in the output buffer where it would otherwise be.
25+
std::unique_ptr<transport::UnboundBuffer> inBuffer;
26+
void* inPtr;
27+
size_t inElements;
28+
std::unique_ptr<transport::UnboundBuffer> outBuffer;
29+
void* outPtr;
30+
size_t outElements;
31+
32+
// Number of bytes per element.
33+
size_t elementSize;
34+
35+
// Tag for this gather operation.
36+
// Must be unique across operations executing in parallel.
37+
uint32_t tag;
38+
};
39+
40+
void allgather(const std::shared_ptr<Context>& context, AllgatherOptions& opts);
41+
42+
} // namespace gloo

gloo/test/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
set(GLOO_TEST_SRCS
2+
"${CMAKE_CURRENT_SOURCE_DIR}/allgather_test.cc"
23
"${CMAKE_CURRENT_SOURCE_DIR}/allreduce_builder_test.cc"
34
"${CMAKE_CURRENT_SOURCE_DIR}/allreduce_test.cc"
45
"${CMAKE_CURRENT_SOURCE_DIR}/barrier_test.cc"

gloo/test/allgather_test.cc

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include <thread>
1212
#include <vector>
1313

14+
#include "gloo/allgather.h"
1415
#include "gloo/allgather_ring.h"
1516
#include "gloo/common/common.h"
1617
#include "gloo/test/base_test.h"
@@ -109,6 +110,67 @@ INSTANTIATE_TEST_CASE_P(
109110
::testing::ValuesIn(genMemorySizes()),
110111
::testing::Range(1, 4)));
111112

113+
using NewParam = std::tuple<int, int, bool>;
114+
115+
class AllgatherNewTest : public BaseTest,
116+
public ::testing::WithParamInterface<NewParam> {};
117+
118+
TEST_P(AllgatherNewTest, Default) {
119+
auto contextSize = std::get<0>(GetParam());
120+
auto dataSize = std::get<1>(GetParam());
121+
auto passBuffers = std::get<2>(GetParam());
122+
123+
auto validate = [dataSize](
124+
const std::shared_ptr<Context>& context,
125+
Fixture<uint64_t>& output) {
126+
const auto ptr = output.getPointer();
127+
const auto stride = context->size;
128+
for (auto j = 0; j < context->size; j++) {
129+
for (auto k = 0; k < dataSize; k++) {
130+
ASSERT_EQ(j + k * stride, ptr[k + j * dataSize])
131+
<< "Mismatch at index " << (k + j * dataSize);
132+
}
133+
}
134+
};
135+
136+
spawn(contextSize, [&](std::shared_ptr<Context> context) {
137+
auto input = Fixture<uint64_t>(context, 1, dataSize);
138+
auto output = Fixture<uint64_t>(context, 1, contextSize * dataSize);
139+
140+
AllgatherOptions opts;
141+
opts.elementSize = sizeof(uint64_t);
142+
143+
if (passBuffers) {
144+
// Run with (optionally cached) unbound buffers in options
145+
opts.inBuffer = context->createUnboundBuffer(
146+
input.getPointer(),
147+
dataSize * sizeof(uint64_t));
148+
opts.outBuffer = context->createUnboundBuffer(
149+
output.getPointer(),
150+
contextSize * dataSize * sizeof(uint64_t));
151+
} else {
152+
// Run with raw pointers and sizes in options
153+
opts.inPtr = input.getPointer();
154+
opts.inElements = dataSize;
155+
opts.outPtr = output.getPointer();
156+
opts.outElements = contextSize * dataSize;
157+
}
158+
159+
input.assignValues();
160+
allgather(context, opts);
161+
validate(context, output);
162+
});
163+
}
164+
165+
INSTANTIATE_TEST_CASE_P(
166+
AllgatherNewDefault,
167+
AllgatherNewTest,
168+
::testing::Combine(
169+
::testing::Values(2, 4, 7),
170+
::testing::ValuesIn(genMemorySizes()),
171+
::testing::Values(false, true)));
172+
173+
112174
} // namespace
113175
} // namespace test
114176
} // namespace gloo

gloo/test/base_test.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,14 @@ class Fixture {
136136
}
137137
}
138138

139+
void clear() {
140+
for (auto i = 0; i < srcs.size(); i++) {
141+
for (auto j = 0; j < count; j++) {
142+
srcs[i][j] = 0;
143+
}
144+
}
145+
}
146+
139147
void checkBroadcastResult(Fixture<T>& fixture, int root, int rootPointer) {
140148
// Expected is set to the expected value at ptr[0]
141149
const auto expected = root * fixture.srcs.size() + rootPointer;

gloo/types.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ namespace gloo {
5555
//
5656

5757
constexpr uint8_t kGatherSlotPrefix = 0x01;
58+
constexpr uint8_t kAllgatherSlotPrefix = 0x02;
5859

5960
class Slot {
6061
public:

0 commit comments

Comments
 (0)