Skip to content

Commit 86f1bcb

Browse files
pieternfacebook-github-bot
authored andcommitted
Add gather collective (#136)
Summary: This is the first "new style" collective that doesn't use the pattern where it is initialized once, holds some state, and is used many times. The new style is intended for use where collectives can be called at any time, such that initialization can no longer be amortized because reuse may no longer apply. Pull Request resolved: #136 Reviewed By: teng-li Differential Revision: D10249938 Pulled By: pietern fbshipit-source-id: c6f5d134c685247ad37e678e1430e30d6e5328b9
1 parent a94e6bf commit 86f1bcb

File tree

8 files changed

+303
-0
lines changed

8 files changed

+303
-0
lines changed

gloo/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ list(APPEND GLOO_SRCS
99
"${CMAKE_CURRENT_SOURCE_DIR}/algorithm.cc"
1010
"${CMAKE_CURRENT_SOURCE_DIR}/allreduce_local.cc"
1111
"${CMAKE_CURRENT_SOURCE_DIR}/context.cc"
12+
"${CMAKE_CURRENT_SOURCE_DIR}/gather.cc"
13+
"${CMAKE_CURRENT_SOURCE_DIR}/types.cc"
1214
)
1315

1416
list(APPEND GLOO_HDRS
@@ -23,6 +25,7 @@ list(APPEND GLOO_HDRS
2325
"${CMAKE_CURRENT_SOURCE_DIR}/barrier_all_to_all.h"
2426
"${CMAKE_CURRENT_SOURCE_DIR}/barrier_all_to_one.h"
2527
"${CMAKE_CURRENT_SOURCE_DIR}/broadcast_one_to_all.h"
28+
"${CMAKE_CURRENT_SOURCE_DIR}/gather.h"
2629
"${CMAKE_CURRENT_SOURCE_DIR}/reduce_scatter.h"
2730
"${CMAKE_CURRENT_SOURCE_DIR}/context.h"
2831
"${CMAKE_CURRENT_SOURCE_DIR}/math.h"

gloo/gather.cc

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
/**
2+
* Copyright (c) 2018-present, Facebook, Inc.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree. An additional grant
7+
* of patent rights can be found in the PATENTS file in the same directory.
8+
*/
9+
10+
#include "gloo/gather.h"
11+
12+
#include <cstring>
13+
14+
#include "gloo/common/logging.h"
15+
#include "gloo/types.h"
16+
17+
namespace gloo {
18+
19+
void gather(const std::shared_ptr<Context>& context, GatherOptions& opts) {
20+
std::unique_ptr<transport::UnboundBuffer> tmpInBuffer;
21+
std::unique_ptr<transport::UnboundBuffer> tmpOutBuffer;
22+
transport::UnboundBuffer* in = nullptr;
23+
transport::UnboundBuffer* out = nullptr;
24+
const auto slot = Slot::build(kGatherSlotPrefix, opts.tag);
25+
26+
// Sanity checks
27+
GLOO_ENFORCE(opts.elementSize > 0);
28+
29+
// Figure out pointer to input buffer
30+
if (opts.inBuffer) {
31+
in = opts.inBuffer.get();
32+
} else {
33+
GLOO_ENFORCE(opts.inPtr != nullptr);
34+
GLOO_ENFORCE(opts.inElements > 0);
35+
tmpInBuffer = context->createUnboundBuffer(
36+
opts.inPtr, opts.inElements * opts.elementSize);
37+
in = tmpInBuffer.get();
38+
}
39+
40+
if (context->rank == opts.root) {
41+
const size_t chunkSize = in->size;
42+
43+
// Figure out pointer to output buffer (only for root rank)
44+
if (opts.outBuffer) {
45+
out = opts.outBuffer.get();
46+
} else {
47+
GLOO_ENFORCE(opts.outPtr != nullptr);
48+
GLOO_ENFORCE(opts.outElements > 0);
49+
tmpOutBuffer = context->createUnboundBuffer(
50+
opts.outPtr, opts.outElements * opts.elementSize);
51+
out = tmpOutBuffer.get();
52+
}
53+
54+
// Ensure the output buffer has the right size.
55+
GLOO_ENFORCE(in->size * context->size == out->size);
56+
57+
// Post receive operations from peers into out buffer
58+
for (size_t i = 0; i < context->size; i++) {
59+
if (i == context->rank) {
60+
continue;
61+
}
62+
out->recv(i, slot, i * chunkSize, chunkSize);
63+
}
64+
65+
// Copy local input to output
66+
memcpy(
67+
static_cast<char*>(out->ptr) + (context->rank * chunkSize),
68+
in->ptr,
69+
chunkSize);
70+
71+
// Wait for receive operations to complete
72+
for (size_t i = 0; i < context->size; i++) {
73+
if (i == context->rank) {
74+
continue;
75+
}
76+
out->waitRecv();
77+
}
78+
} else {
79+
in->send(opts.root, slot);
80+
in->waitSend();
81+
}
82+
}
83+
84+
} // namespace gloo

gloo/gather.h

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
/**
2+
* Copyright (c) 2018-present, Facebook, Inc.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree. An additional grant
7+
* of patent rights can be found in the PATENTS file in the same directory.
8+
*/
9+
10+
#pragma once
11+
12+
#include "gloo/context.h"
13+
#include "gloo/transport/unbound_buffer.h"
14+
15+
namespace gloo {
16+
17+
struct GatherOptions {
18+
// The input and output buffers can either be specified as an unbound
19+
// buffer (that can be cached and reused by the caller), or a
20+
// literal pointer and number of elements stored at that pointer.
21+
std::unique_ptr<transport::UnboundBuffer> inBuffer;
22+
void* inPtr;
23+
size_t inElements;
24+
std::unique_ptr<transport::UnboundBuffer> outBuffer;
25+
void* outPtr;
26+
size_t outElements;
27+
28+
// Number of bytes per element.
29+
size_t elementSize;
30+
31+
// Rank of receiving process.
32+
int root;
33+
34+
// Tag for this gather operation.
35+
// Must be unique across operations executing in parallel.
36+
uint32_t tag;
37+
};
38+
39+
void gather(const std::shared_ptr<Context>& context, GatherOptions& opts);
40+
41+
} // namespace gloo

gloo/test/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ set(GLOO_TEST_SRCS
44
"${CMAKE_CURRENT_SOURCE_DIR}/barrier_test.cc"
55
"${CMAKE_CURRENT_SOURCE_DIR}/broadcast_builder_test.cc"
66
"${CMAKE_CURRENT_SOURCE_DIR}/broadcast_test.cc"
7+
"${CMAKE_CURRENT_SOURCE_DIR}/gather_test.cc"
78
"${CMAKE_CURRENT_SOURCE_DIR}/linux_test.cc"
89
"${CMAKE_CURRENT_SOURCE_DIR}/main.cc"
910
"${CMAKE_CURRENT_SOURCE_DIR}/send_recv_test.cc"

gloo/test/base_test.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,10 @@ class Fixture {
172172
}
173173
}
174174

175+
T* getPointer() const {
176+
return srcs.front().get();
177+
}
178+
175179
std::vector<T*> getPointers() const {
176180
std::vector<T*> out;
177181
for (const auto& src : srcs) {

gloo/test/gather_test.cc

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
/**
2+
* Copyright (c) 2018-present, Facebook, Inc.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree. An additional grant
7+
* of patent rights can be found in the PATENTS file in the same directory.
8+
*/
9+
10+
#include "gloo/gather.h"
11+
#include "gloo/test/base_test.h"
12+
13+
namespace gloo {
14+
namespace test {
15+
namespace {
16+
17+
// Test parameterization.
18+
using Param = std::tuple<int, size_t>;
19+
20+
// Test fixture.
21+
class GatherTest : public BaseTest,
22+
public ::testing::WithParamInterface<Param> {
23+
};
24+
25+
TEST_P(GatherTest, Default) {
26+
auto contextSize = std::get<0>(GetParam());
27+
auto dataSize = std::get<1>(GetParam());
28+
29+
spawn(contextSize, [&](std::shared_ptr<Context> context) {
30+
auto input = Fixture<uint64_t>(context, 1, dataSize);
31+
auto output = Fixture<uint64_t>(context, 1, contextSize * dataSize);
32+
33+
// Initialize fixture with globally unique values
34+
input.assignValues();
35+
36+
GatherOptions opts;
37+
opts.inPtr = input.getPointer();
38+
opts.inElements = dataSize;
39+
opts.elementSize = sizeof(uint64_t);
40+
41+
// Take turns being root
42+
for (auto i = 0; i < context->size; i++) {
43+
// Set output pointer only when root
44+
if (i == context->rank) {
45+
opts.outPtr = output.getPointer();
46+
opts.outElements = dataSize * contextSize;
47+
} else {
48+
opts.outPtr = nullptr;
49+
opts.outElements = 0;
50+
}
51+
52+
opts.root = i;
53+
gather(context, opts);
54+
55+
// Validate result if root
56+
if (i == context->rank) {
57+
const auto ptr = output.getPointer();
58+
const auto stride = context->size;
59+
for (auto j = 0; j < context->size; j++) {
60+
for (auto k = 0; k < dataSize; k++) {
61+
ASSERT_EQ(j + k * stride, ptr[k + j * dataSize])
62+
<< "Mismatch at index " << (k + j * dataSize);
63+
}
64+
}
65+
}
66+
}
67+
});
68+
}
69+
70+
std::vector<size_t> genMemorySizes() {
71+
std::vector<size_t> v;
72+
v.push_back(1);
73+
v.push_back(10);
74+
v.push_back(100);
75+
v.push_back(1000);
76+
return v;
77+
}
78+
79+
INSTANTIATE_TEST_CASE_P(
80+
GatherDefault,
81+
GatherTest,
82+
::testing::Combine(
83+
::testing::Values(2, 4, 7),
84+
::testing::ValuesIn(genMemorySizes())));
85+
86+
} // namespace
87+
} // namespace test
88+
} // namespace gloo

gloo/types.cc

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
/**
2+
* Copyright (c) 2018-present, Facebook, Inc.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree. An additional grant
7+
* of patent rights can be found in the PATENTS file in the same directory.
8+
*/
9+
10+
#include "gloo/types.h"
11+
12+
#include <stdexcept>
13+
14+
namespace gloo {
15+
16+
Slot Slot::build(uint8_t prefix, uint32_t tag) {
17+
uint64_t u64prefix = ((uint64_t)prefix) << 56;
18+
uint64_t u64tag = (((uint64_t)tag) & 0xffffffff) << 24;
19+
return Slot(u64prefix || u64tag, 0);
20+
}
21+
22+
Slot Slot::operator+(uint8_t i) const {
23+
// Maximum of 8 bits for use in a single collective operation.
24+
// To avoid conflicts between them, raise if it overflows.
25+
auto delta = delta_ + i;
26+
if (delta > 0xff) {
27+
throw std::runtime_error(
28+
"Slot overflow: delta " + std::to_string(delta) + " > 0xff");
29+
}
30+
31+
return Slot(base_, delta);
32+
}
33+
34+
} // namespace gloo

gloo/types.h

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,54 @@
2727

2828
namespace gloo {
2929

30+
// Unlike old style collectives that are class instances that hold
31+
// some state, the new style collectives do not need initialization
32+
// before they can run. Instead of asking the context for a series of
33+
// slots and storing them for later use and reuse, the new style
34+
// collectives take a slot (or tag) argument that allows for
35+
// concurrent execution of multiple collectives on the same context.
36+
//
37+
// This tag is what determines the slot numbers for the send and recv
38+
// operations that the collectives end up executing. A single
39+
// collective may have many send and recv operations running in
40+
// parallel, so instead of using the specified tag verbatim, we use it
41+
// as a prefix. Also, to avoid conflicts between collectives with the
42+
// same tag, we have another tag prefix per collective type. Out of
43+
// the 64 bits we can use for a slot, we use 8 of them to identify a
44+
// collective, 32 to identify the collective tag, another 8 for use by
45+
// the collective operation itself (allowing for 256 independent send
46+
// and recv operations against the same point to point pair), and
47+
// leave 16 bits unused.
48+
//
49+
// Below, you find constexprs for the prefix per collective type, as
50+
// well as a way to compute slots when executing a collective. The
51+
// slot class below captures both a prefix and a delta on that prefix
52+
// to support addition with bounds checking. It is usable as an
53+
// uint64_t, but one that cannot overflow beyond the bits allocated
54+
// for use within a collective.
55+
//
56+
57+
constexpr uint8_t kGatherSlotPrefix = 0x01;
58+
59+
class Slot {
60+
public:
61+
static Slot build(uint8_t prefix, uint32_t tag);
62+
63+
operator uint64_t() const {
64+
return base_ + delta_;
65+
}
66+
67+
Slot operator+(uint8_t i) const;
68+
69+
protected:
70+
explicit Slot(uint64_t base, uint64_t delta)
71+
: base_(base), delta_(delta) {
72+
}
73+
74+
const uint64_t base_;
75+
const uint64_t delta_;
76+
};
77+
3078
struct float16;
3179
float16 cpu_float2half_rn(float f);
3280
float cpu_half2float(float16 h);

0 commit comments

Comments
 (0)