Skip to content

Commit 578c079

Browse files
committed
mpi: #385 and #379
1 parent 7483943 commit 578c079

File tree

15 files changed

+407
-237
lines changed

15 files changed

+407
-237
lines changed

.github/workflows/tests.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@ jobs:
105105
if: github.event.pull_request.draft == false
106106
needs: [conan-cache]
107107
runs-on: ubuntu-latest
108+
timeout-minutes: 20
108109
strategy:
109110
fail-fast: false
110111
matrix:
@@ -139,6 +140,7 @@ jobs:
139140
if: github.event.pull_request.draft == false
140141
needs: [conan-cache]
141142
runs-on: ubuntu-latest
143+
timeout-minutes: 20
142144
env:
143145
CONAN_CACHE_MOUNT_SOURCE: ~/.conan/
144146
steps:

include/faabric/mpi/MpiMessage.h

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
#pragma once
2+
3+
#include <cstdint>
4+
#include <vector>
5+
6+
namespace faabric::mpi {

// Tag carried by every MpiMessage in its messageType field. The enum is
// unscoped with a fixed int32_t underlying type and explicit numeric values;
// because MpiMessage is copied byte-for-byte when serialised, these values
// should be treated as part of the wire format.
// NOTE(review): the enumerator names suggest one tag per MPI operation -
// confirm the exact semantics against the callers in MpiWorld.
enum MpiMessageType : int32_t
{
    NORMAL = 0,
    BARRIER_JOIN = 1,
    BARRIER_DONE = 2,
    SCATTER = 3,
    GATHER = 4,
    ALLGATHER = 5,
    REDUCE = 6,
    SCAN = 7,
    ALLREDUCE = 8,
    ALLTOALL = 9,
    SENDRECV = 10,
    BROADCAST = 11,
};

// Plain-data description of a single MPI message: a fixed-size header plus a
// raw pointer to the payload. The members carry no default initialisers, so a
// default-constructed MpiMessage holds indeterminate values until assigned.
struct MpiMessage
{
    int32_t id;
    int32_t worldId;
    int32_t sendRank;
    int32_t recvRank;
    // Multiplied together by payloadSize() to give the payload length.
    // NOTE(review): presumably bytes-per-element and element count - confirm.
    int32_t typeSize;
    int32_t count;
    MpiMessageType messageType;
    // Payload storage. The pointer is not owned by any RAII type here;
    // NOTE(review): confirm who is responsible for freeing it.
    void* buffer;
};

// Number of payload bytes described by msg (typeSize * count).
//
// Both operands are widened to size_t before multiplying: multiplying the two
// int32_t fields directly overflows (undefined behaviour) as soon as the
// payload exceeds INT32_MAX bytes.
inline size_t payloadSize(const MpiMessage& msg)
{
    return static_cast<size_t>(msg.typeSize) * static_cast<size_t>(msg.count);
}

// Total number of bytes needed to hold msg in serialised form: the raw
// MpiMessage struct followed by its payload.
inline size_t msgSize(const MpiMessage& msg)
{
    return sizeof(MpiMessage) + payloadSize(msg);
}

// Writes msg (header, then payload if any) into buffer.
void serializeMpiMsg(std::vector<uint8_t>& buffer, const MpiMessage& msg);

// Reconstructs *msg from the serialised representation held in bytes.
void parseMpiMsg(const std::vector<uint8_t>& bytes, MpiMessage* msg);
}

include/faabric/mpi/MpiMessageBuffer.h

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1+
#include <faabric/mpi/MpiMessage.h>
12
#include <faabric/mpi/mpi.h>
2-
#include <faabric/mpi/mpi.pb.h>
33

44
#include <iterator>
55
#include <list>
6+
#include <memory>
67

78
namespace faabric::mpi {
89
/* The MPI message buffer (MMB) keeps track of the asynchronous
@@ -25,17 +26,20 @@ class MpiMessageBuffer
2526
{
2627
public:
2728
int requestId = -1;
28-
std::shared_ptr<MPIMessage> msg = nullptr;
29+
std::shared_ptr<MpiMessage> msg = nullptr;
2930
int sendRank = -1;
3031
int recvRank = -1;
3132
uint8_t* buffer = nullptr;
3233
faabric_datatype_t* dataType = nullptr;
3334
int count = -1;
34-
MPIMessage::MPIMessageType messageType = MPIMessage::NORMAL;
35+
MpiMessageType messageType = MpiMessageType::NORMAL;
3536

3637
bool isAcknowledged() { return msg != nullptr; }
3738

38-
void acknowledge(std::shared_ptr<MPIMessage> msgIn) { msg = msgIn; }
39+
void acknowledge(const MpiMessage& msgIn)
40+
{
41+
msg = std::make_shared<MpiMessage>(msgIn);
42+
}
3943
};
4044

4145
/* Interface to query the buffer size */

include/faabric/mpi/MpiWorld.h

Lines changed: 23 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
#pragma once
22

3+
#include <faabric/mpi/MpiMessage.h>
34
#include <faabric/mpi/MpiMessageBuffer.h>
45
#include <faabric/mpi/mpi.h>
5-
#include <faabric/mpi/mpi.pb.h>
66
#include <faabric/proto/faabric.pb.h>
77
#include <faabric/scheduler/InMemoryMessageQueue.h>
88
#include <faabric/transport/PointToPointBroker.h>
@@ -26,10 +26,9 @@ namespace faabric::mpi {
2626
// -----------------------------------
2727
// MPITOPTP - mocking at the MPI level won't be needed when using the PTP broker
2828
// as the broker already has mocking capabilities
29-
std::vector<std::shared_ptr<MPIMessage>> getMpiMockedMessages(int sendRank);
29+
std::vector<MpiMessage> getMpiMockedMessages(int sendRank);
3030

31-
typedef faabric::util::FixedCapacityQueue<std::shared_ptr<MPIMessage>>
32-
InMemoryMpiQueue;
31+
typedef faabric::util::SpinLockQueue<MpiMessage> InMemoryMpiQueue;
3332

3433
class MpiWorld
3534
{
@@ -73,36 +72,36 @@ class MpiWorld
7372
const uint8_t* buffer,
7473
faabric_datatype_t* dataType,
7574
int count,
76-
MPIMessage::MPIMessageType messageType = MPIMessage::NORMAL);
75+
MpiMessageType messageType = MpiMessageType::NORMAL);
7776

7877
int isend(int sendRank,
7978
int recvRank,
8079
const uint8_t* buffer,
8180
faabric_datatype_t* dataType,
8281
int count,
83-
MPIMessage::MPIMessageType messageType = MPIMessage::NORMAL);
82+
MpiMessageType messageType = MpiMessageType::NORMAL);
8483

8584
void broadcast(int rootRank,
8685
int thisRank,
8786
uint8_t* buffer,
8887
faabric_datatype_t* dataType,
8988
int count,
90-
MPIMessage::MPIMessageType messageType = MPIMessage::NORMAL);
89+
MpiMessageType messageType = MpiMessageType::NORMAL);
9190

9291
void recv(int sendRank,
9392
int recvRank,
9493
uint8_t* buffer,
9594
faabric_datatype_t* dataType,
9695
int count,
9796
MPI_Status* status,
98-
MPIMessage::MPIMessageType messageType = MPIMessage::NORMAL);
97+
MpiMessageType messageType = MpiMessageType::NORMAL);
9998

10099
int irecv(int sendRank,
101100
int recvRank,
102101
uint8_t* buffer,
103102
faabric_datatype_t* dataType,
104103
int count,
105-
MPIMessage::MPIMessageType messageType = MPIMessage::NORMAL);
104+
MpiMessageType messageType = MpiMessageType::NORMAL);
106105

107106
void awaitAsyncRequest(int requestId);
108107

@@ -185,8 +184,6 @@ class MpiWorld
185184

186185
std::shared_ptr<InMemoryMpiQueue> getLocalQueue(int sendRank, int recvRank);
187186

188-
long getLocalQueueSize(int sendRank, int recvRank);
189-
190187
void overrideHost(const std::string& newHost);
191188

192189
double getWTime();
@@ -240,29 +237,36 @@ class MpiWorld
240237
void sendRemoteMpiMessage(std::string dstHost,
241238
int sendRank,
242239
int recvRank,
243-
const std::shared_ptr<MPIMessage>& msg);
240+
const MpiMessage& msg);
244241

245-
std::shared_ptr<MPIMessage> recvRemoteMpiMessage(int sendRank,
246-
int recvRank);
242+
MpiMessage recvRemoteMpiMessage(int sendRank, int recvRank);
247243

248244
// Support for asynchronous communications
249245
std::shared_ptr<MpiMessageBuffer> getUnackedMessageBuffer(int sendRank,
250246
int recvRank);
251247

252-
std::shared_ptr<MPIMessage> recvBatchReturnLast(int sendRank,
253-
int recvRank,
254-
int batchSize = 0);
248+
MpiMessage recvBatchReturnLast(int sendRank,
249+
int recvRank,
250+
int batchSize = 0);
255251

256252
/* Helper methods */
257253

258254
void checkRanksRange(int sendRank, int recvRank);
259255

260256
// Abstraction of the bulk of the recv work, shared among various functions
261-
void doRecv(std::shared_ptr<MPIMessage>& m,
257+
void doRecv(const MpiMessage& m,
258+
uint8_t* buffer,
259+
faabric_datatype_t* dataType,
260+
int count,
261+
MPI_Status* status,
262+
MpiMessageType messageType = MpiMessageType::NORMAL);
263+
264+
// Abstraction of the bulk of the recv work, shared among various functions
265+
void doRecv(std::unique_ptr<MpiMessage> m,
262266
uint8_t* buffer,
263267
faabric_datatype_t* dataType,
264268
int count,
265269
MPI_Status* status,
266-
MPIMessage::MPIMessageType messageType = MPIMessage::NORMAL);
270+
MpiMessageType messageType = MpiMessageType::NORMAL);
267271
};
268272
}

include/faabric/util/queue.h

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include <faabric/util/locks.h>
55
#include <faabric/util/logging.h>
66

7+
#include <boost/lockfree/spsc_queue.hpp>
78
#include <condition_variable>
89
#include <queue>
910
#include <readerwriterqueue/readerwritercircularbuffer.h>
@@ -215,6 +216,48 @@ class FixedCapacityQueue
215216
moodycamel::BlockingReaderWriterCircularBuffer<T> mq;
216217
};
217218

219+
// High-performance, spin-lock single-producer, single-consumer queue. This
220+
// queue spin-locks, so use at your own risk!
221+
template<typename T>
222+
class SpinLockQueue
223+
{
224+
public:
225+
void enqueue(T& value, long timeoutMs = DEFAULT_QUEUE_TIMEOUT_MS)
226+
{
227+
while (!mq.push(value)) {
228+
;
229+
};
230+
}
231+
232+
T dequeue(long timeoutMs = DEFAULT_QUEUE_TIMEOUT_MS)
233+
{
234+
T value;
235+
236+
while (!mq.pop(value)) {
237+
;
238+
}
239+
240+
return value;
241+
}
242+
243+
long size()
244+
{
245+
throw std::runtime_error("Size for fast queue unimplemented!");
246+
}
247+
248+
void drain()
249+
{
250+
while (mq.pop()) {
251+
;
252+
}
253+
}
254+
255+
void reset() { ; }
256+
257+
private:
258+
boost::lockfree::spsc_queue<T, boost::lockfree::capacity<1024>> mq;
259+
};
260+
218261
class TokenPool
219262
{
220263
public:

src/mpi/CMakeLists.txt

Lines changed: 1 addition & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -38,32 +38,12 @@ endif()
3838
# -----------------------------------------------
3939

4040
if (NOT ("${CMAKE_PROJECT_NAME}" STREQUAL "faabricmpi"))
41-
# Generate protobuf headers
42-
set(MPI_PB_HEADER_COPIED "${FAABRIC_INCLUDE_DIR}/faabric/mpi/mpi.pb.h")
43-
44-
protobuf_generate_cpp(MPI_PB_SRC MPI_PB_HEADER mpi.proto)
45-
46-
# Copy the generated headers into place
47-
add_custom_command(
48-
OUTPUT "${MPI_PB_HEADER_COPIED}"
49-
DEPENDS "${MPI_PB_HEADER}"
50-
COMMAND ${CMAKE_COMMAND}
51-
ARGS -E copy ${MPI_PB_HEADER} ${FAABRIC_INCLUDE_DIR}/faabric/mpi/
52-
)
53-
54-
add_custom_target(
55-
mpi_pbh_copied
56-
DEPENDS ${MPI_PB_HEADER_COPIED}
57-
)
58-
59-
add_dependencies(faabric_common_dependencies mpi_pbh_copied)
60-
6141
faabric_lib(mpi
6242
MpiContext.cpp
43+
MpiMessage.cpp
6344
MpiMessageBuffer.cpp
6445
MpiWorld.cpp
6546
MpiWorldRegistry.cpp
66-
${MPI_PB_SRC}
6747
)
6848

6949
target_link_libraries(mpi PRIVATE

src/mpi/MpiMessage.cpp

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
#include <faabric/mpi/MpiMessage.h>
#include <faabric/util/memory.h>

#include <cassert>
#include <cstdint>
#include <cstring>
#include <stdexcept>
7+
8+
namespace faabric::mpi {
9+
10+
void parseMpiMsg(const std::vector<uint8_t>& bytes, MpiMessage* msg)
11+
{
12+
assert(msg != nullptr);
13+
assert(bytes.size() >= sizeof(MpiMessage));
14+
std::memcpy(msg, bytes.data(), sizeof(MpiMessage));
15+
size_t thisPayloadSize = bytes.size() - sizeof(MpiMessage);
16+
assert(thisPayloadSize == payloadSize(*msg));
17+
18+
if (thisPayloadSize == 0) {
19+
msg->buffer = nullptr;
20+
return;
21+
}
22+
23+
msg->buffer = faabric::util::malloc(thisPayloadSize);
24+
std::memcpy(
25+
msg->buffer, bytes.data() + sizeof(MpiMessage), thisPayloadSize);
26+
}
27+
28+
void serializeMpiMsg(std::vector<uint8_t>& buffer, const MpiMessage& msg)
29+
{
30+
std::memcpy(buffer.data(), &msg, sizeof(MpiMessage));
31+
size_t payloadSz = payloadSize(msg);
32+
if (payloadSz > 0 && msg.buffer != nullptr) {
33+
std::memcpy(buffer.data() + sizeof(MpiMessage), msg.buffer, payloadSz);
34+
}
35+
}
36+
}

0 commit comments

Comments
 (0)