diff --git a/ompi/mca/coll/hcoll/coll_hcoll.h b/ompi/mca/coll/hcoll/coll_hcoll.h index 25a60b3922..980f0ebb2d 100644 --- a/ompi/mca/coll/hcoll/coll_hcoll.h +++ b/ompi/mca/coll/hcoll/coll_hcoll.h @@ -137,8 +137,12 @@ struct mca_coll_hcoll_module_t { mca_coll_base_module_t *previous_ibarrier_module; mca_coll_base_module_iallgather_fn_t previous_iallgather; mca_coll_base_module_t *previous_iallgather_module; + mca_coll_base_module_iallgatherv_fn_t previous_iallgatherv; + mca_coll_base_module_t *previous_iallgatherv_module; mca_coll_base_module_iallreduce_fn_t previous_iallreduce; mca_coll_base_module_t *previous_iallreduce_module; + mca_coll_base_module_ireduce_fn_t previous_ireduce; + mca_coll_base_module_t *previous_ireduce_module; mca_coll_base_module_igatherv_fn_t previous_igatherv; mca_coll_base_module_t *previous_igatherv_module; mca_coll_base_module_ialltoall_fn_t previous_ialltoall; @@ -175,7 +179,15 @@ int mca_coll_hcoll_allgather(void *sbuf, int scount, struct ompi_communicator_t *comm, mca_coll_base_module_t *module); -int mca_coll_hcoll_gather(void *sbuf, int scount, +int mca_coll_hcoll_allgatherv(const void *sbuf, int scount, + struct ompi_datatype_t *sdtype, + void *rbuf, const int *rcount, + const int *displs, + struct ompi_datatype_t *rdtype, + struct ompi_communicator_t *comm, + mca_coll_base_module_t *module); + +int mca_coll_hcoll_gather(const void *sbuf, int scount, struct ompi_datatype_t *sdtype, void *rbuf, int rcount, struct ompi_datatype_t *rdtype, @@ -189,7 +201,14 @@ int mca_coll_hcoll_allreduce(void *sbuf, void *rbuf, int count, struct ompi_communicator_t *comm, mca_coll_base_module_t *module); -int mca_coll_hcoll_alltoall(void *sbuf, int scount, +int mca_coll_hcoll_reduce(const void *sbuf, void *rbuf, int count, + struct ompi_datatype_t *dtype, + struct ompi_op_t *op, + int root, + struct ompi_communicator_t *comm, + mca_coll_base_module_t *module); + +int mca_coll_hcoll_alltoall(const void *sbuf, int scount, struct ompi_datatype_t *sdtype, void* rbuf, int rcount, struct ompi_datatype_t *rdtype, @@ -231,14 +250,31 @@ int mca_coll_hcoll_iallgather(void *sbuf, int scount, ompi_request_t** request, mca_coll_base_module_t *module); -int mca_coll_hcoll_iallreduce(void *sbuf, void *rbuf, int count, +int mca_coll_hcoll_iallgatherv(const void *sbuf, int scount, + struct ompi_datatype_t *sdtype, + void *rbuf, const int *rcount, + const int *displs, + struct ompi_datatype_t *rdtype, + struct ompi_communicator_t *comm, + ompi_request_t** request, + mca_coll_base_module_t *module); + +int mca_coll_hcoll_iallreduce(const void *sbuf, void *rbuf, int count, + struct ompi_datatype_t *dtype, + struct ompi_op_t *op, + struct ompi_communicator_t *comm, + ompi_request_t** request, + mca_coll_base_module_t *module); + +int mca_coll_hcoll_ireduce(const void *sbuf, void *rbuf, int count, struct ompi_datatype_t *dtype, struct ompi_op_t *op, + int root, struct ompi_communicator_t *comm, ompi_request_t** request, mca_coll_base_module_t *module); -int mca_coll_hcoll_ialltoall(void *sbuf, int scount, +int mca_coll_hcoll_ialltoall(const void *sbuf, int scount, struct ompi_datatype_t *sdtype, void* rbuf, int rcount, struct ompi_datatype_t *rdtype, @@ -256,7 +292,7 @@ int mca_coll_hcoll_ialltoallv(void *sbuf, int *scounts, ompi_request_t **req, mca_coll_base_module_t *module); -int mca_coll_hcoll_igatherv(void* sbuf, int scount, +int mca_coll_hcoll_igatherv(const void* sbuf, int scount, struct ompi_datatype_t *sdtype, void* rbuf, int *rcounts, int *displs, struct ompi_datatype_t *rdtype, diff --git a/ompi/mca/coll/hcoll/coll_hcoll_module.c b/ompi/mca/coll/hcoll/coll_hcoll_module.c index f1af6931b1..3a118b3d0b 100644 --- a/ompi/mca/coll/hcoll/coll_hcoll_module.c +++ b/ompi/mca/coll/hcoll/coll_hcoll_module.c @@ -41,12 +41,15 @@ static void mca_coll_hcoll_module_clear(mca_coll_hcoll_module_t *hcoll_module) hcoll_module->previous_alltoall = NULL; hcoll_module->previous_alltoallv = NULL; hcoll_module->previous_alltoallw = NULL; + hcoll_module->previous_reduce = NULL; hcoll_module->previous_reduce_scatter = NULL; hcoll_module->previous_ibarrier = NULL; hcoll_module->previous_ibcast = NULL; hcoll_module->previous_iallreduce = NULL; hcoll_module->previous_iallgather = NULL; + hcoll_module->previous_iallgatherv = NULL; hcoll_module->previous_igatherv = NULL; + hcoll_module->previous_ireduce = NULL; } static void mca_coll_hcoll_module_construct(mca_coll_hcoll_module_t *hcoll_module) @@ -80,17 +83,21 @@ static void mca_coll_hcoll_module_destruct(mca_coll_hcoll_module_t *hcoll_module OBJ_RELEASE(hcoll_module->previous_bcast_module); OBJ_RELEASE(hcoll_module->previous_allreduce_module); OBJ_RELEASE(hcoll_module->previous_allgather_module); + OBJ_RELEASE(hcoll_module->previous_allgatherv_module); OBJ_RELEASE(hcoll_module->previous_gatherv_module); OBJ_RELEASE(hcoll_module->previous_alltoall_module); OBJ_RELEASE(hcoll_module->previous_alltoallv_module); + OBJ_RELEASE(hcoll_module->previous_reduce_module); OBJ_RELEASE(hcoll_module->previous_ibarrier_module); OBJ_RELEASE(hcoll_module->previous_ibcast_module); OBJ_RELEASE(hcoll_module->previous_iallreduce_module); OBJ_RELEASE(hcoll_module->previous_iallgather_module); + OBJ_RELEASE(hcoll_module->previous_iallgatherv_module); OBJ_RELEASE(hcoll_module->previous_igatherv_module); OBJ_RELEASE(hcoll_module->previous_ialltoall_module); OBJ_RELEASE(hcoll_module->previous_ialltoallv_module); + OBJ_RELEASE(hcoll_module->previous_ireduce_module); /* OBJ_RELEASE(hcoll_module->previous_allgatherv_module); @@ -127,7 +134,9 @@ static int mca_coll_hcoll_save_coll_handlers(mca_coll_hcoll_module_t *hcoll_modu HCOL_SAVE_PREV_COLL_API(barrier); HCOL_SAVE_PREV_COLL_API(bcast); HCOL_SAVE_PREV_COLL_API(allreduce); + HCOL_SAVE_PREV_COLL_API(reduce); HCOL_SAVE_PREV_COLL_API(allgather); + HCOL_SAVE_PREV_COLL_API(allgatherv); HCOL_SAVE_PREV_COLL_API(gatherv); HCOL_SAVE_PREV_COLL_API(alltoall); HCOL_SAVE_PREV_COLL_API(alltoallv); @@ -135,7 +144,9 @@ static int mca_coll_hcoll_save_coll_handlers(mca_coll_hcoll_module_t *hcoll_modu HCOL_SAVE_PREV_COLL_API(ibarrier); HCOL_SAVE_PREV_COLL_API(ibcast); HCOL_SAVE_PREV_COLL_API(iallreduce); + HCOL_SAVE_PREV_COLL_API(ireduce); HCOL_SAVE_PREV_COLL_API(iallgather); + HCOL_SAVE_PREV_COLL_API(iallgatherv); HCOL_SAVE_PREV_COLL_API(igatherv); HCOL_SAVE_PREV_COLL_API(ialltoall); HCOL_SAVE_PREV_COLL_API(ialltoallv); @@ -312,14 +323,26 @@ mca_coll_hcoll_comm_query(struct ompi_communicator_t *comm, int *priority) hcoll_module->super.coll_barrier = hcoll_collectives.coll_barrier ? mca_coll_hcoll_barrier : NULL; hcoll_module->super.coll_bcast = hcoll_collectives.coll_bcast ? mca_coll_hcoll_bcast : NULL; hcoll_module->super.coll_allgather = hcoll_collectives.coll_allgather ? mca_coll_hcoll_allgather : NULL; + hcoll_module->super.coll_allgatherv = hcoll_collectives.coll_allgatherv ? mca_coll_hcoll_allgatherv : NULL; hcoll_module->super.coll_allreduce = hcoll_collectives.coll_allreduce ? mca_coll_hcoll_allreduce : NULL; hcoll_module->super.coll_alltoall = hcoll_collectives.coll_alltoall ? mca_coll_hcoll_alltoall : NULL; hcoll_module->super.coll_alltoallv = hcoll_collectives.coll_alltoallv ? mca_coll_hcoll_alltoallv : NULL; hcoll_module->super.coll_gatherv = hcoll_collectives.coll_gatherv ? mca_coll_hcoll_gatherv : NULL; + hcoll_module->super.coll_reduce = hcoll_collectives.coll_reduce ? mca_coll_hcoll_reduce : NULL; hcoll_module->super.coll_ibarrier = hcoll_collectives.coll_ibarrier ? mca_coll_hcoll_ibarrier : NULL; hcoll_module->super.coll_ibcast = hcoll_collectives.coll_ibcast ? mca_coll_hcoll_ibcast : NULL; hcoll_module->super.coll_iallgather = hcoll_collectives.coll_iallgather ? mca_coll_hcoll_iallgather : NULL; +#if HCOLL_API >= HCOLL_VERSION(3,5) + hcoll_module->super.coll_iallgatherv = hcoll_collectives.coll_iallgatherv ? mca_coll_hcoll_iallgatherv : NULL; +#else + hcoll_module->super.coll_iallgatherv = NULL; +#endif hcoll_module->super.coll_iallreduce = hcoll_collectives.coll_iallreduce ? mca_coll_hcoll_iallreduce : NULL; +#if HCOLL_API >= HCOLL_VERSION(3,5) + hcoll_module->super.coll_ireduce = hcoll_collectives.coll_ireduce ? mca_coll_hcoll_ireduce : NULL; +#else + hcoll_module->super.coll_ireduce = NULL; +#endif hcoll_module->super.coll_gather = /*hcoll_collectives.coll_gather ? mca_coll_hcoll_gather :*/ NULL; hcoll_module->super.coll_igatherv = hcoll_collectives.coll_igatherv ? mca_coll_hcoll_igatherv : NULL; hcoll_module->super.coll_ialltoall = /*hcoll_collectives.coll_ialltoall ? mca_coll_hcoll_ialltoall : */ NULL; diff --git a/ompi/mca/coll/hcoll/coll_hcoll_ops.c b/ompi/mca/coll/hcoll/coll_hcoll_ops.c index 32e7612d2f..2617759272 100644 --- a/ompi/mca/coll/hcoll/coll_hcoll_ops.c +++ b/ompi/mca/coll/hcoll/coll_hcoll_ops.c @@ -97,7 +97,52 @@ int mca_coll_hcoll_allgather(void *sbuf, int scount, return rc; } -int mca_coll_hcoll_gather(void *sbuf, int scount, +int mca_coll_hcoll_allgatherv(const void *sbuf, int scount, + struct ompi_datatype_t *sdtype, + void *rbuf, const int *rcount, + const int *displs, + struct ompi_datatype_t *rdtype, + struct ompi_communicator_t *comm, + mca_coll_base_module_t *module) +{ + dte_data_representation_t stype; + dte_data_representation_t rtype; + int rc; + HCOL_VERBOSE(20,"RUNNING HCOL ALLGATHERV"); + mca_coll_hcoll_module_t *hcoll_module = (mca_coll_hcoll_module_t*)module; + stype = ompi_dtype_2_dte_dtype(sdtype); + rtype = ompi_dtype_2_dte_dtype(rdtype); + if (OPAL_UNLIKELY((HCOL_DTE_IS_ZERO(stype) || HCOL_DTE_IS_ZERO(rtype) + || HCOL_DTE_IS_COMPLEX(stype) || HCOL_DTE_IS_COMPLEX(rtype))) + && mca_coll_hcoll_component.hcoll_datatype_fallback){ + /*If we are here then datatype is not simple predefined datatype */ + /*In future we need to add more complex mapping to the dte_data_representation_t */ + /* Now use fallback */ + HCOL_VERBOSE(20,"Ompi_datatype is not supported: sdtype = %s, rdtype = %s; calling fallback allgatherv;", + sdtype->super.name, + rdtype->super.name); + rc = hcoll_module->previous_allgatherv(sbuf,scount,sdtype, + rbuf,rcount, + displs, + rdtype, + comm, + hcoll_module->previous_allgatherv_module); + return rc; + } + rc = hcoll_collectives.coll_allgatherv((void *)sbuf,scount,stype,rbuf,rcount,displs,rtype,hcoll_module->hcoll_context); + if (HCOLL_SUCCESS != rc){ + HCOL_VERBOSE(20,"RUNNING FALLBACK ALLGATHERV"); + rc = hcoll_module->previous_allgatherv(sbuf,scount,sdtype, + rbuf,rcount, + displs, + rdtype, + comm, + hcoll_module->previous_allgatherv_module); + } + return rc; +} + +int mca_coll_hcoll_gather(const void *sbuf, int scount, struct ompi_datatype_t *sdtype, void *rbuf, int rcount, struct ompi_datatype_t *rdtype, @@ -186,7 +231,59 @@ int mca_coll_hcoll_allreduce(void *sbuf, void *rbuf, int count, return rc; } -int mca_coll_hcoll_alltoall(void *sbuf, int scount, +int mca_coll_hcoll_reduce(const void *sbuf, void *rbuf, int count, + struct ompi_datatype_t *dtype, + struct ompi_op_t *op, + int root, + struct ompi_communicator_t *comm, + mca_coll_base_module_t *module) +{ + dte_data_representation_t Dtype; + hcoll_dte_op_t *Op; + int rc; + HCOL_VERBOSE(20,"RUNNING HCOL REDUCE"); + mca_coll_hcoll_module_t *hcoll_module = (mca_coll_hcoll_module_t*)module; + Dtype = ompi_dtype_2_dte_dtype(dtype); + if (OPAL_UNLIKELY((HCOL_DTE_IS_ZERO(Dtype) || HCOL_DTE_IS_COMPLEX(Dtype))) + && mca_coll_hcoll_component.hcoll_datatype_fallback){ + /*If we are here then datatype is not simple predefined datatype */ + /*In future we need to add more complex mapping to the dte_data_representation_t */ + /* Now use fallback */ + HCOL_VERBOSE(20,"Ompi_datatype is not supported: dtype = %s; calling fallback reduce;", + dtype->super.name); + rc = hcoll_module->previous_reduce(sbuf,rbuf, + count,dtype,op, + root, + comm, hcoll_module->previous_reduce_module); + return rc; + } + + Op = ompi_op_2_hcolrte_op(op); + if (OPAL_UNLIKELY(HCOL_DTE_OP_NULL == Op->id)){ + /*If we are here then datatype is not simple predefined datatype */ + /*In future we need to add more complex mapping to the dte_data_representation_t */ + /* Now use fallback */ + HCOL_VERBOSE(20,"ompi_op_t is not supported: op = %s; calling fallback reduce;", + op->o_name); + rc = hcoll_module->previous_reduce(sbuf,rbuf, + count,dtype,op, + root, + comm, hcoll_module->previous_reduce_module); + return rc; + } + + rc = hcoll_collectives.coll_reduce((void *)sbuf,rbuf,count,Dtype,Op,root,hcoll_module->hcoll_context); + if (HCOLL_SUCCESS != rc){ + HCOL_VERBOSE(20,"RUNNING FALLBACK REDUCE"); + rc = hcoll_module->previous_reduce(sbuf,rbuf, + count,dtype,op, + root, + comm, hcoll_module->previous_reduce_module); + } + return rc; +} + +int mca_coll_hcoll_alltoall(const void *sbuf, int scount, struct ompi_datatype_t *sdtype, void* rbuf, int rcount, struct ompi_datatype_t *rdtype, @@ -397,7 +494,58 @@ int mca_coll_hcoll_iallgather(void *sbuf, int scount, return rc; } -int mca_coll_hcoll_iallreduce(void *sbuf, void *rbuf, int count, +#if HCOLL_API >= HCOLL_VERSION(3,5) +int mca_coll_hcoll_iallgatherv(const void *sbuf, int scount, + struct ompi_datatype_t *sdtype, + void *rbuf, const int *rcount, + const int *displs, + struct ompi_datatype_t *rdtype, + struct ompi_communicator_t *comm, + ompi_request_t ** request, + mca_coll_base_module_t *module) +{ + dte_data_representation_t stype; + dte_data_representation_t rtype; + int rc; + HCOL_VERBOSE(20,"RUNNING HCOL NON-BLOCKING ALLGATHERV"); + mca_coll_hcoll_module_t *hcoll_module = (mca_coll_hcoll_module_t*)module; + stype = ompi_dtype_2_dte_dtype(sdtype); + rtype = ompi_dtype_2_dte_dtype(rdtype); + void **rt_handle = (void **) request; + if (OPAL_UNLIKELY((HCOL_DTE_IS_ZERO(stype) || HCOL_DTE_IS_ZERO(rtype) + || HCOL_DTE_IS_COMPLEX(stype) || HCOL_DTE_IS_COMPLEX(rtype))) + && mca_coll_hcoll_component.hcoll_datatype_fallback){ + /*If we are here then datatype is not simple predefined datatype */ + /*In future we need to add more complex mapping to the dte_data_representation_t */ + /* Now use fallback */ + HCOL_VERBOSE(20,"Ompi_datatype is not supported: sdtype = %s, rdtype = %s; calling fallback non-blocking allgatherv;", + sdtype->super.name, + rdtype->super.name); + rc = hcoll_module->previous_iallgatherv(sbuf,scount,sdtype, + rbuf,rcount, + displs, + rdtype, + comm, + request, + hcoll_module->previous_iallgatherv_module); + return rc; + } + rc = hcoll_collectives.coll_iallgatherv((void *)sbuf,scount,stype,rbuf,rcount,displs,rtype, + hcoll_module->hcoll_context, rt_handle); + if (HCOLL_SUCCESS != rc){ + HCOL_VERBOSE(20,"RUNNING FALLBACK NON-BLOCKING ALLGATHER"); + rc = hcoll_module->previous_iallgatherv(sbuf,scount,sdtype, + rbuf,rcount, + displs, + rdtype, + comm, + request, + hcoll_module->previous_iallgatherv_module); + } + return rc; +} +#endif +int mca_coll_hcoll_iallreduce(const void *sbuf, void *rbuf, int count, struct ompi_datatype_t *dtype, struct ompi_op_t *op, struct ompi_communicator_t *comm, @@ -447,8 +595,65 @@ int mca_coll_hcoll_iallreduce(void *sbuf, void *rbuf, int count, } return rc; } +#if HCOLL_API >= HCOLL_VERSION(3,5) +int mca_coll_hcoll_ireduce(const void *sbuf, void *rbuf, int count, + struct ompi_datatype_t *dtype, + struct ompi_op_t *op, + int root, + struct ompi_communicator_t *comm, + ompi_request_t ** request, + mca_coll_base_module_t *module) +{ + dte_data_representation_t Dtype; + hcoll_dte_op_t *Op; + int rc; + HCOL_VERBOSE(20,"RUNNING HCOL NON-BLOCKING REDUCE"); + mca_coll_hcoll_module_t *hcoll_module = (mca_coll_hcoll_module_t*)module; + Dtype = ompi_dtype_2_dte_dtype(dtype); + void **rt_handle = (void**) request; + if (OPAL_UNLIKELY((HCOL_DTE_IS_ZERO(Dtype) || HCOL_DTE_IS_COMPLEX(Dtype))) + && mca_coll_hcoll_component.hcoll_datatype_fallback){ + /*If we are here then datatype is not simple predefined datatype */ + /*In future we need to add more complex mapping to the dte_data_representation_t */ + /* Now use fallback */ + HCOL_VERBOSE(20,"Ompi_datatype is not supported: dtype = %s; calling fallback non-blocking reduce;", + dtype->super.name); + rc = hcoll_module->previous_ireduce(sbuf,rbuf,count,dtype,op, + root, + comm, request, + hcoll_module->previous_ireduce_module); + return rc; + } -int mca_coll_hcoll_igatherv(void* sbuf, int scount, + Op = ompi_op_2_hcolrte_op(op); + if (OPAL_UNLIKELY(HCOL_DTE_OP_NULL == Op->id)){ + /*If we are here then datatype is not simple predefined datatype */ + /*In future we need to add more complex mapping to the dte_data_representation_t */ + /* Now use fallback */ + HCOL_VERBOSE(20,"ompi_op_t is not supported: op = %s; calling fallback non-blocking reduce;", + op->o_name); + rc = hcoll_module->previous_ireduce(sbuf,rbuf, + count,dtype,op, + root, + comm, request, + hcoll_module->previous_ireduce_module); + return rc; + } + + rc = hcoll_collectives.coll_ireduce((void *)sbuf,rbuf,count,Dtype,Op,root,hcoll_module->hcoll_context,rt_handle); + if (HCOLL_SUCCESS != rc){ + HCOL_VERBOSE(20,"RUNNING FALLBACK NON-BLOCKING REDUCE"); + rc = hcoll_module->previous_ireduce(sbuf,rbuf, + count,dtype,op, + root, + comm, + request, + hcoll_module->previous_ireduce_module); + } + return rc; +} +#endif +int mca_coll_hcoll_igatherv(const void* sbuf, int scount, struct ompi_datatype_t *sdtype, void* rbuf, int *rcounts, int *displs, struct ompi_datatype_t *rdtype,