@@ -66,6 +66,19 @@ static inline int check_sync_state(ompi_osc_ucx_module_t *module, int target,
6666 return OMPI_SUCCESS ;
6767}
6868
69+ typedef struct {
70+ size_t dt_bytes ;
71+ uint64_t result_val ;
72+ void * result_addr ;
73+ } cas_result_handle_t ;
74+
75+ static void cas_result_cb (void * data ) {
76+ cas_result_handle_t * cas_result = (cas_result_handle_t * )data ;
77+ memcpy (cas_result -> result_addr , & cas_result -> result_val , cas_result -> dt_bytes );
78+ free (cas_result );
79+ }
80+
81+
6982static inline int create_iov_list (const void * addr , int count , ompi_datatype_t * datatype ,
7083 ucx_iovec_t * * ucx_iov , uint32_t * ucx_iov_count ) {
7184 int ret = OMPI_SUCCESS ;
@@ -808,14 +821,35 @@ int ompi_osc_ucx_compare_and_swap(const void *origin_addr, const void *compare_a
808821 }
809822
810823 ompi_datatype_type_size (dt , & dt_bytes );
811- ret = opal_common_ucx_wpmem_cmpswp (module -> mem ,* (uint64_t * )compare_addr ,
812- * (uint64_t * )origin_addr , target ,
813- result_addr , dt_bytes , remote_addr );
824+ if (sizeof (uint64_t ) < dt_bytes ) {
825+ return OMPI_ERR_NOT_SUPPORTED ;
826+ }
827+
828+ uint64_t compare_val ;
829+ memcpy (& compare_val , compare_addr , dt_bytes );
830+
831+ // the completion handler copies the returned value into the result address
832+ // and free's the memory
833+ cas_result_handle_t * cas_result = malloc (sizeof (* cas_result ));
834+ cas_result -> dt_bytes = dt_bytes ;
835+ cas_result -> result_addr = result_addr ;
836+ memcpy (& cas_result -> result_val , origin_addr , dt_bytes );
837+ ret = opal_common_ucx_wpmem_fetch_nb (module -> mem , UCP_ATOMIC_FETCH_OP_CSWAP ,
838+ compare_val , target ,
839+ & cas_result -> result_val , dt_bytes , remote_addr ,
840+ & cas_result_cb , cas_result );
814841
815842 if (module -> acc_single_intrinsic ) {
816843 return ret ;
817844 }
818845
846+ // fence before releasing the accumulate lock
847+ ret = opal_common_ucx_wpmem_fence (module -> mem );
848+ if (ret != OMPI_SUCCESS ) {
849+ OSC_UCX_VERBOSE (1 , "opal_common_ucx_mem_fence failed: %d" , ret );
850+ // don't return error, try to release the accumulate lock
851+ }
852+
819853 return end_atomicity (module , target );
820854}
821855
0 commit comments