Skip to content

Commit 9e30ac0

Browse files
author
rhc54
authored
Merge pull request #2248 from ggouaillardet/topic/v1.10/ialltoall_in_place
v1.10: add support for MPI_IN_PLACE in MPI_Ialltoall
2 parents 118ad7c + e194e68 commit 9e30ac0

File tree

8 files changed

+413
-89
lines changed

8 files changed

+413
-89
lines changed

ompi/mca/coll/libnbc/nbc_iallreduce.c

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -541,9 +541,7 @@ static inline int allred_sched_linear(int rank, int rsize, void *sendbuf, void *
541541
} else {
542542
res = NBC_Sched_recv ((void *)(-gap), true, count, datatype, 0, schedule);
543543
}
544-
if (OPAL_UNLIKELY(OMPI_SUCCESS != res)) {
545-
return res;
546-
}
544+
if (NBC_OK != res) { printf("Error in NBC_Sched_recv() (%i)\n", res); return res; }
547545

548546
if (0 == rank) {
549547
char *rbuf, *lbuf, *buf;
@@ -574,9 +572,7 @@ static inline int allred_sched_linear(int rank, int rsize, void *sendbuf, void *
574572
if (NBC_OK != res) { printf("Error in NBC_Sched_barrier() (%i)\n", res); return res; }
575573

576574
res = NBC_Sched_op (lbuf, tmplbuf, rbuf, tmprbuf, count, datatype, op, schedule);
577-
if (OPAL_UNLIKELY(OMPI_SUCCESS != res)) {
578-
return res;
579-
}
575+
if (NBC_OK != res) { printf("Error in NBC_Sched_op() (%i)\n", res); return res; }
580576

581577
res = NBC_Sched_barrier(schedule);
582578
if (NBC_OK != res) { printf("Error in NBC_Sched_barrier() (%i)\n", res); return res; }

ompi/mca/coll/libnbc/nbc_ialltoall.c

Lines changed: 64 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
* Copyright (c) 2013 Los Alamos National Security, LLC. All rights
88
* reserved.
99
* Copyright (c) 2014 NVIDIA Corporation. All rights reserved.
10+
* Copyright (c) 2016 Research Organization for Information Science
11+
* and Technology (RIST). All rights reserved.
1012
*
1113
* Author(s): Torsten Hoefler <[email protected]>
1214
*
@@ -16,6 +18,7 @@
1618
static inline int a2a_sched_linear(int rank, int p, MPI_Aint sndext, MPI_Aint rcvext, NBC_Schedule *schedule, void* sendbuf, int sendcount, MPI_Datatype sendtype, void* recvbuf, int recvcount, MPI_Datatype recvtype, MPI_Comm comm);
1719
static inline int a2a_sched_pairwise(int rank, int p, MPI_Aint sndext, MPI_Aint rcvext, NBC_Schedule *schedule, void* sendbuf, int sendcount, MPI_Datatype sendtype, void* recvbuf, int recvcount, MPI_Datatype recvtype, MPI_Comm comm);
1820
static inline int a2a_sched_diss(int rank, int p, MPI_Aint sndext, MPI_Aint rcvext, NBC_Schedule* schedule, void* sendbuf, int sendcount, MPI_Datatype sendtype, void* recvbuf, int recvcount, MPI_Datatype recvtype, MPI_Comm comm, NBC_Handle *handle);
21+
static inline int a2a_sched_inplace(int rank, int p, NBC_Schedule* schedule, void* buf, int count, MPI_Datatype type, MPI_Aint ext, ptrdiff_t gap, MPI_Comm comm);
1922

2023
#ifdef NBC_CACHE_SCHEDULE
2124
/* tree comparison function for schedule cache */
@@ -48,10 +51,11 @@ int ompi_coll_libnbc_ialltoall(void* sendbuf, int sendcount, MPI_Datatype sendty
4851
NBC_Alltoall_args *args, *found, search;
4952
#endif
5053
char *rbuf, *sbuf, inplace;
51-
enum {NBC_A2A_LINEAR, NBC_A2A_PAIRWISE, NBC_A2A_DISS} alg;
54+
enum {NBC_A2A_LINEAR, NBC_A2A_PAIRWISE, NBC_A2A_DISS, NBC_A2A_INPLACE} alg;
5255
NBC_Handle *handle;
5356
ompi_coll_libnbc_request_t **coll_req = (ompi_coll_libnbc_request_t**) request;
5457
ompi_coll_libnbc_module_t *libnbc_module = (ompi_coll_libnbc_module_t*) module;
58+
ptrdiff_t span, gap;
5559

5660
NBC_IN_PLACE(sendbuf, recvbuf, inplace);
5761

@@ -72,7 +76,9 @@ int ompi_coll_libnbc_ialltoall(void* sendbuf, int sendcount, MPI_Datatype sendty
7276
/* algorithm selection */
7377
a2asize = sndsize*sendcount*p;
7478
/* this number is optimized for TCP on odin.cs.indiana.edu */
75-
if((p <= 8) && ((a2asize < 1<<17) || (sndsize*sendcount < 1<<12))) {
79+
if (inplace) {
80+
alg = NBC_A2A_INPLACE;
81+
} else if((p <= 8) && ((a2asize < 1<<17) || (sndsize*sendcount < 1<<12))) {
7682
/* just send as fast as we can if we have at most 8 peers, if the
7783
* total communicated size is smaller than 1<<17 *and* if we don't
7884
* have eager messages (msgsize < 1<<13) */
@@ -92,7 +98,11 @@ int ompi_coll_libnbc_ialltoall(void* sendbuf, int sendcount, MPI_Datatype sendty
9298
}
9399

94100
/* allocate temp buffer if we need one */
95-
if(alg == NBC_A2A_DISS) {
101+
if (alg == NBC_A2A_INPLACE) {
102+
span = opal_datatype_span(&recvtype->super, recvcount, &gap);
103+
handle->tmpbuf = malloc(span);
104+
if (OPAL_UNLIKELY(NULL == handle->tmpbuf)) { printf("Error in malloc()\n"); return NBC_OOR; }
105+
} else if (alg == NBC_A2A_DISS) {
96106
/* only A2A_DISS needs buffers */
97107
if(NBC_Type_intrinsic(sendtype)) {
98108
datasize = sndext*sendcount;
@@ -146,13 +156,16 @@ int ompi_coll_libnbc_ialltoall(void* sendbuf, int sendcount, MPI_Datatype sendty
146156
#endif
147157
/* not found - generate new schedule */
148158
schedule = (NBC_Schedule*)malloc(sizeof(NBC_Schedule));
149-
if (NULL == schedule) { printf("Error in malloc()\n"); return res; }
159+
if (NULL == schedule) { printf("Error in malloc()\n"); return NBC_OOR; }
150160

151161
res = NBC_Sched_create(schedule);
152162
if(res != NBC_OK) { printf("Error in NBC_Sched_create (%i)\n", res); return res; }
153163

154164
switch(alg) {
155-
case NBC_A2A_LINEAR:
165+
case NBC_A2A_INPLACE:
166+
res = a2a_sched_inplace(rank, p, schedule, recvbuf, recvcount, recvtype, rcvext, gap, comm);
167+
break;
168+
case NBC_A2A_LINEAR:
156169
res = a2a_sched_linear(rank, p, sndext, rcvext, schedule, sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, comm);
157170
break;
158171
case NBC_A2A_DISS:
@@ -224,7 +237,7 @@ int ompi_coll_libnbc_ialltoall_inter (void* sendbuf, int sendcount, MPI_Datatype
224237
if (MPI_SUCCESS != res) { printf("MPI Error in MPI_Type_extent() (%i)\n", res); return res; }
225238

226239
schedule = (NBC_Schedule*)malloc(sizeof(NBC_Schedule));
227-
if (NULL == schedule) { printf("Error in malloc() (%i)\n", res); return res; }
240+
if (NULL == schedule) { printf("Error in malloc()\n"); return NBC_OOR; }
228241

229242
handle->tmpbuf=NULL;
230243

@@ -378,3 +391,48 @@ static inline int a2a_sched_diss(int rank, int p, MPI_Aint sndext, MPI_Aint rcve
378391
return NBC_OK;
379392
}
380393

394+
static inline int a2a_sched_inplace(int rank, int p, NBC_Schedule* schedule, void* buf, int count,
395+
MPI_Datatype type, MPI_Aint ext, ptrdiff_t gap, MPI_Comm comm) {
396+
int res;
397+
398+
for (int i = 1 ; i < (p+1)/2 ; i++) {
399+
int speer = (rank + i) % p;
400+
int rpeer = (rank + p - i) % p;
401+
char *sbuf = (char *) buf + speer * count * ext;
402+
char *rbuf = (char *) buf + rpeer * count * ext;
403+
404+
res = NBC_Sched_copy (rbuf, false, count, type,
405+
(void *)(-gap), true, count, type,
406+
schedule);
407+
if (NBC_OK != res) { printf("Error in NBC_Sched_copy() (%i)\n", res); return res; }
408+
res = NBC_Sched_barrier(schedule);
409+
if (NBC_OK != res) { printf("Error in NBC_Sched_barr() (%i)\n", res); return res; }
410+
res = NBC_Sched_send (sbuf, false , count, type, speer, schedule);
411+
if (NBC_OK != res) { printf("Error in NBC_Sched_send() (%i)\n", res); return res; }
412+
res = NBC_Sched_recv (rbuf, false , count, type, rpeer, schedule);
413+
if (NBC_OK != res) { printf("Error in NBC_Sched_recv() (%i)\n", res); return res; }
414+
res = NBC_Sched_barrier(schedule);
415+
if (NBC_OK != res) { printf("Error in NBC_Sched_barr() (%i)\n", res); return res; }
416+
res = NBC_Sched_send ((void *)(-gap), true, count, type, rpeer, schedule);
417+
if (NBC_OK != res) { printf("Error in NBC_Sched_send() (%i)\n", res); return res; }
418+
res = NBC_Sched_recv (sbuf, false, count, type, speer, schedule);
419+
if (NBC_OK != res) { printf("Error in NBC_Sched_recv() (%i)\n", res); return res; }
420+
res = NBC_Sched_barrier(schedule);
421+
if (NBC_OK != res) { printf("Error in NBC_Sched_barr() (%i)\n", res); return res; }
422+
}
423+
if (0 == (p%2)) {
424+
int peer = (rank + p/2) % p;
425+
426+
char *tbuf = (char *) buf + peer * count * ext;
427+
res = NBC_Sched_copy (tbuf, false, count, type, (void *)(-gap), true, count, type, schedule);
428+
if (NBC_OK != res) { printf("Error in NBC_Sched_copy() (%i)\n", res); return res; }
429+
res = NBC_Sched_barrier(schedule);
430+
if (NBC_OK != res) { printf("Error in NBC_Sched_barr() (%i)\n", res); return res; }
431+
res = NBC_Sched_send ((void *)(-gap), true , count, type, peer, schedule);
432+
if (NBC_OK != res) { printf("Error in NBC_Sched_send() (%i)\n", res); return res; }
433+
res = NBC_Sched_recv (tbuf, false , count, type, peer, schedule);
434+
if (NBC_OK != res) { printf("Error in NBC_Sched_recv() (%i)\n", res); return res; }
435+
}
436+
437+
return NBC_OK;
438+
}

0 commit comments

Comments
 (0)