@@ -323,7 +323,7 @@ static inline int get_dynamic_win_info(uint64_t remote_addr, ompi_osc_ucx_module
323323 return ret ;
324324}
325325
/*
 * do_atomic_op_replace_sum (renamed from atomic_op_replace_sum in this diff).
 *
 * Only the hunks below are visible; the lines between hunks are not shown,
 * so these notes describe the visible statements only.
 *
 * The visible loop runs once per origin element (origin_count iterations).
 * Each iteration copies origin_dt_bytes bytes from origin_addr into a
 * zero-initialised uint64_t and issues opal_common_ucx_wpmem_fetch_nb() with
 * that value. The fetch result is directed to result_addr, or to
 * module->req_result when result_addr is NULL.
 *
 * ucx_req (new parameter): when non-NULL, the last iteration passes
 * &req_completion and ucx_req to the fetch call; every earlier iteration
 * still passes NULL/NULL. If the last element is not the only one, a
 * opal_common_ucx_wpmem_fence() is issued before that final fetch.
 *
 * Returns ret at the end; the visible fence-failure path returns OMPI_ERROR.
 */
326- static int atomic_op_replace_sum (
326+ static int do_atomic_op_replace_sum (
327327 ompi_osc_ucx_module_t * module ,
328328 struct ompi_op_t * op ,
329329 int target ,
@@ -333,7 +333,8 @@ static int atomic_op_replace_sum(
333333 ptrdiff_t target_disp ,
334334 int target_count ,
335335 struct ompi_datatype_t * target_dt ,
336- void * result_addr )
336+ void * result_addr ,
337+ ompi_osc_ucx_request_t * ucx_req )
337338{
338339 int ret = OMPI_SUCCESS ;
339340 size_t origin_dt_bytes ;
@@ -363,12 +364,27 @@ static int atomic_op_replace_sum(
363364 opcode = UCP_ATOMIC_FETCH_OP_FADD ;
364365 }
365366
// Callback and its argument stay NULL until the last iteration (see below).
367+ opal_common_ucx_user_req_handler_t user_req_cb = NULL ;
368+ void * user_req_ptr = NULL ;
// NOTE(review): if origin_count is 0 the loop body never runs, so a non-NULL
// ucx_req is never handed to req_completion in the visible code - confirm
// callers cannot pass a zero count, or that lines outside this hunk handle it.
366369 for (int i = 0 ; i < origin_count ; ++ i ) {
367370 uint64_t value = 0 ;
371+ if ((origin_count - 1 ) == i && NULL != ucx_req ) {
372+ // the last item is used to feed the request, if needed
373+ user_req_cb = & req_completion ;
374+ user_req_ptr = ucx_req ;
375+ // issue a fence if this is the last but not the only element
376+ if (0 < i ) {
377+ ret = opal_common_ucx_wpmem_fence (module -> mem );
378+ if (ret != OMPI_SUCCESS ) {
// NOTE(review): the message below names opal_common_ucx_mem_fence while the
// call above is opal_common_ucx_wpmem_fence; the failing ret is logged and
// the function then returns OMPI_ERROR instead of ret.
379+ OSC_UCX_VERBOSE (1 , "opal_common_ucx_mem_fence failed: %d" , ret );
380+ return OMPI_ERROR ;
381+ }
382+ }
383+ }
368384 memcpy (& value , origin_addr , origin_dt_bytes );
369385 ret = opal_common_ucx_wpmem_fetch_nb (module -> mem , opcode , value , target ,
370386 result_addr ? result_addr : & (module -> req_result ),
371- origin_dt_bytes , remote_addr , NULL , NULL );
387+ origin_dt_bytes , remote_addr , user_req_cb , user_req_ptr );
372388
373389 // advance origin and remote address
374390 origin_addr = (void * )((intptr_t )origin_addr + origin_dt_bytes );
@@ -381,7 +397,7 @@ static int atomic_op_replace_sum(
381397 return ret ;
382398}
383399
/*
 * do_atomic_op_cswap (renamed from atomic_op_cswap in this diff).
 *
 * Only the hunks below are visible; the lines between hunks are not shown,
 * so these notes describe the visible statements only.
 *
 * The visible body contains a do { ... } while (1) retry loop that is left
 * through a break, followed by a "store the result if necessary" step and an
 * advance of remote_addr by origin_dt_bytes.
 *
 * ucx_req (new parameter): when non-NULL it is marked complete with
 * ompi_request_complete() immediately before the final return.
 *
 * Returns ret.
 */
384- static int atomic_op_cswap (
400+ static int do_atomic_op_cswap (
385401 ompi_osc_ucx_module_t * module ,
386402 struct ompi_op_t * op ,
387403 int target ,
@@ -391,7 +407,8 @@ static int atomic_op_cswap(
391407 ptrdiff_t target_disp ,
392408 int target_count ,
393409 struct ompi_datatype_t * target_dt ,
394- void * result_addr )
410+ void * result_addr ,
411+ ompi_osc_ucx_request_t * ucx_req )
395412{
396413 int ret = OMPI_SUCCESS ;
397414 size_t origin_dt_bytes ;
@@ -432,6 +449,7 @@ static int atomic_op_cswap(
432449 return ret ;
433450 }
434451
452+ /* JS: move this loop into the request to overlap multiple cas operations? */
435453 do {
436454
437455 tmp_val = target_val ;
@@ -451,6 +469,8 @@ static int atomic_op_cswap(
451469 break ;
452470 }
453471
// Carry tmp_val into target_val for the next pass of the retry loop.
// NOTE(review): presumably tmp_val now holds the value observed at the target
// by the failed swap - the swap call is outside the visible hunk; confirm.
472+ target_val = tmp_val ;
473+
454474 } while (1 );
455475
456476 // store the result if necessary
@@ -463,6 +483,41 @@ static int atomic_op_cswap(
463483 remote_addr += origin_dt_bytes ;
464484 }
465485
// NOTE(review): this completion runs for whatever value ret holds here. The
// request-based callers in this diff call OMPI_OSC_UCX_REQUEST_RETURN when a
// non-success code comes back, so confirm ret cannot be an error at this
// point (otherwise an already-completed request would be returned).
486+ if (NULL != ucx_req ) {
487+ // nothing to wait for so mark the request as completed
488+ ompi_request_complete (& ucx_req -> super , true);
489+ }
490+
491+ return ret ;
492+ }
493+
/*
 * do_atomic_op: dispatcher added by this diff.
 *
 * Selects the helper by operation: ompi_mpi_op_replace and ompi_mpi_op_sum go
 * to do_atomic_op_replace_sum(); every other op goes to do_atomic_op_cswap().
 * All arguments, including result_addr and ucx_req, are forwarded unchanged.
 *
 * Called from the module->acc_single_intrinsic branches of accumulate_req()
 * and get_accumulate_req() in this diff.
 *
 * Returns the selected helper's return value.
 */
494+ static inline
495+ int do_atomic_op (
496+ ompi_osc_ucx_module_t * module ,
497+ struct ompi_op_t * op ,
498+ int target ,
499+ const void * origin_addr ,
500+ int origin_count ,
501+ struct ompi_datatype_t * origin_dt ,
502+ ptrdiff_t target_disp ,
503+ int target_count ,
504+ struct ompi_datatype_t * target_dt ,
505+ void * result_addr ,
506+ ompi_osc_ucx_request_t * ucx_req )
507+ {
508+ int ret ;
509+
// The op is identified by comparing its address against the predefined op objects.
510+ if (op == & ompi_mpi_op_replace .op || op == & ompi_mpi_op_sum .op ) {
511+ ret = do_atomic_op_replace_sum (module , op , target ,
512+ origin_addr , origin_count , origin_dt ,
513+ target_disp , target_count , target_dt ,
514+ result_addr , ucx_req );
515+ } else {
516+ ret = do_atomic_op_cswap (module , op , target ,
517+ origin_addr , origin_count , origin_dt ,
518+ target_disp , target_count , target_dt ,
519+ result_addr , ucx_req );
520+ }
466521 return ret ;
467522}
468523
@@ -576,11 +631,14 @@ int ompi_osc_ucx_get(void *origin_addr, int origin_count,
576631 }
577632}
578633
/*
 * accumulate_req: the former body of ompi_osc_ucx_accumulate(), made static
 * and given an extra ucx_req parameter by this diff.
 *
 * Only the hunks below are visible; the lines between hunks are not shown,
 * so these notes describe the visible statements only.
 *
 * ucx_req may be NULL (the ompi_osc_ucx_accumulate() wrapper in this diff
 * passes NULL). When module->acc_single_intrinsic is set the call is handed
 * to do_atomic_op() together with ucx_req. Otherwise the visible tail marks a
 * non-NULL ucx_req complete and returns the result of end_atomicity().
 */
579- int ompi_osc_ucx_accumulate (const void * origin_addr , int origin_count ,
580- struct ompi_datatype_t * origin_dt ,
581- int target , ptrdiff_t target_disp , int target_count ,
582- struct ompi_datatype_t * target_dt ,
583- struct ompi_op_t * op , struct ompi_win_t * win ) {
634+ static
635+ int accumulate_req (const void * origin_addr , int origin_count ,
636+ struct ompi_datatype_t * origin_dt ,
637+ int target , ptrdiff_t target_disp , int target_count ,
638+ struct ompi_datatype_t * target_dt ,
639+ struct ompi_op_t * op , struct ompi_win_t * win ,
640+ ompi_osc_ucx_request_t * ucx_req ) {
641+
584642 ompi_osc_ucx_module_t * module = (ompi_osc_ucx_module_t * ) win -> w_osc_module ;
585643 int ret = OMPI_SUCCESS ;
586644
@@ -594,18 +652,10 @@ int ompi_osc_ucx_accumulate(const void *origin_addr, int origin_count,
594652 }
595653
596654 if (module -> acc_single_intrinsic ) {
597- if (op == & ompi_mpi_op_replace .op || op == & ompi_mpi_op_sum .op ) {
598- ret = atomic_op_replace_sum (module , op , target ,
599- origin_addr , origin_count , origin_dt ,
600- target_disp , target_count , target_dt ,
601- & (module -> req_result ));
602- } else {
603- ret = atomic_op_cswap (module , op , target ,
604- origin_addr , origin_count , origin_dt ,
605- target_disp , target_count , target_dt ,
606- & (module -> req_result ));
607- }
608- return ret ;
// The removed code passed &(module->req_result) as result_addr; the new call
// passes NULL. The visible part of do_atomic_op_replace_sum substitutes
// &(module->req_result) for a NULL result_addr.
// NOTE(review): how do_atomic_op_cswap treats a NULL result_addr is not fully
// visible in this diff - confirm it is equivalent to the old behaviour.
655+ return do_atomic_op (module , op , target ,
656+ origin_addr , origin_count , origin_dt ,
657+ target_disp , target_count , target_dt ,
658+ NULL , ucx_req );
609659 }
610660
611661
@@ -712,9 +762,23 @@ int ompi_osc_ucx_accumulate(const void *origin_addr, int origin_count,
712762 free (temp_addr_holder );
713763 }
714764
// NOTE(review): the request is marked complete before end_atomicity() has
// returned. ompi_osc_ucx_raccumulate() in this diff calls
// OMPI_OSC_UCX_REQUEST_RETURN when this function returns a non-success code,
// so a failing end_atomicity() would return an already-completed request -
// confirm that ordering is intended.
765+ if (NULL != ucx_req ) {
766+ // nothing to wait for, mark request as completed
767+ ompi_request_complete (& ucx_req -> super , true);
768+ }
769+
715770 return end_atomicity (module , target );
716771}
717772
/*
 * ompi_osc_ucx_accumulate: non-request entry point, re-added by this diff as
 * a thin wrapper. Forwards every argument to accumulate_req() and passes NULL
 * for the request argument; returns accumulate_req()'s result unchanged.
 */
773+ int ompi_osc_ucx_accumulate (const void * origin_addr , int origin_count ,
774+ struct ompi_datatype_t * origin_dt ,
775+ int target , ptrdiff_t target_disp , int target_count ,
776+ struct ompi_datatype_t * target_dt ,
777+ struct ompi_op_t * op , struct ompi_win_t * win ) {
778+ return accumulate_req (origin_addr , origin_count , origin_dt , target ,
779+ target_disp , target_count , target_dt , op , win , NULL );
780+ }
781+
718782int ompi_osc_ucx_compare_and_swap (const void * origin_addr , const void * compare_addr ,
719783 void * result_addr , struct ompi_datatype_t * dt ,
720784 int target , ptrdiff_t target_disp ,
@@ -813,13 +877,15 @@ int ompi_osc_ucx_fetch_and_op(const void *origin_addr, void *result_addr,
813877 }
814878}
815879
/*
 * get_accumulate_req: the former body of ompi_osc_ucx_get_accumulate(), made
 * static and given an extra ucx_req parameter by this diff.
 *
 * Only the hunks below are visible; the lines between hunks are not shown,
 * so these notes describe the visible statements only.
 *
 * ucx_req may be NULL (the ompi_osc_ucx_get_accumulate() wrapper in this diff
 * passes NULL). When module->acc_single_intrinsic is set the call is handed
 * to do_atomic_op() with the caller's result_addr and ucx_req. Otherwise
 * start_atomicity() is called (returning its error code on failure), and the
 * visible tail marks a non-NULL ucx_req complete and returns the result of
 * end_atomicity().
 */
816- int ompi_osc_ucx_get_accumulate (const void * origin_addr , int origin_count ,
817- struct ompi_datatype_t * origin_dt ,
818- void * result_addr , int result_count ,
819- struct ompi_datatype_t * result_dt ,
820- int target , ptrdiff_t target_disp ,
821- int target_count , struct ompi_datatype_t * target_dt ,
822- struct ompi_op_t * op , struct ompi_win_t * win ) {
880+ static
881+ int get_accumulate_req (const void * origin_addr , int origin_count ,
882+ struct ompi_datatype_t * origin_dt ,
883+ void * result_addr , int result_count ,
884+ struct ompi_datatype_t * result_dt ,
885+ int target , ptrdiff_t target_disp ,
886+ int target_count , struct ompi_datatype_t * target_dt ,
887+ struct ompi_op_t * op , struct ompi_win_t * win ,
888+ ompi_osc_ucx_request_t * ucx_req ) {
823889 ompi_osc_ucx_module_t * module = (ompi_osc_ucx_module_t * ) win -> w_osc_module ;
824890 int ret = OMPI_SUCCESS ;
825891
@@ -829,19 +895,12 @@ int ompi_osc_ucx_get_accumulate(const void *origin_addr, int origin_count,
829895 }
830896
831897 if (module -> acc_single_intrinsic ) {
832- if (op == & ompi_mpi_op_replace .op || op == & ompi_mpi_op_sum .op ) {
833- ret = atomic_op_replace_sum (module , op , target ,
834- origin_addr , origin_count , origin_dt ,
835- target_disp , target_count , target_dt , result_addr );
836- } else {
837- ret = atomic_op_cswap (module , op , target ,
838- origin_addr , origin_count , origin_dt ,
839- target_disp , target_count , target_dt , result_addr );
840- }
841- return ret ;
// The op-based selection that used to live here is now inside do_atomic_op().
898+ return do_atomic_op (module , op , target ,
899+ origin_addr , origin_count , origin_dt ,
900+ target_disp , target_count , target_dt ,
901+ result_addr , ucx_req );
842902 }
843903
844-
845904 ret = start_atomicity (module , target );
846905 if (ret != OMPI_SUCCESS ) {
847906 return ret ;
@@ -953,9 +1012,28 @@ int ompi_osc_ucx_get_accumulate(const void *origin_addr, int origin_count,
9531012 }
9541013 }
9551014
// NOTE(review): the request is marked complete before end_atomicity() has
// returned. ompi_osc_ucx_rget_accumulate() in this diff calls
// OMPI_OSC_UCX_REQUEST_RETURN when this function returns a non-success code,
// so a failing end_atomicity() would return an already-completed request -
// confirm that ordering is intended.
1015+ if (NULL != ucx_req ) {
1016+ // nothing to wait for, mark request as completed
1017+ ompi_request_complete (& ucx_req -> super , true);
1018+ }
1019+
1020+
9561021 return end_atomicity (module , target );
9571022}
9581023
/*
 * ompi_osc_ucx_get_accumulate: non-request entry point, re-added by this diff
 * as a thin wrapper. Forwards every argument to get_accumulate_req() and
 * passes NULL for the request argument; returns its result unchanged.
 */
1024+ int ompi_osc_ucx_get_accumulate (const void * origin_addr , int origin_count ,
1025+ struct ompi_datatype_t * origin_dt ,
1026+ void * result_addr , int result_count ,
1027+ struct ompi_datatype_t * result_dt ,
1028+ int target , ptrdiff_t target_disp ,
1029+ int target_count , struct ompi_datatype_t * target_dt ,
1030+ struct ompi_op_t * op , struct ompi_win_t * win ) {
1031+
1032+ return get_accumulate_req (origin_addr , origin_count , origin_dt , result_addr ,
1033+ result_count , result_dt , target , target_disp ,
1034+ target_count , target_dt , op , win , NULL );
1035+ }
1036+
9591037int ompi_osc_ucx_rput (const void * origin_addr , int origin_count ,
9601038 struct ompi_datatype_t * origin_dt ,
9611039 int target , ptrdiff_t target_disp , int target_count ,
@@ -1077,14 +1155,13 @@ int ompi_osc_ucx_raccumulate(const void *origin_addr, int origin_count,
10771155 OMPI_OSC_UCX_REQUEST_ALLOC (win , ucx_req );
10781156 assert (NULL != ucx_req );
10791157
1080- ret = ompi_osc_ucx_accumulate (origin_addr , origin_count , origin_dt , target , target_disp ,
1081- target_count , target_dt , op , win );
1158+ ret = accumulate_req (origin_addr , origin_count , origin_dt , target , target_disp ,
1159+ target_count , target_dt , op , win , ucx_req );
10821160 if (ret != OMPI_SUCCESS ) {
10831161 OMPI_OSC_UCX_REQUEST_RETURN (ucx_req );
10841162 return ret ;
10851163 }
10861164
1087- ompi_request_complete (& ucx_req -> super , true);
10881165 * request = & ucx_req -> super ;
10891166
10901167 return ret ;
@@ -1110,17 +1187,15 @@ int ompi_osc_ucx_rget_accumulate(const void *origin_addr, int origin_count,
11101187 OMPI_OSC_UCX_REQUEST_ALLOC (win , ucx_req );
11111188 assert (NULL != ucx_req );
11121189
1113- ret = ompi_osc_ucx_get_accumulate (origin_addr , origin_count , origin_datatype ,
1114- result_addr , result_count , result_datatype ,
1115- target , target_disp , target_count ,
1116- target_datatype , op , win );
1190+ ret = get_accumulate_req (origin_addr , origin_count , origin_datatype ,
1191+ result_addr , result_count , result_datatype ,
1192+ target , target_disp , target_count ,
1193+ target_datatype , op , win , ucx_req );
11171194 if (ret != OMPI_SUCCESS ) {
11181195 OMPI_OSC_UCX_REQUEST_RETURN (ucx_req );
11191196 return ret ;
11201197 }
11211198
1122- ompi_request_complete (& ucx_req -> super , true);
1123-
11241199 * request = & ucx_req -> super ;
11251200
11261201 return ret ;
0 commit comments