Skip to content

Commit f568475

Browse files
authored
Merge pull request #3347 from alinask/topic/ucx-sync-send
PML UCX: handle a synchronous send.
2 parents 7ea0595 + 49913c6 commit f568475

File tree

1 file changed

+40
-45
lines changed

1 file changed

+40
-45
lines changed

ompi/mca/pml/ucx/pml_ucx.c

+40-45
Original file line number | Diff line number | Diff line change
@@ -601,7 +601,7 @@ int mca_pml_ucx_isend_init(const void *buf, size_t count, ompi_datatype_t *datat
601601
return OMPI_SUCCESS;
602602
}
603603

604-
static int
604+
static ucs_status_ptr_t
605605
mca_pml_ucx_bsend(ucp_ep_h ep, const void *buf, size_t count,
606606
ompi_datatype_t *datatype, uint64_t pml_tag)
607607
{
@@ -623,21 +623,21 @@ mca_pml_ucx_bsend(ucp_ep_h ep, const void *buf, size_t count,
623623
if (OPAL_UNLIKELY(NULL == packed_data)) {
624624
OBJ_DESTRUCT(&opal_conv);
625625
PML_UCX_ERROR("bsend: failed to allocate buffer");
626-
return OMPI_ERR_OUT_OF_RESOURCE;
626+
return UCS_STATUS_PTR(OMPI_ERROR);
627627
}
628628

629629
iov_count = 1;
630630
iov.iov_base = packed_data;
631631
iov.iov_len = packed_length;
632632

633-
PML_UCX_VERBOSE(8, "bsend of packed buffer %p len %d", packed_data, packed_length);
633+
PML_UCX_VERBOSE(8, "bsend of packed buffer %p len %zu", packed_data, packed_length);
634634
offset = 0;
635635
opal_convertor_set_position(&opal_conv, &offset);
636636
if (0 > opal_convertor_pack(&opal_conv, &iov, &iov_count, &packed_length)) {
637637
mca_pml_base_bsend_request_free(packed_data);
638638
OBJ_DESTRUCT(&opal_conv);
639639
PML_UCX_ERROR("bsend: failed to pack user datatype");
640-
return OMPI_ERROR;
640+
return UCS_STATUS_PTR(OMPI_ERROR);
641641
}
642642

643643
OBJ_DESTRUCT(&opal_conv);
@@ -648,17 +648,34 @@ mca_pml_ucx_bsend(ucp_ep_h ep, const void *buf, size_t count,
648648
if (NULL == req) {
649649
/* request was completed in place */
650650
mca_pml_base_bsend_request_free(packed_data);
651-
return OMPI_SUCCESS;
651+
return NULL;
652652
}
653653

654654
if (OPAL_UNLIKELY(UCS_PTR_IS_ERR(req))) {
655655
mca_pml_base_bsend_request_free(packed_data);
656656
PML_UCX_ERROR("ucx bsend failed: %s", ucs_status_string(UCS_PTR_STATUS(req)));
657-
return OMPI_ERROR;
657+
return UCS_STATUS_PTR(OMPI_ERROR);
658658
}
659659

660660
req->req_complete_cb_data = packed_data;
661-
return OMPI_SUCCESS;
661+
return NULL;
662+
}
663+
664+
/*
 * Common entry point that issues a tagged send according to the MPI send mode.
 *
 *   - MCA_PML_BASE_SEND_BUFFERED:    forwards to mca_pml_ucx_bsend(), passing
 *                                    the OMPI datatype; `ucx_datatype` and `cb`
 *                                    are not used on this path.
 *   - MCA_PML_BASE_SEND_SYNCHRONOUS: calls ucp_tag_send_sync_nb().
 *   - any other mode:                calls ucp_tag_send_nb().
 *
 * ep           - UCP endpoint to send on.
 * buf, count   - user buffer and element count, passed through unchanged.
 * datatype     - OMPI datatype (consumed only by the buffered path).
 * ucx_datatype - UCP datatype (consumed only by the two ucp_tag_send_* paths).
 * tag          - UCP tag, passed through unchanged.
 * mode         - send mode that selects the branch.
 * cb           - completion callback handed to the ucp_tag_send_* calls.
 *
 * Returns the ucs_status_ptr_t produced by the selected call, unmodified.
 * In the hunks visible in this diff, the buffered path returns NULL whenever
 * the send was issued without error (even if the underlying UCX request is
 * still in flight) and UCS_STATUS_PTR(OMPI_ERROR) on failure. The callers
 * shown in this diff cast the result to ompi_request_t*.
 * NOTE(review): the exact NULL / error-pointer / request-pointer contract of
 * the two ucp_tag_send_* calls is defined by UCX, not here -- confirm against
 * the UCP API reference.
 */
static inline ucs_status_ptr_t mca_pml_ucx_common_send(ucp_ep_h ep, const void *buf,
665+
size_t count,
666+
ompi_datatype_t *datatype,
667+
ucp_datatype_t ucx_datatype,
668+
ucp_tag_t tag,
669+
mca_pml_base_send_mode_t mode,
670+
ucp_send_callback_t cb)
671+
{
672+
if (OPAL_UNLIKELY(MCA_PML_BASE_SEND_BUFFERED == mode)) {
673+
return mca_pml_ucx_bsend(ep, buf, count, datatype, tag);
674+
} else if (OPAL_UNLIKELY(MCA_PML_BASE_SEND_SYNCHRONOUS == mode)) {
675+
return ucp_tag_send_sync_nb(ep, buf, count, ucx_datatype, tag, cb);
676+
} else {
677+
return ucp_tag_send_nb(ep, buf, count, ucx_datatype, tag, cb);
678+
}
662679
}
663680

664681
int mca_pml_ucx_isend(const void *buf, size_t count, ompi_datatype_t *datatype,
@@ -674,25 +691,17 @@ int mca_pml_ucx_isend(const void *buf, size_t count, ompi_datatype_t *datatype,
674691
mode == MCA_PML_BASE_SEND_BUFFERED ? "b" : "",
675692
(void*)request)
676693

677-
/* TODO special care to sync/buffered send */
678-
679694
ep = mca_pml_ucx_get_ep(comm, dst);
680695
if (OPAL_UNLIKELY(NULL == ep)) {
681696
PML_UCX_ERROR("Failed to get ep for rank %d", dst);
682697
return OMPI_ERROR;
683698
}
684699

685-
/* Special care to sync/buffered send */
686-
if (OPAL_UNLIKELY(MCA_PML_BASE_SEND_BUFFERED == mode)) {
687-
*request = &ompi_pml_ucx.completed_send_req;
688-
return mca_pml_ucx_bsend(ep, buf, count, datatype,
689-
PML_UCX_MAKE_SEND_TAG(tag, comm));
690-
}
700+
req = (ompi_request_t*)mca_pml_ucx_common_send(ep, buf, count, datatype,
701+
mca_pml_ucx_get_datatype(datatype),
702+
PML_UCX_MAKE_SEND_TAG(tag, comm), mode,
703+
mca_pml_ucx_send_completion);
691704

692-
req = (ompi_request_t*)ucp_tag_send_nb(ep, buf, count,
693-
mca_pml_ucx_get_datatype(datatype),
694-
PML_UCX_MAKE_SEND_TAG(tag, comm),
695-
mca_pml_ucx_send_completion);
696705
if (req == NULL) {
697706
PML_UCX_VERBOSE(8, "returning completed request");
698707
*request = &ompi_pml_ucx.completed_send_req;
@@ -723,16 +732,11 @@ int mca_pml_ucx_send(const void *buf, size_t count, ompi_datatype_t *datatype, i
723732
return OMPI_ERROR;
724733
}
725734

726-
/* Special care to sync/buffered send */
727-
if (OPAL_UNLIKELY(MCA_PML_BASE_SEND_BUFFERED == mode)) {
728-
return mca_pml_ucx_bsend(ep, buf, count, datatype,
729-
PML_UCX_MAKE_SEND_TAG(tag, comm));
730-
}
735+
req = (ompi_request_t*)mca_pml_ucx_common_send(ep, buf, count, datatype,
736+
mca_pml_ucx_get_datatype(datatype),
737+
PML_UCX_MAKE_SEND_TAG(tag, comm),
738+
mode, mca_pml_ucx_send_completion);
731739

732-
req = (ompi_request_t*)ucp_tag_send_nb(ep, buf, count,
733-
mca_pml_ucx_get_datatype(datatype),
734-
PML_UCX_MAKE_SEND_TAG(tag, comm),
735-
mca_pml_ucx_send_completion);
736740
if (OPAL_LIKELY(req == NULL)) {
737741
return OMPI_SUCCESS;
738742
} else if (!UCS_PTR_IS_ERR(req)) {
@@ -891,7 +895,6 @@ int mca_pml_ucx_start(size_t count, ompi_request_t** requests)
891895
mca_pml_ucx_persistent_request_t *preq;
892896
ompi_request_t *tmp_req;
893897
size_t i;
894-
int rc;
895898

896899
for (i = 0; i < count; ++i) {
897900
preq = (mca_pml_ucx_persistent_request_t *)requests[i];
@@ -906,22 +909,14 @@ int mca_pml_ucx_start(size_t count, ompi_request_t** requests)
906909
mca_pml_ucx_request_reset(&preq->ompi);
907910

908911
if (preq->flags & MCA_PML_UCX_REQUEST_FLAG_SEND) {
909-
if (OPAL_UNLIKELY(MCA_PML_BASE_SEND_BUFFERED == preq->send.mode)) {
910-
PML_UCX_VERBOSE(8, "start bsend request %p", (void*)preq);
911-
rc = mca_pml_ucx_bsend(preq->send.ep, preq->buffer, preq->count,
912-
preq->ompi_datatype, preq->tag);
913-
if (OMPI_SUCCESS != rc) {
914-
return rc;
915-
}
916-
/* pretend that we got immediate completion */
917-
tmp_req = NULL;
918-
} else {
919-
PML_UCX_VERBOSE(8, "start send request %p", (void*)preq);
920-
tmp_req = (ompi_request_t*)ucp_tag_send_nb(preq->send.ep, preq->buffer,
921-
preq->count, preq->datatype,
922-
preq->tag,
923-
mca_pml_ucx_psend_completion);
924-
}
912+
tmp_req = (ompi_request_t*)mca_pml_ucx_common_send(preq->send.ep,
913+
preq->buffer,
914+
preq->count,
915+
preq->ompi_datatype,
916+
preq->datatype,
917+
preq->tag,
918+
preq->send.mode,
919+
mca_pml_ucx_psend_completion);
925920
} else {
926921
PML_UCX_VERBOSE(8, "start recv request %p", (void*)preq);
927922
tmp_req = (ompi_request_t*)ucp_tag_recv_nb(ompi_pml_ucx.ucp_worker,

0 commit comments

Comments
 (0)