Skip to content

Commit aa8f7f4

Browse files
authored
Merge pull request #7893 from bureddy/cuda-ucx
UCX: initialize cuda from ucx pml component
2 parents 1f237f5 + 2547e24 commit aa8f7f4

File tree

3 files changed

+24
-1
lines changed

3 files changed

+24
-1
lines changed

config/ompi_check_ucx.m4

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,10 @@ AC_DEFUN([OMPI_CHECK_UCX],[
129129
[AC_DEFINE([HAVE_UCP_WORKER_ADDRESS_FLAGS], [1],
130130
[have worker address attribute])], [],
131131
[#include <ucp/api/ucp.h>])
132+
AC_CHECK_DECLS([UCP_ATTR_FIELD_MEMORY_TYPES],
133+
[AC_DEFINE([HAVE_UCP_ATTR_MEMORY_TYPES], [1],
134+
[have memory types attribute])], [],
135+
[#include <ucp/api/ucp.h>])
132136
AC_CHECK_DECLS([ucp_tag_send_nbx,
133137
ucp_tag_send_sync_nbx,
134138
ucp_tag_recv_nbx],

ompi/mca/pml/ucx/pml_ucx.c

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@
2222
#include "ompi/message/message.h"
2323
#include "ompi/mca/pml/base/pml_base_bsend.h"
2424
#include "opal/mca/common/ucx/common_ucx.h"
25+
#if OPAL_CUDA_SUPPORT
26+
#include "opal/mca/common/cuda/common_cuda.h"
27+
#endif /* OPAL_CUDA_SUPPORT */
2528
#include "pml_ucx_request.h"
2629

2730
#include <inttypes.h>
@@ -230,22 +233,37 @@ int mca_pml_ucx_open(void)
230233

231234
/* Query UCX attributes */
232235
attr.field_mask = UCP_ATTR_FIELD_REQUEST_SIZE;
236+
#if HAVE_UCP_ATTR_MEMORY_TYPES
237+
attr.field_mask |= UCP_ATTR_FIELD_MEMORY_TYPES;
238+
#endif
233239
status = ucp_context_query(ompi_pml_ucx.ucp_context, &attr);
234240
if (UCS_OK != status) {
235241
ucp_cleanup(ompi_pml_ucx.ucp_context);
236242
ompi_pml_ucx.ucp_context = NULL;
237243
return OMPI_ERROR;
238244
}
239245

240-
ompi_pml_ucx.request_size = attr.request_size;
246+
ompi_pml_ucx.request_size = attr.request_size;
247+
ompi_pml_ucx.cuda_initialized = false;
241248

249+
#if HAVE_UCP_ATTR_MEMORY_TYPES && OPAL_CUDA_SUPPORT
250+
if (attr.memory_types & UCS_BIT(UCS_MEMORY_TYPE_CUDA)) {
251+
mca_common_cuda_stage_one_init();
252+
ompi_pml_ucx.cuda_initialized = true;
253+
}
254+
#endif
242255
return OMPI_SUCCESS;
243256
}
244257

245258
int mca_pml_ucx_close(void)
246259
{
247260
PML_UCX_VERBOSE(1, "mca_pml_ucx_close");
248261

262+
#if OPAL_CUDA_SUPPORT
263+
if (ompi_pml_ucx.cuda_initialized) {
264+
mca_common_cuda_fini();
265+
}
266+
#endif
249267
if (ompi_pml_ucx.ucp_context != NULL) {
250268
ucp_cleanup(ompi_pml_ucx.ucp_context);
251269
ompi_pml_ucx.ucp_context = NULL;

ompi/mca/pml/ucx/pml_ucx.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ struct mca_pml_ucx_module {
5757
mca_pml_ucx_freelist_t convs;
5858

5959
int priority;
60+
bool cuda_initialized;
6061
};
6162

6263
extern mca_pml_base_component_2_0_0_t mca_pml_ucx_component;

0 commit comments

Comments
 (0)