From d0a4e0775fc6c015df60a0adee5bbf1badbe02e5 Mon Sep 17 00:00:00 2001
From: Roie Danino
Date: Sun, 6 Apr 2025 17:32:41 +0300
Subject: [PATCH] OSHMEM/MCA/SPML/UCX: added support for team management functions

Signed-off-by: Roie Danino
---
 oshmem/mca/spml/ucx/spml_ucx.c | 184 ++++++++++++++++++++++++++++++---
 oshmem/mca/spml/ucx/spml_ucx.h |  28 +++++
 oshmem/shmem/c/shmem_team.c    |  15 +---
 3 files changed, 200 insertions(+), 27 deletions(-)

diff --git a/oshmem/mca/spml/ucx/spml_ucx.c b/oshmem/mca/spml/ucx/spml_ucx.c
index 1b2fca7d6f1..40b20ff7bff 100644
--- a/oshmem/mca/spml/ucx/spml_ucx.c
+++ b/oshmem/mca/spml/ucx/spml_ucx.c
@@ -1755,53 +1755,207 @@ int mca_spml_ucx_team_sync(shmem_team_t team)
     return OSHMEM_ERR_NOT_IMPLEMENTED;
 }
 
-/* This routine is not implemented */
+/* Return the calling PE's index within team (shmem_my_pe() for SHMEM_TEAM_WORLD) */
 int mca_spml_ucx_team_my_pe(shmem_team_t team)
 {
-    return OSHMEM_ERR_NOT_IMPLEMENTED;
+    mca_spml_ucx_team_t *ucx_team = (mca_spml_ucx_team_t *)team;
+
+    if (team == SHMEM_TEAM_WORLD) {
+        return shmem_my_pe();
+    }
+
+    return ucx_team->my_pe;
 }
 
-/* This routine is not implemented */
+/* Return the number of PEs in team (shmem_n_pes() for SHMEM_TEAM_WORLD) */
 int mca_spml_ucx_team_n_pes(shmem_team_t team)
 {
-    return OSHMEM_ERR_NOT_IMPLEMENTED;
+    mca_spml_ucx_team_t *ucx_team = (mca_spml_ucx_team_t *)team;
+
+    if (team == SHMEM_TEAM_WORLD) {
+        return shmem_n_pes();
+    }
+
+    return ucx_team->n_pes;
 }
 
-/* This routine is not implemented */
+/* Copy the configuration stored for team into *config.
+ * config_mask is currently not consulted. */
 int mca_spml_ucx_team_get_config(shmem_team_t team, long config_mask,
         shmem_team_config_t *config)
 {
-    return OSHMEM_ERR_NOT_IMPLEMENTED;
+    mca_spml_ucx_team_t *ucx_team = (mca_spml_ucx_team_t *)team;
+
+    SPML_UCX_VALIDATE_TEAM(team);
+
+    /* ucx_team->config is a pointer: copy the public part it points to */
+    memcpy(config, &ucx_team->config->super, sizeof(shmem_team_config_t));
+
+    return OSHMEM_SUCCESS;
+}
+
+/* Nonzero when src_pe is one of the PEs start, start + stride, ... (size entries) */
+static inline int mca_spml_ucx_is_pe_in_strided_team(int src_pe, int start,
+                                                     int stride, int size)
+{
+    return (src_pe >= start) && (src_pe < start + size * stride)
+        && ((src_pe - start) % stride == 0);
 }
 
-/* This routine is not implemented */
+/* Translate src_pe from src_team numbering to dest_team numbering.
+ * Returns SPML_UCX_PE_NOT_IN_TEAM when the PE is not a member of dest_team. */
 int mca_spml_ucx_team_translate_pe(shmem_team_t src_team, int src_pe,
-        shmem_team_t dest_team)
+                                shmem_team_t dest_team)
 {
-    return OSHMEM_ERR_NOT_IMPLEMENTED;
+    mca_spml_ucx_team_t *ucx_src_team  = (mca_spml_ucx_team_t*) src_team;
+    mca_spml_ucx_team_t *ucx_dest_team = (mca_spml_ucx_team_t*) dest_team;
+    int global_pe;
+
+    if ((src_pe == SPML_UCX_PE_NOT_IN_TEAM) || (src_team == dest_team)) {
+        return src_pe;
+    }
+
+    /* SHMEM_TEAM_WORLD is never dereferenced as a mca_spml_ucx_team_t (see
+     * team_my_pe/team_n_pes); its PE numbers are already global */
+    if (src_team == SHMEM_TEAM_WORLD) {
+        global_pe = src_pe;
+    } else {
+        global_pe = ucx_src_team->start + src_pe * ucx_src_team->stride;
+    }
+
+    if (dest_team == SHMEM_TEAM_WORLD) {
+        return global_pe;
+    }
+
+    if (!mca_spml_ucx_is_pe_in_strided_team(global_pe, ucx_dest_team->start,
+                                            ucx_dest_team->stride,
+                                            ucx_dest_team->n_pes)) {
+        return SPML_UCX_PE_NOT_IN_TEAM;
+    }
+
+    return (global_pe - ucx_dest_team->start) / ucx_dest_team->stride;
 }
 
-/* This routine is not implemented */
+/* Create a team from the PEs start, start + stride, ... (size entries) of
+ * parent_team. The new team is heap-allocated; release it with
+ * mca_spml_ucx_team_destroy(). */
 int mca_spml_ucx_team_split_strided(shmem_team_t parent_team, int start, int
         stride, int size, const shmem_team_config_t *config, long config_mask,
         shmem_team_t *new_team)
 {
-    return OSHMEM_ERR_NOT_IMPLEMENTED;
+    mca_spml_ucx_team_t *ucx_parent_team;
+    mca_spml_ucx_team_t *ucx_new_team;
+    int parent_pe;
+    int parent_start;
+    int parent_stride;
+    int my_pe;
+
+    SPML_UCX_ASSERT(((start + size * stride) <= oshmem_num_procs()) &&
+                    (stride > 0) && (size > 0));
+
+    if (parent_team == SHMEM_TEAM_WORLD) {
+        parent_pe     = shmem_my_pe();
+        parent_start  = 0;
+        parent_stride = 1;
+    } else {
+        ucx_parent_team = (mca_spml_ucx_team_t*) parent_team;
+        parent_pe       = ucx_parent_team->my_pe;
+        parent_start    = ucx_parent_team->start;
+        parent_stride   = ucx_parent_team->stride;
+    }
+
+    if (mca_spml_ucx_is_pe_in_strided_team(parent_pe, start, stride, size)) {
+        my_pe = (parent_pe - start) / stride;
+    } else {
+        /* not in team, according to spec it should be SHMEM_TEAM_INVALID but its value is NULL which
+           can be also interpreted as 0 (first pe), therefore -1 is used */
+        my_pe = SPML_UCX_PE_NOT_IN_TEAM;
+    }
+
+    ucx_new_team = (mca_spml_ucx_team_t *)malloc(sizeof(*ucx_new_team));
+    if (ucx_new_team == NULL) {
+        SPML_UCX_ERROR("failed to allocate team");
+        return OSHMEM_ERR_OUT_OF_RESOURCE;
+    }
+
+    ucx_new_team->config = calloc(1, sizeof(*ucx_new_team->config));
+    if (ucx_new_team->config == NULL) {
+        SPML_UCX_ERROR("failed to allocate team configuration");
+        free(ucx_new_team);
+        return OSHMEM_ERR_OUT_OF_RESOURCE;
+    }
+
+    if (config != NULL) {
+        memcpy(&ucx_new_team->config->super, config, sizeof(shmem_team_config_t));
+    }
+
+    /* In order to simplify pe translations start and stride are calculated with respect to
+     * world_team */
+    ucx_new_team->start       = parent_start + (start * parent_stride);
+    ucx_new_team->stride      = parent_stride * stride;
+    ucx_new_team->n_pes       = size;
+    ucx_new_team->my_pe       = my_pe;
+    ucx_new_team->parent_team = (mca_spml_ucx_team_t*)parent_team;
+
+    *new_team = (shmem_team_t)ucx_new_team;
+
+    return OSHMEM_SUCCESS;
 }
 
-/* This routine is not implemented */
+/* Split parent_team into an x-axis team (xrange consecutive PEs) and a y-axis
+ * team (stride xrange), both selected from the calling PE's position. */
 int mca_spml_ucx_team_split_2d(shmem_team_t parent_team, int xrange, const
         shmem_team_config_t *xaxis_config, long xaxis_mask, shmem_team_t
         *xaxis_team, const shmem_team_config_t *yaxis_config, long yaxis_mask,
         shmem_team_t *yaxis_team)
 {
-    return OSHMEM_ERR_NOT_IMPLEMENTED;
+    mca_spml_ucx_team_t *ucx_parent_team = (mca_spml_ucx_team_t*) parent_team;
+    int parent_n_pes = (parent_team == SHMEM_TEAM_WORLD) ?
+                       oshmem_num_procs() :
+                       ucx_parent_team->n_pes;
+    int parent_my_pe = (parent_team == SHMEM_TEAM_WORLD) ?
+                       shmem_my_pe() :
+                       ucx_parent_team->my_pe;
+    int yrange = parent_n_pes / xrange;
+    int pe_x = parent_my_pe % xrange;
+    int pe_y = parent_my_pe / xrange;
+    int rc;
+
+    /* Create x-team of my_pe */
+    rc = mca_spml_ucx_team_split_strided(parent_team, pe_y * xrange, 1, xrange,
+                                         xaxis_config, xaxis_mask, xaxis_team);
+
+    if (rc != OSHMEM_SUCCESS) {
+        SPML_UCX_ERROR("mca_spml_ucx_team_split_strided failed (x-axis team creation)");
+        return rc;
+    }
+
+    /* Create y-team of my_pe */
+    rc = mca_spml_ucx_team_split_strided(parent_team, pe_x, xrange, yrange,
+                                         yaxis_config, yaxis_mask, yaxis_team);
+    if (rc != OSHMEM_SUCCESS) {
+        SPML_UCX_ERROR("mca_spml_ucx_team_split_strided failed (y-axis team creation)");
+        goto out_free_xaxis;
+    }
+
+    return OSHMEM_SUCCESS;
+
+out_free_xaxis:
+    mca_spml_ucx_team_destroy(*xaxis_team);
+    return rc;
 }
 
-/* This routine is not implemented */
+/* Free a team created by mca_spml_ucx_team_split_strided() and its configuration */
 int mca_spml_ucx_team_destroy(shmem_team_t team)
 {
-    return OSHMEM_ERR_NOT_IMPLEMENTED;
+    mca_spml_ucx_team_t *ucx_team = (mca_spml_ucx_team_t *)team;
+
+    SPML_UCX_VALIDATE_TEAM(team);
+
+    free(ucx_team->config);
+    free(team);
+
+    return OSHMEM_SUCCESS;
 }
 
 /* This routine is not implemented */
diff --git a/oshmem/mca/spml/ucx/spml_ucx.h b/oshmem/mca/spml/ucx/spml_ucx.h
index 9d36c14cd7d..0d1f345e925 100644
--- a/oshmem/mca/spml/ucx/spml_ucx.h
+++ b/oshmem/mca/spml/ucx/spml_ucx.h
@@ -48,6 +48,18 @@ BEGIN_C_DECLS
 #define SPML_UCX_TRANSP_IDX 0
 #define SPML_UCX_TRANSP_CNT 1
 #define SPML_UCX_SERVICE_SEG 0
+/* Team-relative PE index used when a PE is not a member of the team */
+#define SPML_UCX_PE_NOT_IN_TEAM (-1)
+
+/* Log an error and return OSHMEM_ERROR from the enclosing function when
+ * _team is SHMEM_TEAM_INVALID */
+#define SPML_UCX_VALIDATE_TEAM(_team) \
+    do { \
+        if (OPAL_UNLIKELY((_team) == SHMEM_TEAM_INVALID)) { \
+            SPML_UCX_ERROR("Invalid team at %s", __func__); \
+            return OSHMEM_ERROR; \
+        } \
+    } while (0)
 
 enum {
     SPML_UCX_STRONG_ORDERING_NONE = 0, /* don't use strong ordering */
@@ -115,6 +127,22 @@ typedef struct mca_spml_ucx_ctx_array {
     mca_spml_ucx_ctx_t **ctxs;
 } mca_spml_ucx_ctx_array_t;
 
+/* Team configuration; 'super' holds the public shmem_team_config_t */
+typedef struct mca_spml_ucx_team_config {
+    shmem_team_config_t super;
+
+} mca_spml_ucx_team_config_t;
+
+/* Strided team descriptor; start and stride are stored relative to the world team */
+typedef struct mca_spml_ucx_team {
+    int n_pes;                               /* number of PEs in the team */
+    int my_pe;                               /* calling PE's index, or SPML_UCX_PE_NOT_IN_TEAM */
+    int stride;                              /* distance between consecutive team PEs */
+    int start;                               /* first PE of the team */
+    mca_spml_ucx_team_config_t *config;      /* heap-allocated, freed by team_destroy */
+    struct mca_spml_ucx_team *parent_team;   /* team this one was split from */
+} mca_spml_ucx_team_t;
+
 struct mca_spml_ucx {
     mca_spml_base_module_t super;
     ucp_context_h ucp_context;
diff --git a/oshmem/shmem/c/shmem_team.c b/oshmem/shmem/c/shmem_team.c
index 7004080f869..0a4aacfa48d 100644
--- a/oshmem/shmem/c/shmem_team.c
+++ b/oshmem/shmem/c/shmem_team.c
@@ -51,14 +51,9 @@ void shmem_team_sync(shmem_team_t team)
 
 int shmem_team_my_pe(shmem_team_t team)
 {
-    int rc = 0;
-
     RUNTIME_CHECK_INIT();
 
-    rc = MCA_SPML_CALL(team_my_pe(team));
-    RUNTIME_CHECK_IMPL_RC(rc);
-
-    return rc;
+    return MCA_SPML_CALL(team_my_pe(team));
 }
 
 int shmem_team_n_pes(shmem_team_t team)
@@ -85,15 +80,11 @@ int shmem_team_get_config(shmem_team_t team, long config_mask, shmem_team_config
 }
 int shmem_team_translate_pe(shmem_team_t src_team, int src_pe, shmem_team_t dest_team)
 {
-    int rc = 0;
-
     RUNTIME_CHECK_INIT();
 
-    rc = MCA_SPML_CALL(team_translate_pe(src_team, src_pe, dest_team));
-    RUNTIME_CHECK_IMPL_RC(rc);
-
-    return rc;
+    return MCA_SPML_CALL(team_translate_pe(src_team, src_pe, dest_team));
 }
+
 int shmem_team_split_strided (shmem_team_t parent_team, int start, int stride, int size,
         const shmem_team_config_t *config, long config_mask,
         shmem_team_t *new_team)