7
7
* Copyright (c) 2013 Los Alamos National Security, LLC. All rights
8
8
* reserved.
9
9
* Copyright (c) 2014 NVIDIA Corporation. All rights reserved.
10
+ * Copyright (c) 2016 Research Organization for Information Science
11
+ * and Technology (RIST). All rights reserved.
10
12
*
11
13
* Author(s): Torsten Hoefler <[email protected] >
12
14
*
16
18
static inline int a2a_sched_linear (int rank , int p , MPI_Aint sndext , MPI_Aint rcvext , NBC_Schedule * schedule , void * sendbuf , int sendcount , MPI_Datatype sendtype , void * recvbuf , int recvcount , MPI_Datatype recvtype , MPI_Comm comm );
17
19
static inline int a2a_sched_pairwise (int rank , int p , MPI_Aint sndext , MPI_Aint rcvext , NBC_Schedule * schedule , void * sendbuf , int sendcount , MPI_Datatype sendtype , void * recvbuf , int recvcount , MPI_Datatype recvtype , MPI_Comm comm );
18
20
static inline int a2a_sched_diss (int rank , int p , MPI_Aint sndext , MPI_Aint rcvext , NBC_Schedule * schedule , void * sendbuf , int sendcount , MPI_Datatype sendtype , void * recvbuf , int recvcount , MPI_Datatype recvtype , MPI_Comm comm , NBC_Handle * handle );
21
+ static inline int a2a_sched_inplace (int rank , int p , NBC_Schedule * schedule , void * buf , int count , MPI_Datatype type , MPI_Aint ext , ptrdiff_t gap , MPI_Comm comm );
19
22
20
23
#ifdef NBC_CACHE_SCHEDULE
21
24
/* tree comparison function for schedule cache */
@@ -48,10 +51,11 @@ int ompi_coll_libnbc_ialltoall(void* sendbuf, int sendcount, MPI_Datatype sendty
48
51
NBC_Alltoall_args * args , * found , search ;
49
52
#endif
50
53
char * rbuf , * sbuf , inplace ;
51
- enum {NBC_A2A_LINEAR , NBC_A2A_PAIRWISE , NBC_A2A_DISS } alg ;
54
+ enum {NBC_A2A_LINEAR , NBC_A2A_PAIRWISE , NBC_A2A_DISS , NBC_A2A_INPLACE } alg ;
52
55
NBC_Handle * handle ;
53
56
ompi_coll_libnbc_request_t * * coll_req = (ompi_coll_libnbc_request_t * * ) request ;
54
57
ompi_coll_libnbc_module_t * libnbc_module = (ompi_coll_libnbc_module_t * ) module ;
58
+ ptrdiff_t span , gap ;
55
59
56
60
NBC_IN_PLACE (sendbuf , recvbuf , inplace );
57
61
@@ -72,7 +76,9 @@ int ompi_coll_libnbc_ialltoall(void* sendbuf, int sendcount, MPI_Datatype sendty
72
76
/* algorithm selection */
73
77
a2asize = sndsize * sendcount * p ;
74
78
/* this number is optimized for TCP on odin.cs.indiana.edu */
75
- if ((p <= 8 ) && ((a2asize < 1 <<17 ) || (sndsize * sendcount < 1 <<12 ))) {
79
+ if (inplace ) {
80
+ alg = NBC_A2A_INPLACE ;
81
+ } else if ((p <= 8 ) && ((a2asize < 1 <<17 ) || (sndsize * sendcount < 1 <<12 ))) {
76
82
/* just send as fast as we can if we have less than 8 peers, if the
77
83
* total communicated size is smaller than 1<<17 *and* if we don't
78
84
* have eager messages (msgsize < 1<<13) */
@@ -92,7 +98,11 @@ int ompi_coll_libnbc_ialltoall(void* sendbuf, int sendcount, MPI_Datatype sendty
92
98
}
93
99
94
100
/* allocate temp buffer if we need one */
95
- if (alg == NBC_A2A_DISS ) {
101
+ if (alg == NBC_A2A_INPLACE ) {
102
+ span = opal_datatype_span (& recvtype -> super , recvcount , & gap );
103
+ handle -> tmpbuf = malloc (span );
104
+ if (OPAL_UNLIKELY (NULL == handle -> tmpbuf )) { printf ("Error in malloc()\n" ); return NBC_OOR ; }
105
+ } else if (alg == NBC_A2A_DISS ) {
96
106
/* only A2A_DISS needs buffers */
97
107
if (NBC_Type_intrinsic (sendtype )) {
98
108
datasize = sndext * sendcount ;
@@ -146,13 +156,16 @@ int ompi_coll_libnbc_ialltoall(void* sendbuf, int sendcount, MPI_Datatype sendty
146
156
#endif
147
157
/* not found - generate new schedule */
148
158
schedule = (NBC_Schedule * )malloc (sizeof (NBC_Schedule ));
149
- if (NULL == schedule ) { printf ("Error in malloc()\n" ); return res ; }
159
+ if (NULL == schedule ) { printf ("Error in malloc()\n" ); return NBC_OOR ; }
150
160
151
161
res = NBC_Sched_create (schedule );
152
162
if (res != NBC_OK ) { printf ("Error in NBC_Sched_create (%i)\n" , res ); return res ; }
153
163
154
164
switch (alg ) {
155
- case NBC_A2A_LINEAR :
165
+ case NBC_A2A_INPLACE :
166
+ res = a2a_sched_inplace (rank , p , schedule , recvbuf , recvcount , recvtype , rcvext , gap , comm );
167
+ break ;
168
+ case NBC_A2A_LINEAR :
156
169
res = a2a_sched_linear (rank , p , sndext , rcvext , schedule , sendbuf , sendcount , sendtype , recvbuf , recvcount , recvtype , comm );
157
170
break ;
158
171
case NBC_A2A_DISS :
@@ -224,7 +237,7 @@ int ompi_coll_libnbc_ialltoall_inter (void* sendbuf, int sendcount, MPI_Datatype
224
237
if (MPI_SUCCESS != res ) { printf ("MPI Error in MPI_Type_extent() (%i)\n" , res ); return res ; }
225
238
226
239
schedule = (NBC_Schedule * )malloc (sizeof (NBC_Schedule ));
227
- if (NULL == schedule ) { printf ("Error in malloc() (%i) \n" , res ); return res ; }
240
+ if (NULL == schedule ) { printf ("Error in malloc()\n" ); return NBC_OOR ; }
228
241
229
242
handle -> tmpbuf = NULL ;
230
243
@@ -378,3 +391,48 @@ static inline int a2a_sched_diss(int rank, int p, MPI_Aint sndext, MPI_Aint rcve
378
391
return NBC_OK ;
379
392
}
380
393
394
+ static inline int a2a_sched_inplace (int rank , int p , NBC_Schedule * schedule , void * buf , int count ,
395
+ MPI_Datatype type , MPI_Aint ext , ptrdiff_t gap , MPI_Comm comm ) {
396
+ int res ;
397
+
398
+ for (int i = 1 ; i < (p + 1 )/2 ; i ++ ) {
399
+ int speer = (rank + i ) % p ;
400
+ int rpeer = (rank + p - i ) % p ;
401
+ char * sbuf = (char * ) buf + speer * count * ext ;
402
+ char * rbuf = (char * ) buf + rpeer * count * ext ;
403
+
404
+ res = NBC_Sched_copy (rbuf , false, count , type ,
405
+ (void * )(- gap ), true, count , type ,
406
+ schedule );
407
+ if (NBC_OK != res ) { printf ("Error in NBC_Sched_copy() (%i)\n" , res ); return res ; }
408
+ res = NBC_Sched_barrier (schedule );
409
+ if (NBC_OK != res ) { printf ("Error in NBC_Sched_barr() (%i)\n" , res ); return res ; }
410
+ res = NBC_Sched_send (sbuf , false , count , type , speer , schedule );
411
+ if (NBC_OK != res ) { printf ("Error in NBC_Sched_send() (%i)\n" , res ); return res ; }
412
+ res = NBC_Sched_recv (rbuf , false , count , type , rpeer , schedule );
413
+ if (NBC_OK != res ) { printf ("Error in NBC_Sched_recv() (%i)\n" , res ); return res ; }
414
+ res = NBC_Sched_barrier (schedule );
415
+ if (NBC_OK != res ) { printf ("Error in NBC_Sched_barr() (%i)\n" , res ); return res ; }
416
+ res = NBC_Sched_send ((void * )(- gap ), true, count , type , rpeer , schedule );
417
+ if (NBC_OK != res ) { printf ("Error in NBC_Sched_send() (%i)\n" , res ); return res ; }
418
+ res = NBC_Sched_recv (sbuf , false, count , type , speer , schedule );
419
+ if (NBC_OK != res ) { printf ("Error in NBC_Sched_recv() (%i)\n" , res ); return res ; }
420
+ res = NBC_Sched_barrier (schedule );
421
+ if (NBC_OK != res ) { printf ("Error in NBC_Sched_barr() (%i)\n" , res ); return res ; }
422
+ }
423
+ if (0 == (p %2 )) {
424
+ int peer = (rank + p /2 ) % p ;
425
+
426
+ char * tbuf = (char * ) buf + peer * count * ext ;
427
+ res = NBC_Sched_copy (tbuf , false, count , type , (void * )(- gap ), true, count , type , schedule );
428
+ if (NBC_OK != res ) { printf ("Error in NBC_Sched_copy() (%i)\n" , res ); return res ; }
429
+ res = NBC_Sched_barrier (schedule );
430
+ if (NBC_OK != res ) { printf ("Error in NBC_Sched_barr() (%i)\n" , res ); return res ; }
431
+ res = NBC_Sched_send ((void * )(- gap ), true , count , type , peer , schedule );
432
+ if (NBC_OK != res ) { printf ("Error in NBC_Sched_send() (%i)\n" , res ); return res ; }
433
+ res = NBC_Sched_recv (tbuf , false , count , type , peer , schedule );
434
+ if (NBC_OK != res ) { printf ("Error in NBC_Sched_recv() (%i)\n" , res ); return res ; }
435
+ }
436
+
437
+ return NBC_OK ;
438
+ }
0 commit comments