1212 * Copyright (c) 2008 Sun Microsystems, Inc. All rights reserved.
1313 * Copyright (c) 2012 Oak Ridge National Labs. All rights reserved.
1414 * Copyright (c) 2012 Sandia National Laboratories. All rights reserved.
15- * Copyright (c) 2014-2015 Research Organization for Information Science
15+ * Copyright (c) 2014-2016 Research Organization for Information Science
1616 * and Technology (RIST). All rights reserved.
1717 * $COPYRIGHT$
1818 *
@@ -58,7 +58,7 @@ mca_coll_basic_reduce_scatter_block_intra(const void *sbuf, void *rbuf, int rcou
5858 mca_coll_base_module_t * module )
5959{
6060 int rank , size , count , err = OMPI_SUCCESS ;
61- ptrdiff_t extent , buf_size , gap ;
61+ ptrdiff_t gap , span ;
6262 char * recv_buf = NULL , * recv_buf_free = NULL ;
6363
6464 /* Initialize */
@@ -72,8 +72,7 @@ mca_coll_basic_reduce_scatter_block_intra(const void *sbuf, void *rbuf, int rcou
7272 }
7373
7474 /* get datatype information */
75- ompi_datatype_type_extent (dtype , & extent );
76- buf_size = opal_datatype_span (& dtype -> super , count , & gap );
75+ span = opal_datatype_span (& dtype -> super , count , & gap );
7776
7877 /* Handle MPI_IN_PLACE */
7978 if (MPI_IN_PLACE == sbuf ) {
@@ -83,12 +82,12 @@ mca_coll_basic_reduce_scatter_block_intra(const void *sbuf, void *rbuf, int rcou
8382 if (0 == rank ) {
8483 /* temporary receive buffer. See coll_basic_reduce.c for
8584 details on sizing */
86- recv_buf_free = (char * ) malloc (buf_size );
87- recv_buf = recv_buf_free - gap ;
85+ recv_buf_free = (char * ) malloc (span );
8886 if (NULL == recv_buf_free ) {
8987 err = OMPI_ERR_OUT_OF_RESOURCE ;
9088 goto cleanup ;
9189 }
90+ recv_buf = recv_buf_free - gap ;
9291 }
9392
9493 /* reduction */
@@ -126,8 +125,9 @@ mca_coll_basic_reduce_scatter_block_inter(const void *sbuf, void *rbuf, int rcou
126125{
127126 int err , i , rank , root = 0 , rsize , lsize ;
128127 int totalcounts ;
129- ptrdiff_t lb , extent ;
128+ ptrdiff_t gap , span ;
130129 char * tmpbuf = NULL , * tmpbuf2 = NULL ;
130+ char * lbuf , * buf ;
131131 ompi_request_t * req ;
132132
133133 rank = ompi_comm_rank (comm );
@@ -151,16 +151,15 @@ mca_coll_basic_reduce_scatter_block_inter(const void *sbuf, void *rbuf, int rcou
151151 *
152152 */
153153 if (rank == root ) {
154- err = ompi_datatype_get_extent (dtype , & lb , & extent );
155- if (OMPI_SUCCESS != err ) {
156- return OMPI_ERROR ;
157- }
154+ span = opal_datatype_span (& dtype -> super , totalcounts , & gap );
158155
159- tmpbuf = (char * ) malloc (totalcounts * extent );
160- tmpbuf2 = (char * ) malloc (totalcounts * extent );
156+ tmpbuf = (char * ) malloc (span );
157+ tmpbuf2 = (char * ) malloc (span );
161158 if (NULL == tmpbuf || NULL == tmpbuf2 ) {
162159 return OMPI_ERR_OUT_OF_RESOURCE ;
163160 }
161+ lbuf = tmpbuf - gap ;
162+ buf = tmpbuf2 - gap ;
164163
165164 /* Do a send-recv between the two root processes to avoid deadlock */
166165 err = MCA_PML_CALL (isend (sbuf , totalcounts , dtype , 0 ,
@@ -170,7 +169,7 @@ mca_coll_basic_reduce_scatter_block_inter(const void *sbuf, void *rbuf, int rcou
170169 goto exit ;
171170 }
172171
173- err = MCA_PML_CALL (recv (tmpbuf2 , totalcounts , dtype , 0 ,
172+ err = MCA_PML_CALL (recv (lbuf , totalcounts , dtype , 0 ,
174173 MCA_COLL_BASE_TAG_REDUCE_SCATTER , comm ,
175174 MPI_STATUS_IGNORE ));
176175 if (OMPI_SUCCESS != err ) {
@@ -188,15 +187,18 @@ mca_coll_basic_reduce_scatter_block_inter(const void *sbuf, void *rbuf, int rcou
188187 * tmpbuf2.
189188 */
190189 for (i = 1 ; i < rsize ; i ++ ) {
191- err = MCA_PML_CALL (recv (tmpbuf , totalcounts , dtype , i ,
190+ char * tbuf ;
191+ err = MCA_PML_CALL (recv (buf , totalcounts , dtype , i ,
192192 MCA_COLL_BASE_TAG_REDUCE_SCATTER , comm ,
193193 MPI_STATUS_IGNORE ));
194194 if (MPI_SUCCESS != err ) {
195195 goto exit ;
196196 }
197197
198198 /* Perform the reduction */
199- ompi_op_reduce (op , tmpbuf , tmpbuf2 , totalcounts , dtype );
199+ ompi_op_reduce (op , lbuf , buf , totalcounts , dtype );
200+ /* swap the buffers */
201+ tbuf = lbuf ; lbuf = buf ; buf = tbuf ;
200202 }
201203 } else {
202204 /* If not root, send data to the root. */
@@ -209,7 +211,7 @@ mca_coll_basic_reduce_scatter_block_inter(const void *sbuf, void *rbuf, int rcou
209211 }
210212
211213 /* Now do a scatter on the local communicator */
212- err = comm -> c_local_comm -> c_coll .coll_scatter (tmpbuf2 , rcount , dtype ,
214+ err = comm -> c_local_comm -> c_coll .coll_scatter (lbuf , rcount , dtype ,
213215 rbuf , rcount , dtype , 0 ,
214216 comm -> c_local_comm ,
215217 comm -> c_local_comm -> c_coll .coll_scatter_module );
0 commit comments