diff --git a/include/faabric/mpi/MpiMessageBuffer.h b/include/faabric/mpi/MpiMessageBuffer.h
index 9fc67b644..4038d022e 100644
--- a/include/faabric/mpi/MpiMessageBuffer.h
+++ b/include/faabric/mpi/MpiMessageBuffer.h
@@ -25,7 +25,7 @@ class MpiMessageBuffer
     {
       public:
         int requestId = -1;
-        std::shared_ptr<MPIMessage> msg = nullptr;
+        std::unique_ptr<MPIMessage> msg = nullptr;
         int sendRank = -1;
        int recvRank = -1;
         uint8_t* buffer = nullptr;
@@ -33,9 +33,12 @@ class MpiMessageBuffer
         int count = -1;
         MPIMessage::MPIMessageType messageType = MPIMessage::NORMAL;

-        bool isAcknowledged() { return msg != nullptr; }
+        bool isAcknowledged() const { return msg != nullptr; }

-        void acknowledge(std::shared_ptr<MPIMessage> msgIn) { msg = msgIn; }
+        void acknowledge(std::unique_ptr<MPIMessage> msgIn)
+        {
+            msg = std::move(msgIn);
+        }
     };

     /* Interface to query the buffer size */
diff --git a/include/faabric/mpi/MpiWorld.h b/include/faabric/mpi/MpiWorld.h
index adee54137..184393393 100644
--- a/include/faabric/mpi/MpiWorld.h
+++ b/include/faabric/mpi/MpiWorld.h
@@ -26,9 +26,9 @@ namespace faabric::mpi {
 // -----------------------------------
 // MPITOPTP - mocking at the MPI level won't be needed when using the PTP broker
 // as the broker already has mocking capabilities
-std::vector<std::shared_ptr<MPIMessage>> getMpiMockedMessages(int sendRank);
+std::vector<MPIMessage> getMpiMockedMessages(int sendRank);

-typedef faabric::util::FixedCapacityQueue<std::shared_ptr<MPIMessage>>
+typedef faabric::util::FixedCapacityQueue<std::unique_ptr<MPIMessage>>
   InMemoryMpiQueue;

 class MpiWorld
@@ -240,16 +240,16 @@
     void sendRemoteMpiMessage(std::string dstHost,
                               int sendRank,
                               int recvRank,
-                              const std::shared_ptr<MPIMessage>& msg);
+                              const std::unique_ptr<MPIMessage>& msg);

-    std::shared_ptr<MPIMessage> recvRemoteMpiMessage(int sendRank,
+    std::unique_ptr<MPIMessage> recvRemoteMpiMessage(int sendRank,
                                                      int recvRank);

     // Support for asyncrhonous communications
     std::shared_ptr<MpiMessageBuffer> getUnackedMessageBuffer(int sendRank,
                                                               int recvRank);

-    std::shared_ptr<MPIMessage> recvBatchReturnLast(int sendRank,
+    std::unique_ptr<MPIMessage> recvBatchReturnLast(int sendRank,
                                                     int recvRank,
                                                     int batchSize = 0);

@@ -258,7 +258,7 @@
     void checkRanksRange(int sendRank, int recvRank);

     // Abstraction of the bulk of the recv work, shared among various functions
-    void doRecv(std::shared_ptr<MPIMessage>& m,
+    void doRecv(std::unique_ptr<MPIMessage> m,
                 uint8_t* buffer,
                 faabric_datatype_t* dataType,
                 int count,
diff --git a/src/mpi/MpiMessageBuffer.cpp b/src/mpi/MpiMessageBuffer.cpp
index f8f2663b5..0da9003ee 100644
--- a/src/mpi/MpiMessageBuffer.cpp
+++ b/src/mpi/MpiMessageBuffer.cpp
@@ -31,7 +31,7 @@ MpiMessageIterator MpiMessageBuffer::getRequestPendingMsg(int requestId)
     MpiMessageIterator msgIt =
       std::find_if(pendingMsgs.begin(),
                    pendingMsgs.end(),
-                   [requestId](PendingAsyncMpiMessage pendingMsg) {
+                   [requestId](PendingAsyncMpiMessage& pendingMsg) {
                        return pendingMsg.requestId == requestId;
                    });

@@ -48,7 +48,7 @@ MpiMessageIterator MpiMessageBuffer::getFirstNullMsgUntil(
   const MpiMessageIterator& msgItEnd)
 {
     return std::find_if(
-      pendingMsgs.begin(), msgItEnd, [](PendingAsyncMpiMessage pendingMsg) {
+      pendingMsgs.begin(), msgItEnd, [](PendingAsyncMpiMessage& pendingMsg) {
          return pendingMsg.msg == nullptr;
      });
 }
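The MpiMessageBuffer change above switches pending entries to owning semantics: acknowledge() now takes the unique_ptr by value and moves it into the msg member, so each buffered message has exactly one owner, and a by-value copy of a PendingAsyncMpiMessage (as in the old lambdas) no longer compiles. A minimal standalone sketch of the pattern, using a hypothetical Message struct in place of the protobuf-generated MPIMessage:

    #include <cassert>
    #include <memory>
    #include <utility>

    struct Message
    {
        int id = 0;
    };

    class Pending
    {
      public:
        std::unique_ptr<Message> msg = nullptr;

        bool isAcknowledged() const { return msg != nullptr; }

        // Taking the pointer by value and moving it into the member
        // transfers ownership; the caller's pointer is left null
        void acknowledge(std::unique_ptr<Message> msgIn)
        {
            msg = std::move(msgIn);
        }
    };

    int main()
    {
        Pending p;
        auto m = std::make_unique<Message>();
        p.acknowledge(std::move(m));

        assert(m == nullptr);       // the caller no longer owns the message
        assert(p.isAcknowledged()); // the buffer entry does
        return 0;
    }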
diff --git a/src/mpi/MpiWorld.cpp b/src/mpi/MpiWorld.cpp
index cda95ed8e..30ffc4f54 100644
--- a/src/mpi/MpiWorld.cpp
+++ b/src/mpi/MpiWorld.cpp
@@ -33,10 +33,9 @@ static std::mutex mockMutex;

 // The identifier in this map is the sending rank. For the receiver's rank
 // we can inspect the MPIMessage object
-static std::map<int, std::vector<std::shared_ptr<MPIMessage>>>
-  mpiMockedMessages;
+static std::map<int, std::vector<MPIMessage>> mpiMockedMessages;

-std::vector<std::shared_ptr<MPIMessage>> getMpiMockedMessages(int sendRank)
+std::vector<MPIMessage> getMpiMockedMessages(int sendRank)
 {
     faabric::util::UniqueLock lock(mockMutex);
     return mpiMockedMessages[sendRank];
@@ -52,7 +51,7 @@ MpiWorld::MpiWorld()
 void MpiWorld::sendRemoteMpiMessage(std::string dstHost,
                                     int sendRank,
                                     int recvRank,
-                                    const std::shared_ptr<MPIMessage>& msg)
+                                    const std::unique_ptr<MPIMessage>& msg)
 {
     std::string serialisedBuffer;
     if (!msg->SerializeToString(&serialisedBuffer)) {
@@ -78,7 +77,7 @@ void MpiWorld::sendRemoteMpiMessage(std::string dstHost,
     }
 }

-std::shared_ptr<MPIMessage> MpiWorld::recvRemoteMpiMessage(int sendRank,
+std::unique_ptr<MPIMessage> MpiWorld::recvRemoteMpiMessage(int sendRank,
                                                            int recvRank)
 {
     std::vector<uint8_t> msg;
@@ -95,7 +94,7 @@ std::shared_ptr<MPIMessage> MpiWorld::recvRemoteMpiMessage(int sendRank,
         throw e;
     }
     PARSE_MSG(MPIMessage, msg.data(), msg.size());
-    return std::make_shared<MPIMessage>(parsedMsg);
+    return std::make_unique<MPIMessage>(parsedMsg);
 }

 std::shared_ptr<MpiMessageBuffer> MpiWorld::getUnackedMessageBuffer(
@@ -478,7 +477,7 @@ int MpiWorld::irecv(int sendRank,
     assert(!pendingMsg.isAcknowledged());

     auto umb = getUnackedMessageBuffer(sendRank, recvRank);
-    umb->addMessage(pendingMsg);
+    umb->addMessage(std::move(pendingMsg));

     return requestId;
 }
@@ -506,7 +505,7 @@ void MpiWorld::send(int sendRank,
     int msgId = (localMsgCount + 1) % INT32_MAX;

     // Create the message
-    auto m = std::make_shared<MPIMessage>();
+    auto m = std::make_unique<MPIMessage>();
     m->set_id(msgId);
     m->set_worldid(id);
     m->set_sender(sendRank);
@@ -522,7 +521,7 @@

     // Mock the message sending in tests
     if (faabric::util::isMockMode()) {
-        mpiMockedMessages[sendRank].push_back(m);
+        mpiMockedMessages[sendRank].push_back(*m);
         return;
     }

@@ -534,7 +533,7 @@
     } else {
         SPDLOG_TRACE(
           "MPI - send remote {} -> {} ({})", sendRank, recvRank, messageType);
-        sendRemoteMpiMessage(otherHost, sendRank, recvRank, m);
+        sendRemoteMpiMessage(otherHost, sendRank, recvRank, std::move(m));
     }

     /* 02/05/2022 - The following bit of code fails randomly with a protobuf
@@ -573,13 +572,13 @@ void MpiWorld::recv(int sendRank,
     }

     // Recv message from underlying transport
-    std::shared_ptr<MPIMessage> m = recvBatchReturnLast(sendRank, recvRank);
+    auto m = recvBatchReturnLast(sendRank, recvRank);

     // Do the processing
-    doRecv(m, buffer, dataType, count, status, messageType);
+    doRecv(std::move(m), buffer, dataType, count, status, messageType);
 }

-void MpiWorld::doRecv(std::shared_ptr<MPIMessage>& m,
+void MpiWorld::doRecv(std::unique_ptr<MPIMessage> m,
                       uint8_t* buffer,
                       faabric_datatype_t* dataType,
                       int count,
@@ -1012,10 +1011,10 @@ void MpiWorld::awaitAsyncRequest(int requestId)
     std::list<MpiMessageBuffer::PendingAsyncMpiMessage>::iterator msgIt =
       umb->getRequestPendingMsg(requestId);

-    std::shared_ptr<MPIMessage> m;
+    std::unique_ptr<MPIMessage> m;
     if (msgIt->msg != nullptr) {
         // This id has already been acknowledged by a recv call, so do the recv
-        m = msgIt->msg;
+        m = std::move(msgIt->msg);
     } else {
         // We need to acknowledge all messages not acknowledged from the
         // begining until us
@@ -1023,7 +1022,7 @@
           sendRank, recvRank, umb->getTotalUnackedMessagesUntil(msgIt) + 1);
     }

-    doRecv(m,
+    doRecv(std::move(m),
            msgIt->buffer,
            msgIt->dataType,
            msgIt->count,
@@ -1395,15 +1394,18 @@ void MpiWorld::allToAll(int rank,
 // queues.
 void MpiWorld::probe(int sendRank, int recvRank, MPI_Status* status)
 {
+    throw std::runtime_error("MPI probe is not supported!");
+
+    /*
     const std::shared_ptr<InMemoryMpiQueue>& queue =
       getLocalQueue(sendRank, recvRank);
-    // 30/12/21 - Peek will throw a runtime error
     std::shared_ptr<MPIMessage> m = *(queue->peek());

     faabric_datatype_t* datatype = getFaabricDatatypeFromId(m->type());
     status->bytesSize = m->count() * datatype->size;
     status->MPI_ERROR = 0;
     status->MPI_SOURCE = m->sender();
+    */
 }

 void MpiWorld::barrier(int thisRank)
@@ -1456,7 +1458,7 @@ void MpiWorld::initLocalQueues()
     }
 }

-std::shared_ptr<MPIMessage> MpiWorld::recvBatchReturnLast(int sendRank,
+std::unique_ptr<MPIMessage> MpiWorld::recvBatchReturnLast(int sendRank,
                                                           int recvRank,
                                                           int batchSize)
 {
@@ -1478,7 +1480,7 @@
     // Recv message: first we receive all messages for which there is an id
     // in the unacknowleged buffer but no msg. Note that these messages
     // (batchSize - 1) were `irecv`-ed before ours.
-    std::shared_ptr<MPIMessage> ourMsg;
+    std::unique_ptr<MPIMessage> ourMsg;
     auto msgIt = umb->getFirstNullMsg();
     if (isLocal) {
         // First receive messages that happened before us
@@ -1489,7 +1491,7 @@

             // Put the unacked message in the UMB
             assert(!msgIt->isAcknowledged());
-            msgIt->acknowledge(pendingMsg);
+            msgIt->acknowledge(std::move(pendingMsg));
             msgIt++;
         } catch (faabric::util::QueueTimeoutException& e) {
             SPDLOG_ERROR(
@@ -1525,7 +1527,7 @@

         // Put the unacked message in the UMB
         assert(!msgIt->isAcknowledged());
-        msgIt->acknowledge(pendingMsg);
+        msgIt->acknowledge(std::move(pendingMsg));
         msgIt++;
     }
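One detail worth calling out in the MpiWorld.cpp hunks above: in awaitAsyncRequest, m = std::move(msgIt->msg) transfers the buffered message out of its pending entry and leaves the entry's pointer null, which stays consistent with isAcknowledged() being defined as msg != nullptr. A minimal sketch of that move-out step, with hypothetical Pending/Message types standing in for the real ones:

    #include <cassert>
    #include <list>
    #include <memory>
    #include <utility>

    struct Message
    {
        int id = 0;
    };

    struct Pending
    {
        std::unique_ptr<Message> msg;
    };

    int main()
    {
        std::list<Pending> pendingMsgs;
        pendingMsgs.push_back({ std::make_unique<Message>() });

        auto msgIt = pendingMsgs.begin();

        // Moving out of the entry resets msgIt->msg to nullptr, so the
        // entry cannot later be mistaken for an acknowledged message
        std::unique_ptr<Message> m = std::move(msgIt->msg);

        assert(m != nullptr);
        assert(msgIt->msg == nullptr);
        return 0;
    }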
diff --git a/tests/test/mpi/test_mpi_message_buffer.cpp b/tests/test/mpi/test_mpi_message_buffer.cpp
index 1674172fd..1184150e2 100644
--- a/tests/test/mpi/test_mpi_message_buffer.cpp
+++ b/tests/test/mpi/test_mpi_message_buffer.cpp
@@ -21,7 +21,7 @@ MpiMessageBuffer::PendingAsyncMpiMessage genRandomArguments(
     pendingMsg.requestId = requestId;

     if (!nullMsg) {
-        pendingMsg.msg = std::make_shared<MPIMessage>();
+        pendingMsg.msg = std::make_unique<MPIMessage>();
     }

     return pendingMsg;
diff --git a/tests/test/mpi/test_remote_mpi_worlds.cpp b/tests/test/mpi/test_remote_mpi_worlds.cpp
index 1e56b48b1..13dd51707 100644
--- a/tests/test/mpi/test_remote_mpi_worlds.cpp
+++ b/tests/test/mpi/test_remote_mpi_worlds.cpp
@@ -21,12 +21,11 @@ using namespace faabric::mpi;
 using namespace faabric::scheduler;

 namespace tests {
-std::set<int> getReceiversFromMessages(
-  std::vector<std::shared_ptr<MPIMessage>> msgs)
+std::set<int> getReceiversFromMessages(std::vector<MPIMessage> msgs)
 {
     std::set<int> receivers;
     for (const auto& msg : msgs) {
-        receivers.insert(msg->destination());
+        receivers.insert(msg.destination());
     }

     return receivers;
@@ -219,12 +218,11 @@ TEST_CASE_METHOD(RemoteMpiTestFixture,
     thisWorld.destroy();
 }

-std::set<int> getMsgCountsFromMessages(
-  std::vector<std::shared_ptr<MPIMessage>> msgs)
+std::set<int> getMsgCountsFromMessages(std::vector<MPIMessage> msgs)
 {
     std::set<int> counts;
     for (const auto& msg : msgs) {
-        counts.insert(msg->count());
+        counts.insert(msg.count());
     }

     return counts;
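Since the mock map now stores plain MPIMessage values rather than shared pointers, the test helpers above take std::vector<MPIMessage> and use value access (msg.destination(), msg.count()) in place of pointer access. A sketch of the same consumption pattern, with a hypothetical stand-in struct whose fields approximate the protobuf accessors:

    #include <cassert>
    #include <set>
    #include <vector>

    // Stand-in for the protobuf MPIMessage, reduced to the fields the
    // helpers read
    struct Message
    {
        int destination = 0;
        int count = 0;
    };

    std::set<int> getReceiversFromMessages(std::vector<Message> msgs)
    {
        std::set<int> receivers;
        for (const auto& msg : msgs) {
            // Value access replaces the msg->destination() pointer access
            // used when messages were held through shared_ptr
            receivers.insert(msg.destination);
        }
        return receivers;
    }

    int main()
    {
        std::vector<Message> msgs = { { 1, 8 }, { 2, 16 }, { 1, 32 } };
        std::set<int> receivers = getReceiversFromMessages(msgs);
        assert(receivers == (std::set<int>{ 1, 2 }));
        return 0;
    }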