diff --git a/ompi/mca/osc/ucx/osc_ucx_component.c b/ompi/mca/osc/ucx/osc_ucx_component.c index 17071cadc13..431a86b9445 100644 --- a/ompi/mca/osc/ucx/osc_ucx_component.c +++ b/ompi/mca/osc/ucx/osc_ucx_component.c @@ -142,11 +142,78 @@ static int progress_callback(void) { return 0; } +static int ucp_context_init(void) { + int ret = OMPI_SUCCESS; + ucs_status_t status; + ucp_config_t *config = NULL; + ucp_params_t context_params; + + status = ucp_config_read("MPI", NULL, &config); + if (UCS_OK != status) { + OSC_UCX_VERBOSE(1, "ucp_config_read failed: %d", status); + return OMPI_ERROR; + } + + /* initialize UCP context */ + memset(&context_params, 0, sizeof(context_params)); + context_params.field_mask = UCP_PARAM_FIELD_FEATURES | + UCP_PARAM_FIELD_MT_WORKERS_SHARED | + UCP_PARAM_FIELD_ESTIMATED_NUM_EPS | + UCP_PARAM_FIELD_REQUEST_INIT | + UCP_PARAM_FIELD_REQUEST_SIZE; + context_params.features = UCP_FEATURE_RMA | UCP_FEATURE_AMO32 | UCP_FEATURE_AMO64; + context_params.mt_workers_shared = 0; + context_params.estimated_num_eps = ompi_proc_world_size(); + context_params.request_init = internal_req_init; + context_params.request_size = sizeof(ompi_osc_ucx_internal_request_t); + + status = ucp_init(&context_params, config, &mca_osc_ucx_component.ucp_context); + ucp_config_release(config); + if (UCS_OK != status) { + OSC_UCX_VERBOSE(1, "ucp_init failed: %d", status); + ret = OMPI_ERROR; + } + + return ret; +} static int component_init(bool enable_progress_threads, bool enable_mpi_threads) { + opal_common_ucx_support_level_t support_level = OPAL_COMMON_UCX_SUPPORT_NONE; + mca_base_var_source_t param_source = MCA_BASE_VAR_SOURCE_DEFAULT; + int ret = OMPI_SUCCESS, + param = -1; + mca_osc_ucx_component.enable_mpi_threads = enable_mpi_threads; opal_common_ucx_mca_register(); - return OMPI_SUCCESS; + + ret = ucp_context_init(); + if (OMPI_ERROR == ret) { + return OMPI_ERR_NOT_AVAILABLE; + } + + support_level = opal_common_ucx_support_level(mca_osc_ucx_component.ucp_context); + 
if (OPAL_COMMON_UCX_SUPPORT_NONE == support_level) { + ucp_cleanup(mca_osc_ucx_component.ucp_context); + mca_osc_ucx_component.ucp_context = NULL; + return OMPI_ERR_NOT_AVAILABLE; + } + + param = mca_base_var_find("ompi","osc","ucx","priority"); + if (0 <= param) { + (void) mca_base_var_get_value(param, NULL, &param_source, NULL); + } + + /* + * Retain priority if we have supported devices and transports. + * Lower priority if we have supported transports, but not supported devices. + */ + if (MCA_BASE_VAR_SOURCE_DEFAULT == param_source) { + mca_osc_ucx_component.priority = (support_level == OPAL_COMMON_UCX_SUPPORT_DEVICE) ? + mca_osc_ucx_component.priority : 9; + OSC_UCX_VERBOSE(2, "returning priority %d", mca_osc_ucx_component.priority); + } + + return ret; } static int component_finalize(void) { @@ -165,7 +232,10 @@ static int component_finalize(void) { assert(mca_osc_ucx_component.num_incomplete_req_ops == 0); if (mca_osc_ucx_component.env_initialized == true) { OBJ_DESTRUCT(&mca_osc_ucx_component.requests); - ucp_cleanup(mca_osc_ucx_component.ucp_context); + if (NULL != mca_osc_ucx_component.ucp_context) { + ucp_cleanup(mca_osc_ucx_component.ucp_context); + mca_osc_ucx_component.ucp_context = NULL; + } mca_osc_ucx_component.env_initialized = false; } opal_common_ucx_mca_deregister(); @@ -317,18 +387,9 @@ static int component_select(struct ompi_win_t *win, void **base, size_t size, in _osc_ucx_init_lock(); if (mca_osc_ucx_component.env_initialized == false) { - ucp_config_t *config = NULL; - ucp_params_t context_params; ucp_worker_params_t worker_params; ucp_worker_attr_t worker_attr; - status = ucp_config_read("MPI", NULL, &config); - if (UCS_OK != status) { - OSC_UCX_VERBOSE(1, "ucp_config_read failed: %d", status); - ret = OMPI_ERROR; - goto select_unlock; - } - OBJ_CONSTRUCT(&mca_osc_ucx_component.requests, opal_free_list_t); ret = opal_free_list_init (&mca_osc_ucx_component.requests, sizeof(ompi_osc_ucx_request_t), @@ -340,28 +401,6 @@ static int 
component_select(struct ompi_win_t *win, void **base, size_t size, in goto select_unlock; } - /* initialize UCP context */ - - memset(&context_params, 0, sizeof(context_params)); - context_params.field_mask = UCP_PARAM_FIELD_FEATURES | - UCP_PARAM_FIELD_MT_WORKERS_SHARED | - UCP_PARAM_FIELD_ESTIMATED_NUM_EPS | - UCP_PARAM_FIELD_REQUEST_INIT | - UCP_PARAM_FIELD_REQUEST_SIZE; - context_params.features = UCP_FEATURE_RMA | UCP_FEATURE_AMO32 | UCP_FEATURE_AMO64; - context_params.mt_workers_shared = 0; - context_params.estimated_num_eps = ompi_proc_world_size(); - context_params.request_init = internal_req_init; - context_params.request_size = sizeof(ompi_osc_ucx_internal_request_t); - - status = ucp_init(&context_params, config, &mca_osc_ucx_component.ucp_context); - ucp_config_release(config); - if (UCS_OK != status) { - OSC_UCX_VERBOSE(1, "ucp_init failed: %d", status); - ret = OMPI_ERROR; - goto select_unlock; - } - assert(mca_osc_ucx_component.ucp_worker == NULL); memset(&worker_params, 0, sizeof(worker_params)); worker_params.field_mask = UCP_WORKER_PARAM_FIELD_THREAD_MODE;