17
17
* Copyright (c) 2017-2022 IBM Corporation. All rights reserved.
18
18
* Copyright (c) 2021 Amazon.com, Inc. or its affiliates. All Rights
19
19
* reserved.
20
+ * Copyright (c) 2022 BULL S.A.S. All rights reserved.
20
21
* $COPYRIGHT$
21
22
*
22
23
* Additional copyrights may follow
@@ -222,8 +223,8 @@ int ompi_coll_base_alltoall_intra_bruck(const void *sbuf, int scount,
222
223
struct ompi_communicator_t * comm ,
223
224
mca_coll_base_module_t * module )
224
225
{
225
- int i , k , line = -1 , rank , size , err = 0 ;
226
- int sendto , recvfrom , distance , * displs = NULL , * blen = NULL ;
226
+ int i , line = -1 , rank , size , err = 0 ;
227
+ int sendto , recvfrom , distance , * displs = NULL ;
227
228
char * tmpbuf = NULL , * tmpbuf_free = NULL ;
228
229
ptrdiff_t sext , rext , span , gap = 0 ;
229
230
struct ompi_datatype_t * new_ddt ;
@@ -245,31 +246,31 @@ int ompi_coll_base_alltoall_intra_bruck(const void *sbuf, int scount,
245
246
err = ompi_datatype_type_extent (rdtype , & rext );
246
247
if (err != MPI_SUCCESS ) { line = __LINE__ ; goto err_hndl ; }
247
248
248
- span = opal_datatype_span (& sdtype -> super , (int64_t )size * scount , & gap );
249
-
250
- displs = (int * ) malloc (size * sizeof (int ));
251
- if (displs == NULL ) { line = __LINE__ ; err = -1 ; goto err_hndl ; }
252
- blen = (int * ) malloc (size * sizeof (int ));
253
- if (blen == NULL ) { line = __LINE__ ; err = -1 ; goto err_hndl ; }
249
+ span = opal_datatype_span (& rdtype -> super , (int64_t )size * rcount , & gap );
254
250
255
251
/* tmp buffer allocation for message data */
256
252
tmpbuf_free = (char * )malloc (span );
257
253
if (tmpbuf_free == NULL ) { line = __LINE__ ; err = -1 ; goto err_hndl ; }
258
254
tmpbuf = tmpbuf_free - gap ;
259
255
260
256
/* Step 1 - local rotation - shift up by rank */
261
- err = ompi_datatype_copy_content_same_ddt (sdtype ,
262
- (int32_t ) ((ptrdiff_t )(size - rank ) * (ptrdiff_t )scount ),
263
- tmpbuf ,
264
- ((char * ) sbuf ) + (ptrdiff_t )rank * (ptrdiff_t )scount * sext );
257
+ err = ompi_datatype_sndrcv (sbuf + ((ptrdiff_t ) rank * scount * sext ),
258
+ (int32_t ) (size - rank ) * scount ,
259
+ sdtype ,
260
+ tmpbuf ,
261
+ (int32_t ) (size - rank ) * rcount ,
262
+ rdtype );
265
263
if (err < 0 ) {
266
264
line = __LINE__ ; err = -1 ; goto err_hndl ;
267
265
}
268
266
269
267
if (rank != 0 ) {
270
- err = ompi_datatype_copy_content_same_ddt (sdtype , (ptrdiff_t )rank * (ptrdiff_t )scount ,
271
- tmpbuf + (ptrdiff_t )(size - rank ) * (ptrdiff_t )scount * sext ,
272
- (char * ) sbuf );
268
+ err = ompi_datatype_sndrcv (sbuf ,
269
+ (int32_t ) rank * scount ,
270
+ sdtype ,
271
+ tmpbuf + ((ptrdiff_t ) (size - rank ) * rcount * rext ),
272
+ (int32_t ) rank * rcount ,
273
+ rdtype );
273
274
if (err < 0 ) {
274
275
line = __LINE__ ; err = -1 ; goto err_hndl ;
275
276
}
@@ -280,19 +281,19 @@ int ompi_coll_base_alltoall_intra_bruck(const void *sbuf, int scount,
280
281
281
282
sendto = (rank + distance ) % size ;
282
283
recvfrom = (rank - distance + size ) % size ;
283
- k = 0 ;
284
-
285
- /* create indexed datatype */
286
- for ( i = 1 ; i < size ; i ++ ) {
287
- if (( i & distance ) == distance ) {
288
- displs [ k ] = ( ptrdiff_t ) i * ( ptrdiff_t ) scount ;
289
- blen [ k ] = scount ;
290
- k ++ ;
284
+
285
+ new_ddt = ompi_datatype_create (( 1 + size / distance ) * ( 2 + rdtype -> super . desc . used ));
286
+
287
+ /* Create datatype describing data sent/received */
288
+ for ( i = distance ; i < size ; i += 2 * distance ) {
289
+ int nblocks = distance ;
290
+ if ( i + distance >= size ) {
291
+ nblocks = size - i ;
291
292
}
293
+ ompi_datatype_add (new_ddt , rdtype , rcount * nblocks ,
294
+ i * rcount * rext , rext );
292
295
}
293
- /* Set indexes and displacements */
294
- err = ompi_datatype_create_indexed (k , blen , displs , sdtype , & new_ddt );
295
- if (err != MPI_SUCCESS ) { line = __LINE__ ; goto err_hndl ; }
296
+
296
297
/* Commit the new datatype */
297
298
err = ompi_datatype_commit (& new_ddt );
298
299
if (err != MPI_SUCCESS ) { line = __LINE__ ; goto err_hndl ; }
@@ -324,19 +325,16 @@ int ompi_coll_base_alltoall_intra_bruck(const void *sbuf, int scount,
324
325
}
325
326
326
327
/* Step 4 - clean up */
327
- if (tmpbuf != NULL ) free (tmpbuf_free );
328
- if (displs != NULL ) free (displs );
329
- if (blen != NULL ) free (blen );
328
+ if (tmpbuf_free != NULL ) free (tmpbuf_free );
330
329
return OMPI_SUCCESS ;
331
330
332
331
err_hndl :
333
332
OPAL_OUTPUT ((ompi_coll_base_framework .framework_output ,
334
333
"%s:%4d\tError occurred %d, rank %2d" , __FILE__ , line , err ,
335
334
rank ));
336
335
(void )line ; // silence compiler warning
337
- if (tmpbuf != NULL ) free (tmpbuf_free );
336
+ if (tmpbuf_free != NULL ) free (tmpbuf_free );
338
337
if (displs != NULL ) free (displs );
339
- if (blen != NULL ) free (blen );
340
338
return err ;
341
339
}
342
340
0 commit comments