|
22 | 22 | #include "ompi/message/message.h"
|
23 | 23 | #include "ompi/mca/pml/base/pml_base_bsend.h"
|
24 | 24 | #include "opal/mca/common/ucx/common_ucx.h"
|
| 25 | +#if OPAL_CUDA_SUPPORT |
| 26 | +#include "opal/mca/common/cuda/common_cuda.h" |
| 27 | +#endif /* OPAL_CUDA_SUPPORT */ |
25 | 28 | #include "pml_ucx_request.h"
|
26 | 29 |
|
27 | 30 | #include <inttypes.h>
|
@@ -230,22 +233,37 @@ int mca_pml_ucx_open(void)
|
230 | 233 |
|
231 | 234 | /* Query UCX attributes */
|
232 | 235 | attr.field_mask = UCP_ATTR_FIELD_REQUEST_SIZE;
|
| 236 | +#if HAVE_UCP_ATTR_MEMORY_TYPES && OPAL_CUDA_SUPPORT |
| 237 | + attr.field_mask |= UCP_ATTR_FIELD_MEMORY_TYPES; |
| 238 | +#endif |
233 | 239 | status = ucp_context_query(ompi_pml_ucx.ucp_context, &attr);
|
234 | 240 | if (UCS_OK != status) {
|
235 | 241 | ucp_cleanup(ompi_pml_ucx.ucp_context);
|
236 | 242 | ompi_pml_ucx.ucp_context = NULL;
|
237 | 243 | return OMPI_ERROR;
|
238 | 244 | }
|
239 | 245 |
|
240 |
| - ompi_pml_ucx.request_size = attr.request_size; |
| 246 | + ompi_pml_ucx.request_size = attr.request_size; |
| 247 | + ompi_pml_ucx.cuda_initialized = false; |
241 | 248 |
|
| 249 | +#if HAVE_UCP_ATTR_MEMORY_TYPES && OPAL_CUDA_SUPPORT |
| 250 | + if (attr.memory_types & UCS_BIT(UCS_MEMORY_TYPE_CUDA)) { |
| 251 | + mca_common_cuda_stage_one_init(); |
| 252 | + ompi_pml_ucx.cuda_initialized = true; |
| 253 | + } |
| 254 | +#endif |
242 | 255 | return OMPI_SUCCESS;
|
243 | 256 | }
|
244 | 257 |
|
245 | 258 | int mca_pml_ucx_close(void)
|
246 | 259 | {
|
247 | 260 | PML_UCX_VERBOSE(1, "mca_pml_ucx_close");
|
248 | 261 |
|
| 262 | +#if OPAL_CUDA_SUPPORT |
| 263 | + if (ompi_pml_ucx.cuda_initialized) { |
| 264 | + mca_common_cuda_fini(); |
| 265 | + } |
| 266 | +#endif |
249 | 267 | if (ompi_pml_ucx.ucp_context != NULL) {
|
250 | 268 | ucp_cleanup(ompi_pml_ucx.ucp_context);
|
251 | 269 | ompi_pml_ucx.ucp_context = NULL;
|
|
0 commit comments