@@ -601,7 +601,7 @@ int mca_pml_ucx_isend_init(const void *buf, size_t count, ompi_datatype_t *datat
601
601
return OMPI_SUCCESS ;
602
602
}
603
603
604
- static int
604
+ static ucs_status_ptr_t
605
605
mca_pml_ucx_bsend (ucp_ep_h ep , const void * buf , size_t count ,
606
606
ompi_datatype_t * datatype , uint64_t pml_tag )
607
607
{
@@ -623,21 +623,21 @@ mca_pml_ucx_bsend(ucp_ep_h ep, const void *buf, size_t count,
623
623
if (OPAL_UNLIKELY (NULL == packed_data )) {
624
624
OBJ_DESTRUCT (& opal_conv );
625
625
PML_UCX_ERROR ("bsend: failed to allocate buffer" );
626
- return OMPI_ERR_OUT_OF_RESOURCE ;
626
+ return UCS_STATUS_PTR ( OMPI_ERROR ) ;
627
627
}
628
628
629
629
iov_count = 1 ;
630
630
iov .iov_base = packed_data ;
631
631
iov .iov_len = packed_length ;
632
632
633
- PML_UCX_VERBOSE (8 , "bsend of packed buffer %p len %d " , packed_data , packed_length );
633
+ PML_UCX_VERBOSE (8 , "bsend of packed buffer %p len %zu " , packed_data , packed_length );
634
634
offset = 0 ;
635
635
opal_convertor_set_position (& opal_conv , & offset );
636
636
if (0 > opal_convertor_pack (& opal_conv , & iov , & iov_count , & packed_length )) {
637
637
mca_pml_base_bsend_request_free (packed_data );
638
638
OBJ_DESTRUCT (& opal_conv );
639
639
PML_UCX_ERROR ("bsend: failed to pack user datatype" );
640
- return OMPI_ERROR ;
640
+ return UCS_STATUS_PTR ( OMPI_ERROR ) ;
641
641
}
642
642
643
643
OBJ_DESTRUCT (& opal_conv );
@@ -648,17 +648,34 @@ mca_pml_ucx_bsend(ucp_ep_h ep, const void *buf, size_t count,
648
648
if (NULL == req ) {
649
649
/* request was completed in place */
650
650
mca_pml_base_bsend_request_free (packed_data );
651
- return OMPI_SUCCESS ;
651
+ return NULL ;
652
652
}
653
653
654
654
if (OPAL_UNLIKELY (UCS_PTR_IS_ERR (req ))) {
655
655
mca_pml_base_bsend_request_free (packed_data );
656
656
PML_UCX_ERROR ("ucx bsend failed: %s" , ucs_status_string (UCS_PTR_STATUS (req )));
657
- return OMPI_ERROR ;
657
+ return UCS_STATUS_PTR ( OMPI_ERROR ) ;
658
658
}
659
659
660
660
req -> req_complete_cb_data = packed_data ;
661
- return OMPI_SUCCESS ;
661
+ return NULL ;
662
+ }
663
+
664
+ static inline ucs_status_ptr_t mca_pml_ucx_common_send (ucp_ep_h ep , const void * buf , /* Shared send dispatcher: picks the send routine from `mode` and returns that routine's status pointer unchanged. */
665
+ size_t count ,
666
+ ompi_datatype_t * datatype , /* only forwarded on the buffered path */
667
+ ucp_datatype_t ucx_datatype , /* only forwarded on the two ucp_tag_send_* paths */
668
+ ucp_tag_t tag ,
669
+ mca_pml_base_send_mode_t mode ,
670
+ ucp_send_callback_t cb ) /* not forwarded on the buffered path */
671
+ {
672
+ if (OPAL_UNLIKELY (MCA_PML_BASE_SEND_BUFFERED == mode )) {
673
+ return mca_pml_ucx_bsend (ep , buf , count , datatype , tag ); /* buffered: mca_pml_ucx_bsend takes neither ucx_datatype nor cb */
674
+ } else if (OPAL_UNLIKELY (MCA_PML_BASE_SEND_SYNCHRONOUS == mode )) {
675
+ return ucp_tag_send_sync_nb (ep , buf , count , ucx_datatype , tag , cb ); /* synchronous mode */
676
+ } else {
677
+ return ucp_tag_send_nb (ep , buf , count , ucx_datatype , tag , cb ); /* every other mode */
678
+ }
662
679
}
663
680
664
681
int mca_pml_ucx_isend (const void * buf , size_t count , ompi_datatype_t * datatype ,
@@ -674,25 +691,17 @@ int mca_pml_ucx_isend(const void *buf, size_t count, ompi_datatype_t *datatype,
674
691
mode == MCA_PML_BASE_SEND_BUFFERED ? "b" : "" ,
675
692
(void * )request )
676
693
677
- /* TODO special care to sync/buffered send */
678
-
679
694
ep = mca_pml_ucx_get_ep (comm , dst );
680
695
if (OPAL_UNLIKELY (NULL == ep )) {
681
696
PML_UCX_ERROR ("Failed to get ep for rank %d" , dst );
682
697
return OMPI_ERROR ;
683
698
}
684
699
685
- /* Special care to sync/buffered send */
686
- if (OPAL_UNLIKELY (MCA_PML_BASE_SEND_BUFFERED == mode )) {
687
- * request = & ompi_pml_ucx .completed_send_req ;
688
- return mca_pml_ucx_bsend (ep , buf , count , datatype ,
689
- PML_UCX_MAKE_SEND_TAG (tag , comm ));
690
- }
700
+ req = (ompi_request_t * )mca_pml_ucx_common_send (ep , buf , count , datatype ,
701
+ mca_pml_ucx_get_datatype (datatype ),
702
+ PML_UCX_MAKE_SEND_TAG (tag , comm ), mode ,
703
+ mca_pml_ucx_send_completion );
691
704
692
- req = (ompi_request_t * )ucp_tag_send_nb (ep , buf , count ,
693
- mca_pml_ucx_get_datatype (datatype ),
694
- PML_UCX_MAKE_SEND_TAG (tag , comm ),
695
- mca_pml_ucx_send_completion );
696
705
if (req == NULL ) {
697
706
PML_UCX_VERBOSE (8 , "returning completed request" );
698
707
* request = & ompi_pml_ucx .completed_send_req ;
@@ -723,16 +732,11 @@ int mca_pml_ucx_send(const void *buf, size_t count, ompi_datatype_t *datatype, i
723
732
return OMPI_ERROR ;
724
733
}
725
734
726
- /* Special care to sync/buffered send */
727
- if (OPAL_UNLIKELY (MCA_PML_BASE_SEND_BUFFERED == mode )) {
728
- return mca_pml_ucx_bsend (ep , buf , count , datatype ,
729
- PML_UCX_MAKE_SEND_TAG (tag , comm ));
730
- }
735
+ req = (ompi_request_t * )mca_pml_ucx_common_send (ep , buf , count , datatype ,
736
+ mca_pml_ucx_get_datatype (datatype ),
737
+ PML_UCX_MAKE_SEND_TAG (tag , comm ),
738
+ mode , mca_pml_ucx_send_completion );
731
739
732
- req = (ompi_request_t * )ucp_tag_send_nb (ep , buf , count ,
733
- mca_pml_ucx_get_datatype (datatype ),
734
- PML_UCX_MAKE_SEND_TAG (tag , comm ),
735
- mca_pml_ucx_send_completion );
736
740
if (OPAL_LIKELY (req == NULL )) {
737
741
return OMPI_SUCCESS ;
738
742
} else if (!UCS_PTR_IS_ERR (req )) {
@@ -891,7 +895,6 @@ int mca_pml_ucx_start(size_t count, ompi_request_t** requests)
891
895
mca_pml_ucx_persistent_request_t * preq ;
892
896
ompi_request_t * tmp_req ;
893
897
size_t i ;
894
- int rc ;
895
898
896
899
for (i = 0 ; i < count ; ++ i ) {
897
900
preq = (mca_pml_ucx_persistent_request_t * )requests [i ];
@@ -906,22 +909,14 @@ int mca_pml_ucx_start(size_t count, ompi_request_t** requests)
906
909
mca_pml_ucx_request_reset (& preq -> ompi );
907
910
908
911
if (preq -> flags & MCA_PML_UCX_REQUEST_FLAG_SEND ) {
909
- if (OPAL_UNLIKELY (MCA_PML_BASE_SEND_BUFFERED == preq -> send .mode )) {
910
- PML_UCX_VERBOSE (8 , "start bsend request %p" , (void * )preq );
911
- rc = mca_pml_ucx_bsend (preq -> send .ep , preq -> buffer , preq -> count ,
912
- preq -> ompi_datatype , preq -> tag );
913
- if (OMPI_SUCCESS != rc ) {
914
- return rc ;
915
- }
916
- /* pretend that we got immediate completion */
917
- tmp_req = NULL ;
918
- } else {
919
- PML_UCX_VERBOSE (8 , "start send request %p" , (void * )preq );
920
- tmp_req = (ompi_request_t * )ucp_tag_send_nb (preq -> send .ep , preq -> buffer ,
921
- preq -> count , preq -> datatype ,
922
- preq -> tag ,
923
- mca_pml_ucx_psend_completion );
924
- }
912
+ tmp_req = (ompi_request_t * )mca_pml_ucx_common_send (preq -> send .ep ,
913
+ preq -> buffer ,
914
+ preq -> count ,
915
+ preq -> ompi_datatype ,
916
+ preq -> datatype ,
917
+ preq -> tag ,
918
+ preq -> send .mode ,
919
+ mca_pml_ucx_psend_completion );
925
920
} else {
926
921
PML_UCX_VERBOSE (8 , "start recv request %p" , (void * )preq );
927
922
tmp_req = (ompi_request_t * )ucp_tag_recv_nb (ompi_pml_ucx .ucp_worker ,
0 commit comments