Skip to content

Commit 79d942b

Browse files
committed
comm/cid: use ibcast to distribute result in intercomm case
This commit updates the intercomm allgather to do a local comm bcast as the final step. This should resolve a hang seen in intercomm tests. Signed-off-by: Nathan Hjelm <[email protected]>
1 parent a6d515b commit 79d942b

File tree

1 file changed

+17
-35
lines changed

1 file changed

+17
-35
lines changed

ompi/communicator/comm_cid.c

+17-35
Original file line number | Diff line number | Diff line change
@@ -103,10 +103,6 @@ struct ompi_comm_allreduce_context_t {
103103
ompi_comm_cid_context_t *cid_context;
104104
int *tmpbuf;
105105

106-
/* for intercomm allreduce */
107-
int *rcounts;
108-
int *rdisps;
109-
110106
/* for group allreduce */
111107
int peers_comm[3];
112108
};
@@ -121,8 +117,6 @@ static void ompi_comm_allreduce_context_construct (ompi_comm_allreduce_context_t
121117
static void ompi_comm_allreduce_context_destruct (ompi_comm_allreduce_context_t *context)
122118
{
123119
free (context->tmpbuf);
124-
free (context->rcounts);
125-
free (context->rdisps);
126120
}
127121

128122
OBJ_CLASS_INSTANCE (ompi_comm_allreduce_context_t, opal_object_t,
@@ -602,7 +596,7 @@ static int ompi_comm_allreduce_intra_nb (int *inbuf, int *outbuf, int count, str
602596
/* Non-blocking version of ompi_comm_allreduce_inter */
603597
static int ompi_comm_allreduce_inter_leader_exchange (ompi_comm_request_t *request);
604598
static int ompi_comm_allreduce_inter_leader_reduce (ompi_comm_request_t *request);
605-
static int ompi_comm_allreduce_inter_allgather (ompi_comm_request_t *request);
599+
static int ompi_comm_allreduce_inter_bcast (ompi_comm_request_t *request);
606600

607601
static int ompi_comm_allreduce_inter_nb (int *inbuf, int *outbuf,
608602
int count, struct ompi_op_t *op,
@@ -636,18 +630,19 @@ static int ompi_comm_allreduce_inter_nb (int *inbuf, int *outbuf,
636630
rsize = ompi_comm_remote_size (intercomm);
637631
local_rank = ompi_comm_rank (intercomm);
638632

639-
context->tmpbuf = (int *) calloc (count, sizeof(int));
640-
context->rdisps = (int *) calloc (rsize, sizeof(int));
641-
context->rcounts = (int *) calloc (rsize, sizeof(int));
642-
if (OPAL_UNLIKELY (NULL == context->tmpbuf || NULL == context->rdisps || NULL == context->rcounts)) {
643-
ompi_comm_request_return (request);
644-
return OMPI_ERR_OUT_OF_RESOURCE;
633+
if (0 == local_rank) {
634+
context->tmpbuf = (int *) calloc (count, sizeof(int));
635+
if (OPAL_UNLIKELY (NULL == context->tmpbuf)) {
636+
ompi_comm_request_return (request);
637+
return OMPI_ERR_OUT_OF_RESOURCE;
638+
}
645639
}
646640

647641
/* Execute the inter-allreduce: the result from the local will be in the buffer of the remote group
648642
* and vise-versa. */
649-
rc = intercomm->c_coll.coll_iallreduce (inbuf, context->tmpbuf, count, MPI_INT, op, intercomm,
650-
&subreq, intercomm->c_coll.coll_iallreduce_module);
643+
rc = intercomm->c_local_comm->c_coll.coll_ireduce (inbuf, context->tmpbuf, count, MPI_INT, op, 0,
644+
intercomm->c_local_comm, &subreq,
645+
intercomm->c_local_comm->c_coll.coll_ireduce_module);
651646
if (OPAL_UNLIKELY(OMPI_SUCCESS != rc)) {
652647
ompi_comm_request_return (request);
653648
return rc;
@@ -656,7 +651,7 @@ static int ompi_comm_allreduce_inter_nb (int *inbuf, int *outbuf,
656651
if (0 == local_rank) {
657652
ompi_comm_request_schedule_append (request, ompi_comm_allreduce_inter_leader_exchange, &subreq, 1);
658653
} else {
659-
ompi_comm_request_schedule_append (request, ompi_comm_allreduce_inter_allgather, &subreq, 1);
654+
ompi_comm_request_schedule_append (request, ompi_comm_allreduce_inter_bcast, &subreq, 1);
660655
}
661656

662657
ompi_comm_request_start (request);
@@ -696,33 +691,20 @@ static int ompi_comm_allreduce_inter_leader_reduce (ompi_comm_request_t *request
696691

697692
ompi_op_reduce (context->op, context->tmpbuf, context->outbuf, context->count, MPI_INT);
698693

699-
return ompi_comm_allreduce_inter_allgather (request);
694+
return ompi_comm_allreduce_inter_bcast (request);
700695
}
701696

702697

703-
static int ompi_comm_allreduce_inter_allgather (ompi_comm_request_t *request)
698+
static int ompi_comm_allreduce_inter_bcast (ompi_comm_request_t *request)
704699
{
705700
ompi_comm_allreduce_context_t *context = (ompi_comm_allreduce_context_t *) request->context;
706-
ompi_communicator_t *intercomm = context->cid_context->comm;
701+
ompi_communicator_t *comm = context->cid_context->comm->c_local_comm;
707702
ompi_request_t *subreq;
708703
int scount = 0, rc;
709704

710-
/* distribute the overall result to all processes in the other group.
711-
Instead of using bcast, we are using here allgatherv, to avoid the
712-
possible deadlock. Else, we need an algorithm to determine,
713-
which group sends first in the inter-bcast and which receives
714-
the result first.
715-
*/
716-
717-
if (0 != ompi_comm_rank (intercomm)) {
718-
context->rcounts[0] = context->count;
719-
} else {
720-
scount = context->count;
721-
}
722-
723-
rc = intercomm->c_coll.coll_iallgatherv (context->outbuf, scount, MPI_INT, context->outbuf,
724-
context->rcounts, context->rdisps, MPI_INT, intercomm,
725-
&subreq, intercomm->c_coll.coll_iallgatherv_module);
705+
/* both roots have the same result. broadcast to the local group */
706+
rc = comm->c_coll.coll_ibcast (context->outbuf, context->count, MPI_INT, 0, comm,
707+
&subreq, comm->c_coll.coll_ibcast_module);
726708
if (OMPI_SUCCESS != rc) {
727709
return rc;
728710
}

0 commit comments

Comments (0)