Skip to content

Commit 7cb3ec0

Browse files
authored
Merge pull request #10488 from FlorentGermain-Bull/fix_coll_base_alltoall_bruck
Coll/base: FIX bruck when sdtype and rdtype are different
2 parents 8fbfe4c + 99ca572 commit 7cb3ec0

File tree

1 file changed

+29
-31
lines changed

1 file changed

+29
-31
lines changed

ompi/mca/coll/base/coll_base_alltoall.c

Lines changed: 29 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
* Copyright (c) 2017-2022 IBM Corporation. All rights reserved.
1818
* Copyright (c) 2021 Amazon.com, Inc. or its affiliates. All Rights
1919
* reserved.
20+
* Copyright (c) 2022 BULL S.A.S. All rights reserved.
2021
* $COPYRIGHT$
2122
*
2223
* Additional copyrights may follow
@@ -222,8 +223,8 @@ int ompi_coll_base_alltoall_intra_bruck(const void *sbuf, int scount,
222223
struct ompi_communicator_t *comm,
223224
mca_coll_base_module_t *module)
224225
{
225-
int i, k, line = -1, rank, size, err = 0;
226-
int sendto, recvfrom, distance, *displs = NULL, *blen = NULL;
226+
int i, line = -1, rank, size, err = 0;
227+
int sendto, recvfrom, distance, *displs = NULL;
227228
char *tmpbuf = NULL, *tmpbuf_free = NULL;
228229
ptrdiff_t sext, rext, span, gap = 0;
229230
struct ompi_datatype_t *new_ddt;
@@ -245,31 +246,31 @@ int ompi_coll_base_alltoall_intra_bruck(const void *sbuf, int scount,
245246
err = ompi_datatype_type_extent (rdtype, &rext);
246247
if (err != MPI_SUCCESS) { line = __LINE__; goto err_hndl; }
247248

248-
span = opal_datatype_span(&sdtype->super, (int64_t)size * scount, &gap);
249-
250-
displs = (int *) malloc(size * sizeof(int));
251-
if (displs == NULL) { line = __LINE__; err = -1; goto err_hndl; }
252-
blen = (int *) malloc(size * sizeof(int));
253-
if (blen == NULL) { line = __LINE__; err = -1; goto err_hndl; }
249+
span = opal_datatype_span(&rdtype->super, (int64_t)size * rcount, &gap);
254250

255251
/* tmp buffer allocation for message data */
256252
tmpbuf_free = (char *)malloc(span);
257253
if (tmpbuf_free == NULL) { line = __LINE__; err = -1; goto err_hndl; }
258254
tmpbuf = tmpbuf_free - gap;
259255

260256
/* Step 1 - local rotation - shift up by rank */
261-
err = ompi_datatype_copy_content_same_ddt (sdtype,
262-
(int32_t) ((ptrdiff_t)(size - rank) * (ptrdiff_t)scount),
263-
tmpbuf,
264-
((char*) sbuf) + (ptrdiff_t)rank * (ptrdiff_t)scount * sext);
257+
err = ompi_datatype_sndrcv (sbuf + ((ptrdiff_t) rank * scount * sext),
258+
(int32_t) (size - rank) * scount,
259+
sdtype,
260+
tmpbuf,
261+
(int32_t) (size - rank) * rcount,
262+
rdtype);
265263
if (err<0) {
266264
line = __LINE__; err = -1; goto err_hndl;
267265
}
268266

269267
if (rank != 0) {
270-
err = ompi_datatype_copy_content_same_ddt (sdtype, (ptrdiff_t)rank * (ptrdiff_t)scount,
271-
tmpbuf + (ptrdiff_t)(size - rank) * (ptrdiff_t)scount* sext,
272-
(char*) sbuf);
268+
err = ompi_datatype_sndrcv (sbuf,
269+
(int32_t) rank * scount,
270+
sdtype,
271+
tmpbuf + ((ptrdiff_t) (size - rank) * rcount * rext),
272+
(int32_t) rank * rcount,
273+
rdtype);
273274
if (err<0) {
274275
line = __LINE__; err = -1; goto err_hndl;
275276
}
@@ -280,19 +281,19 @@ int ompi_coll_base_alltoall_intra_bruck(const void *sbuf, int scount,
280281

281282
sendto = (rank + distance) % size;
282283
recvfrom = (rank - distance + size) % size;
283-
k = 0;
284-
285-
/* create indexed datatype */
286-
for (i = 1; i < size; i++) {
287-
if (( i & distance) == distance) {
288-
displs[k] = (ptrdiff_t)i * (ptrdiff_t)scount;
289-
blen[k] = scount;
290-
k++;
284+
285+
new_ddt = ompi_datatype_create((1 + size/distance) * (2 + rdtype->super.desc.used));
286+
287+
/* Create datatype describing data sent/received */
288+
for (i = distance; i < size; i += 2*distance) {
289+
int nblocks = distance;
290+
if (i + distance >= size) {
291+
nblocks = size - i;
291292
}
293+
ompi_datatype_add(new_ddt, rdtype, rcount * nblocks,
294+
i * rcount * rext, rext);
292295
}
293-
/* Set indexes and displacements */
294-
err = ompi_datatype_create_indexed(k, blen, displs, sdtype, &new_ddt);
295-
if (err != MPI_SUCCESS) { line = __LINE__; goto err_hndl; }
296+
296297
/* Commit the new datatype */
297298
err = ompi_datatype_commit(&new_ddt);
298299
if (err != MPI_SUCCESS) { line = __LINE__; goto err_hndl; }
@@ -324,19 +325,16 @@ int ompi_coll_base_alltoall_intra_bruck(const void *sbuf, int scount,
324325
}
325326

326327
/* Step 4 - clean up */
327-
if (tmpbuf != NULL) free(tmpbuf_free);
328-
if (displs != NULL) free(displs);
329-
if (blen != NULL) free(blen);
328+
if (tmpbuf_free != NULL) free(tmpbuf_free);
330329
return OMPI_SUCCESS;
331330

332331
err_hndl:
333332
OPAL_OUTPUT((ompi_coll_base_framework.framework_output,
334333
"%s:%4d\tError occurred %d, rank %2d", __FILE__, line, err,
335334
rank));
336335
(void)line; // silence compiler warning
337-
if (tmpbuf != NULL) free(tmpbuf_free);
336+
if (tmpbuf_free != NULL) free(tmpbuf_free);
338337
if (displs != NULL) free(displs);
339-
if (blen != NULL) free(blen);
340338
return err;
341339
}
342340

0 commit comments

Comments
 (0)