Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions config/ompi_check_ucx.m4
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,10 @@ AC_DEFUN([OMPI_CHECK_UCX],[
[AC_DEFINE([HAVE_UCP_WORKER_ADDRESS_FLAGS], [1],
[have worker address attribute])], [],
[#include <ucp/api/ucp.h>])
AC_CHECK_DECLS([UCP_ATTR_FIELD_MEMORY_TYPES],
[AC_DEFINE([HAVE_UCP_ATTR_MEMORY_TYPES], [1],
[have memory types attribute])], [],
[#include <ucp/api/ucp.h>])
AC_CHECK_DECLS([ucp_tag_send_nbx,
ucp_tag_send_sync_nbx,
ucp_tag_recv_nbx],
Expand Down
20 changes: 19 additions & 1 deletion ompi/mca/pml/ucx/pml_ucx.c
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@
#include "ompi/message/message.h"
#include "ompi/mca/pml/base/pml_base_bsend.h"
#include "opal/mca/common/ucx/common_ucx.h"
#if OPAL_CUDA_SUPPORT
#include "opal/mca/common/cuda/common_cuda.h"
#endif /* OPAL_CUDA_SUPPORT */
#include "pml_ucx_request.h"

#include <inttypes.h>
Expand Down Expand Up @@ -227,22 +230,37 @@ int mca_pml_ucx_open(void)

/* Query UCX attributes */
attr.field_mask = UCP_ATTR_FIELD_REQUEST_SIZE;
#if HAVE_UCP_ATTR_MEMORY_TYPES
attr.field_mask |= UCP_ATTR_FIELD_MEMORY_TYPES;
#endif
status = ucp_context_query(ompi_pml_ucx.ucp_context, &attr);
if (UCS_OK != status) {
ucp_cleanup(ompi_pml_ucx.ucp_context);
ompi_pml_ucx.ucp_context = NULL;
return OMPI_ERROR;
}

ompi_pml_ucx.request_size = attr.request_size;
ompi_pml_ucx.request_size = attr.request_size;
ompi_pml_ucx.cuda_initialized = false;

#if HAVE_UCP_ATTR_MEMORY_TYPES && OPAL_CUDA_SUPPORT
if (attr.memory_types & UCS_BIT(UCS_MEMORY_TYPE_CUDA)) {
mca_common_cuda_stage_one_init();
ompi_pml_ucx.cuda_initialized = true;
}
#endif
return OMPI_SUCCESS;
}

int mca_pml_ucx_close(void)
{
PML_UCX_VERBOSE(1, "mca_pml_ucx_close");

#if OPAL_CUDA_SUPPORT
if (ompi_pml_ucx.cuda_initialized) {
mca_common_cuda_fini();
}
#endif
if (ompi_pml_ucx.ucp_context != NULL) {
ucp_cleanup(ompi_pml_ucx.ucp_context);
ompi_pml_ucx.ucp_context = NULL;
Expand Down
1 change: 1 addition & 0 deletions ompi/mca/pml/ucx/pml_ucx.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ struct mca_pml_ucx_module {
mca_pml_ucx_freelist_t convs;

int priority;
bool cuda_initialized;
};

extern mca_pml_base_component_2_0_0_t mca_pml_ucx_component;
Expand Down