@@ -235,48 +235,72 @@ static inline int ddt_put_get(ompi_osc_ucx_module_t *module,
235235 return ret ;
236236}
237237
238- static inline int start_atomicity (ompi_osc_ucx_module_t * module , int target ) {
238+ static inline bool need_acc_lock (ompi_osc_ucx_module_t * module , int target )
239+ {
240+ ompi_osc_ucx_lock_t * lock = NULL ;
241+ opal_hash_table_get_value_uint32 (& module -> outstanding_locks ,
242+ (uint32_t ) target , (void * * ) & lock );
243+
244+ /* if there is an exclusive lock there is no need to acqurie the accumulate lock */
245+ return !(NULL != lock && LOCK_EXCLUSIVE == lock -> type );
246+ }
247+
248+ static inline int start_atomicity (
249+ ompi_osc_ucx_module_t * module ,
250+ int target ,
251+ bool * lock_acquired ) {
239252 uint64_t result_value = -1 ;
240253 uint64_t remote_addr = (module -> state_addrs )[target ] + OSC_UCX_STATE_ACC_LOCK_OFFSET ;
241254 int ret = OMPI_SUCCESS ;
242255
243- for (;;) {
244- ret = opal_common_ucx_wpmem_cmpswp (module -> state_mem ,
245- TARGET_LOCK_UNLOCKED , TARGET_LOCK_EXCLUSIVE ,
246- target , & result_value , sizeof (result_value ),
247- remote_addr );
248- if (ret != OMPI_SUCCESS ) {
249- OSC_UCX_VERBOSE (1 , "opal_common_ucx_mem_cmpswp failed: %d" , ret );
250- return OMPI_ERROR ;
251- }
252- if (result_value == TARGET_LOCK_UNLOCKED ) {
253- return OMPI_SUCCESS ;
256+ if (need_acc_lock (module , target )) {
257+ for (;;) {
258+ ret = opal_common_ucx_wpmem_cmpswp (module -> state_mem ,
259+ TARGET_LOCK_UNLOCKED , TARGET_LOCK_EXCLUSIVE ,
260+ target , & result_value , sizeof (result_value ),
261+ remote_addr );
262+ if (ret != OMPI_SUCCESS ) {
263+ OSC_UCX_VERBOSE (1 , "opal_common_ucx_mem_cmpswp failed: %d" , ret );
264+ return OMPI_ERROR ;
265+ }
266+ if (result_value == TARGET_LOCK_UNLOCKED ) {
267+ return OMPI_SUCCESS ;
268+ }
269+
270+ ucp_worker_progress (mca_osc_ucx_component .wpool -> dflt_worker );
254271 }
255272
256- ucp_worker_progress (mca_osc_ucx_component .wpool -> dflt_worker );
273+ * lock_acquired = true;
274+ } else {
275+ * lock_acquired = false;
257276 }
258277}
259278
260279static inline int end_atomicity (
261280 ompi_osc_ucx_module_t * module ,
262281 int target ,
282+ bool lock_acquired ,
263283 void * free_ptr ) {
264284 uint64_t result_value = 0 ;
265285 uint64_t remote_addr = (module -> state_addrs )[target ] + OSC_UCX_STATE_ACC_LOCK_OFFSET ;
266286 int ret = OMPI_SUCCESS ;
267287
268- /* fence any still active operations */
269- ret = opal_common_ucx_wpmem_fence (module -> mem );
270- if (ret != OMPI_SUCCESS ) {
271- OSC_UCX_VERBOSE (1 , "opal_common_ucx_mem_fence failed: %d" , ret );
272- return OMPI_ERROR ;
273- }
274-
275- ret = opal_common_ucx_wpmem_fetch (module -> state_mem ,
276- UCP_ATOMIC_FETCH_OP_SWAP , TARGET_LOCK_UNLOCKED ,
277- target , & result_value , sizeof (result_value ),
278- remote_addr );
288+ if (lock_acquired ) {
289+ /* fence any still active operations */
290+ ret = opal_common_ucx_wpmem_fence (module -> mem );
291+ if (ret != OMPI_SUCCESS ) {
292+ OSC_UCX_VERBOSE (1 , "opal_common_ucx_mem_fence failed: %d" , ret );
293+ return OMPI_ERROR ;
294+ }
279295
296+ ret = opal_common_ucx_wpmem_fetch (module -> state_mem ,
297+ UCP_ATOMIC_FETCH_OP_SWAP , TARGET_LOCK_UNLOCKED ,
298+ target , & result_value , sizeof (result_value ),
299+ remote_addr );
300+ } else if (NULL != free_ptr ){
301+ /* flush before freeing the buffer */
302+ ret = opal_common_ucx_wpmem_flush (module -> state_mem , OPAL_COMMON_UCX_SCOPE_EP , target );
303+ }
280304 /* TODO: encapsulate in a request and make the release non-blocking */
281305 if (NULL != free_ptr ) {
282306 free (free_ptr );
@@ -562,6 +586,7 @@ int accumulate_req(const void *origin_addr, int origin_count,
562586 ompi_osc_ucx_module_t * module = (ompi_osc_ucx_module_t * ) win -> w_osc_module ;
563587 int ret = OMPI_SUCCESS ;
564588 void * free_ptr = NULL ;
589+ bool lock_acquired = false;
565590
566591 ret = check_sync_state (module , target , false);
567592 if (ret != OMPI_SUCCESS ) {
@@ -579,8 +604,7 @@ int accumulate_req(const void *origin_addr, int origin_count,
579604 NULL , ucx_req );
580605 }
581606
582-
583- ret = start_atomicity (module , target );
607+ ret = start_atomicity (module , target , & lock_acquired );
584608 if (ret != OMPI_SUCCESS ) {
585609 return ret ;
586610 }
@@ -681,7 +705,7 @@ int accumulate_req(const void *origin_addr, int origin_count,
681705 ompi_request_complete (& ucx_req -> super , true);
682706 }
683707
684- return end_atomicity (module , target , free_ptr );
708+ return end_atomicity (module , target , lock_acquired , free_ptr );
685709}
686710
687711int ompi_osc_ucx_accumulate (const void * origin_addr , int origin_count ,
@@ -701,17 +725,16 @@ int ompi_osc_ucx_compare_and_swap(const void *origin_addr, const void *compare_a
701725 uint64_t remote_addr = (module -> addrs [target ]) + target_disp * OSC_UCX_GET_DISP (module , target );
702726 size_t dt_bytes ;
703727 int ret = OMPI_SUCCESS ;
728+ bool lock_acquired = false;
704729
705730 ret = check_sync_state (module , target , false);
706731 if (ret != OMPI_SUCCESS ) {
707732 return ret ;
708733 }
709734
710- if (!module -> acc_single_intrinsic ) {
711- ret = start_atomicity (module , target );
712- if (ret != OMPI_SUCCESS ) {
713- return ret ;
714- }
735+ ret = start_atomicity (module , target , & lock_acquired );
736+ if (ret != OMPI_SUCCESS ) {
737+ return ret ;
715738 }
716739
717740 if (module -> flavor == MPI_WIN_FLAVOR_DYNAMIC ) {
@@ -738,7 +761,7 @@ int ompi_osc_ucx_compare_and_swap(const void *origin_addr, const void *compare_a
738761 return ret ;
739762 }
740763
741- return end_atomicity (module , target , NULL );
764+ return end_atomicity (module , target , lock_acquired , NULL );
742765}
743766
744767int ompi_osc_ucx_fetch_and_op (const void * origin_addr , void * result_addr ,
@@ -759,9 +782,10 @@ int ompi_osc_ucx_fetch_and_op(const void *origin_addr, void *result_addr,
759782 uint64_t value = origin_addr ? * (uint64_t * )origin_addr : 0 ;
760783 ucp_atomic_fetch_op_t opcode ;
761784 size_t dt_bytes ;
785+ bool lock_acquired = false;
762786
763787 if (!module -> acc_single_intrinsic ) {
764- ret = start_atomicity (module , target );
788+ ret = start_atomicity (module , target , & lock_acquired );
765789 if (ret != OMPI_SUCCESS ) {
766790 return ret ;
767791 }
@@ -792,7 +816,7 @@ int ompi_osc_ucx_fetch_and_op(const void *origin_addr, void *result_addr,
792816 return ret ;
793817 }
794818
795- return end_atomicity (module , target , NULL );
819+ return end_atomicity (module , target , lock_acquired , NULL );
796820 } else {
797821 return ompi_osc_ucx_get_accumulate (origin_addr , 1 , dt , result_addr , 1 , dt ,
798822 target , target_disp , 1 , dt , op , win );
@@ -811,6 +835,7 @@ int get_accumulate_req(const void *origin_addr, int origin_count,
811835 ompi_osc_ucx_module_t * module = (ompi_osc_ucx_module_t * ) win -> w_osc_module ;
812836 int ret = OMPI_SUCCESS ;
813837 void * free_addr = NULL ;
838+ bool lock_acquired = false;
814839
815840 ret = check_sync_state (module , target , false);
816841 if (ret != OMPI_SUCCESS ) {
@@ -824,7 +849,7 @@ int get_accumulate_req(const void *origin_addr, int origin_count,
824849 result_addr , ucx_req );
825850 }
826851
827- ret = start_atomicity (module , target );
852+ ret = start_atomicity (module , target , & lock_acquired );
828853 if (ret != OMPI_SUCCESS ) {
829854 return ret ;
830855 }
@@ -933,7 +958,7 @@ int get_accumulate_req(const void *origin_addr, int origin_count,
933958 }
934959
935960
936- return end_atomicity (module , target , free_addr );
961+ return end_atomicity (module , target , lock_acquired , free_addr );
937962}
938963
939964int ompi_osc_ucx_get_accumulate (const void * origin_addr , int origin_count ,
0 commit comments