Skip to content

Commit 36bcc48

Browse files
authored
Merge pull request #7902 from vspetrov/v4.1.x_hcoll_reduce_scatter
V4.1.x hcoll reduce scatter
2 parents 6b4e2f0 + 6f40118 commit 36bcc48

File tree

3 files changed

+119
-8
lines changed

3 files changed

+119
-8
lines changed

ompi/mca/coll/hcoll/coll_hcoll.h

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,8 @@ struct mca_coll_hcoll_module_t {
142142
mca_coll_base_module_t *previous_scatterv_module;
143143
mca_coll_base_module_reduce_scatter_fn_t previous_reduce_scatter;
144144
mca_coll_base_module_t *previous_reduce_scatter_module;
145+
mca_coll_base_module_reduce_scatter_block_fn_t previous_reduce_scatter_block;
146+
mca_coll_base_module_t *previous_reduce_scatter_block_module;
145147
mca_coll_base_module_ibcast_fn_t previous_ibcast;
146148
mca_coll_base_module_t *previous_ibcast_module;
147149
mca_coll_base_module_ibarrier_fn_t previous_ibarrier;
@@ -212,6 +214,18 @@ int mca_coll_hcoll_allreduce(const void *sbuf, void *rbuf, int count,
212214
struct ompi_communicator_t *comm,
213215
mca_coll_base_module_t *module);
214216

217+
#if HCOLL_API > HCOLL_VERSION(4,5)
218+
int mca_coll_hcoll_reduce_scatter_block(const void *sbuf, void *rbuf, int rcount,
219+
struct ompi_datatype_t *dtype,
220+
struct ompi_op_t *op,
221+
struct ompi_communicator_t *comm,
222+
mca_coll_base_module_t *module);
223+
int mca_coll_hcoll_reduce_scatter(const void *sbuf, void *rbuf, const int* rcounts,
224+
struct ompi_datatype_t *dtype,
225+
struct ompi_op_t *op,
226+
struct ompi_communicator_t *comm,
227+
mca_coll_base_module_t *module);
228+
#endif
215229
int mca_coll_hcoll_reduce(const void *sbuf, void *rbuf, int count,
216230
struct ompi_datatype_t *dtype,
217231
struct ompi_op_t *op,
@@ -303,11 +317,11 @@ int mca_coll_hcoll_ialltoall(const void *sbuf, int scount,
303317
mca_coll_base_module_t *module);
304318

305319
#if HCOLL_API >= HCOLL_VERSION(3,7)
306-
int mca_coll_hcoll_ialltoallv(const void *sbuf, int *scounts,
307-
int *sdisps,
320+
int mca_coll_hcoll_ialltoallv(const void *sbuf, const int *scounts,
321+
const int *sdisps,
308322
struct ompi_datatype_t *sdtype,
309-
void *rbuf, int *rcounts,
310-
int *rdisps,
323+
void *rbuf, const int *rcounts,
324+
const int *rdisps,
311325
struct ompi_datatype_t *rdtype,
312326
struct ompi_communicator_t *comm,
313327
ompi_request_t **req,

ompi/mca/coll/hcoll/coll_hcoll_module.c

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ static void mca_coll_hcoll_module_clear(mca_coll_hcoll_module_t *hcoll_module)
5151
hcoll_module->previous_alltoallw = NULL;
5252
hcoll_module->previous_reduce = NULL;
5353
hcoll_module->previous_reduce_scatter = NULL;
54+
hcoll_module->previous_reduce_scatter_block = NULL;
5455
hcoll_module->previous_ibarrier = NULL;
5556
hcoll_module->previous_ibcast = NULL;
5657
hcoll_module->previous_iallreduce = NULL;
@@ -119,6 +120,8 @@ static void mca_coll_hcoll_module_destruct(mca_coll_hcoll_module_t *hcoll_module
119120
OBJ_RELEASE_IF_NOT_NULL(hcoll_module->previous_barrier_module);
120121
OBJ_RELEASE_IF_NOT_NULL(hcoll_module->previous_bcast_module);
121122
OBJ_RELEASE_IF_NOT_NULL(hcoll_module->previous_allreduce_module);
123+
OBJ_RELEASE_IF_NOT_NULL(hcoll_module->previous_reduce_scatter_block_module);
124+
OBJ_RELEASE_IF_NOT_NULL(hcoll_module->previous_reduce_scatter_module);
122125
OBJ_RELEASE_IF_NOT_NULL(hcoll_module->previous_allgather_module);
123126
OBJ_RELEASE_IF_NOT_NULL(hcoll_module->previous_allgatherv_module);
124127
OBJ_RELEASE_IF_NOT_NULL(hcoll_module->previous_gatherv_module);
@@ -173,6 +176,8 @@ static int mca_coll_hcoll_save_coll_handlers(mca_coll_hcoll_module_t *hcoll_modu
173176
HCOL_SAVE_PREV_COLL_API(barrier);
174177
HCOL_SAVE_PREV_COLL_API(bcast);
175178
HCOL_SAVE_PREV_COLL_API(allreduce);
179+
HCOL_SAVE_PREV_COLL_API(reduce_scatter_block);
180+
HCOL_SAVE_PREV_COLL_API(reduce_scatter);
176181
HCOL_SAVE_PREV_COLL_API(reduce);
177182
HCOL_SAVE_PREV_COLL_API(allgather);
178183
HCOL_SAVE_PREV_COLL_API(allgatherv);
@@ -419,6 +424,12 @@ mca_coll_hcoll_comm_query(struct ompi_communicator_t *comm, int *priority)
419424
hcoll_module->super.coll_ialltoallv = hcoll_collectives.coll_ialltoallv ? mca_coll_hcoll_ialltoallv : NULL;
420425
#else
421426
hcoll_module->super.coll_ialltoallv = NULL;
427+
#endif
428+
#if HCOLL_API > HCOLL_VERSION(4,5)
429+
hcoll_module->super.coll_reduce_scatter_block = hcoll_collectives.coll_reduce_scatter_block ?
430+
mca_coll_hcoll_reduce_scatter_block : NULL;
431+
hcoll_module->super.coll_reduce_scatter = hcoll_collectives.coll_reduce_scatter ?
432+
mca_coll_hcoll_reduce_scatter : NULL;
422433
#endif
423434
*priority = cm->hcoll_priority;
424435
module = &hcoll_module->super;

ompi/mca/coll/hcoll/coll_hcoll_ops.c

Lines changed: 90 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,8 @@ int mca_coll_hcoll_allgatherv(const void *sbuf, int scount,
136136
hcoll_module->previous_allgatherv_module);
137137
return rc;
138138
}
139-
rc = hcoll_collectives.coll_allgatherv((void *)sbuf,scount,stype,rbuf,rcount,displs,rtype,hcoll_module->hcoll_context);
139+
rc = hcoll_collectives.coll_allgatherv((void *)sbuf,scount,stype,rbuf,(int*)rcount,
140+
(int*)displs,rtype,hcoll_module->hcoll_context);
140141
if (HCOLL_SUCCESS != rc){
141142
HCOL_VERBOSE(20,"RUNNING FALLBACK ALLGATHERV");
142143
rc = hcoll_module->previous_allgatherv(sbuf,scount,sdtype,
@@ -558,7 +559,7 @@ int mca_coll_hcoll_iallgatherv(const void *sbuf, int scount,
558559
hcoll_module->previous_iallgatherv_module);
559560
return rc;
560561
}
561-
rc = hcoll_collectives.coll_iallgatherv((void *)sbuf,scount,stype,rbuf,rcount,displs,rtype,
562+
rc = hcoll_collectives.coll_iallgatherv((void *)sbuf,scount,stype,rbuf,(int*)rcount,(int*)displs,rtype,
562563
hcoll_module->hcoll_context, rt_handle);
563564
if (HCOLL_SUCCESS != rc){
564565
HCOL_VERBOSE(20,"RUNNING FALLBACK NON-BLOCKING ALLGATHER");
@@ -724,9 +725,9 @@ int mca_coll_hcoll_igatherv(const void* sbuf, int scount,
724725

725726

726727
#if HCOLL_API >= HCOLL_VERSION(3,7)
727-
int mca_coll_hcoll_ialltoallv(const void *sbuf, int *scounts, int *sdisps,
728+
int mca_coll_hcoll_ialltoallv(const void *sbuf, const int *scounts, const int *sdisps,
728729
struct ompi_datatype_t *sdtype,
729-
void *rbuf, int *rcounts, int *rdisps,
730+
void *rbuf, const int *rcounts, const int *rdisps,
730731
struct ompi_datatype_t *rdtype,
731732
struct ompi_communicator_t *comm,
732733
ompi_request_t ** request,
@@ -760,3 +761,88 @@ int mca_coll_hcoll_ialltoallv(const void *sbuf, int *scounts, int *sdisps,
760761
return rc;
761762
}
762763
#endif
764+
765+
#if HCOLL_API > HCOLL_VERSION(4,5)
766+
int mca_coll_hcoll_reduce_scatter_block(const void *sbuf, void *rbuf, int rcount,
767+
struct ompi_datatype_t *dtype,
768+
struct ompi_op_t *op,
769+
struct ompi_communicator_t *comm,
770+
mca_coll_base_module_t *module) {
771+
dte_data_representation_t Dtype;
772+
hcoll_dte_op_t *Op;
773+
int rc;
774+
HCOL_VERBOSE(20,"RUNNING HCOL REDUCE SCATTER BLOCK");
775+
mca_coll_hcoll_module_t *hcoll_module = (mca_coll_hcoll_module_t*)module;
776+
Dtype = ompi_dtype_2_hcoll_dtype(dtype, NO_DERIVED);
777+
if (OPAL_UNLIKELY(HCOL_DTE_IS_ZERO(Dtype))){
778+
/*If we are here then datatype is not simple predefined datatype */
779+
/*In future we need to add more complex mapping to the dte_data_representation_t */
780+
/* Now use fallback */
781+
HCOL_VERBOSE(20,"Ompi_datatype is not supported: dtype = %s; calling fallback allreduce;",
782+
dtype->super.name);
783+
goto fallback;
784+
}
785+
786+
Op = ompi_op_2_hcolrte_op(op);
787+
if (OPAL_UNLIKELY(HCOL_DTE_OP_NULL == Op->id)){
788+
/*If we are here then datatype is not simple predefined datatype */
789+
/*In future we need to add more complex mapping to the dte_data_representation_t */
790+
/* Now use fallback */
791+
HCOL_VERBOSE(20,"ompi_op_t is not supported: op = %s; calling fallback allreduce;",
792+
op->o_name);
793+
goto fallback;
794+
}
795+
796+
rc = hcoll_collectives.coll_reduce_scatter_block((void *)sbuf,rbuf,rcount,Dtype,Op,hcoll_module->hcoll_context);
797+
if (HCOLL_SUCCESS != rc){
798+
fallback:
799+
HCOL_VERBOSE(20,"RUNNING FALLBACK ALLREDUCE");
800+
rc = hcoll_module->previous_reduce_scatter_block(sbuf,rbuf,
801+
rcount,dtype,op,
802+
comm, hcoll_module->previous_allreduce_module);
803+
}
804+
return rc;
805+
}
806+
807+
int mca_coll_hcoll_reduce_scatter(const void *sbuf, void *rbuf, const int* rcounts,
808+
struct ompi_datatype_t *dtype,
809+
struct ompi_op_t *op,
810+
struct ompi_communicator_t *comm,
811+
mca_coll_base_module_t *module) {
812+
dte_data_representation_t Dtype;
813+
hcoll_dte_op_t *Op;
814+
int rc;
815+
HCOL_VERBOSE(20,"RUNNING HCOL REDUCE SCATTER");
816+
mca_coll_hcoll_module_t *hcoll_module = (mca_coll_hcoll_module_t*)module;
817+
Dtype = ompi_dtype_2_hcoll_dtype(dtype, NO_DERIVED);
818+
if (OPAL_UNLIKELY(HCOL_DTE_IS_ZERO(Dtype))){
819+
/*If we are here then datatype is not simple predefined datatype */
820+
/*In future we need to add more complex mapping to the dte_data_representation_t */
821+
/* Now use fallback */
822+
HCOL_VERBOSE(20,"Ompi_datatype is not supported: dtype = %s; calling fallback allreduce;",
823+
dtype->super.name);
824+
goto fallback;
825+
}
826+
827+
Op = ompi_op_2_hcolrte_op(op);
828+
if (OPAL_UNLIKELY(HCOL_DTE_OP_NULL == Op->id)){
829+
/*If we are here then datatype is not simple predefined datatype */
830+
/*In future we need to add more complex mapping to the dte_data_representation_t */
831+
/* Now use fallback */
832+
HCOL_VERBOSE(20,"ompi_op_t is not supported: op = %s; calling fallback allreduce;",
833+
op->o_name);
834+
goto fallback;
835+
}
836+
837+
rc = hcoll_collectives.coll_reduce_scatter((void*)sbuf, rbuf, (int*)rcounts,
838+
Dtype, Op, hcoll_module->hcoll_context);
839+
if (HCOLL_SUCCESS != rc){
840+
fallback:
841+
HCOL_VERBOSE(20,"RUNNING FALLBACK ALLREDUCE");
842+
rc = hcoll_module->previous_reduce_scatter(sbuf,rbuf,
843+
rcounts,dtype,op,
844+
comm, hcoll_module->previous_allreduce_module);
845+
}
846+
return rc;
847+
}
848+
#endif

0 commit comments

Comments
 (0)