9 changes: 6 additions & 3 deletions include/faabric/mpi/MpiMessageBuffer.h
@@ -25,17 +25,20 @@ class MpiMessageBuffer
     {
       public:
         int requestId = -1;
-        std::shared_ptr<MPIMessage> msg = nullptr;
+        std::unique_ptr<MPIMessage> msg = nullptr;
         int sendRank = -1;
         int recvRank = -1;
         uint8_t* buffer = nullptr;
         faabric_datatype_t* dataType = nullptr;
         int count = -1;
         MPIMessage::MPIMessageType messageType = MPIMessage::NORMAL;

-        bool isAcknowledged() { return msg != nullptr; }
+        bool isAcknowledged() const { return msg != nullptr; }

-        void acknowledge(std::shared_ptr<MPIMessage> msgIn) { msg = msgIn; }
+        void acknowledge(std::unique_ptr<MPIMessage> msgIn)
+        {
+            msg = std::move(msgIn);
+        }
     };

     /* Interface to query the buffer size */
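For context on the ownership change above: acknowledge() now takes the message as a move-only std::unique_ptr, so a caller has to hand its pointer over with std::move and is left with an empty pointer afterwards. A minimal standalone sketch (not part of this PR; it assumes the faabric headers are on the include path and uses a hypothetical function name):

#include <faabric/mpi/MpiMessageBuffer.h>

#include <cassert>
#include <memory>
#include <utility>

using namespace faabric::mpi;

void exampleAcknowledge()
{
    MpiMessageBuffer::PendingAsyncMpiMessage pending;
    auto incoming = std::make_unique<MPIMessage>();

    // Ownership moves into the buffer entry; `incoming` is null afterwards
    pending.acknowledge(std::move(incoming));

    assert(pending.isAcknowledged());
    assert(incoming == nullptr);
}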
12 changes: 6 additions & 6 deletions include/faabric/mpi/MpiWorld.h
@@ -26,9 +26,9 @@ namespace faabric::mpi {
 // -----------------------------------
 // MPITOPTP - mocking at the MPI level won't be needed when using the PTP broker
 // as the broker already has mocking capabilities
-std::vector<std::shared_ptr<MPIMessage>> getMpiMockedMessages(int sendRank);
+std::vector<MPIMessage> getMpiMockedMessages(int sendRank);

-typedef faabric::util::FixedCapacityQueue<std::shared_ptr<MPIMessage>>
+typedef faabric::util::FixedCapacityQueue<std::unique_ptr<MPIMessage>>
   InMemoryMpiQueue;

 class MpiWorld
@@ -240,16 +240,16 @@ class MpiWorld
     void sendRemoteMpiMessage(std::string dstHost,
                               int sendRank,
                               int recvRank,
-                              const std::shared_ptr<MPIMessage>& msg);
+                              const std::unique_ptr<MPIMessage>& msg);

-    std::shared_ptr<MPIMessage> recvRemoteMpiMessage(int sendRank,
+    std::unique_ptr<MPIMessage> recvRemoteMpiMessage(int sendRank,
                                                      int recvRank);

     // Support for asyncrhonous communications
     std::shared_ptr<MpiMessageBuffer> getUnackedMessageBuffer(int sendRank,
                                                               int recvRank);

-    std::shared_ptr<MPIMessage> recvBatchReturnLast(int sendRank,
+    std::unique_ptr<MPIMessage> recvBatchReturnLast(int sendRank,
                                                     int recvRank,
                                                     int batchSize = 0);

@@ -258,7 +258,7 @@
     void checkRanksRange(int sendRank, int recvRank);

     // Abstraction of the bulk of the recv work, shared among various functions
-    void doRecv(std::shared_ptr<MPIMessage>& m,
+    void doRecv(std::unique_ptr<MPIMessage> m,
                 uint8_t* buffer,
                 faabric_datatype_t* dataType,
                 int count,
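A side note on the doRecv change above: the message parameter is now a by-value std::unique_ptr, i.e. a sink parameter. The caller gives up ownership with std::move and the message is destroyed exactly once, when the callee finishes with it. A small standalone sketch of that pattern (hypothetical types, not faabric code):

#include <cstdio>
#include <memory>
#include <utility>

struct Message
{
    int count = 0;
};

// Sink parameter: ownership of the message moves into the function
static void consume(std::unique_ptr<Message> m)
{
    std::printf("received count=%d\n", m->count);
    // `m` goes out of scope here and the Message is freed exactly once
}

int main()
{
    auto m = std::make_unique<Message>();
    m->count = 3;
    consume(std::move(m)); // `m` is empty from this point on
    return 0;
}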
4 changes: 2 additions & 2 deletions src/mpi/MpiMessageBuffer.cpp
@@ -31,7 +31,7 @@ MpiMessageIterator MpiMessageBuffer::getRequestPendingMsg(int requestId)
     MpiMessageIterator msgIt =
       std::find_if(pendingMsgs.begin(),
                    pendingMsgs.end(),
-                   [requestId](PendingAsyncMpiMessage pendingMsg) {
+                   [requestId](PendingAsyncMpiMessage& pendingMsg) {
                        return pendingMsg.requestId == requestId;
                    });

@@ -48,7 +48,7 @@ MpiMessageIterator MpiMessageBuffer::getFirstNullMsgUntil(
   const MpiMessageIterator& msgItEnd)
 {
     return std::find_if(
-      pendingMsgs.begin(), msgItEnd, [](PendingAsyncMpiMessage pendingMsg) {
+      pendingMsgs.begin(), msgItEnd, [](PendingAsyncMpiMessage& pendingMsg) {
           return pendingMsg.msg == nullptr;
       });
 }
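The lambdas above switch from taking PendingAsyncMpiMessage by value to by reference: the struct now holds a std::unique_ptr member, which makes it move-only, so a by-value parameter would try to copy each element and no longer compiles. A standalone illustration of the same constraint (hypothetical names, not faabric code):

#include <algorithm>
#include <list>
#include <memory>

struct Pending
{
    int requestId = -1;
    std::unique_ptr<int> msg; // move-only member, so Pending cannot be copied
};

int main()
{
    std::list<Pending> pendingMsgs(3);
    pendingMsgs.back().requestId = 42;

    // Fine: the element is bound by reference, nothing is copied
    auto it = std::find_if(pendingMsgs.begin(),
                           pendingMsgs.end(),
                           [](Pending& p) { return p.requestId == 42; });

    // A by-value parameter, [](Pending p) { ... }, would fail to compile
    // because it needs Pending's deleted copy constructor.

    return it == pendingMsgs.end() ? 1 : 0;
}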
44 changes: 23 additions & 21 deletions src/mpi/MpiWorld.cpp
@@ -33,10 +33,9 @@ static std::mutex mockMutex;

 // The identifier in this map is the sending rank. For the receiver's rank
 // we can inspect the MPIMessage object
-static std::map<int, std::vector<std::shared_ptr<MPIMessage>>>
-  mpiMockedMessages;
+static std::map<int, std::vector<MPIMessage>> mpiMockedMessages;

-std::vector<std::shared_ptr<MPIMessage>> getMpiMockedMessages(int sendRank)
+std::vector<MPIMessage> getMpiMockedMessages(int sendRank)
 {
     faabric::util::UniqueLock lock(mockMutex);
     return mpiMockedMessages[sendRank];
@@ -52,7 +51,7 @@ MpiWorld::MpiWorld()
 void MpiWorld::sendRemoteMpiMessage(std::string dstHost,
                                     int sendRank,
                                     int recvRank,
-                                    const std::shared_ptr<MPIMessage>& msg)
+                                    const std::unique_ptr<MPIMessage>& msg)
 {
     std::string serialisedBuffer;
     if (!msg->SerializeToString(&serialisedBuffer)) {
@@ -78,7 +77,7 @@ void MpiWorld::sendRemoteMpiMessage(std::string dstHost,
     }
 }

-std::shared_ptr<MPIMessage> MpiWorld::recvRemoteMpiMessage(int sendRank,
+std::unique_ptr<MPIMessage> MpiWorld::recvRemoteMpiMessage(int sendRank,
                                                            int recvRank)
 {
     std::vector<uint8_t> msg;
@@ -95,7 +94,7 @@ std::shared_ptr<MPIMessage> MpiWorld::recvRemoteMpiMessage(int sendRank,
         throw e;
     }
     PARSE_MSG(MPIMessage, msg.data(), msg.size());
-    return std::make_shared<MPIMessage>(parsedMsg);
+    return std::make_unique<MPIMessage>(parsedMsg);
 }

 std::shared_ptr<MpiMessageBuffer> MpiWorld::getUnackedMessageBuffer(
@@ -478,7 +477,7 @@ int MpiWorld::irecv(int sendRank,
     assert(!pendingMsg.isAcknowledged());

     auto umb = getUnackedMessageBuffer(sendRank, recvRank);
-    umb->addMessage(pendingMsg);
+    umb->addMessage(std::move(pendingMsg));

     return requestId;
 }
@@ -506,7 +505,7 @@ void MpiWorld::send(int sendRank,
     int msgId = (localMsgCount + 1) % INT32_MAX;

     // Create the message
-    auto m = std::make_shared<MPIMessage>();
+    auto m = std::make_unique<MPIMessage>();
     m->set_id(msgId);
     m->set_worldid(id);
     m->set_sender(sendRank);
@@ -522,7 +521,7 @@

     // Mock the message sending in tests
     if (faabric::util::isMockMode()) {
-        mpiMockedMessages[sendRank].push_back(m);
+        mpiMockedMessages[sendRank].push_back(*m);
         return;
     }

@@ -534,7 +533,7 @@
     } else {
         SPDLOG_TRACE(
           "MPI - send remote {} -> {} ({})", sendRank, recvRank, messageType);
-        sendRemoteMpiMessage(otherHost, sendRank, recvRank, m);
+        sendRemoteMpiMessage(otherHost, sendRank, recvRank, std::move(m));
     }

     /* 02/05/2022 - The following bit of code fails randomly with a protobuf
@@ -573,13 +572,13 @@ void MpiWorld::recv(int sendRank,
     }

     // Recv message from underlying transport
-    std::shared_ptr<MPIMessage> m = recvBatchReturnLast(sendRank, recvRank);
+    auto m = recvBatchReturnLast(sendRank, recvRank);

     // Do the processing
-    doRecv(m, buffer, dataType, count, status, messageType);
+    doRecv(std::move(m), buffer, dataType, count, status, messageType);
 }

-void MpiWorld::doRecv(std::shared_ptr<MPIMessage>& m,
+void MpiWorld::doRecv(std::unique_ptr<MPIMessage> m,
                       uint8_t* buffer,
                       faabric_datatype_t* dataType,
                       int count,
@@ -1012,18 +1011,18 @@ void MpiWorld::awaitAsyncRequest(int requestId)
     std::list<MpiMessageBuffer::PendingAsyncMpiMessage>::iterator msgIt =
       umb->getRequestPendingMsg(requestId);

-    std::shared_ptr<MPIMessage> m;
+    std::unique_ptr<MPIMessage> m;
     if (msgIt->msg != nullptr) {
         // This id has already been acknowledged by a recv call, so do the recv
-        m = msgIt->msg;
+        m = std::move(msgIt->msg);
     } else {
         // We need to acknowledge all messages not acknowledged from the
         // begining until us
         m = recvBatchReturnLast(
           sendRank, recvRank, umb->getTotalUnackedMessagesUntil(msgIt) + 1);
     }

-    doRecv(m,
+    doRecv(std::move(m),
            msgIt->buffer,
            msgIt->dataType,
            msgIt->count,
@@ -1395,15 +1394,18 @@ void MpiWorld::allToAll(int rank,
 // queues.
 void MpiWorld::probe(int sendRank, int recvRank, MPI_Status* status)
 {
+    throw std::runtime_error("MPI probe is not supported!");
+    /*
     const std::shared_ptr<InMemoryMpiQueue>& queue =
       getLocalQueue(sendRank, recvRank);
     // 30/12/21 - Peek will throw a runtime error
     std::shared_ptr<MPIMessage> m = *(queue->peek());

     faabric_datatype_t* datatype = getFaabricDatatypeFromId(m->type());
     status->bytesSize = m->count() * datatype->size;
     status->MPI_ERROR = 0;
     status->MPI_SOURCE = m->sender();
+    */
 }

 void MpiWorld::barrier(int thisRank)
@@ -1456,7 +1458,7 @@ void MpiWorld::initLocalQueues()
     }
 }

-std::shared_ptr<MPIMessage> MpiWorld::recvBatchReturnLast(int sendRank,
+std::unique_ptr<MPIMessage> MpiWorld::recvBatchReturnLast(int sendRank,
                                                           int recvRank,
                                                           int batchSize)
 {
@@ -1478,7 +1480,7 @@ std::shared_ptr<MPIMessage> MpiWorld::recvBatchReturnLast(int sendRank,
     // Recv message: first we receive all messages for which there is an id
     // in the unacknowleged buffer but no msg. Note that these messages
     // (batchSize - 1) were `irecv`-ed before ours.
-    std::shared_ptr<MPIMessage> ourMsg;
+    std::unique_ptr<MPIMessage> ourMsg;
     auto msgIt = umb->getFirstNullMsg();
     if (isLocal) {
         // First receive messages that happened before us
@@ -1489,7 +1491,7 @@ std::shared_ptr<MPIMessage> MpiWorld::recvBatchReturnLast(int sendRank,

                 // Put the unacked message in the UMB
                 assert(!msgIt->isAcknowledged());
-                msgIt->acknowledge(pendingMsg);
+                msgIt->acknowledge(std::move(pendingMsg));
                 msgIt++;
             } catch (faabric::util::QueueTimeoutException& e) {
                 SPDLOG_ERROR(
@@ -1525,7 +1527,7 @@ std::shared_ptr<MPIMessage> MpiWorld::recvBatchReturnLast(int sendRank,

         // Put the unacked message in the UMB
         assert(!msgIt->isAcknowledged());
-        msgIt->acknowledge(pendingMsg);
+        msgIt->acknowledge(std::move(pendingMsg));
         msgIt++;
     }

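One detail in the mock branch of send() above: mpiMockedMessages now stores MPIMessage objects by value, so push_back(*m) copies the message out of the unique_ptr, and the mocked vector stays valid after the original message goes out of scope. A minimal sketch of that copy-from-unique_ptr pattern (hypothetical types, not faabric code):

#include <cassert>
#include <memory>
#include <string>
#include <vector>

struct Message
{
    std::string payload;
};

int main()
{
    std::vector<Message> mocked; // holds copies, not pointers

    auto m = std::make_unique<Message>();
    m->payload = "hello";

    mocked.push_back(*m); // copy the pointee; `m` still owns its Message
    m.reset();            // destroying the original does not affect the copy

    assert(mocked.back().payload == "hello");
    return 0;
}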
2 changes: 1 addition & 1 deletion tests/test/mpi/test_mpi_message_buffer.cpp
@@ -21,7 +21,7 @@ MpiMessageBuffer::PendingAsyncMpiMessage genRandomArguments(
     pendingMsg.requestId = requestId;

     if (!nullMsg) {
-        pendingMsg.msg = std::make_shared<MPIMessage>();
+        pendingMsg.msg = std::make_unique<MPIMessage>();
     }

     return pendingMsg;
10 changes: 4 additions & 6 deletions tests/test/mpi/test_remote_mpi_worlds.cpp
@@ -21,12 +21,11 @@ using namespace faabric::mpi;
 using namespace faabric::scheduler;

 namespace tests {
-std::set<int> getReceiversFromMessages(
-  std::vector<std::shared_ptr<MPIMessage>> msgs)
+std::set<int> getReceiversFromMessages(std::vector<MPIMessage> msgs)
 {
     std::set<int> receivers;
     for (const auto& msg : msgs) {
-        receivers.insert(msg->destination());
+        receivers.insert(msg.destination());
     }

     return receivers;
@@ -219,12 +218,11 @@ TEST_CASE_METHOD(RemoteMpiTestFixture,
     thisWorld.destroy();
 }

-std::set<int> getMsgCountsFromMessages(
-  std::vector<std::shared_ptr<MPIMessage>> msgs)
+std::set<int> getMsgCountsFromMessages(std::vector<MPIMessage> msgs)
 {
     std::set<int> counts;
     for (const auto& msg : msgs) {
-        counts.insert(msg->count());
+        counts.insert(msg.count());
     }

     return counts;