1
1
#pragma once
2
2
3
+ #include < faabric/mpi/MpiMessage.h>
3
4
#include < faabric/mpi/MpiMessageBuffer.h>
4
5
#include < faabric/mpi/mpi.h>
5
- #include < faabric/mpi/mpi.pb.h>
6
6
#include < faabric/proto/faabric.pb.h>
7
7
#include < faabric/scheduler/InMemoryMessageQueue.h>
8
8
#include < faabric/transport/PointToPointBroker.h>
@@ -26,10 +26,9 @@ namespace faabric::mpi {
26
26
// -----------------------------------
27
27
// MPITOPTP - mocking at the MPI level won't be needed when using the PTP broker
28
28
// as the broker already has mocking capabilities
29
- std::vector<std::shared_ptr<MPIMessage> > getMpiMockedMessages (int sendRank);
29
+ std::vector<MpiMessage > getMpiMockedMessages (int sendRank);
30
30
31
- typedef faabric::util::FixedCapacityQueue<std::shared_ptr<MPIMessage>>
32
- InMemoryMpiQueue;
31
+ typedef faabric::util::SpinLockQueue<MpiMessage> InMemoryMpiQueue;
33
32
34
33
class MpiWorld
35
34
{
@@ -73,36 +72,36 @@ class MpiWorld
73
72
const uint8_t * buffer,
74
73
faabric_datatype_t * dataType,
75
74
int count,
76
- MPIMessage::MPIMessageType messageType = MPIMessage ::NORMAL);
75
+ MpiMessageType messageType = MpiMessageType ::NORMAL);
77
76
78
77
int isend (int sendRank,
79
78
int recvRank,
80
79
const uint8_t * buffer,
81
80
faabric_datatype_t * dataType,
82
81
int count,
83
- MPIMessage::MPIMessageType messageType = MPIMessage ::NORMAL);
82
+ MpiMessageType messageType = MpiMessageType ::NORMAL);
84
83
85
84
void broadcast (int rootRank,
86
85
int thisRank,
87
86
uint8_t * buffer,
88
87
faabric_datatype_t * dataType,
89
88
int count,
90
- MPIMessage::MPIMessageType messageType = MPIMessage ::NORMAL);
89
+ MpiMessageType messageType = MpiMessageType ::NORMAL);
91
90
92
91
void recv (int sendRank,
93
92
int recvRank,
94
93
uint8_t * buffer,
95
94
faabric_datatype_t * dataType,
96
95
int count,
97
96
MPI_Status* status,
98
- MPIMessage::MPIMessageType messageType = MPIMessage ::NORMAL);
97
+ MpiMessageType messageType = MpiMessageType ::NORMAL);
99
98
100
99
int irecv (int sendRank,
101
100
int recvRank,
102
101
uint8_t * buffer,
103
102
faabric_datatype_t * dataType,
104
103
int count,
105
- MPIMessage::MPIMessageType messageType = MPIMessage ::NORMAL);
104
+ MpiMessageType messageType = MpiMessageType ::NORMAL);
106
105
107
106
void awaitAsyncRequest (int requestId);
108
107
@@ -185,8 +184,6 @@ class MpiWorld
185
184
186
185
std::shared_ptr<InMemoryMpiQueue> getLocalQueue (int sendRank, int recvRank);
187
186
188
- long getLocalQueueSize (int sendRank, int recvRank);
189
-
190
187
void overrideHost (const std::string& newHost);
191
188
192
189
double getWTime ();
@@ -240,29 +237,36 @@ class MpiWorld
240
237
void sendRemoteMpiMessage (std::string dstHost,
241
238
int sendRank,
242
239
int recvRank,
243
- const std::shared_ptr<MPIMessage> & msg);
240
+ const MpiMessage & msg);
244
241
245
- std::shared_ptr<MPIMessage> recvRemoteMpiMessage (int sendRank,
246
- int recvRank);
242
+ MpiMessage recvRemoteMpiMessage (int sendRank, int recvRank);
247
243
248
244
// Support for asyncrhonous communications
249
245
std::shared_ptr<MpiMessageBuffer> getUnackedMessageBuffer (int sendRank,
250
246
int recvRank);
251
247
252
- std::shared_ptr<MPIMessage> recvBatchReturnLast (int sendRank,
253
- int recvRank,
254
- int batchSize = 0 );
248
+ MpiMessage recvBatchReturnLast (int sendRank,
249
+ int recvRank,
250
+ int batchSize = 0 );
255
251
256
252
/* Helper methods */
257
253
258
254
void checkRanksRange (int sendRank, int recvRank);
259
255
260
256
// Abstraction of the bulk of the recv work, shared among various functions
261
- void doRecv (std::shared_ptr<MPIMessage>& m,
257
+ void doRecv (const MpiMessage& m,
258
+ uint8_t * buffer,
259
+ faabric_datatype_t * dataType,
260
+ int count,
261
+ MPI_Status* status,
262
+ MpiMessageType messageType = MpiMessageType::NORMAL);
263
+
264
+ // Abstraction of the bulk of the recv work, shared among various functions
265
+ void doRecv (std::unique_ptr<MpiMessage> m,
262
266
uint8_t * buffer,
263
267
faabric_datatype_t * dataType,
264
268
int count,
265
269
MPI_Status* status,
266
- MPIMessage::MPIMessageType messageType = MPIMessage ::NORMAL);
270
+ MpiMessageType messageType = MpiMessageType ::NORMAL);
267
271
};
268
272
}
0 commit comments