@@ -103,10 +103,6 @@ struct ompi_comm_allreduce_context_t {
103
103
ompi_comm_cid_context_t * cid_context ;
104
104
int * tmpbuf ;
105
105
106
- /* for intercomm allreduce */
107
- int * rcounts ;
108
- int * rdisps ;
109
-
110
106
/* for group allreduce */
111
107
int peers_comm [3 ];
112
108
};
@@ -121,8 +117,6 @@ static void ompi_comm_allreduce_context_construct (ompi_comm_allreduce_context_t
121
117
static void ompi_comm_allreduce_context_destruct (ompi_comm_allreduce_context_t * context )
122
118
{
123
119
free (context -> tmpbuf );
124
- free (context -> rcounts );
125
- free (context -> rdisps );
126
120
}
127
121
128
122
OBJ_CLASS_INSTANCE (ompi_comm_allreduce_context_t , opal_object_t ,
@@ -602,7 +596,7 @@ static int ompi_comm_allreduce_intra_nb (int *inbuf, int *outbuf, int count, str
602
596
/* Non-blocking version of ompi_comm_allreduce_inter */
603
597
static int ompi_comm_allreduce_inter_leader_exchange (ompi_comm_request_t * request );
604
598
static int ompi_comm_allreduce_inter_leader_reduce (ompi_comm_request_t * request );
605
- static int ompi_comm_allreduce_inter_allgather (ompi_comm_request_t * request );
599
+ static int ompi_comm_allreduce_inter_bcast (ompi_comm_request_t * request );
606
600
607
601
static int ompi_comm_allreduce_inter_nb (int * inbuf , int * outbuf ,
608
602
int count , struct ompi_op_t * op ,
@@ -637,17 +631,16 @@ static int ompi_comm_allreduce_inter_nb (int *inbuf, int *outbuf,
637
631
local_rank = ompi_comm_rank (intercomm );
638
632
639
633
context -> tmpbuf = (int * ) calloc (count , sizeof (int ));
640
- context -> rdisps = (int * ) calloc (rsize , sizeof (int ));
641
- context -> rcounts = (int * ) calloc (rsize , sizeof (int ));
642
- if (OPAL_UNLIKELY (NULL == context -> tmpbuf || NULL == context -> rdisps || NULL == context -> rcounts )) {
634
+ if (OPAL_UNLIKELY (NULL == context -> tmpbuf )) {
643
635
ompi_comm_request_return (request );
644
636
return OMPI_ERR_OUT_OF_RESOURCE ;
645
637
}
646
638
647
639
/* Execute the inter-allreduce: the result from the local group will be in the buffer of the remote group
648
640
* and vice versa. */
649
- rc = intercomm -> c_coll .coll_iallreduce (inbuf , context -> tmpbuf , count , MPI_INT , op , intercomm ,
650
- & subreq , intercomm -> c_coll .coll_iallreduce_module );
641
+ rc = intercomm -> c_local_comm -> c_coll .coll_iallreduce (inbuf , context -> tmpbuf , count , MPI_INT , op ,
642
+ intercomm -> c_local_comm , & subreq ,
643
+ intercomm -> c_local_comm -> c_coll .coll_iallreduce_module );
651
644
if (OPAL_UNLIKELY (OMPI_SUCCESS != rc )) {
652
645
ompi_comm_request_return (request );
653
646
return rc ;
@@ -656,7 +649,7 @@ static int ompi_comm_allreduce_inter_nb (int *inbuf, int *outbuf,
656
649
if (0 == local_rank ) {
657
650
ompi_comm_request_schedule_append (request , ompi_comm_allreduce_inter_leader_exchange , & subreq , 1 );
658
651
} else {
659
- ompi_comm_request_schedule_append (request , ompi_comm_allreduce_inter_allgather , & subreq , 1 );
652
+ ompi_comm_request_schedule_append (request , ompi_comm_allreduce_inter_bcast , & subreq , 1 );
660
653
}
661
654
662
655
ompi_comm_request_start (request );
@@ -696,33 +689,20 @@ static int ompi_comm_allreduce_inter_leader_reduce (ompi_comm_request_t *request
696
689
697
690
ompi_op_reduce (context -> op , context -> tmpbuf , context -> outbuf , context -> count , MPI_INT );
698
691
699
- return ompi_comm_allreduce_inter_allgather (request );
692
+ return ompi_comm_allreduce_inter_bcast (request );
700
693
}
701
694
702
695
703
- static int ompi_comm_allreduce_inter_allgather (ompi_comm_request_t * request )
696
+ static int ompi_comm_allreduce_inter_bcast (ompi_comm_request_t * request )
704
697
{
705
698
ompi_comm_allreduce_context_t * context = (ompi_comm_allreduce_context_t * ) request -> context ;
706
- ompi_communicator_t * intercomm = context -> cid_context -> comm ;
699
+ ompi_communicator_t * comm = context -> cid_context -> comm -> c_local_comm ;
707
700
ompi_request_t * subreq ;
708
701
int scount = 0 , rc ;
709
702
710
- /* distribute the overall result to all processes in the other group.
711
- Instead of using bcast, we are using here allgatherv, to avoid the
712
- possible deadlock. Else, we need an algorithm to determine,
713
- which group sends first in the inter-bcast and which receives
714
- the result first.
715
- */
716
-
717
- if (0 != ompi_comm_rank (intercomm )) {
718
- context -> rcounts [0 ] = context -> count ;
719
- } else {
720
- scount = context -> count ;
721
- }
722
-
723
- rc = intercomm -> c_coll .coll_iallgatherv (context -> outbuf , scount , MPI_INT , context -> outbuf ,
724
- context -> rcounts , context -> rdisps , MPI_INT , intercomm ,
725
- & subreq , intercomm -> c_coll .coll_iallgatherv_module );
703
+ /* Both roots have the same result; broadcast it to the local group. */
704
+ rc = comm -> c_coll .coll_ibcast (context -> outbuf , context -> count , MPI_INT , 0 , comm ,
705
+ & subreq , comm -> c_coll .coll_ibcast_module );
726
706
if (OMPI_SUCCESS != rc ) {
727
707
return rc ;
728
708
}
0 commit comments