@@ -103,10 +103,6 @@ struct ompi_comm_allreduce_context_t {
103
103
ompi_comm_cid_context_t * cid_context ;
104
104
int * tmpbuf ;
105
105
106
- /* for intercomm allreduce */
107
- int * rcounts ;
108
- int * rdisps ;
109
-
110
106
/* for group allreduce */
111
107
int peers_comm [3 ];
112
108
};
@@ -121,8 +117,6 @@ static void ompi_comm_allreduce_context_construct (ompi_comm_allreduce_context_t
121
117
static void ompi_comm_allreduce_context_destruct (ompi_comm_allreduce_context_t * context )
122
118
{
123
119
free (context -> tmpbuf );
124
- free (context -> rcounts );
125
- free (context -> rdisps );
126
120
}
127
121
128
122
OBJ_CLASS_INSTANCE (ompi_comm_allreduce_context_t , opal_object_t ,
@@ -602,7 +596,7 @@ static int ompi_comm_allreduce_intra_nb (int *inbuf, int *outbuf, int count, str
602
596
/* Non-blocking version of ompi_comm_allreduce_inter */
603
597
static int ompi_comm_allreduce_inter_leader_exchange (ompi_comm_request_t * request );
604
598
static int ompi_comm_allreduce_inter_leader_reduce (ompi_comm_request_t * request );
605
- static int ompi_comm_allreduce_inter_allgather (ompi_comm_request_t * request );
599
+ static int ompi_comm_allreduce_inter_bcast (ompi_comm_request_t * request );
606
600
607
601
static int ompi_comm_allreduce_inter_nb (int * inbuf , int * outbuf ,
608
602
int count , struct ompi_op_t * op ,
@@ -636,18 +630,19 @@ static int ompi_comm_allreduce_inter_nb (int *inbuf, int *outbuf,
636
630
rsize = ompi_comm_remote_size (intercomm );
637
631
local_rank = ompi_comm_rank (intercomm );
638
632
639
- context -> tmpbuf = ( int * ) calloc ( count , sizeof ( int ));
640
- context -> rdisps = (int * ) calloc (rsize , sizeof (int ));
641
- context -> rcounts = ( int * ) calloc ( rsize , sizeof ( int ));
642
- if ( OPAL_UNLIKELY ( NULL == context -> tmpbuf || NULL == context -> rdisps || NULL == context -> rcounts )) {
643
- ompi_comm_request_return ( request ) ;
644
- return OMPI_ERR_OUT_OF_RESOURCE ;
633
+ if ( 0 == local_rank ) {
634
+ context -> tmpbuf = (int * ) calloc (count , sizeof (int ));
635
+ if ( OPAL_UNLIKELY ( NULL == context -> tmpbuf )) {
636
+ ompi_comm_request_return ( request );
637
+ return OMPI_ERR_OUT_OF_RESOURCE ;
638
+ }
645
639
}
646
640
647
641
/* Execute the inter-allreduce: the result from the local will be in the buffer of the remote group
648
642
* and vise-versa. */
649
- rc = intercomm -> c_coll .coll_iallreduce (inbuf , context -> tmpbuf , count , MPI_INT , op , intercomm ,
650
- & subreq , intercomm -> c_coll .coll_iallreduce_module );
643
+ rc = intercomm -> c_local_comm -> c_coll .coll_ireduce (inbuf , context -> tmpbuf , count , MPI_INT , op , 0 ,
644
+ intercomm -> c_local_comm , & subreq ,
645
+ intercomm -> c_local_comm -> c_coll .coll_ireduce_module );
651
646
if (OPAL_UNLIKELY (OMPI_SUCCESS != rc )) {
652
647
ompi_comm_request_return (request );
653
648
return rc ;
@@ -656,7 +651,7 @@ static int ompi_comm_allreduce_inter_nb (int *inbuf, int *outbuf,
656
651
if (0 == local_rank ) {
657
652
ompi_comm_request_schedule_append (request , ompi_comm_allreduce_inter_leader_exchange , & subreq , 1 );
658
653
} else {
659
- ompi_comm_request_schedule_append (request , ompi_comm_allreduce_inter_allgather , & subreq , 1 );
654
+ ompi_comm_request_schedule_append (request , ompi_comm_allreduce_inter_bcast , & subreq , 1 );
660
655
}
661
656
662
657
ompi_comm_request_start (request );
@@ -696,33 +691,20 @@ static int ompi_comm_allreduce_inter_leader_reduce (ompi_comm_request_t *request
696
691
697
692
ompi_op_reduce (context -> op , context -> tmpbuf , context -> outbuf , context -> count , MPI_INT );
698
693
699
- return ompi_comm_allreduce_inter_allgather (request );
694
+ return ompi_comm_allreduce_inter_bcast (request );
700
695
}
701
696
702
697
703
- static int ompi_comm_allreduce_inter_allgather (ompi_comm_request_t * request )
698
+ static int ompi_comm_allreduce_inter_bcast (ompi_comm_request_t * request )
704
699
{
705
700
ompi_comm_allreduce_context_t * context = (ompi_comm_allreduce_context_t * ) request -> context ;
706
- ompi_communicator_t * intercomm = context -> cid_context -> comm ;
701
+ ompi_communicator_t * comm = context -> cid_context -> comm -> c_local_comm ;
707
702
ompi_request_t * subreq ;
708
703
int scount = 0 , rc ;
709
704
710
- /* distribute the overall result to all processes in the other group.
711
- Instead of using bcast, we are using here allgatherv, to avoid the
712
- possible deadlock. Else, we need an algorithm to determine,
713
- which group sends first in the inter-bcast and which receives
714
- the result first.
715
- */
716
-
717
- if (0 != ompi_comm_rank (intercomm )) {
718
- context -> rcounts [0 ] = context -> count ;
719
- } else {
720
- scount = context -> count ;
721
- }
722
-
723
- rc = intercomm -> c_coll .coll_iallgatherv (context -> outbuf , scount , MPI_INT , context -> outbuf ,
724
- context -> rcounts , context -> rdisps , MPI_INT , intercomm ,
725
- & subreq , intercomm -> c_coll .coll_iallgatherv_module );
705
+ /* both roots have the same result. broadcast to the local group */
706
+ rc = comm -> c_coll .coll_ibcast (context -> outbuf , context -> count , MPI_INT , 0 , comm ,
707
+ & subreq , comm -> c_coll .coll_ibcast_module );
726
708
if (OMPI_SUCCESS != rc ) {
727
709
return rc ;
728
710
}
0 commit comments