diff --git a/ompi/mca/coll/accelerator/coll_accelerator.h b/ompi/mca/coll/accelerator/coll_accelerator.h index b170e38f268..e707d7ec7f2 100644 --- a/ompi/mca/coll/accelerator/coll_accelerator.h +++ b/ompi/mca/coll/accelerator/coll_accelerator.h @@ -1,4 +1,5 @@ /* + * Copyright (c) 2024 NVIDIA Corporation. All rights reserved. * Copyright (c) 2014 The University of Tennessee and The University * of Tennessee Research Foundation. All rights * reserved. @@ -45,6 +46,11 @@ mca_coll_accelerator_allreduce(const void *sbuf, void *rbuf, size_t count, struct ompi_communicator_t *comm, mca_coll_base_module_t *module); +int mca_coll_accelerator_reduce_local(const void *sbuf, void *rbuf, size_t count, + struct ompi_datatype_t *dtype, + struct ompi_op_t *op, + mca_coll_base_module_t *module); + int mca_coll_accelerator_reduce(const void *sbuf, void *rbuf, size_t count, struct ompi_datatype_t *dtype, struct ompi_op_t *op, diff --git a/ompi/mca/coll/accelerator/coll_accelerator_module.c b/ompi/mca/coll/accelerator/coll_accelerator_module.c index 4fe1603a8aa..4005f6cdec9 100644 --- a/ompi/mca/coll/accelerator/coll_accelerator_module.c +++ b/ompi/mca/coll/accelerator/coll_accelerator_module.c @@ -1,4 +1,5 @@ /* + * Copyright (c) 2024 NVIDIA Corporation. All rights reserved. * Copyright (c) 2014-2017 The University of Tennessee and The University * of Tennessee Research Foundation. All rights * reserved. @@ -94,6 +95,7 @@ mca_coll_accelerator_comm_query(struct ompi_communicator_t *comm, accelerator_module->super.coll_allreduce = mca_coll_accelerator_allreduce; accelerator_module->super.coll_reduce = mca_coll_accelerator_reduce; + accelerator_module->super.coll_reduce_local = mca_coll_accelerator_reduce_local; accelerator_module->super.coll_reduce_scatter_block = mca_coll_accelerator_reduce_scatter_block; if (!OMPI_COMM_IS_INTER(comm)) { accelerator_module->super.coll_scan = mca_coll_accelerator_scan; @@ -141,6 +143,7 @@ mca_coll_accelerator_module_enable(mca_coll_base_module_t *module, ACCELERATOR_INSTALL_COLL_API(comm, s, allreduce); ACCELERATOR_INSTALL_COLL_API(comm, s, reduce); + ACCELERATOR_INSTALL_COLL_API(comm, s, reduce_local); ACCELERATOR_INSTALL_COLL_API(comm, s, reduce_scatter_block); if (!OMPI_COMM_IS_INTER(comm)) { /* MPI does not define scan/exscan on intercommunicators */ @@ -159,6 +162,7 @@ mca_coll_accelerator_module_disable(mca_coll_base_module_t *module, ACCELERATOR_UNINSTALL_COLL_API(comm, s, allreduce); ACCELERATOR_UNINSTALL_COLL_API(comm, s, reduce); + ACCELERATOR_UNINSTALL_COLL_API(comm, s, reduce_local); ACCELERATOR_UNINSTALL_COLL_API(comm, s, reduce_scatter_block); if (!OMPI_COMM_IS_INTER(comm)) { diff --git a/ompi/mca/coll/accelerator/coll_accelerator_reduce.c b/ompi/mca/coll/accelerator/coll_accelerator_reduce.c index 6b0d3d5d72b..993271fa16b 100644 --- a/ompi/mca/coll/accelerator/coll_accelerator_reduce.c +++ b/ompi/mca/coll/accelerator/coll_accelerator_reduce.c @@ -1,4 +1,5 @@ /* + * Copyright (c) 2024 NVIDIA Corporation. All rights reserved. * Copyright (c) 2004-2023 The University of Tennessee and The University * of Tennessee Research Foundation. All rights * reserved. @@ -84,3 +85,60 @@ mca_coll_accelerator_reduce(const void *sbuf, void *rbuf, size_t count, } return rc; } + +int +mca_coll_accelerator_reduce_local(const void *sbuf, void *rbuf, size_t count, + struct ompi_datatype_t *dtype, + struct ompi_op_t *op, + mca_coll_base_module_t *module) +{ + ptrdiff_t gap; + char *rbuf1 = NULL, *sbuf1 = NULL, *rbuf2 = NULL; + size_t bufsize; + int rc; + + bufsize = opal_datatype_span(&dtype->super, count, &gap); + + rc = mca_coll_accelerator_check_buf((void *)sbuf); + if (rc < 0) { + return rc; + } + + if ((MPI_IN_PLACE != sbuf) && (rc > 0)) { + sbuf1 = (char*)malloc(bufsize); + if (NULL == sbuf1) { + return OMPI_ERR_OUT_OF_RESOURCE; + } + mca_coll_accelerator_memcpy(sbuf1, sbuf, bufsize); + sbuf = sbuf1 - gap; + } + + rc = mca_coll_accelerator_check_buf(rbuf); + if (rc < 0) { + return rc; + } + + if (rc > 0) { + rbuf1 = (char*)malloc(bufsize); + if (NULL == rbuf1) { + if (NULL != sbuf1) free(sbuf1); + return OMPI_ERR_OUT_OF_RESOURCE; + } + mca_coll_accelerator_memcpy(rbuf1, rbuf, bufsize); + rbuf2 = rbuf; /* save away original buffer */ + rbuf = rbuf1 - gap; + } + + ompi_op_reduce(op, (void *)sbuf, rbuf, count, dtype); + rc = OMPI_SUCCESS; + + if (NULL != sbuf1) { + free(sbuf1); + } + if (NULL != rbuf1) { + rbuf = rbuf2; + mca_coll_accelerator_memcpy(rbuf, rbuf1, bufsize); + free(rbuf1); + } + return rc; +}